diff --git a/src/authenticator.rs b/src/authenticator.rs index 6eea182..9d0feda 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -6,7 +6,7 @@ use crate::installed::{InstalledFlow, InstalledFlowReturnMethod}; use crate::refresh::RefreshFlow; use crate::service_account::{ServiceAccountFlow, ServiceAccountFlowOpts, ServiceAccountKey}; use crate::storage::{self, Storage}; -use crate::types::{ApplicationSecret, Token}; +use crate::types::{AccessToken, ApplicationSecret, TokenInfo}; use private::AuthFlow; use std::borrow::Cow; @@ -27,35 +27,35 @@ where C: hyper::client::connect::Connect + 'static, { /// Return the current token for the provided scopes. - pub async fn token<'a, T>(&'a self, scopes: &'a [T]) -> Result + pub async fn token<'a, T>(&'a self, scopes: &'a [T]) -> Result where T: AsRef, { let hashed_scopes = storage::ScopeSet::from(scopes); match (self.storage.get(hashed_scopes), self.auth_flow.app_secret()) { - (Some(t), _) if !t.expired() => { + (Some(t), _) if !t.is_expired() => { // unexpired token found - Ok(t) + Ok(t.into()) } ( - Some(Token { + Some(TokenInfo { refresh_token: Some(refresh_token), .. }), Some(app_secret), ) => { // token is expired but has a refresh token. - let token = + let token_info = RefreshFlow::refresh_token(&self.hyper_client, app_secret, &refresh_token) .await?; - self.storage.set(hashed_scopes, token.clone()).await?; - Ok(token) + self.storage.set(hashed_scopes, token_info.clone()).await?; + Ok(token_info.into()) } _ => { // no token in the cache or the token returned can't be refreshed. - let t = self.auth_flow.token(&self.hyper_client, scopes).await?; - self.storage.set(hashed_scopes, t.clone()).await?; - Ok(t) + let token_info = self.auth_flow.token(&self.hyper_client, scopes).await?; + self.storage.set(hashed_scopes, token_info.clone()).await?; + Ok(token_info.into()) } } } @@ -354,7 +354,7 @@ mod private { use crate::error::Error; use crate::installed::InstalledFlow; use crate::service_account::ServiceAccountFlow; - use crate::types::{ApplicationSecret, Token}; + use crate::types::{ApplicationSecret, TokenInfo}; pub enum AuthFlow { DeviceFlow(DeviceFlow), @@ -375,7 +375,7 @@ mod private { &'a self, hyper_client: &'a hyper::Client, scopes: &'a [T], - ) -> Result + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, diff --git a/src/device.rs b/src/device.rs index e80ad01..b21098b 100644 --- a/src/device.rs +++ b/src/device.rs @@ -2,7 +2,7 @@ use crate::authenticator_delegate::{ DefaultDeviceFlowDelegate, DeviceAuthResponse, DeviceFlowDelegate, }; use crate::error::{AuthError, Error}; -use crate::types::{ApplicationSecret, Token}; +use crate::types::{ApplicationSecret, TokenInfo}; use std::borrow::Cow; use std::time::Duration; @@ -43,7 +43,7 @@ impl DeviceFlow { &self, hyper_client: &hyper::Client, scopes: &[T], - ) -> Result + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, @@ -73,7 +73,7 @@ impl DeviceFlow { app_secret: &ApplicationSecret, device_auth_resp: &DeviceAuthResponse, grant_type: &str, - ) -> Result + ) -> Result where C: hyper::client::connect::Connect + 'static, { @@ -168,7 +168,7 @@ impl DeviceFlow { client: &hyper::Client, device_code: &str, grant_type: &str, - ) -> Result + ) -> Result where C: hyper::client::connect::Connect + 'static, { @@ -188,6 +188,6 @@ impl DeviceFlow { .unwrap(); // TODO: Error checking let res = client.request(request).await?; let body = res.into_body().try_concat().await?; - Token::from_json(&body) + TokenInfo::from_json(&body) } } diff --git a/src/installed.rs b/src/installed.rs index f2ae102..4a6bebd 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -4,7 +4,7 @@ // use crate::authenticator_delegate::{DefaultInstalledFlowDelegate, InstalledFlowDelegate}; use crate::error::Error; -use crate::types::{ApplicationSecret, Token}; +use crate::types::{ApplicationSecret, TokenInfo}; use std::convert::AsRef; use std::future::Future; @@ -93,7 +93,7 @@ impl InstalledFlow { &self, hyper_client: &hyper::Client, scopes: &[T], - ) -> Result + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, @@ -115,7 +115,7 @@ impl InstalledFlow { hyper_client: &hyper::Client, app_secret: &ApplicationSecret, scopes: &[T], - ) -> Result + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, @@ -140,7 +140,7 @@ impl InstalledFlow { hyper_client: &hyper::Client, app_secret: &ApplicationSecret, scopes: &[T], - ) -> Result + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, @@ -178,7 +178,7 @@ impl InstalledFlow { hyper_client: &hyper::Client, app_secret: &ApplicationSecret, server_addr: Option, - ) -> Result + ) -> Result where C: hyper::client::connect::Connect + 'static, { @@ -186,7 +186,7 @@ impl InstalledFlow { let request = Self::request_token(app_secret, authcode, redirect_uri, server_addr); let resp = hyper_client.request(request).await?; let body = resp.into_body().try_concat().await?; - Token::from_json(&body) + TokenInfo::from_json(&body) } /// Sends the authorization code to the provider in order to obtain access and refresh tokens. diff --git a/src/lib.rs b/src/lib.rs index aa7c0fb..7fa2e58 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,4 +93,4 @@ pub use crate::service_account::ServiceAccountKey; #[doc(inline)] pub use crate::error::Error; -pub use crate::types::{ApplicationSecret, ConsoleApplicationSecret, Token}; +pub use crate::types::{AccessToken, ApplicationSecret, ConsoleApplicationSecret}; diff --git a/src/refresh.rs b/src/refresh.rs index 1ed72e8..0ce14bf 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -1,5 +1,5 @@ use crate::error::Error; -use crate::types::{ApplicationSecret, Token}; +use crate::types::{ApplicationSecret, TokenInfo}; use futures_util::try_stream::TryStreamExt; use hyper::header; @@ -10,7 +10,7 @@ use url::form_urlencoded; /// Refresh an expired access token, as obtained by any other authentication flow. /// This flow is useful when your `Token` is expired and allows to obtain a new /// and valid access token. -pub struct RefreshFlow; +pub(crate) struct RefreshFlow; impl RefreshFlow { /// Attempt to refresh the given token, and obtain a new, valid one. @@ -27,11 +27,11 @@ impl RefreshFlow { /// /// # Examples /// Please see the crate landing page for an example. - pub async fn refresh_token( + pub(crate) async fn refresh_token( client: &hyper::Client, client_secret: &ApplicationSecret, refresh_token: &str, - ) -> Result { + ) -> Result { let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ ("client_id", client_secret.client_id.as_str()), @@ -48,7 +48,7 @@ impl RefreshFlow { let resp = client.request(request).await?; let body = resp.into_body().try_concat().await?; - let mut token = Token::from_json(&body)?; + let mut token = TokenInfo::from_json(&body)?; // If the refresh result contains a refresh_token use it, otherwise // continue using our previous refresh_token. token diff --git a/src/service_account.rs b/src/service_account.rs index 1eeeb47..55c545f 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -12,7 +12,7 @@ //! use crate::error::Error; -use crate::types::Token; +use crate::types::TokenInfo; use std::io; @@ -181,7 +181,7 @@ impl ServiceAccountFlow { &self, hyper_client: &hyper::Client, scopes: &[T], - ) -> Result + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, @@ -202,7 +202,7 @@ impl ServiceAccountFlow { .unwrap(); let response = hyper_client.request(request).await?; let body = response.into_body().try_concat().await?; - Token::from_json(&body) + TokenInfo::from_json(&body) } } diff --git a/src/storage.rs b/src/storage.rs index 7d38575..89ed9c2 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -2,7 +2,7 @@ // // See project root for licensing information. // -use crate::types::Token; +use crate::types::TokenInfo; use std::collections::HashMap; use std::io; @@ -114,7 +114,7 @@ impl Storage { pub(crate) async fn set( &self, scopes: ScopeSet<'_, T>, - token: Token, + token: TokenInfo, ) -> Result<(), io::Error> where T: AsRef, @@ -125,7 +125,7 @@ impl Storage { } } - pub(crate) fn get(&self, scopes: ScopeSet) -> Option + pub(crate) fn get(&self, scopes: ScopeSet) -> Option where T: AsRef, { @@ -141,7 +141,7 @@ impl Storage { #[derive(Debug, Clone)] struct JSONToken { scopes: Vec, - token: Token, + token: TokenInfo, hash: ScopeHash, filter: ScopeFilter, } @@ -154,7 +154,7 @@ impl<'de> Deserialize<'de> for JSONToken { #[derive(Deserialize)] struct RawJSONToken { scopes: Vec, - token: Token, + token: TokenInfo, } let RawJSONToken { scopes, token } = RawJSONToken::deserialize(deserializer)?; let ScopeSet { hash, filter, .. } = ScopeSet::from(&scopes); @@ -175,7 +175,7 @@ impl Serialize for JSONToken { #[derive(Serialize)] struct RawJSONToken<'a> { scopes: &'a [String], - token: &'a Token, + token: &'a TokenInfo, } RawJSONToken { scopes: &self.scopes, @@ -251,7 +251,7 @@ impl JSONTokens { filter, scopes, }: ScopeSet, - ) -> Option + ) -> Option where T: AsRef, { @@ -280,7 +280,7 @@ impl JSONTokens { filter, scopes, }: ScopeSet, - token: Token, + token: TokenInfo, ) -> Result<(), io::Error> where T: AsRef, @@ -326,7 +326,7 @@ impl DiskStorage { pub(crate) async fn set( &self, scopes: ScopeSet<'_, T>, - token: Token, + token: TokenInfo, ) -> Result<(), io::Error> where T: AsRef, @@ -341,7 +341,7 @@ impl DiskStorage { tokio::fs::write(self.filename.clone(), json).await } - pub(crate) fn get(&self, scopes: ScopeSet) -> Option + pub(crate) fn get(&self, scopes: ScopeSet) -> Option where T: AsRef, { @@ -375,10 +375,9 @@ mod tests { #[tokio::test] async fn test_disk_storage() { - let new_token = |access_token: &str| Token { + let new_token = |access_token: &str| TokenInfo { access_token: access_token.to_owned(), refresh_token: None, - token_type: "Bearer".to_owned(), expires_at: None, }; let scope_set = ScopeSet::from(&["myscope"]); diff --git a/src/types.rs b/src/types.rs index c9b3225..5633f0d 100644 --- a/src/types.rs +++ b/src/types.rs @@ -3,25 +3,70 @@ use crate::error::{AuthErrorOr, Error}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; +/// Represents an access token returned by oauth2 servers. All access tokens are +/// Bearer tokens. Other types of tokens are not supported. +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] +pub struct AccessToken { + value: String, + expires_at: Option>, +} + +impl AccessToken { + /// A string representation of the access token. + pub fn as_str(&self) -> &str { + &self.value + } + + /// The time the access token will expire, if any. + pub fn expiration_time(&self) -> Option> { + self.expires_at + } + + /// Determine if the access token is expired. + /// This will report that the token is expired 1 minute prior to the + /// expiration time to ensure that when the token is actually sent to the + /// server it's still valid. + pub fn is_expired(&self) -> bool { + // Consider the token expired if it's within 1 minute of it's expiration + // time. + self.expires_at + .map(|expiration_time| expiration_time - chrono::Duration::minutes(1) <= Utc::now()) + .unwrap_or(false) + } +} + +impl AsRef for AccessToken { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl From for AccessToken { + fn from(value: TokenInfo) -> Self { + AccessToken { + value: value.access_token, + expires_at: value.expires_at, + } + } +} + /// Represents a token as returned by OAuth2 servers. /// /// It is produced by all authentication flows. /// It authenticates certain operations, and must be refreshed once /// it reached it's expiry date. #[derive(Clone, PartialEq, Debug, Deserialize, Serialize)] -pub struct Token { +pub(crate) struct TokenInfo { /// used when authenticating calls to oauth2 enabled services. - pub access_token: String, + pub(crate) access_token: String, /// used to refresh an expired access_token. - pub refresh_token: Option, - /// The token type as string - usually 'Bearer'. - pub token_type: String, + pub(crate) refresh_token: Option, /// The time when the token expires. - pub expires_at: Option>, + pub(crate) expires_at: Option>, } -impl Token { - pub(crate) fn from_json(json_data: &[u8]) -> Result { +impl TokenInfo { + pub(crate) fn from_json(json_data: &[u8]) -> Result { #[derive(Deserialize)] struct RawToken { access_token: String, @@ -37,19 +82,30 @@ impl Token { expires_in, } = serde_json::from_slice::>(json_data)?.into_result()?; + if token_type.to_lowercase().as_str() != "bearer" { + use std::io; + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + r#"unknown token type returned; expected "bearer" found {}"#, + token_type + ), + ) + .into()); + } + let expires_at = expires_in .map(|seconds_from_now| Utc::now() + chrono::Duration::seconds(seconds_from_now)); - Ok(Token { + Ok(TokenInfo { access_token, refresh_token, - token_type, expires_at, }) } /// Returns true if we are expired. - pub fn expired(&self) -> bool { + pub fn is_expired(&self) -> bool { self.expires_at .map(|expiration_time| expiration_time - chrono::Duration::minutes(1) <= Utc::now()) .unwrap_or(false) diff --git a/tests/tests.rs b/tests/tests.rs index 40fe487..fe2d6b3 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -88,7 +88,7 @@ async fn test_device_success() { .token(&["https://www.googleapis.com/scope/1"]) .await .expect("token failed"); - assert_eq!("accesstoken", token.access_token); + assert_eq!("accesstoken", token.as_str()); _m.assert(); } @@ -253,9 +253,7 @@ async fn test_installed_interactive_success() { .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); - assert_eq!("accesstoken", tok.access_token); - assert_eq!("refreshtoken", tok.refresh_token.unwrap()); - assert_eq!("Bearer", tok.token_type); + assert_eq!("accesstoken", tok.as_str()); _m.assert(); } @@ -282,9 +280,7 @@ async fn test_installed_redirect_success() { .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); - assert_eq!("accesstoken", tok.access_token); - assert_eq!("refreshtoken", tok.refresh_token.unwrap()); - assert_eq!("Bearer", tok.token_type); + assert_eq!("accesstoken", tok.as_str()); _m.assert(); } @@ -347,8 +343,8 @@ async fn test_service_account_success() { .token(&["https://www.googleapis.com/auth/pubsub"]) .await .expect("token failed"); - assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); - assert!(Utc::now() + chrono::Duration::seconds(3600) >= tok.expires_at.unwrap()); + assert!(tok.as_str().contains("ya29.c.ElouBywiys0Ly")); + assert!(Utc::now() + chrono::Duration::seconds(3600) >= tok.expiration_time().unwrap()); _m.assert(); } @@ -396,9 +392,7 @@ async fn test_refresh() { .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); - assert_eq!("accesstoken", tok.access_token); - assert_eq!("refreshtoken", tok.refresh_token.unwrap()); - assert_eq!("Bearer", tok.token_type); + assert_eq!("accesstoken", tok.as_str()); _m.assert(); let _m = mockito::mock("POST", "/token") @@ -420,9 +414,7 @@ async fn test_refresh() { .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); - assert_eq!("accesstoken2", tok.access_token); - assert_eq!("refreshtoken", tok.refresh_token.unwrap()); - assert_eq!("Bearer", tok.token_type); + assert_eq!("accesstoken2", tok.as_str()); _m.assert(); let _m = mockito::mock("POST", "/token") @@ -481,7 +473,7 @@ async fn test_memory_storage() { .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); - assert_eq!(token1.access_token.as_str(), "accesstoken"); + assert_eq!(token1.as_str(), "accesstoken"); assert_eq!(token1, token2); _m.assert(); @@ -507,7 +499,7 @@ async fn test_memory_storage() { .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); - assert_eq!(token3.access_token.as_str(), "accesstoken2"); + assert_eq!(token3.as_str(), "accesstoken2"); _m.assert(); } @@ -547,7 +539,7 @@ async fn test_disk_storage() { .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); - assert_eq!(token1.access_token.as_str(), "accesstoken"); + assert_eq!(token1.as_str(), "accesstoken"); assert_eq!(token1, token2); _m.assert(); } @@ -569,6 +561,6 @@ async fn test_disk_storage() { .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); - assert_eq!(token1.access_token.as_str(), "accesstoken"); + assert_eq!(token1.as_str(), "accesstoken"); assert_eq!(token1, token2); }