diff --git a/Cargo.toml b/Cargo.toml index 98ea883..7b5bb27 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,9 @@ tokio = { version = "1.0", features = ["fs", "macros", "io-std", "io-util", "tim url = "2" percent-encoding = "2" futures = "0.3" +async-trait = "^0.1" +anyhow = "1.0.38" +itertools = "0.10.0" [dev-dependencies] httptest = "0.14" diff --git a/examples/custom_flow.rs b/examples/custom_flow.rs index 462e4fd..f34bb7c 100644 --- a/examples/custom_flow.rs +++ b/examples/custom_flow.rs @@ -44,7 +44,7 @@ async fn main() { let sec = yup_oauth2::read_application_secret("client_secret.json") .await .expect("client secret couldn't be read."); - let auth = yup_oauth2::InstalledFlowAuthenticator::builder( + let mut auth = yup_oauth2::InstalledFlowAuthenticator::builder( sec, yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect, ) diff --git a/examples/custom_storage.rs b/examples/custom_storage.rs new file mode 100644 index 0000000..e5599b0 --- /dev/null +++ b/examples/custom_storage.rs @@ -0,0 +1,66 @@ +//! Demonstrating how to create a custom token store +use async_trait::async_trait; +use yup_oauth2::storage::{ScopeSet, TokenInfo, TokenStorage}; + +struct ExampleTokenStore { + store: Vec, +} + +struct StoredToken { + scopes: Vec, + serialized_token: String, +} + +/// Here we implement our own token storage. You could write the serialized token and scope data +/// to disk, an OS keychain, a database or whatever suits your use-case +#[async_trait] +impl TokenStorage for ExampleTokenStore { + async fn set(&mut self, scopes: ScopeSet<'_, &str>, token: TokenInfo) -> anyhow::Result<()> { + let data = serde_json::to_string(&token).unwrap(); + + println!("Storing token for scopes {:?}", scopes); + + self.store.push(StoredToken { + scopes: scopes.scopes(), + serialized_token: data, + }); + + Ok(()) + } + + async fn get(&self, target_scopes: ScopeSet<'_, &str>) -> Option { + // Retrieve the token data + for stored_token in self.store.iter() { + if target_scopes.is_covered_by(&stored_token.scopes) { + return serde_json::from_str(&stored_token.serialized_token).ok(); + } + } + + None + } +} + +#[tokio::main] +async fn main() { + // Put your client secret in the working directory! + let sec = yup_oauth2::read_application_secret("client_secret.json") + .await + .expect("client secret couldn't be read."); + let mut auth = yup_oauth2::InstalledFlowAuthenticator::builder( + sec, + yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect, + ) + .with_storage(yup_oauth2::authenticator::StorageType::Custom(Box::new( + ExampleTokenStore { store: vec![] }, + ))) + .build() + .await + .expect("InstalledFlowAuthenticator failed to build"); + + let scopes = &["https://www.googleapis.com/auth/drive.file"]; + + match auth.token(scopes).await { + Err(e) => println!("error: {:?}", e), + Ok(t) => println!("The token is {:?}", t), + } +} diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..32a9786 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +edition = "2018" diff --git a/src/authenticator.rs b/src/authenticator.rs index 2ea89d8..a095f27 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -5,7 +5,7 @@ use crate::error::Error; use crate::installed::{InstalledFlow, InstalledFlowReturnMethod}; use crate::refresh::RefreshFlow; use crate::service_account::{ServiceAccountFlow, ServiceAccountFlowOpts, ServiceAccountKey}; -use crate::storage::{self, Storage}; +use crate::storage::{self, Storage, TokenStorage}; use crate::types::{AccessToken, ApplicationSecret, TokenInfo}; use private::AuthFlow; @@ -47,7 +47,7 @@ where C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, { /// Return the current token for the provided scopes. - pub async fn token<'a, T>(&'a self, scopes: &'a [T]) -> Result + pub async fn token<'a, T>(&'a mut self, scopes: &'a [T]) -> Result where T: AsRef, { @@ -57,7 +57,7 @@ where /// Return a token for the provided scopes, but don't reuse cached tokens. Instead, /// always fetch a new token from the OAuth server. pub async fn force_refreshed_token<'a, T>( - &'a self, + &'a mut self, scopes: &'a [T], ) -> Result where @@ -68,7 +68,7 @@ where /// Return a cached token or fetch a new one from the server. async fn find_token<'a, T>( - &'a self, + &'a mut self, scopes: &'a [T], force_refresh: bool, ) -> Result @@ -219,6 +219,7 @@ impl AuthenticatorBuilder { tokens: Mutex::new(storage::JSONTokens::new()), }, StorageType::Disk(path) => Storage::Disk(storage::DiskStorage::new(path).await?), + StorageType::Custom(custom_store) => Storage::Custom(custom_store), }; Ok(Authenticator { @@ -236,6 +237,14 @@ impl AuthenticatorBuilder { } } + /// Use the provided token storage mechanism + pub fn with_storage(self, storage_type: StorageType) -> Self { + AuthenticatorBuilder { + storage_type: storage_type, + ..self + } + } + /// Use the provided hyper client. pub fn hyper_client( self, @@ -494,9 +503,14 @@ where } } -enum StorageType { +/// How should the acquired tokens be stored? +pub enum StorageType { + /// Store tokens in memory (and always log in again to acquire a new token on startup) Memory, + /// Store tokens to disk in the given file. Warning, this may be insecure unless you configure your operating system to restrict read access to the file. Disk(PathBuf), + /// Implement your own storage provider + Custom(Box), } #[cfg(test)] diff --git a/src/error.rs b/src/error.rs index 63cf636..04e9210 100644 --- a/src/error.rs +++ b/src/error.rs @@ -153,6 +153,8 @@ pub enum Error { UserError(String), /// A lower level IO error. LowLevelError(io::Error), + /// Other errors produced by a storage provider + OtherError(anyhow::Error), } impl From for Error { @@ -179,6 +181,15 @@ impl From for Error { } } +impl From for Error { + fn from(value: anyhow::Error) -> Error { + match value.downcast::() { + Ok(io_error) => Error::LowLevelError(io_error), + Err(err) => Error::OtherError(err), + } + } +} + impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { match *self { @@ -194,6 +205,7 @@ impl fmt::Display for Error { } Error::UserError(ref s) => s.fmt(f), Error::LowLevelError(ref e) => e.fmt(f), + Error::OtherError(ref e) => e.fmt(f), } } } diff --git a/src/lib.rs b/src/lib.rs index 9b4b0cd..284903e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -77,7 +77,11 @@ mod helper; mod installed; mod refresh; mod service_account; -mod storage; + +/// Interface for storing tokens so that they can be re-used. There are built-in memory and +/// file-based storage providers. You can implement your own by implementing the TokenStorage trait. +pub mod storage; + mod types; #[doc(inline)] diff --git a/src/storage.rs b/src/storage.rs index 308eeea..cfe7f82 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -2,13 +2,16 @@ // // See project root for licensing information. // -use crate::types::TokenInfo; +pub use crate::types::TokenInfo; use futures::lock::Mutex; +use itertools::Itertools; use std::collections::HashMap; use std::io; use std::path::{Path, PathBuf}; +use async_trait::async_trait; + use serde::{Deserialize, Serialize}; // The storage layer allows retrieving tokens for scopes that have been @@ -49,8 +52,9 @@ impl ScopeFilter { } } +/// A set of scopes #[derive(Debug)] -pub(crate) struct ScopeSet<'a, T> { +pub struct ScopeSet<'a, T> { hash: ScopeHash, filter: ScopeFilter, scopes: &'a [T], @@ -73,6 +77,8 @@ impl<'a, T> ScopeSet<'a, T> where T: AsRef, { + /// Convert from an array into a ScopeSet. Automatically invoked by the compiler when + /// an array reference is passed. // implement an inherent from method even though From is implemented. This // is because passing an array ref like &[&str; 1] (&["foo"]) will be auto // deref'd to a slice on function boundaries, but it will not implement the @@ -103,25 +109,72 @@ where scopes, } } + + /// Get the scopes for storage when implementing TokenStorage.set(). + /// Returned scope strings are unique and sorted. + pub fn scopes(&self) -> Vec { + self.scopes + .iter() + .map(|scope| scope.as_ref().to_string()) + .sorted() + .unique() + .collect() + } + + /// Is this set of scopes covered by the other? Returns true if the other + /// set is a superset of this one. Use this when implementing TokenStorage.get() + pub fn is_covered_by(&self, other_scopes: &[String]) -> bool { + self.scopes + .iter() + .all(|s| other_scopes.iter().any(|t| t.as_str() == s.as_ref())) + } +} + +/// Implement your own token storage solution by implementing this trait. You need a way to +/// store and retrieve tokens, each keyed by a set of scopes. +#[async_trait] +pub trait TokenStorage: Send + Sync { + /// Store a token for the given set of scopes so that it can be retrieved later by get() + /// ScopeSet implements Hash so that you can easily serialize and store it. + /// TokenInfo can be serialized with serde. + async fn set(&mut self, scopes: ScopeSet<'_, &str>, token: TokenInfo) -> anyhow::Result<()>; + + /// Retrieve a token stored by set for the given set of scopes + async fn get(&self, scopes: ScopeSet<'_, &str>) -> Option; } pub(crate) enum Storage { Memory { tokens: Mutex }, Disk(DiskStorage), + Custom(Box), } impl Storage { pub(crate) async fn set( - &self, + &mut self, scopes: ScopeSet<'_, T>, token: TokenInfo, - ) -> Result<(), io::Error> + ) -> anyhow::Result<()> where T: AsRef, { match self { - Storage::Memory { tokens } => tokens.lock().await.set(scopes, token), - Storage::Disk(disk_storage) => disk_storage.set(scopes, token).await, + Storage::Memory { tokens } => Ok(tokens.lock().await.set(scopes, token)?), + Storage::Disk(disk_storage) => Ok(disk_storage.set(scopes, token).await?), + Storage::Custom(custom_storage) => { + let str_scopes: Vec<_> = scopes.scopes.iter().map(|scope| scope.as_ref()).collect(); + + (*custom_storage) + .set( + ScopeSet { + hash: scopes.hash, + filter: scopes.filter, + scopes: &str_scopes[..], + }, + token, + ) + .await + } } } @@ -132,6 +185,17 @@ impl Storage { match self { Storage::Memory { tokens } => tokens.lock().await.get(scopes), Storage::Disk(disk_storage) => disk_storage.get(scopes).await, + Storage::Custom(custom_storage) => { + let str_scopes: Vec<_> = scopes.scopes.iter().map(|scope| scope.as_ref()).collect(); + + (*custom_storage) + .get(ScopeSet { + hash: scopes.hash, + filter: scopes.filter, + scopes: &str_scopes[..], + }) + .await + } } } } diff --git a/src/types.rs b/src/types.rs index 4b7dd0d..1060bfa 100644 --- a/src/types.rs +++ b/src/types.rs @@ -56,7 +56,7 @@ impl From for AccessToken { /// It authenticates certain operations, and must be refreshed once /// it reached it's expiry date. #[derive(Clone, PartialEq, Debug, Deserialize, Serialize)] -pub(crate) struct TokenInfo { +pub struct TokenInfo { /// used when authenticating calls to oauth2 enabled services. pub(crate) access_token: String, /// used to refresh an expired access_token. diff --git a/tests/tests.rs b/tests/tests.rs index 548d9c7..53f27ab 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -92,7 +92,7 @@ async fn test_device_success() { }))), ); - let auth = create_device_flow_auth(&server).await; + let mut auth = create_device_flow_auth(&server).await; let token = auth .token(&["https://www.googleapis.com/scope/1"]) .await @@ -117,7 +117,7 @@ async fn test_device_no_code() { "error_description": "description" }))), ); - let auth = create_device_flow_auth(&server).await; + let mut auth = create_device_flow_auth(&server).await; let res = auth.token(&["https://www.googleapis.com/scope/1"]).await; assert!(res.is_err()); assert!(format!("{}", res.unwrap_err()).contains("invalid_client_id")); @@ -155,7 +155,7 @@ async fn test_device_no_token() { "error": "access_denied" }))), ); - let auth = create_device_flow_auth(&server).await; + let mut auth = create_device_flow_auth(&server).await; let res = auth.token(&["https://www.googleapis.com/scope/1"]).await; assert!(res.is_err()); assert!(format!("{}", res.unwrap_err()).contains("access_denied")); @@ -239,7 +239,7 @@ async fn create_installed_flow_auth( async fn test_installed_interactive_success() { let _ = env_logger::try_init(); let server = Server::run(); - let auth = + let mut auth = create_installed_flow_auth(&server, InstalledFlowReturnMethod::Interactive, None).await; server.expect( Expectation::matching(all_of![ @@ -268,7 +268,7 @@ async fn test_installed_interactive_success() { async fn test_installed_redirect_success() { let _ = env_logger::try_init(); let server = Server::run(); - let auth = + let mut auth = create_installed_flow_auth(&server, InstalledFlowReturnMethod::HTTPRedirect, None).await; server.expect( Expectation::matching(all_of![ @@ -297,7 +297,7 @@ async fn test_installed_redirect_success() { async fn test_installed_error() { let _ = env_logger::try_init(); let server = Server::run(); - let auth = + let mut auth = create_installed_flow_auth(&server, InstalledFlowReturnMethod::Interactive, None).await; server.expect( Expectation::matching(all_of![ @@ -347,7 +347,7 @@ async fn test_service_account_success() { use chrono::Utc; let _ = env_logger::try_init(); let server = Server::run(); - let auth = create_service_account_auth(&server).await; + let mut auth = create_service_account_auth(&server).await; server.expect( Expectation::matching(request::method_path("POST", "/token")) @@ -369,7 +369,7 @@ async fn test_service_account_success() { async fn test_service_account_error() { let _ = env_logger::try_init(); let server = Server::run(); - let auth = create_service_account_auth(&server).await; + let mut auth = create_service_account_auth(&server).await; server.expect( Expectation::matching(request::method_path("POST", "/token")).respond_with(json_encoded( serde_json::json!({ @@ -388,7 +388,7 @@ async fn test_service_account_error() { async fn test_refresh() { let _ = env_logger::try_init(); let server = Server::run(); - let auth = + let mut auth = create_installed_flow_auth(&server, InstalledFlowReturnMethod::Interactive, None).await; // We refresh a token whenever it's within 1 minute of expiring. So // acquiring a token that expires in 59 seconds will force a refresh on @@ -486,7 +486,7 @@ async fn test_refresh() { async fn test_memory_storage() { let _ = env_logger::try_init(); let server = Server::run(); - let auth = + let mut auth = create_installed_flow_auth(&server, InstalledFlowReturnMethod::Interactive, None).await; server.expect( Expectation::matching(all_of![ @@ -519,7 +519,7 @@ async fn test_memory_storage() { // Create a new authenticator. This authenticator does not share a cache // with the previous one. Validate that it receives a different token. - let auth2 = + let mut auth2 = create_installed_flow_auth(&server, InstalledFlowReturnMethod::Interactive, None).await; server.expect( Expectation::matching(all_of![ @@ -565,7 +565,7 @@ async fn test_disk_storage() { }))), ); { - let auth = create_installed_flow_auth( + let mut auth = create_installed_flow_auth( &server, InstalledFlowReturnMethod::Interactive, Some(storage_path.clone()), @@ -589,7 +589,7 @@ async fn test_disk_storage() { // Create a new authenticator. This authenticator uses the same token // storage file as the previous one so should receive a token without // making any http requests. - let auth = create_installed_flow_auth( + let mut auth = create_installed_flow_auth( &server, InstalledFlowReturnMethod::Interactive, Some(storage_path.clone()),