diff --git a/src/device.rs b/src/device.rs index 44700c0..a618243 100644 --- a/src/device.rs +++ b/src/device.rs @@ -11,7 +11,7 @@ use url::form_urlencoded; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate, PollInformation, Retry}; use crate::types::{ - ApplicationSecret, GetToken, JsonError, PollError, RequestError, Token, + ApplicationSecret, GetToken, JsonErrorOr, PollError, RequestError, Token, }; pub const GOOGLE_DEVICE_CODE_URL: &'static str = "https://accounts.google.com/o/oauth2/device/code"; @@ -236,25 +236,20 @@ where } let json_bytes = resp.into_body().try_concat().await?; + match json::from_slice::>(&json_bytes)? { + JsonErrorOr::Err(e) => Err(e.into()), + JsonErrorOr::Data(decoded) => { + let expires_in = decoded.expires_in.unwrap_or(60 * 60); - // check for error - match json::from_slice::(&json_bytes) { - Err(_) => {} // ignore, move on - Ok(res) => return Err(RequestError::from(res)), + let pi = PollInformation { + user_code: decoded.user_code, + verification_url: decoded.verification_uri, + expires_at: Utc::now() + chrono::Duration::seconds(expires_in), + interval: Duration::from_secs(i64::abs(decoded.interval) as u64), + }; + Ok((pi, decoded.device_code)) + } } - - let decoded: JsonData = - json::from_slice(&json_bytes).map_err(|e| RequestError::JSONError(e))?; - - let expires_in = decoded.expires_in.unwrap_or(60 * 60); - - let pi = PollInformation { - user_code: decoded.user_code, - verification_url: decoded.verification_uri, - expires_at: Utc::now() + chrono::Duration::seconds(expires_in), - interval: Duration::from_secs(i64::abs(decoded.interval) as u64), - }; - Ok((pi, decoded.device_code)) } /// If the first call is successful, this method may be called. diff --git a/src/installed.rs b/src/installed.rs index 3a2a5f0..88d39bb 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -17,7 +17,7 @@ use url::form_urlencoded; use url::percent_encoding::{percent_encode, QUERY_ENCODE_SET}; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate}; -use crate::types::{ApplicationSecret, GetToken, RequestError, Token}; +use crate::types::{ApplicationSecret, GetToken, RequestError, Token, JsonErrorOr}; const OOB_REDIRECT_URI: &'static str = "urn:ietf:wg:oauth:2.0:oob"; @@ -270,24 +270,21 @@ where .try_concat() .await .map_err(|e| RequestError::ClientError(e))?; - let tokens: JSONTokenResponse = - serde_json::from_slice(&body).map_err(|e| RequestError::JSONError(e))?; - match tokens { - JSONTokenResponse { - error: Some(err), - error_description, - .. - } => Err(RequestError::NegativeServerResponse(err, error_description)), - JSONTokenResponse { - access_token: Some(access_token), - refresh_token, - token_type: Some(token_type), - expires_in, - .. - } => { + + #[derive(Deserialize)] + struct JSONTokenResponse { + access_token: String, + refresh_token: String, + token_type: String, + expires_in: Option, + } + + match serde_json::from_slice::>(&body)? { + JsonErrorOr::Err(err) => Err(err.into()), + JsonErrorOr::Data(JSONTokenResponse{access_token, refresh_token, token_type, expires_in}) => { let mut token = Token { access_token, - refresh_token, + refresh_token: Some(refresh_token), token_type, expires_in, expires_in_timestamp: None, @@ -295,12 +292,6 @@ where token.set_expiry_absolute(); Ok(token) } - JSONTokenResponse { - error_description, .. - } => Err(RequestError::NegativeServerResponse( - "".to_owned(), - error_description, - )), } } @@ -336,17 +327,6 @@ where } } -#[derive(Deserialize)] -struct JSONTokenResponse { - access_token: Option, - refresh_token: Option, - token_type: Option, - expires_in: Option, - - error: Option, - error_description: Option, -} - fn spawn_with_handle(f: F) -> impl Future where F: Future + 'static + Send, diff --git a/src/refresh.rs b/src/refresh.rs index 8963f92..4a4bd94 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -1,11 +1,10 @@ -use crate::types::{ApplicationSecret, JsonError, RefreshResult, RequestError}; +use crate::types::{ApplicationSecret, JsonErrorOr, RefreshResult, RequestError}; use super::Token; use chrono::Utc; use futures_util::try_stream::TryStreamExt; use hyper; use hyper::header; -use serde_json as json; use url::form_urlencoded; /// Implements the [OAuth2 Refresh Token Flow](https://developers.google.com/youtube/v3/guides/authentication#devices). @@ -58,34 +57,28 @@ impl RefreshFlow { Ok(body) => body, Err(err) => return Ok(RefreshResult::Error(err)), }; - if let Ok(json_err) = json::from_slice::(&body) { - return Ok(RefreshResult::RefreshError( - json_err.error, - json_err.error_description, - )); - } + #[derive(Deserialize)] struct JsonToken { access_token: String, token_type: String, expires_in: i64, } - let t: JsonToken = match json::from_slice(&body) { - Err(_) => { - return Ok(RefreshResult::RefreshError( - "failed to deserialized json token from refresh response".to_owned(), - None, - )) - } - Ok(token) => token, - }; - Ok(RefreshResult::Success(Token { - access_token: t.access_token, - token_type: t.token_type, - refresh_token: Some(refresh_token.to_string()), - expires_in: None, - expires_in_timestamp: Some(Utc::now().timestamp() + t.expires_in), - })) + + match serde_json::from_slice::>(&body) { + Err(_) => Ok(RefreshResult::RefreshError("failed to deserialized json token from refresh response".to_owned(), None)), + Ok(JsonErrorOr::Err(json_err)) => Ok(RefreshResult::RefreshError(json_err.error, json_err.error_description)), + Ok(JsonErrorOr::Data(JsonToken{access_token, token_type, expires_in})) => { + Ok(RefreshResult::Success( + Token{ + access_token, + token_type, + refresh_token: Some(refresh_token.to_string()), + expires_in: None, + expires_in_timestamp: Some(Utc::now().timestamp() + expires_in), + })) + }, + } } } diff --git a/src/service_account.rs b/src/service_account.rs index eee3fe4..9b93c94 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -17,7 +17,7 @@ use std::sync::{Arc, Mutex}; use crate::authenticator::{DefaultHyperClient, HyperClientBuilder}; use crate::storage::{hash_scopes, MemoryStorage, TokenStorage}; -use crate::types::{ApplicationSecret, GetToken, JsonError, RequestError, Token}; +use crate::types::{ApplicationSecret, GetToken, JsonErrorOr, RequestError, Token}; use futures::prelude::*; use hyper::header; @@ -302,38 +302,32 @@ where .try_concat() .await .map_err(RequestError::ClientError)?; - if let Ok(jse) = serde_json::from_slice::(&body) { - return Err(RequestError::NegativeServerResponse( - jse.error, - jse.error_description, - )); - } - let token: TokenResponse = - serde_json::from_slice(&body).map_err(RequestError::JSONError)?; - let token = match token { - TokenResponse { + match serde_json::from_slice::>(&body)? { + JsonErrorOr::Err(err) => { + Err(err.into()) + }, + JsonErrorOr::Data(TokenResponse { access_token: Some(access_token), token_type: Some(token_type), expires_in: Some(expires_in), .. - } => { + }) => { let expires_ts = chrono::Utc::now().timestamp() + expires_in; - Token { + Ok(Token { access_token, token_type, refresh_token: None, expires_in: Some(expires_in), expires_in_timestamp: Some(expires_ts), - } - } - _ => { - return Err(RequestError::BadServerResponse(format!( + }) + }, + JsonErrorOr::Data(token) => { + Err(RequestError::BadServerResponse(format!( "Token response lacks fields: {:?}", token ))) } - }; - Ok(token) + } } async fn get_token(&self, scopes: &[T]) -> Result diff --git a/src/types.rs b/src/types.rs index 697ddb0..2c575ec 100644 --- a/src/types.rs +++ b/src/types.rs @@ -15,6 +15,14 @@ pub struct JsonError { pub error_uri: Option, } +/// A helper type to deserialize either a JsonError or another piece of data. +#[derive(Deserialize, Debug)] +#[serde(untagged)] +pub enum JsonErrorOr { + Err(JsonError), + Data(T), +} + /// All possible outcomes of the refresh flow #[derive(Debug)] pub enum RefreshResult { @@ -57,7 +65,7 @@ pub enum RequestError { /// A malformed server response. BadServerResponse(String), /// Error while decoding a JSON response. - JSONError(serde_json::error::Error), + JSONError(serde_json::Error), /// Error within user input. UserError(String), /// A lower level IO error. @@ -90,6 +98,12 @@ impl From for RequestError { } } +impl From for RequestError { + fn from(value: serde_json::Error) -> RequestError { + RequestError::JSONError(value) + } +} + impl fmt::Display for RequestError { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { match *self {