From 2cf2e465d12e2c01354594613beabbae3e3505b5 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 8 Nov 2019 15:22:38 -0800 Subject: [PATCH] Add JsonErrorOr enum to make json error handling more concise/consistent. JsonErrorOr is an untagged enum that is generic over arbitrary data. This means that when deserializing JsonErrorOr it will first check the json field for an 'error' attribute. If one exists it will deserialize into the JsonErrorOr::Err variant that contains a JsonError. If the message doesn't contain an 'error' field it will attempt to deserialize T into he JsonErrorOr::Data variant. --- src/device.rs | 31 ++++++++++++--------------- src/installed.rs | 48 ++++++++++++------------------------------ src/refresh.rs | 41 +++++++++++++++--------------------- src/service_account.rs | 32 ++++++++++++---------------- src/types.rs | 16 +++++++++++++- 5 files changed, 72 insertions(+), 96 deletions(-) diff --git a/src/device.rs b/src/device.rs index 44700c0..a618243 100644 --- a/src/device.rs +++ b/src/device.rs @@ -11,7 +11,7 @@ use url::form_urlencoded; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate, PollInformation, Retry}; use crate::types::{ - ApplicationSecret, GetToken, JsonError, PollError, RequestError, Token, + ApplicationSecret, GetToken, JsonErrorOr, PollError, RequestError, Token, }; pub const GOOGLE_DEVICE_CODE_URL: &'static str = "https://accounts.google.com/o/oauth2/device/code"; @@ -236,25 +236,20 @@ where } let json_bytes = resp.into_body().try_concat().await?; + match json::from_slice::>(&json_bytes)? { + JsonErrorOr::Err(e) => Err(e.into()), + JsonErrorOr::Data(decoded) => { + let expires_in = decoded.expires_in.unwrap_or(60 * 60); - // check for error - match json::from_slice::(&json_bytes) { - Err(_) => {} // ignore, move on - Ok(res) => return Err(RequestError::from(res)), + let pi = PollInformation { + user_code: decoded.user_code, + verification_url: decoded.verification_uri, + expires_at: Utc::now() + chrono::Duration::seconds(expires_in), + interval: Duration::from_secs(i64::abs(decoded.interval) as u64), + }; + Ok((pi, decoded.device_code)) + } } - - let decoded: JsonData = - json::from_slice(&json_bytes).map_err(|e| RequestError::JSONError(e))?; - - let expires_in = decoded.expires_in.unwrap_or(60 * 60); - - let pi = PollInformation { - user_code: decoded.user_code, - verification_url: decoded.verification_uri, - expires_at: Utc::now() + chrono::Duration::seconds(expires_in), - interval: Duration::from_secs(i64::abs(decoded.interval) as u64), - }; - Ok((pi, decoded.device_code)) } /// If the first call is successful, this method may be called. diff --git a/src/installed.rs b/src/installed.rs index 3a2a5f0..88d39bb 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -17,7 +17,7 @@ use url::form_urlencoded; use url::percent_encoding::{percent_encode, QUERY_ENCODE_SET}; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate}; -use crate::types::{ApplicationSecret, GetToken, RequestError, Token}; +use crate::types::{ApplicationSecret, GetToken, RequestError, Token, JsonErrorOr}; const OOB_REDIRECT_URI: &'static str = "urn:ietf:wg:oauth:2.0:oob"; @@ -270,24 +270,21 @@ where .try_concat() .await .map_err(|e| RequestError::ClientError(e))?; - let tokens: JSONTokenResponse = - serde_json::from_slice(&body).map_err(|e| RequestError::JSONError(e))?; - match tokens { - JSONTokenResponse { - error: Some(err), - error_description, - .. - } => Err(RequestError::NegativeServerResponse(err, error_description)), - JSONTokenResponse { - access_token: Some(access_token), - refresh_token, - token_type: Some(token_type), - expires_in, - .. - } => { + + #[derive(Deserialize)] + struct JSONTokenResponse { + access_token: String, + refresh_token: String, + token_type: String, + expires_in: Option, + } + + match serde_json::from_slice::>(&body)? { + JsonErrorOr::Err(err) => Err(err.into()), + JsonErrorOr::Data(JSONTokenResponse{access_token, refresh_token, token_type, expires_in}) => { let mut token = Token { access_token, - refresh_token, + refresh_token: Some(refresh_token), token_type, expires_in, expires_in_timestamp: None, @@ -295,12 +292,6 @@ where token.set_expiry_absolute(); Ok(token) } - JSONTokenResponse { - error_description, .. - } => Err(RequestError::NegativeServerResponse( - "".to_owned(), - error_description, - )), } } @@ -336,17 +327,6 @@ where } } -#[derive(Deserialize)] -struct JSONTokenResponse { - access_token: Option, - refresh_token: Option, - token_type: Option, - expires_in: Option, - - error: Option, - error_description: Option, -} - fn spawn_with_handle(f: F) -> impl Future where F: Future + 'static + Send, diff --git a/src/refresh.rs b/src/refresh.rs index 8963f92..4a4bd94 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -1,11 +1,10 @@ -use crate::types::{ApplicationSecret, JsonError, RefreshResult, RequestError}; +use crate::types::{ApplicationSecret, JsonErrorOr, RefreshResult, RequestError}; use super::Token; use chrono::Utc; use futures_util::try_stream::TryStreamExt; use hyper; use hyper::header; -use serde_json as json; use url::form_urlencoded; /// Implements the [OAuth2 Refresh Token Flow](https://developers.google.com/youtube/v3/guides/authentication#devices). @@ -58,34 +57,28 @@ impl RefreshFlow { Ok(body) => body, Err(err) => return Ok(RefreshResult::Error(err)), }; - if let Ok(json_err) = json::from_slice::(&body) { - return Ok(RefreshResult::RefreshError( - json_err.error, - json_err.error_description, - )); - } + #[derive(Deserialize)] struct JsonToken { access_token: String, token_type: String, expires_in: i64, } - let t: JsonToken = match json::from_slice(&body) { - Err(_) => { - return Ok(RefreshResult::RefreshError( - "failed to deserialized json token from refresh response".to_owned(), - None, - )) - } - Ok(token) => token, - }; - Ok(RefreshResult::Success(Token { - access_token: t.access_token, - token_type: t.token_type, - refresh_token: Some(refresh_token.to_string()), - expires_in: None, - expires_in_timestamp: Some(Utc::now().timestamp() + t.expires_in), - })) + + match serde_json::from_slice::>(&body) { + Err(_) => Ok(RefreshResult::RefreshError("failed to deserialized json token from refresh response".to_owned(), None)), + Ok(JsonErrorOr::Err(json_err)) => Ok(RefreshResult::RefreshError(json_err.error, json_err.error_description)), + Ok(JsonErrorOr::Data(JsonToken{access_token, token_type, expires_in})) => { + Ok(RefreshResult::Success( + Token{ + access_token, + token_type, + refresh_token: Some(refresh_token.to_string()), + expires_in: None, + expires_in_timestamp: Some(Utc::now().timestamp() + expires_in), + })) + }, + } } } diff --git a/src/service_account.rs b/src/service_account.rs index eee3fe4..9b93c94 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -17,7 +17,7 @@ use std::sync::{Arc, Mutex}; use crate::authenticator::{DefaultHyperClient, HyperClientBuilder}; use crate::storage::{hash_scopes, MemoryStorage, TokenStorage}; -use crate::types::{ApplicationSecret, GetToken, JsonError, RequestError, Token}; +use crate::types::{ApplicationSecret, GetToken, JsonErrorOr, RequestError, Token}; use futures::prelude::*; use hyper::header; @@ -302,38 +302,32 @@ where .try_concat() .await .map_err(RequestError::ClientError)?; - if let Ok(jse) = serde_json::from_slice::(&body) { - return Err(RequestError::NegativeServerResponse( - jse.error, - jse.error_description, - )); - } - let token: TokenResponse = - serde_json::from_slice(&body).map_err(RequestError::JSONError)?; - let token = match token { - TokenResponse { + match serde_json::from_slice::>(&body)? { + JsonErrorOr::Err(err) => { + Err(err.into()) + }, + JsonErrorOr::Data(TokenResponse { access_token: Some(access_token), token_type: Some(token_type), expires_in: Some(expires_in), .. - } => { + }) => { let expires_ts = chrono::Utc::now().timestamp() + expires_in; - Token { + Ok(Token { access_token, token_type, refresh_token: None, expires_in: Some(expires_in), expires_in_timestamp: Some(expires_ts), - } - } - _ => { - return Err(RequestError::BadServerResponse(format!( + }) + }, + JsonErrorOr::Data(token) => { + Err(RequestError::BadServerResponse(format!( "Token response lacks fields: {:?}", token ))) } - }; - Ok(token) + } } async fn get_token(&self, scopes: &[T]) -> Result diff --git a/src/types.rs b/src/types.rs index 697ddb0..2c575ec 100644 --- a/src/types.rs +++ b/src/types.rs @@ -15,6 +15,14 @@ pub struct JsonError { pub error_uri: Option, } +/// A helper type to deserialize either a JsonError or another piece of data. +#[derive(Deserialize, Debug)] +#[serde(untagged)] +pub enum JsonErrorOr { + Err(JsonError), + Data(T), +} + /// All possible outcomes of the refresh flow #[derive(Debug)] pub enum RefreshResult { @@ -57,7 +65,7 @@ pub enum RequestError { /// A malformed server response. BadServerResponse(String), /// Error while decoding a JSON response. - JSONError(serde_json::error::Error), + JSONError(serde_json::Error), /// Error within user input. UserError(String), /// A lower level IO error. @@ -90,6 +98,12 @@ impl From for RequestError { } } +impl From for RequestError { + fn from(value: serde_json::Error) -> RequestError { + RequestError::JSONError(value) + } +} + impl fmt::Display for RequestError { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { match *self {