diff --git a/Cargo.toml b/Cargo.toml index 49a3df1..a30608d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,28 +12,22 @@ edition = "2018" [dependencies] base64 = "0.10" -chrono = "0.4" -http = "0.1" -hyper = {version = "0.12", default-features = false} -hyper-rustls = "0.17" -itertools = "0.8" -log = "0.3" +chrono = { version = "0.4", features = ["serde"] } +http = "0.2" +hyper = "0.13.1" +hyper-rustls = "0.19" +log = "0.4" rustls = "0.16" -serde = "1.0" +seahash = "3.0.6" +serde = {version = "1.0", features = ["derive"]} serde_json = "1.0" -serde_derive = "1.0" +tokio = { version = "0.2", features = ["fs", "macros", "io-std", "time"] } url = "1" -futures = "0.1" -tokio-threadpool = "0.1" -tokio = "0.1" -tokio-timer = "0.2" [dev-dependencies] -getopts = "0.2" -open = "1.1" -yup-hyper-mock = "3.14" -mockito = "0.17" +httptest = "0.5" env_logger = "0.6" +tempfile = "3.1" [workspace] members = ["examples/test-installed/", "examples/test-svc-acct/", "examples/test-device/"] diff --git a/examples/test-device/Cargo.toml b/examples/test-device/Cargo.toml index 1b4ed95..21557bb 100644 --- a/examples/test-device/Cargo.toml +++ b/examples/test-device/Cargo.toml @@ -6,7 +6,4 @@ edition = "2018" [dependencies] yup-oauth2 = { path = "../../" } -hyper = "0.12" -hyper-rustls = "0.17" -futures = "0.1" -tokio = "0.1" +tokio = { version = "0.2", features = ["macros"] } diff --git a/examples/test-device/src/main.rs b/examples/test-device/src/main.rs index 278fd78..6586d40 100644 --- a/examples/test-device/src/main.rs +++ b/examples/test-device/src/main.rs @@ -1,20 +1,19 @@ -use futures::prelude::*; -use yup_oauth2::{self, Authenticator, DeviceFlow, GetToken}; +use yup_oauth2::DeviceFlowAuthenticator; -use std::path; -use tokio; - -fn main() { - let creds = yup_oauth2::read_application_secret(path::Path::new("clientsecret.json")) +#[tokio::main] +async fn main() { + let app_secret = yup_oauth2::read_application_secret("clientsecret.json") + .await .expect("clientsecret"); - let mut auth = Authenticator::new(DeviceFlow::new(creds)) + let auth = DeviceFlowAuthenticator::builder(app_secret) .persist_tokens_to_disk("tokenstorage.json") .build() + .await .expect("authenticator"); - let scopes = vec!["https://www.googleapis.com/auth/youtube.readonly"]; - let mut rt = tokio::runtime::Runtime::new().unwrap(); - let fut = auth.token(scopes).and_then(|tok| Ok(println!("{:?}", tok))); - - println!("{:?}", rt.block_on(fut)); + let scopes = &["https://www.googleapis.com/auth/youtube.readonly"]; + match auth.token(scopes).await { + Err(e) => println!("error: {:?}", e), + Ok(t) => println!("token: {:?}", t), + } } diff --git a/examples/test-installed/Cargo.toml b/examples/test-installed/Cargo.toml index 0d6e654..6b69312 100644 --- a/examples/test-installed/Cargo.toml +++ b/examples/test-installed/Cargo.toml @@ -6,7 +6,4 @@ edition = "2018" [dependencies] yup-oauth2 = { path = "../../" } -hyper = "0.12" -hyper-rustls = "0.17" -futures = "0.1" -tokio = "0.1" +tokio = { version = "0.2", features = ["macros"] } diff --git a/examples/test-installed/src/main.rs b/examples/test-installed/src/main.rs index 3aa29ca..43a797c 100644 --- a/examples/test-installed/src/main.rs +++ b/examples/test-installed/src/main.rs @@ -1,36 +1,21 @@ -use futures::prelude::*; -use yup_oauth2::GetToken; -use yup_oauth2::{Authenticator, InstalledFlow}; +use yup_oauth2::{InstalledFlowAuthenticator, InstalledFlowReturnMethod}; -use hyper::client::Client; -use hyper_rustls::HttpsConnector; - -use std::path::Path; - -fn main() { - let https = HttpsConnector::new(1); - let client = Client::builder() - .keep_alive(false) - .build::<_, hyper::Body>(https); - let ad = yup_oauth2::DefaultFlowDelegate; - let secret = yup_oauth2::read_application_secret(Path::new("clientsecret.json")) +#[tokio::main] +async fn main() { + let app_secret = yup_oauth2::read_application_secret("clientsecret.json") + .await .expect("clientsecret.json"); - let mut auth = Authenticator::new(InstalledFlow::new( - secret, - yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect(8081), - )) - .persist_tokens_to_disk("tokencache.json") - .build() - .unwrap(); - let s = "https://www.googleapis.com/auth/drive.file".to_string(); - let scopes = vec![s]; + let auth = + InstalledFlowAuthenticator::builder(app_secret, InstalledFlowReturnMethod::HTTPRedirect) + .persist_tokens_to_disk("tokencache.json") + .build() + .await + .unwrap(); + let scopes = &["https://www.googleapis.com/auth/drive.file"]; - let tok = auth.token(scopes); - let fut = tok.map_err(|e| println!("error: {:?}", e)).and_then(|t| { - println!("The token is {:?}", t); - Ok(()) - }); - - tokio::run(fut) + match auth.token(scopes).await { + Err(e) => println!("error: {:?}", e), + Ok(t) => println!("The token is {:?}", t), + } } diff --git a/examples/test-svc-acct/Cargo.toml b/examples/test-svc-acct/Cargo.toml index 4bccf09..abc2694 100644 --- a/examples/test-svc-acct/Cargo.toml +++ b/examples/test-svc-acct/Cargo.toml @@ -6,7 +6,4 @@ edition = "2018" [dependencies] yup-oauth2 = { path = "../../" } -hyper = "0.12" -hyper-rustls = "0.17" -futures = "0.1" -tokio = "0.1" +tokio = { version = "0.2", features = ["macros"] } diff --git a/examples/test-svc-acct/src/main.rs b/examples/test-svc-acct/src/main.rs index ebaaac1..ee67c03 100644 --- a/examples/test-svc-acct/src/main.rs +++ b/examples/test-svc-acct/src/main.rs @@ -1,29 +1,18 @@ -use yup_oauth2; +use yup_oauth2::ServiceAccountAuthenticator; -use futures::prelude::*; -use yup_oauth2::GetToken; +#[tokio::main] +async fn main() { + let creds = yup_oauth2::read_service_account_key("serviceaccount.json") + .await + .unwrap(); + let sa = ServiceAccountAuthenticator::builder(creds) + .build() + .await + .unwrap(); + let scopes = &["https://www.googleapis.com/auth/pubsub"]; -use tokio; - -use std::path; - -fn main() { - let creds = - yup_oauth2::service_account_key_from_file(path::Path::new("serviceaccount.json")).unwrap(); - let mut sa = yup_oauth2::ServiceAccountAccess::new(creds).build(); - - let fut = sa - .token(vec!["https://www.googleapis.com/auth/pubsub"]) - .and_then(|tok| { - println!("token is: {:?}", tok); - Ok(()) - }); - let fut2 = sa - .token(vec!["https://www.googleapis.com/auth/pubsub"]) - .and_then(|tok| { - println!("cached token is {:?} and should be identical", tok); - Ok(()) - }); - let all = fut.join(fut2).then(|_| Ok(())); - tokio::run(all) + let tok = sa.token(scopes).await.unwrap(); + println!("token is: {:?}", tok); + let tok = sa.token(scopes).await.unwrap(); + println!("cached token is {:?} and should be identical", tok); } diff --git a/src/authenticator.rs b/src/authenticator.rs index 3a7189e..0500e67 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -1,42 +1,429 @@ -use crate::authenticator_delegate::{AuthenticatorDelegate, DefaultAuthenticatorDelegate, Retry}; +//! Module contianing the core functionality for OAuth2 Authentication. +use crate::authenticator_delegate::{DeviceFlowDelegate, InstalledFlowDelegate}; +use crate::device::DeviceFlow; +use crate::error::Error; +use crate::installed::{InstalledFlow, InstalledFlowReturnMethod}; use crate::refresh::RefreshFlow; -use crate::storage::{hash_scopes, DiskTokenStorage, MemoryStorage, TokenStorage}; -use crate::types::{ApplicationSecret, GetToken, RefreshResult, RequestError, Token}; +use crate::service_account::{ServiceAccountFlow, ServiceAccountFlowOpts, ServiceAccountKey}; +use crate::storage::{self, Storage}; +use crate::types::{AccessToken, ApplicationSecret, TokenInfo}; +use private::AuthFlow; -use futures::{future, prelude::*}; -use tokio_timer; - -use std::error::Error; +use std::borrow::Cow; +use std::fmt; use std::io; -use std::path::Path; -use std::sync::{Arc, Mutex}; +use std::path::PathBuf; +use std::sync::Mutex; -/// Authenticator abstracts different `GetToken` implementations behind one type and handles -/// caching received tokens. It's important to use it (instead of the flows directly) because -/// otherwise the user needs to be asked for new authorization every time a token is generated. -/// -/// `ServiceAccountAccess` does not need (and does not work) with `Authenticator`, given that it -/// does not require interaction and implements its own caching. Use it directly. -/// -/// NOTE: It is recommended to use a client constructed like this in order to prevent functions -/// like `hyper::run()` from hanging: `let client = hyper::Client::builder().keep_alive(false);`. -/// Due to token requests being rare, this should not result in a too bad performance problem. -struct AuthenticatorImpl< - T: GetToken, - S: TokenStorage, - AD: AuthenticatorDelegate, - C: hyper::client::connect::Connect, -> { - client: hyper::Client, - inner: Arc>, - store: Arc>, - delegate: AD, +/// Authenticator is responsible for fetching tokens, handling refreshing tokens, +/// and optionally persisting tokens to disk. +pub struct Authenticator { + hyper_client: hyper::Client, + storage: Storage, + auth_flow: AuthFlow, } -/// A trait implemented for any hyper::Client as well as teh DefaultHyperClient. -pub trait HyperClientBuilder { - type Connector: hyper::client::connect::Connect; +struct DisplayScopes<'a, T>(&'a [T]); +impl<'a, T> fmt::Display for DisplayScopes<'a, T> +where + T: AsRef, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("[")?; + let mut iter = self.0.iter(); + if let Some(first) = iter.next() { + f.write_str(first.as_ref())?; + for scope in iter { + f.write_str(", ")?; + f.write_str(scope.as_ref())?; + } + } + f.write_str("]") + } +} +impl Authenticator +where + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, +{ + /// Return the current token for the provided scopes. + pub async fn token<'a, T>(&'a self, scopes: &'a [T]) -> Result + where + T: AsRef, + { + log::debug!( + "access token requested for scopes: {}", + DisplayScopes(scopes) + ); + let hashed_scopes = storage::ScopeSet::from(scopes); + match (self.storage.get(hashed_scopes), self.auth_flow.app_secret()) { + (Some(t), _) if !t.is_expired() => { + // unexpired token found + log::debug!("found valid token in cache: {:?}", t); + Ok(t.into()) + } + ( + Some(TokenInfo { + refresh_token: Some(refresh_token), + .. + }), + Some(app_secret), + ) => { + // token is expired but has a refresh token. + let token_info = + RefreshFlow::refresh_token(&self.hyper_client, app_secret, &refresh_token) + .await?; + self.storage.set(hashed_scopes, token_info.clone()).await?; + Ok(token_info.into()) + } + _ => { + // no token in the cache or the token returned can't be refreshed. + let token_info = self.auth_flow.token(&self.hyper_client, scopes).await?; + self.storage.set(hashed_scopes, token_info.clone()).await?; + Ok(token_info.into()) + } + } + } +} + +/// Configure an Authenticator using the builder pattern. +pub struct AuthenticatorBuilder { + hyper_client_builder: C, + storage_type: StorageType, + auth_flow: F, +} + +/// Create an authenticator that uses the installed flow. +/// ``` +/// # async fn foo() { +/// # use yup_oauth2::InstalledFlowReturnMethod; +/// # let custom_flow_delegate = yup_oauth2::authenticator_delegate::DefaultInstalledFlowDelegate; +/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").await.unwrap(); +/// let authenticator = yup_oauth2::InstalledFlowAuthenticator::builder( +/// app_secret, +/// InstalledFlowReturnMethod::HTTPRedirect, +/// ) +/// .build() +/// .await +/// .expect("failed to create authenticator"); +/// # } +/// ``` +pub struct InstalledFlowAuthenticator; +impl InstalledFlowAuthenticator { + /// Use the builder pattern to create an Authenticator that uses the installed flow. + pub fn builder( + app_secret: ApplicationSecret, + method: InstalledFlowReturnMethod, + ) -> AuthenticatorBuilder { + AuthenticatorBuilder::::with_auth_flow(InstalledFlow::new( + app_secret, method, + )) + } +} + +/// Create an authenticator that uses the device flow. +/// ``` +/// # async fn foo() { +/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").await.unwrap(); +/// let authenticator = yup_oauth2::DeviceFlowAuthenticator::builder(app_secret) +/// .build() +/// .await +/// .expect("failed to create authenticator"); +/// # } +/// ``` +pub struct DeviceFlowAuthenticator; +impl DeviceFlowAuthenticator { + /// Use the builder pattern to create an Authenticator that uses the device flow. + pub fn builder( + app_secret: ApplicationSecret, + ) -> AuthenticatorBuilder { + AuthenticatorBuilder::::with_auth_flow(DeviceFlow::new(app_secret)) + } +} + +/// Create an authenticator that uses a service account. +/// ``` +/// # async fn foo() { +/// # let service_account_key = yup_oauth2::read_service_account_key("/tmp/foo").await.unwrap(); +/// let authenticator = yup_oauth2::ServiceAccountAuthenticator::builder(service_account_key) +/// .build() +/// .await +/// .expect("failed to create authenticator"); +/// # } +/// ``` +pub struct ServiceAccountAuthenticator; +impl ServiceAccountAuthenticator { + /// Use the builder pattern to create an Authenticator that uses a service account. + pub fn builder( + service_account_key: ServiceAccountKey, + ) -> AuthenticatorBuilder { + AuthenticatorBuilder::::with_auth_flow(ServiceAccountFlowOpts { + key: service_account_key, + subject: None, + }) + } +} + +/// ## Methods available when building any Authenticator. +/// ``` +/// # async fn foo() { +/// # let custom_hyper_client = hyper::Client::new(); +/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").await.unwrap(); +/// let authenticator = yup_oauth2::DeviceFlowAuthenticator::builder(app_secret) +/// .hyper_client(custom_hyper_client) +/// .persist_tokens_to_disk("/tmp/tokenfile.json") +/// .build() +/// .await +/// .expect("failed to create authenticator"); +/// # } +/// ``` +impl AuthenticatorBuilder { + async fn common_build( + hyper_client_builder: C, + storage_type: StorageType, + auth_flow: AuthFlow, + ) -> io::Result> + where + C: HyperClientBuilder, + { + let hyper_client = hyper_client_builder.build_hyper_client(); + let storage = match storage_type { + StorageType::Memory => Storage::Memory { + tokens: Mutex::new(storage::JSONTokens::new()), + }, + StorageType::Disk(path) => Storage::Disk(storage::DiskStorage::new(path).await?), + }; + + Ok(Authenticator { + hyper_client, + storage, + auth_flow, + }) + } + + fn with_auth_flow(auth_flow: F) -> AuthenticatorBuilder { + AuthenticatorBuilder { + hyper_client_builder: DefaultHyperClient, + storage_type: StorageType::Memory, + auth_flow, + } + } + + /// Use the provided hyper client. + pub fn hyper_client( + self, + hyper_client: hyper::Client, + ) -> AuthenticatorBuilder, F> { + AuthenticatorBuilder { + hyper_client_builder: hyper_client, + storage_type: self.storage_type, + auth_flow: self.auth_flow, + } + } + + /// Persist tokens to disk in the provided filename. + pub fn persist_tokens_to_disk>(self, path: P) -> AuthenticatorBuilder { + AuthenticatorBuilder { + storage_type: StorageType::Disk(path.into()), + ..self + } + } +} + +/// ## Methods available when building a device flow Authenticator. +/// ``` +/// # async fn foo() { +/// # let custom_flow_delegate = yup_oauth2::authenticator_delegate::DefaultDeviceFlowDelegate; +/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").await.unwrap(); +/// let authenticator = yup_oauth2::DeviceFlowAuthenticator::builder(app_secret) +/// .device_code_url("foo") +/// .flow_delegate(Box::new(custom_flow_delegate)) +/// .grant_type("foo") +/// .build() +/// .await +/// .expect("failed to create authenticator"); +/// # } +/// ``` +impl AuthenticatorBuilder { + /// Use the provided device code url. + pub fn device_code_url(self, url: impl Into>) -> Self { + AuthenticatorBuilder { + auth_flow: DeviceFlow { + device_code_url: url.into(), + ..self.auth_flow + }, + ..self + } + } + + /// Use the provided DeviceFlowDelegate. + pub fn flow_delegate(self, flow_delegate: Box) -> Self { + AuthenticatorBuilder { + auth_flow: DeviceFlow { + flow_delegate, + ..self.auth_flow + }, + ..self + } + } + + /// Use the provided grant type. + pub fn grant_type(self, grant_type: impl Into>) -> Self { + AuthenticatorBuilder { + auth_flow: DeviceFlow { + grant_type: grant_type.into(), + ..self.auth_flow + }, + ..self + } + } + + /// Create the authenticator. + pub async fn build(self) -> io::Result> + where + C: HyperClientBuilder, + { + Self::common_build( + self.hyper_client_builder, + self.storage_type, + AuthFlow::DeviceFlow(self.auth_flow), + ) + .await + } +} + +/// ## Methods available when building an installed flow Authenticator. +/// ``` +/// # async fn foo() { +/// # use yup_oauth2::InstalledFlowReturnMethod; +/// # let custom_flow_delegate = yup_oauth2::authenticator_delegate::DefaultInstalledFlowDelegate; +/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").await.unwrap(); +/// let authenticator = yup_oauth2::InstalledFlowAuthenticator::builder( +/// app_secret, +/// InstalledFlowReturnMethod::HTTPRedirect, +/// ) +/// .flow_delegate(Box::new(custom_flow_delegate)) +/// .build() +/// .await +/// .expect("failed to create authenticator"); +/// # } +/// ``` +impl AuthenticatorBuilder { + /// Use the provided InstalledFlowDelegate. + pub fn flow_delegate(self, flow_delegate: Box) -> Self { + AuthenticatorBuilder { + auth_flow: InstalledFlow { + flow_delegate, + ..self.auth_flow + }, + ..self + } + } + + /// Create the authenticator. + pub async fn build(self) -> io::Result> + where + C: HyperClientBuilder, + { + Self::common_build( + self.hyper_client_builder, + self.storage_type, + AuthFlow::InstalledFlow(self.auth_flow), + ) + .await + } +} + +/// ## Methods available when building a service account authenticator. +/// ``` +/// # async fn foo() { +/// # let service_account_key = yup_oauth2::read_service_account_key("/tmp/foo").await.unwrap(); +/// let authenticator = yup_oauth2::ServiceAccountAuthenticator::builder( +/// service_account_key, +/// ) +/// .subject("mysubject") +/// .build() +/// .await +/// .expect("failed to create authenticator"); +/// # } +/// ``` +impl AuthenticatorBuilder { + /// Use the provided subject. + pub fn subject(self, subject: impl Into) -> Self { + AuthenticatorBuilder { + auth_flow: ServiceAccountFlowOpts { + subject: Some(subject.into()), + ..self.auth_flow + }, + ..self + } + } + + /// Create the authenticator. + pub async fn build(self) -> io::Result> + where + C: HyperClientBuilder, + { + let service_account_auth_flow = ServiceAccountFlow::new(self.auth_flow)?; + Self::common_build( + self.hyper_client_builder, + self.storage_type, + AuthFlow::ServiceAccountFlow(service_account_auth_flow), + ) + .await + } +} + +mod private { + use crate::device::DeviceFlow; + use crate::error::Error; + use crate::installed::InstalledFlow; + use crate::service_account::ServiceAccountFlow; + use crate::types::{ApplicationSecret, TokenInfo}; + + pub enum AuthFlow { + DeviceFlow(DeviceFlow), + InstalledFlow(InstalledFlow), + ServiceAccountFlow(ServiceAccountFlow), + } + + impl AuthFlow { + pub(crate) fn app_secret(&self) -> Option<&ApplicationSecret> { + match self { + AuthFlow::DeviceFlow(device_flow) => Some(&device_flow.app_secret), + AuthFlow::InstalledFlow(installed_flow) => Some(&installed_flow.app_secret), + AuthFlow::ServiceAccountFlow(_) => None, + } + } + + pub(crate) async fn token<'a, C, T>( + &'a self, + hyper_client: &'a hyper::Client, + scopes: &'a [T], + ) -> Result + where + T: AsRef, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, + { + match self { + AuthFlow::DeviceFlow(device_flow) => device_flow.token(hyper_client, scopes).await, + AuthFlow::InstalledFlow(installed_flow) => { + installed_flow.token(hyper_client, scopes).await + } + AuthFlow::ServiceAccountFlow(service_account_flow) => { + service_account_flow.token(hyper_client, scopes).await + } + } + } + } +} + +/// A trait implemented for any hyper::Client as well as the DefaultHyperClient. +pub trait HyperClientBuilder { + /// The hyper connector that the resulting hyper client will use. + type Connector: hyper::client::connect::Connect + Clone + Send + Sync + 'static; + + /// Create a hyper::Client fn build_hyper_client(self) -> hyper::Client; } @@ -48,13 +435,13 @@ impl HyperClientBuilder for DefaultHyperClient { fn build_hyper_client(self) -> hyper::Client { hyper::Client::builder() .keep_alive(false) - .build::<_, hyper::Body>(hyper_rustls::HttpsConnector::new(1)) + .build::<_, hyper::Body>(hyper_rustls::HttpsConnector::new()) } } impl HyperClientBuilder for hyper::Client where - C: hyper::client::connect::Connect, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, { type Connector = C; @@ -63,271 +450,18 @@ where } } -/// An internal trait implemented by flows to be used by an authenticator. -pub trait AuthFlow { - type TokenGetter: GetToken; - - fn build_token_getter(self, client: hyper::Client) -> Self::TokenGetter; +enum StorageType { + Memory, + Disk(PathBuf), } -/// An authenticator can be used with `InstalledFlow`'s or `DeviceFlow`'s and -/// will refresh tokens as they expire as well as optionally persist tokens to -/// disk. -pub struct Authenticator< - T: AuthFlow, - S: TokenStorage, - AD: AuthenticatorDelegate, - C: HyperClientBuilder, -> { - client: C, - token_getter: T, - store: io::Result, - delegate: AD, -} +#[cfg(test)] +mod tests { + use super::*; -impl Authenticator -where - T: AuthFlow<::Connector>, -{ - /// Create a new authenticator with the provided flow. By default a new - /// hyper::Client will be created the default authenticator delegate will be - /// used, and tokens will not be persisted to disk. - /// Accepted flow types are DeviceFlow and InstalledFlow. - /// - /// Examples - /// ``` - /// use std::path::Path; - /// use yup_oauth2::{ApplicationSecret, Authenticator, DeviceFlow}; - /// let creds = ApplicationSecret::default(); - /// let auth = Authenticator::new(DeviceFlow::new(creds)).build().unwrap(); - /// ``` - pub fn new( - flow: T, - ) -> Authenticator { - Authenticator { - client: DefaultHyperClient, - token_getter: flow, - store: Ok(MemoryStorage::new()), - delegate: DefaultAuthenticatorDelegate, - } - } -} - -impl Authenticator -where - T: AuthFlow, - S: TokenStorage, - AD: AuthenticatorDelegate, - C: HyperClientBuilder, -{ - /// Use the provided hyper client. - pub fn hyper_client( - self, - hyper_client: hyper::Client, - ) -> Authenticator> - where - NewC: hyper::client::connect::Connect, - T: AuthFlow, - { - Authenticator { - client: hyper_client, - token_getter: self.token_getter, - store: self.store, - delegate: self.delegate, - } - } - - /// Persist tokens to disk in the provided filename. - pub fn persist_tokens_to_disk>( - self, - path: P, - ) -> Authenticator { - let disk_storage = DiskTokenStorage::new(path.as_ref().to_str().unwrap()); - Authenticator { - client: self.client, - token_getter: self.token_getter, - store: disk_storage, - delegate: self.delegate, - } - } - - /// Use the provided authenticator delegate. - pub fn delegate( - self, - delegate: NewAD, - ) -> Authenticator { - Authenticator { - client: self.client, - token_getter: self.token_getter, - store: self.store, - delegate: delegate, - } - } - - /// Create the authenticator. - pub fn build(self) -> io::Result - where - T::TokenGetter: 'static + GetToken + Send, - S: 'static + Send, - AD: 'static + Send, - C::Connector: 'static + Clone + Send, - { - let client = self.client.build_hyper_client(); - let store = Arc::new(Mutex::new(self.store?)); - let inner = Arc::new(Mutex::new( - self.token_getter.build_token_getter(client.clone()), - )); - - Ok(AuthenticatorImpl { - client, - inner, - store, - delegate: self.delegate, - }) - } -} - -impl< - GT: 'static + GetToken + Send, - S: 'static + TokenStorage + Send, - AD: 'static + AuthenticatorDelegate + Send, - C: 'static + hyper::client::connect::Connect + Clone + Send, - > GetToken for AuthenticatorImpl -{ - /// Returns the API Key of the inner flow. - fn api_key(&mut self) -> Option { - self.inner.lock().unwrap().api_key() - } - /// Returns the application secret of the inner flow. - fn application_secret(&self) -> ApplicationSecret { - self.inner.lock().unwrap().application_secret() - } - - fn token( - &mut self, - scopes: I, - ) -> Box + Send> - where - T: Into, - I: IntoIterator, - { - let (scope_key, scopes) = hash_scopes(scopes); - let store = self.store.clone(); - let mut delegate = self.delegate.clone(); - let client = self.client.clone(); - let appsecret = self.inner.lock().unwrap().application_secret(); - let gettoken = self.inner.clone(); - let loopfn = move |()| -> Box< - dyn Future, Error = RequestError> + Send, - > { - // How well does this work with tokio? - match store.lock().unwrap().get( - scope_key.clone(), - &scopes.iter().map(|s| s.as_str()).collect(), - ) { - Ok(Some(t)) => { - if !t.expired() { - return Box::new(Ok(future::Loop::Break(t)).into_future()); - } - // Implement refresh flow. - let refresh_token = t.refresh_token.clone(); - let mut delegate = delegate.clone(); - let store = store.clone(); - let scopes = scopes.clone(); - let refresh_fut = RefreshFlow::refresh_token( - client.clone(), - appsecret.clone(), - refresh_token.unwrap(), - ) - .and_then(move |rr| -> Box, Error=RequestError> + Send> { - match rr { - RefreshResult::Error(ref e) => { - delegate.token_refresh_failed( - format!("{}", e.description().to_string()), - &Some("the request has likely timed out".to_string()), - ); - Box::new(Err(RequestError::Refresh(rr)).into_future()) - } - RefreshResult::RefreshError(ref s, ref ss) => { - delegate.token_refresh_failed( - format!("{} {}", s, ss.clone().map(|s| format!("({})", s)).unwrap_or("".to_string())), - &Some("the refresh token is likely invalid and your authorization has been revoked".to_string()), - ); - Box::new(Err(RequestError::Refresh(rr)).into_future()) - } - RefreshResult::Success(t) => { - if let Err(e) = store.lock().unwrap().set(scope_key, &scopes.iter().map(|s| s.as_str()).collect(), Some(t.clone())) { - match delegate.token_storage_failure(true, &e) { - Retry::Skip => Box::new(Ok(future::Loop::Break(t)).into_future()), - Retry::Abort => Box::new(Err(RequestError::Cache(Box::new(e))).into_future()), - Retry::After(d) => Box::new( - tokio_timer::sleep(d) - .then(|_| Ok(future::Loop::Continue(()))), - ) - as Box< - dyn Future< - Item = future::Loop, - Error = RequestError> + Send>, - } - } else { - Box::new(Ok(future::Loop::Break(t)).into_future()) - } - }, - } - }); - Box::new(refresh_fut) - } - Ok(None) => { - let store = store.clone(); - let scopes = scopes.clone(); - let mut delegate = delegate.clone(); - Box::new( - gettoken - .lock() - .unwrap() - .token(scopes.clone()) - .and_then(move |t| { - if let Err(e) = store.lock().unwrap().set( - scope_key, - &scopes.iter().map(|s| s.as_str()).collect(), - Some(t.clone()), - ) { - match delegate.token_storage_failure(true, &e) { - Retry::Skip => { - Box::new(Ok(future::Loop::Break(t)).into_future()) - } - Retry::Abort => Box::new( - Err(RequestError::Cache(Box::new(e))).into_future(), - ), - Retry::After(d) => Box::new( - tokio_timer::sleep(d) - .then(|_| Ok(future::Loop::Continue(()))), - ) - as Box< - dyn Future< - Item = future::Loop, - Error = RequestError, - > + Send, - >, - } - } else { - Box::new(Ok(future::Loop::Break(t)).into_future()) - } - }), - ) - } - Err(err) => match delegate.token_storage_failure(false, &err) { - Retry::Abort | Retry::Skip => { - return Box::new(Err(RequestError::Cache(Box::new(err))).into_future()) - } - Retry::After(d) => { - return Box::new( - tokio_timer::sleep(d).then(|_| Ok(future::Loop::Continue(()))), - ) - } - }, - } - }; - Box::new(future::loop_fn((), loopfn)) + #[test] + fn ensure_send_sync() { + fn is_send_sync() {} + is_send_sync::::Connector>>() } } diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index 9077f6f..6308bd5 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -1,36 +1,21 @@ -use hyper; +//! Module containing types related to delegates. +use crate::error::{AuthErrorOr, Error}; -use std::error::Error; -use std::fmt; -use std::io; - -use crate::types::{PollError, RequestError}; - -use chrono::{DateTime, Local, Utc}; +use std::pin::Pin; use std::time::Duration; -use futures::{future, prelude::*}; -use tokio::io as tio; - -/// A utility type to indicate how operations DeviceFlowHelper operations should be retried -pub enum Retry { - /// Signal you don't want to retry - Abort, - /// Signals you want to retry after the given duration - After(Duration), - /// Instruct the caller to attempt to keep going, or choose an alternate path. - /// If this is not supported, it will have the same effect as `Abort` - Skip, -} +use chrono::{DateTime, Local, Utc}; +use std::future::Future; /// Contains state of pending authentication requests #[derive(Clone, Debug, PartialEq)] -pub struct PollInformation { +pub struct DeviceAuthResponse { + /// The device verification code. + pub device_code: String, /// Code the user must enter ... pub user_code: String, - /// ... at the verification URL - pub verification_url: String, - + /// ... at the verification URI + pub verification_uri: String, /// The `user_code` expires at the given time /// It's the time the user has left to authenticate your application pub expires_at: DateTime, @@ -39,162 +24,134 @@ pub struct PollInformation { pub interval: Duration, } -impl fmt::Display for PollInformation { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - writeln!(f, "Proceed with polling until {}", self.expires_at) +impl DeviceAuthResponse { + pub(crate) fn from_json(json_data: &[u8]) -> Result { + Ok(serde_json::from_slice::>(json_data)?.into_result()?) } } -impl fmt::Display for PollError { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - match *self { - PollError::HttpError(ref err) => err.fmt(f), - PollError::Expired(ref date) => writeln!(f, "Authentication expired at {}", date), - PollError::AccessDenied => "Access denied by user".fmt(f), - PollError::TimedOut => "Timed out waiting for token".fmt(f), - PollError::Other(ref s) => format!("Unknown server error: {}", s).fmt(f), +impl<'de> serde::Deserialize<'de> for DeviceAuthResponse { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(serde::Deserialize)] + struct RawDeviceAuthResponse { + device_code: String, + user_code: String, + // The standard dictates that verification_uri is required, but + // sadly google uses verification_url currently. One of these two + // fields need to be set and verification_uri takes precedence if + // they both are set. + verification_uri: Option, + verification_url: Option, + expires_in: i64, + interval: Option, } + + let RawDeviceAuthResponse { + device_code, + user_code, + verification_uri, + verification_url, + expires_in, + interval, + } = RawDeviceAuthResponse::deserialize(deserializer)?; + + let verification_uri = verification_uri.or(verification_url).ok_or_else(|| { + serde::de::Error::custom("neither verification_uri nor verification_url specified") + })?; + let expires_at = Utc::now() + chrono::Duration::seconds(expires_in); + let interval = Duration::from_secs(interval.unwrap_or(5)); + Ok(DeviceAuthResponse { + device_code, + user_code, + verification_uri, + expires_at, + interval, + }) } } -impl Error for PollError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - match *self { - PollError::HttpError(ref e) => Some(e), - _ => None, - } - } -} - -/// A partially implemented trait to interact with the `Authenticator` -/// -/// The only method that needs to be implemented manually is `present_user_code(...)`, -/// as no assumptions are made on how this presentation should happen. -pub trait AuthenticatorDelegate: Clone { - /// Called whenever there is an client, usually if there are network problems. - /// - /// Return retry information. - fn client_error(&mut self, _: &hyper::Error) -> Retry { - Retry::Abort - } - - /// Called whenever we failed to retrieve a token or set a token due to a storage error. - /// You may use it to either ignore the incident or retry. - /// This can be useful if the underlying `TokenStorage` may fail occasionally. - /// if `is_set` is true, the failure resulted from `TokenStorage.set(...)`. Otherwise, - /// it was `TokenStorage.get(...)` - fn token_storage_failure(&mut self, is_set: bool, _: &dyn Error) -> Retry { - let _ = is_set; - Retry::Abort - } - - /// The server denied the attempt to obtain a request code - fn request_failure(&mut self, _: RequestError) {} - - /// Called if we could not acquire a refresh token for a reason possibly specified - /// by the server. - /// This call is made for the delegate's information only. - fn token_refresh_failed>( - &mut self, - error: S, - error_description: &Option, - ) { - { - let _ = error; - } - { - let _ = error_description; - } - } -} - -/// FlowDelegate methods are called when an OAuth flow needs to ask the application what to do in -/// certain cases. -pub trait FlowDelegate: Clone { - /// Called if the request code is expired. You will have to start over in this case. - /// This will be the last call the delegate receives. - /// Given `DateTime` is the expiration date - fn expired(&mut self, _: &DateTime) {} - - /// Called if the user denied access. You would have to start over. - /// This will be the last call the delegate receives. - fn denied(&mut self) {} - - /// Called as long as we are waiting for the user to authorize us. - /// Can be used to print progress information, or decide to time-out. - /// - /// If the returned `Retry` variant is a duration. - /// # Notes - /// * Only used in `DeviceFlow`. Return value will only be used if it - /// is larger than the interval desired by the server. - fn pending(&mut self, _: &PollInformation) -> Retry { - Retry::After(Duration::from_secs(5)) - } - - /// Configure a custom redirect uri if needed. - fn redirect_uri(&self) -> Option { - None - } +/// DeviceFlowDelegate methods are called when a device flow needs to ask the +/// application what to do in certain cases. +pub trait DeviceFlowDelegate: Send + Sync { /// The server has returned a `user_code` which must be shown to the user, - /// along with the `verification_url`. + /// along with the `verification_uri`. /// # Notes /// * Will be called exactly once, provided we didn't abort during `request_code` phase. - /// * Will only be called if the Authenticator's flow_type is `FlowType::Device`. - fn present_user_code(&mut self, pi: &PollInformation) { - println!( - "Please enter {} at {} and grant access to this application", - pi.user_code, pi.verification_url - ); - println!("Do not close this application until you either denied or granted access."); - println!( - "You have time until {}.", - pi.expires_at.with_timezone(&Local) - ); + fn present_user_code<'a>( + &'a self, + device_auth_resp: &'a DeviceAuthResponse, + ) -> Pin + Send + 'a>> { + Box::pin(present_user_code(device_auth_resp)) + } +} + +async fn present_user_code(device_auth_resp: &DeviceAuthResponse) { + println!( + "Please enter {} at {} and grant access to this application", + device_auth_resp.user_code, device_auth_resp.verification_uri + ); + println!("Do not close this application until you either denied or granted access."); + println!( + "You have time until {}.", + device_auth_resp.expires_at.with_timezone(&Local) + ); +} + +/// InstalledFlowDelegate methods are called when an installed flow needs to ask +/// the application what to do in certain cases. +pub trait InstalledFlowDelegate: Send + Sync { + /// Configure a custom redirect uri if needed. + fn redirect_uri(&self) -> Option<&str> { + None } - /// This method is used by the InstalledFlow. /// We need the user to navigate to a URL using their browser and potentially paste back a code /// (or maybe not). Whether they have to enter a code depends on the InstalledFlowReturnMethod /// used. - fn present_user_url + fmt::Display>( - &mut self, - url: S, + fn present_user_url<'a>( + &'a self, + url: &'a str, need_code: bool, - ) -> Box, Error = Box> + Send> { - if need_code { - println!( - "Please direct your browser to {}, follow the instructions and enter the \ - code displayed here: ", - url - ); - - Box::new( - tio::lines(io::BufReader::new(tio::stdin())) - .into_future() - .map_err(|(e, _)| { - println!("{:?}", e); - Box::new(e) as Box - }) - .and_then(|(l, _)| Ok(l)), - ) - } else { - println!( - "Please direct your browser to {} and follow the instructions displayed \ - there.", - url - ); - Box::new(future::ok(None)) - } + ) -> Pin> + Send + 'a>> { + Box::pin(present_user_url(url, need_code)) } } -/// Uses all default implementations by AuthenticatorDelegate, and makes the trait's -/// implementation usable in the first place. -#[derive(Clone)] -pub struct DefaultAuthenticatorDelegate; -impl AuthenticatorDelegate for DefaultAuthenticatorDelegate {} +async fn present_user_url(url: &str, need_code: bool) -> Result { + use tokio::io::AsyncBufReadExt; + if need_code { + println!( + "Please direct your browser to {}, follow the instructions and enter the \ + code displayed here: ", + url + ); + let mut user_input = String::new(); + tokio::io::BufReader::new(tokio::io::stdin()) + .read_line(&mut user_input) + .await + .map_err(|e| format!("couldn't read code: {}", e))?; + // remove trailing whitespace. + user_input.truncate(user_input.trim_end().len()); + Ok(user_input) + } else { + println!( + "Please direct your browser to {} and follow the instructions displayed \ + there.", + url + ); + Ok(String::new()) + } +} -/// Uses all default implementations in the FlowDelegate trait. -#[derive(Clone)] -pub struct DefaultFlowDelegate; -impl FlowDelegate for DefaultFlowDelegate {} +/// Uses all default implementations in the DeviceFlowDelegate trait. +#[derive(Copy, Clone)] +pub struct DefaultDeviceFlowDelegate; +impl DeviceFlowDelegate for DefaultDeviceFlowDelegate {} + +/// Uses all default implementations in the DeviceFlowDelegate trait. +#[derive(Copy, Clone)] +pub struct DefaultInstalledFlowDelegate; +impl InstalledFlowDelegate for DefaultInstalledFlowDelegate {} diff --git a/src/device.rs b/src/device.rs index 227f1a9..5378494 100644 --- a/src/device.rs +++ b/src/device.rs @@ -1,213 +1,112 @@ -use std::iter::{FromIterator, IntoIterator}; +use crate::authenticator_delegate::{ + DefaultDeviceFlowDelegate, DeviceAuthResponse, DeviceFlowDelegate, +}; +use crate::error::{AuthError, Error}; +use crate::types::{ApplicationSecret, TokenInfo}; + +use std::borrow::Cow; use std::time::Duration; -use ::log::{error, log}; -use chrono::{self, Utc}; -use futures::stream::Stream; -use futures::{future, prelude::*}; -use http; -use hyper; use hyper::header; -use itertools::Itertools; -use serde_json as json; -use tokio_timer; use url::form_urlencoded; -use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate, PollInformation, Retry}; -use crate::types::{ - ApplicationSecret, Flow, FlowType, GetToken, JsonError, PollError, RequestError, Token, -}; +pub const GOOGLE_DEVICE_CODE_URL: &str = "https://accounts.google.com/o/oauth2/device/code"; -pub const GOOGLE_DEVICE_CODE_URL: &'static str = "https://accounts.google.com/o/oauth2/device/code"; +// https://developers.google.com/identity/protocols/OAuth2ForDevices#step-4:-poll-googles-authorization-server +pub const GOOGLE_GRANT_TYPE: &str = "http://oauth.net/grant_type/device/1.0"; /// Implements the [Oauth2 Device Flow](https://developers.google.com/youtube/v3/guides/authentication#devices) /// It operates in two steps: /// * obtain a code to show to the user /// * (repeatedly) poll for the user to authenticate your application -#[derive(Clone)] -pub struct DeviceFlow { - application_secret: ApplicationSecret, - device_code_url: String, - flow_delegate: FD, - wait: Duration, +pub struct DeviceFlow { + pub(crate) app_secret: ApplicationSecret, + pub(crate) device_code_url: Cow<'static, str>, + pub(crate) flow_delegate: Box, + pub(crate) grant_type: Cow<'static, str>, } -impl DeviceFlow { +impl DeviceFlow { /// Create a new DeviceFlow. The default FlowDelegate will be used and the /// default wait time is 120 seconds. - pub fn new(secret: ApplicationSecret) -> DeviceFlow { + pub(crate) fn new(app_secret: ApplicationSecret) -> Self { DeviceFlow { - application_secret: secret, - device_code_url: GOOGLE_DEVICE_CODE_URL.to_string(), - flow_delegate: DefaultFlowDelegate, - wait: Duration::from_secs(120), - } - } -} - -impl DeviceFlow { - /// Use the provided device code url. - pub fn device_code_url(self, url: String) -> Self { - DeviceFlow { - device_code_url: url, - ..self + app_secret, + device_code_url: GOOGLE_DEVICE_CODE_URL.into(), + flow_delegate: Box::new(DefaultDeviceFlowDelegate), + grant_type: GOOGLE_GRANT_TYPE.into(), } } - /// Use the provided FlowDelegate. - pub fn delegate(self, delegate: NewFD) -> DeviceFlow { - DeviceFlow { - application_secret: self.application_secret, - device_code_url: self.device_code_url, - flow_delegate: delegate, - wait: self.wait, - } - } - - /// Use the provided wait duration. - pub fn wait_duration(self, duration: Duration) -> Self { - DeviceFlow { - wait: duration, - ..self - } - } -} - -impl crate::authenticator::AuthFlow for DeviceFlow -where - FD: FlowDelegate + Send + 'static, - C: hyper::client::connect::Connect + 'static, -{ - type TokenGetter = DeviceFlowImpl; - - fn build_token_getter(self, client: hyper::Client) -> Self::TokenGetter { - DeviceFlowImpl { - client, - application_secret: self.application_secret, - device_code_url: self.device_code_url, - fd: self.flow_delegate, - wait: Duration::from_secs(1200), - } - } -} - -/// The DeviceFlow implementation. -pub struct DeviceFlowImpl { - client: hyper::Client, - application_secret: ApplicationSecret, - /// Usually GOOGLE_DEVICE_CODE_URL - device_code_url: String, - fd: FD, - wait: Duration, -} - -impl Flow for DeviceFlowImpl { - fn type_id() -> FlowType { - FlowType::Device(String::new()) - } -} - -impl< - FD: FlowDelegate + Clone + Send + 'static, - C: hyper::client::connect::Connect + Sync + 'static, - > GetToken for DeviceFlowImpl -{ - fn token( - &mut self, - scopes: I, - ) -> Box + Send> + pub(crate) async fn token( + &self, + hyper_client: &hyper::Client, + scopes: &[T], + ) -> Result where - T: Into, - I: IntoIterator, + T: AsRef, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, { - self.retrieve_device_token(Vec::from_iter(scopes.into_iter().map(Into::into))) - } - fn api_key(&mut self) -> Option { - None - } - fn application_secret(&self) -> ApplicationSecret { - self.application_secret.clone() - } -} - -impl DeviceFlowImpl -where - C: hyper::client::connect::Connect + Sync + 'static, - C::Transport: 'static, - C::Future: 'static, - FD: FlowDelegate + Clone + Send + 'static, -{ - /// Essentially what `GetToken::token` does: Retrieve a token for the given scopes without - /// caching. - fn retrieve_device_token<'a>( - &mut self, - scopes: Vec, - ) -> Box + Send> { - let application_secret = self.application_secret.clone(); - let client = self.client.clone(); - let wait = self.wait; - let mut fd = self.fd.clone(); - let request_code = Self::request_code( - application_secret.clone(), - client.clone(), - self.device_code_url.clone(), + let device_auth_resp = Self::request_code( + &self.app_secret, + hyper_client, + &self.device_code_url, scopes, ) - .and_then(move |(pollinf, device_code)| { - fd.present_user_code(&pollinf); - Ok((pollinf, device_code)) - }); - let fd = self.fd.clone(); - Box::new(request_code.and_then(move |(pollinf, device_code)| { - future::loop_fn(0, move |i| { - // Make a copy of everything every time, because the loop function needs to be - // repeatable, i.e. we can't move anything out. - let pt = Self::poll_token( - application_secret.clone(), - client.clone(), - device_code.clone(), - pollinf.clone(), - fd.clone(), - ); - let maxn = wait.as_secs() / pollinf.interval.as_secs(); - let mut fd = fd.clone(); - let pollinf = pollinf.clone(); - tokio_timer::sleep(pollinf.interval) - .then(|_| pt) - .then(move |r| match r { - Ok(None) if i < maxn => match fd.pending(&pollinf) { - Retry::Abort | Retry::Skip => { - Box::new(Err(RequestError::Poll(PollError::TimedOut)).into_future()) - } - Retry::After(d) => Box::new( - tokio_timer::sleep(d) - .then(move |_| Ok(future::Loop::Continue(i + 1))), - ) - as Box< - dyn Future< - Item = future::Loop, - Error = RequestError, - > + Send, - >, - }, - Ok(Some(tok)) => Box::new(Ok(future::Loop::Break(tok)).into_future()), - Err(e @ PollError::AccessDenied) - | Err(e @ PollError::TimedOut) - | Err(e @ PollError::Expired(_)) => { - Box::new(Err(RequestError::Poll(e)).into_future()) - } - Err(ref e) if i < maxn => { - error!("Unknown error from poll token api: {}", e); - Box::new(Ok(future::Loop::Continue(i + 1)).into_future()) - } - // Too many attempts. - Ok(None) | Err(_) => { - error!("Too many poll attempts"); - Box::new(Err(RequestError::Poll(PollError::TimedOut)).into_future()) - } - }) - }) - })) + .await?; + log::debug!("Presenting code to user"); + self.flow_delegate + .present_user_code(&device_auth_resp) + .await; + self.wait_for_device_token( + hyper_client, + &self.app_secret, + &device_auth_resp, + &self.grant_type, + ) + .await + } + + async fn wait_for_device_token( + &self, + hyper_client: &hyper::Client, + app_secret: &ApplicationSecret, + device_auth_resp: &DeviceAuthResponse, + grant_type: &str, + ) -> Result + where + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, + { + let mut interval = device_auth_resp.interval; + log::debug!("Polling every {:?} for device token", interval); + loop { + tokio::time::delay_for(interval).await; + interval = match Self::poll_token( + &app_secret, + hyper_client, + &device_auth_resp.device_code, + grant_type, + ) + .await + { + Ok(token) => return Ok(token), + Err(Error::AuthError(AuthError { error, .. })) + if error.as_str() == "authorization_pending" => + { + log::debug!("still waiting on authorization from the server"); + interval + } + Err(Error::AuthError(AuthError { error, .. })) if error.as_str() == "slow_down" => { + let interval = interval + Duration::from_secs(5); + log::debug!( + "server requested slow_down. Increasing polling interval to {:?}", + interval + ); + interval + } + Err(err) => return Err(err), + } + } } /// The first step involves asking the server for a code that the user @@ -225,95 +124,40 @@ where /// * If called after a successful result was returned at least once. /// # Examples /// See test-cases in source code for a more complete example. - fn request_code( - application_secret: ApplicationSecret, - client: hyper::Client, - device_code_url: String, - scopes: Vec, - ) -> impl Future { - // note: cloned() shouldn't be needed, see issue - // https://github.com/servo/rust-url/issues/81 + async fn request_code( + application_secret: &ApplicationSecret, + client: &hyper::Client, + device_code_url: &str, + scopes: &[T], + ) -> Result + where + T: AsRef, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, + { let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ - ("client_id", application_secret.client_id.clone()), - ( - "scope", - scopes - .into_iter() - .intersperse(" ".to_string()) - .collect::(), - ), + ("client_id", application_secret.client_id.as_str()), + ("scope", crate::helper::join(scopes, " ").as_str()), ]) .finish(); // note: works around bug in rustlang // https://github.com/rust-lang/rust/issues/22252 - let request = hyper::Request::post(device_code_url) + let req = hyper::Request::post(device_code_url) .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(req)) - .into_future(); - request - .then( - move |request: Result, http::Error>| { - let request = request.unwrap(); - client.request(request) - }, - ) - .then( - |r: Result, hyper::error::Error>| { - match r { - Err(err) => { - return Err(RequestError::ClientError(err)); - } - Ok(res) => { - // This return type is defined in https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.2 - // The alias is present as Google use a non-standard name for verification_uri. - // According to the standard interval is optional, however, all tested implementations provide it. - // verification_uri_complete is optional in the standard but not provided in tested implementations. - #[derive(Deserialize)] - struct JsonData { - device_code: String, - user_code: String, - #[serde(alias = "verification_url")] - verification_uri: String, - expires_in: Option, - interval: i64, - } - - let json_str: String = res - .into_body() - .concat2() - .wait() - .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) - .unwrap(); // TODO: error handling - - // check for error - match json::from_str::(&json_str) { - Err(_) => {} // ignore, move on - Ok(res) => return Err(RequestError::from(res)), - } - - let decoded: JsonData = json::from_str(&json_str).unwrap(); - - let expires_in = decoded.expires_in.unwrap_or(60 * 60); - - let pi = PollInformation { - user_code: decoded.user_code, - verification_url: decoded.verification_uri, - expires_at: Utc::now() + chrono::Duration::seconds(expires_in), - interval: Duration::from_secs(i64::abs(decoded.interval) as u64), - }; - Ok((pi, decoded.device_code)) - } - } - }, - ) + .unwrap(); + log::debug!("requesting code from server: {:?}", req); + let (head, body) = client.request(req).await?.into_parts(); + let body = hyper::body::to_bytes(body).await?; + log::debug!("received response; head: {:?}, body: {:?}", head, body); + DeviceAuthResponse::from_json(&body) } /// If the first call is successful, this method may be called. /// As long as we are waiting for authentication, it will return `Ok(None)`. /// You should call it within the interval given the previously returned - /// `PollInformation.interval` field. + /// `DeviceAuthResponse.interval` field. /// /// The operation was successful once you receive an Ok(Some(Token)) for the first time. /// Subsequent calls will return the previous result, which may also be an error state. @@ -328,27 +172,22 @@ where /// /// # Examples /// See test-cases in source code for a more complete example. - fn poll_token<'a>( - application_secret: ApplicationSecret, - client: hyper::Client, - device_code: String, - pi: PollInformation, - mut fd: FD, - ) -> impl Future, Error = PollError> { - let expired = if pi.expires_at <= Utc::now() { - fd.expired(&pi.expires_at); - Err(PollError::Expired(pi.expires_at)).into_future() - } else { - Ok(()).into_future() - }; - + async fn poll_token<'a, C>( + application_secret: &ApplicationSecret, + client: &hyper::Client, + device_code: &str, + grant_type: &str, + ) -> Result + where + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, + { // We should be ready for a new request let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ - ("client_id", &application_secret.client_id[..]), - ("client_secret", &application_secret.client_secret), - ("code", &device_code), - ("grant_type", "http://oauth.net/grant_type/device/1.0"), + ("client_id", application_secret.client_id.as_str()), + ("client_secret", application_secret.client_secret.as_str()), + ("code", device_code), + ("grant_type", grant_type), ]) .finish(); @@ -356,184 +195,10 @@ where .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(req)) .unwrap(); // TODO: Error checking - expired - .and_then(move |_| client.request(request).map_err(|e| PollError::HttpError(e))) - .map(|res| { - res.into_body() - .concat2() - .wait() - .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) - .unwrap() // TODO: error handling - }) - .and_then(move |json_str: String| { - #[derive(Deserialize)] - struct JsonError { - error: String, - } - - match json::from_str::(&json_str) { - Err(_) => {} // ignore, move on, it's not an error - Ok(res) => { - match res.error.as_ref() { - "access_denied" => { - fd.denied(); - return Err(PollError::AccessDenied); - } - "authorization_pending" => return Ok(None), - s => { - return Err(PollError::Other(format!( - "server message '{}' not understood", - s - ))) - } - }; - } - } - - // yes, we expect that ! - let mut t: Token = json::from_str(&json_str).unwrap(); - t.set_expiry_absolute(); - - Ok(Some(t.clone())) - }) - } -} - -#[cfg(test)] -mod tests { - use hyper; - use hyper_rustls::HttpsConnector; - use mockito; - use tokio; - - use super::*; - use crate::authenticator::AuthFlow; - use crate::helper::parse_application_secret; - - #[test] - fn test_device_end2end() { - #[derive(Clone)] - struct FD; - impl FlowDelegate for FD { - fn present_user_code(&mut self, pi: &PollInformation) { - assert_eq!("https://example.com/verify", pi.verification_url); - } - } - - let server_url = mockito::server_url(); - let app_secret = r#"{"installed":{"client_id":"902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com","project_id":"yup-test-243420","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://oauth2.googleapis.com/token","auth_provider_x509_cert_url":"https://www.googleapis.com/oauth2/v1/certs","client_secret":"iuMPN6Ne1PD7cos29Tk9rlqH","redirect_uris":["urn:ietf:wg:oauth:2.0:oob","http://localhost"]}}"#; - let mut app_secret = parse_application_secret(app_secret).unwrap(); - app_secret.token_uri = format!("{}/token", server_url); - let device_code_url = format!("{}/code", server_url); - - let https = HttpsConnector::new(1); - let client = hyper::Client::builder() - .keep_alive(false) - .build::<_, hyper::Body>(https); - - let mut flow = DeviceFlow::new(app_secret) - .delegate(FD) - .device_code_url(device_code_url) - .build_token_getter(client); - - let mut rt = tokio::runtime::Builder::new() - .core_threads(1) - .panic_handler(|e| std::panic::resume_unwind(e)) - .build() - .unwrap(); - - // Successful path - { - let code_response = r#"{"device_code": "devicecode", "user_code": "usercode", "verification_url": "https://example.com/verify", "expires_in": 1234567, "interval": 1}"#; - let _m = mockito::mock("POST", "/code") - .match_body(mockito::Matcher::Regex( - ".*client_id=902216714886.*".to_string(), - )) - .with_status(200) - .with_body(code_response) - .create(); - let token_response = r#"{"access_token": "accesstoken", "refresh_token": "refreshtoken", "token_type": "Bearer", "expires_in": 1234567}"#; - let _m = mockito::mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*client_secret=iuMPN6Ne1PD7cos29Tk9rlqH&code=devicecode.*".to_string(), - )) - .with_status(200) - .with_body(token_response) - .create(); - - let fut = flow - .token(vec!["https://www.googleapis.com/scope/1"]) - .then(|token| { - let token = token.unwrap(); - assert_eq!("accesstoken", token.access_token); - Ok(()) as Result<(), ()> - }); - rt.block_on(fut).expect("block_on"); - - _m.assert(); - } - // Code is not delivered. - { - let code_response = - r#"{"error": "invalid_client_id", "error_description": "description"}"#; - let _m = mockito::mock("POST", "/code") - .match_body(mockito::Matcher::Regex( - ".*client_id=902216714886.*".to_string(), - )) - .with_status(400) - .with_body(code_response) - .create(); - let token_response = r#"{"access_token": "accesstoken", "refresh_token": "refreshtoken", "token_type": "Bearer", "expires_in": 1234567}"#; - let _m = mockito::mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*client_secret=iuMPN6Ne1PD7cos29Tk9rlqH&code=devicecode.*".to_string(), - )) - .with_status(200) - .with_body(token_response) - .expect(0) // Never called! - .create(); - - let fut = flow - .token(vec!["https://www.googleapis.com/scope/1"]) - .then(|token| { - assert!(token.is_err()); - assert!(format!("{}", token.unwrap_err()).contains("invalid_client_id")); - Ok(()) as Result<(), ()> - }); - rt.block_on(fut).expect("block_on"); - - _m.assert(); - } - // Token is not delivered. - { - let code_response = r#"{"device_code": "devicecode", "user_code": "usercode", "verification_url": "https://example.com/verify", "expires_in": 1234567, "interval": 1}"#; - let _m = mockito::mock("POST", "/code") - .match_body(mockito::Matcher::Regex( - ".*client_id=902216714886.*".to_string(), - )) - .with_status(200) - .with_body(code_response) - .create(); - let token_response = r#"{"error": "access_denied"}"#; - let _m = mockito::mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*client_secret=iuMPN6Ne1PD7cos29Tk9rlqH&code=devicecode.*".to_string(), - )) - .with_status(400) - .with_body(token_response) - .expect(1) - .create(); - - let fut = flow - .token(vec!["https://www.googleapis.com/scope/1"]) - .then(|token| { - assert!(token.is_err()); - assert!(format!("{}", token.unwrap_err()).contains("Access denied by user")); - Ok(()) as Result<(), ()> - }); - rt.block_on(fut).expect("block_on"); - - _m.assert(); - } + log::debug!("polling for token: {:?}", request); + let (head, body) = client.request(request).await?.into_parts(); + let body = hyper::body::to_bytes(body).await?; + log::debug!("received response; head: {:?} body: {:?}", head, body); + TokenInfo::from_json(&body) } } diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..63cf636 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,256 @@ +//! Module containing various error types. + +use std::borrow::Cow; +use std::error::Error as StdError; +use std::fmt; +use std::io; + +use serde::Deserialize; + +/// Error returned by the authorization server. +/// https://tools.ietf.org/html/rfc6749#section-5.2 +/// https://tools.ietf.org/html/rfc8628#section-3.5 +#[derive(Deserialize, Debug, PartialEq, Eq)] +pub struct AuthError { + /// Error code from the server. + pub error: AuthErrorCode, + /// Human-readable text providing additional information. + pub error_description: Option, + /// A URI identifying a human-readable web page with information about the error. + pub error_uri: Option, +} + +impl fmt::Display for AuthError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", &self.error.as_str())?; + if let Some(desc) = &self.error_description { + write!(f, ": {}", desc)?; + } + if let Some(uri) = &self.error_uri { + write!(f, "; See {} for more info", uri)?; + } + Ok(()) + } +} +impl StdError for AuthError {} + +/// The error code returned by the authorization server. +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum AuthErrorCode { + /// invalid_request + InvalidRequest, + /// invalid_client + InvalidClient, + /// invalid_grant + InvalidGrant, + /// unauthorized_client + UnauthorizedClient, + /// unsupported_grant_type + UnsupportedGrantType, + /// invalid_scope + InvalidScope, + /// access_denied + AccessDenied, + /// expired_token + ExpiredToken, + /// other error + Other(String), +} + +impl AuthErrorCode { + /// The error code as a &str + pub fn as_str(&self) -> &str { + match self { + AuthErrorCode::InvalidRequest => "invalid_request", + AuthErrorCode::InvalidClient => "invalid_client", + AuthErrorCode::InvalidGrant => "invalid_grant", + AuthErrorCode::UnauthorizedClient => "unauthorized_client", + AuthErrorCode::UnsupportedGrantType => "unsupported_grant_type", + AuthErrorCode::InvalidScope => "invalid_scope", + AuthErrorCode::AccessDenied => "access_denied", + AuthErrorCode::ExpiredToken => "expired_token", + AuthErrorCode::Other(s) => s.as_str(), + } + } + + fn from_string<'a>(s: impl Into>) -> AuthErrorCode { + let s = s.into(); + match s.as_ref() { + "invalid_request" => AuthErrorCode::InvalidRequest, + "invalid_client" => AuthErrorCode::InvalidClient, + "invalid_grant" => AuthErrorCode::InvalidGrant, + "unauthorized_client" => AuthErrorCode::UnauthorizedClient, + "unsupported_grant_type" => AuthErrorCode::UnsupportedGrantType, + "invalid_scope" => AuthErrorCode::InvalidScope, + "access_denied" => AuthErrorCode::AccessDenied, + "expired_token" => AuthErrorCode::ExpiredToken, + _ => AuthErrorCode::Other(s.into_owned()), + } + } +} + +impl From for AuthErrorCode { + fn from(s: String) -> Self { + AuthErrorCode::from_string(s) + } +} + +impl<'a> From<&'a str> for AuthErrorCode { + fn from(s: &str) -> Self { + AuthErrorCode::from_string(s) + } +} + +impl<'de> Deserialize<'de> for AuthErrorCode { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct V; + impl<'de> serde::de::Visitor<'de> for V { + type Value = AuthErrorCode; + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str("any string") + } + fn visit_string(self, value: String) -> Result { + Ok(value.into()) + } + fn visit_str(self, value: &str) -> Result { + Ok(value.into()) + } + } + deserializer.deserialize_string(V) + } +} + +/// A helper type to deserialize either an AuthError or another piece of data. +#[derive(Deserialize, Debug)] +#[serde(untagged)] +pub(crate) enum AuthErrorOr { + AuthError(AuthError), + Data(T), +} + +impl AuthErrorOr { + pub(crate) fn into_result(self) -> Result { + match self { + AuthErrorOr::AuthError(err) => Result::Err(err), + AuthErrorOr::Data(value) => Result::Ok(value), + } + } +} + +/// Encapsulates all possible results of the `token(...)` operation +#[derive(Debug)] +pub enum Error { + /// Indicates connection failure + HttpError(hyper::Error), + /// The server returned an error. + AuthError(AuthError), + /// Error while decoding a JSON response. + JSONError(serde_json::Error), + /// Error within user input. + UserError(String), + /// A lower level IO error. + LowLevelError(io::Error), +} + +impl From for Error { + fn from(error: hyper::Error) -> Error { + Error::HttpError(error) + } +} + +impl From for Error { + fn from(value: AuthError) -> Error { + Error::AuthError(value) + } +} + +impl From for Error { + fn from(value: serde_json::Error) -> Error { + Error::JSONError(value) + } +} + +impl From for Error { + fn from(value: io::Error) -> Error { + Error::LowLevelError(value) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + match *self { + Error::HttpError(ref err) => err.fmt(f), + Error::AuthError(ref err) => err.fmt(f), + Error::JSONError(ref e) => { + write!( + f, + "JSON Error; this might be a bug with unexpected server responses! {}", + e + )?; + Ok(()) + } + Error::UserError(ref s) => s.fmt(f), + Error::LowLevelError(ref e) => e.fmt(f), + } + } +} + +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match *self { + Error::HttpError(ref err) => Some(err), + Error::AuthError(ref err) => Some(err), + Error::JSONError(ref err) => Some(err), + Error::LowLevelError(ref err) => Some(err), + _ => None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_auth_error_code_deserialize() { + assert_eq!( + AuthErrorCode::InvalidRequest, + serde_json::from_str(r#""invalid_request""#).unwrap() + ); + assert_eq!( + AuthErrorCode::InvalidClient, + serde_json::from_str(r#""invalid_client""#).unwrap() + ); + assert_eq!( + AuthErrorCode::InvalidGrant, + serde_json::from_str(r#""invalid_grant""#).unwrap() + ); + assert_eq!( + AuthErrorCode::UnauthorizedClient, + serde_json::from_str(r#""unauthorized_client""#).unwrap() + ); + assert_eq!( + AuthErrorCode::UnsupportedGrantType, + serde_json::from_str(r#""unsupported_grant_type""#).unwrap() + ); + assert_eq!( + AuthErrorCode::InvalidScope, + serde_json::from_str(r#""invalid_scope""#).unwrap() + ); + assert_eq!( + AuthErrorCode::AccessDenied, + serde_json::from_str(r#""access_denied""#).unwrap() + ); + assert_eq!( + AuthErrorCode::ExpiredToken, + serde_json::from_str(r#""expired_token""#).unwrap() + ); + assert_eq!( + AuthErrorCode::Other("undefined".to_owned()), + serde_json::from_str(r#""undefined""#).unwrap() + ); + } +} diff --git a/src/helper.rs b/src/helper.rs index c9471fc..b143db7 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -1,63 +1,72 @@ -#![allow(dead_code)] - //! Helper functions allowing you to avoid writing boilerplate code for common operations, such as //! parsing JSON or reading files. // Copyright (c) 2016 Google Inc (lewinb@google.com). // // Refer to the project root for licensing information. - -use serde_json; - -use std::fs; -use std::io::{self, Read}; -use std::path::Path; - use crate::service_account::ServiceAccountKey; use crate::types::{ApplicationSecret, ConsoleApplicationSecret}; -/// Read an application secret from a file. -pub fn read_application_secret(path: &Path) -> io::Result { - let mut secret = String::new(); - let mut file = fs::OpenOptions::new().read(true).open(path)?; - file.read_to_string(&mut secret)?; +use std::io; +use std::path::Path; - parse_application_secret(&secret) +/// Read an application secret from a file. +pub async fn read_application_secret>(path: P) -> io::Result { + parse_application_secret(tokio::fs::read(path).await?) } /// Read an application secret from a JSON string. -pub fn parse_application_secret>(secret: S) -> io::Result { - let result: serde_json::Result = - serde_json::from_str(secret.as_ref()); - match result { - Err(e) => Err(io::Error::new( +pub fn parse_application_secret>(secret: S) -> io::Result { + let decoded: ConsoleApplicationSecret = + serde_json::from_slice(secret.as_ref()).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("Bad application secret: {}", e), + ) + })?; + + if let Some(web) = decoded.web { + Ok(web) + } else if let Some(installed) = decoded.installed { + Ok(installed) + } else { + Err(io::Error::new( io::ErrorKind::InvalidData, - format!("Bad application secret: {}", e), - )), - Ok(decoded) => { - if decoded.web.is_some() { - Ok(decoded.web.unwrap()) - } else if decoded.installed.is_some() { - Ok(decoded.installed.unwrap()) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "Unknown application secret format", - )) - } - } + "Unknown application secret format", + )) } } /// Read a service account key from a JSON file. You can download the JSON keys from the Google /// Cloud Console or the respective console of your service provider. -pub fn service_account_key_from_file>(path: S) -> io::Result { - let mut key = String::new(); - let mut file = fs::OpenOptions::new().read(true).open(path)?; - file.read_to_string(&mut key)?; - - match serde_json::from_str(&key) { - Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, format!("{}", e))), - Ok(decoded) => Ok(decoded), - } +pub async fn read_service_account_key>(path: P) -> io::Result { + let key = tokio::fs::read(path).await?; + serde_json::from_slice(&key).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("Bad service account key: {}", e), + ) + }) +} + +pub(crate) fn join(pieces: &[T], separator: &str) -> String +where + T: AsRef, +{ + let mut iter = pieces.iter(); + let first = match iter.next() { + Some(p) => p, + None => return String::new(), + }; + let num_separators = pieces.len() - 1; + let pieces_size: usize = pieces.iter().map(|p| p.as_ref().len()).sum(); + let size = pieces_size + separator.len() * num_separators; + let mut result = String::with_capacity(size); + result.push_str(first.as_ref()); + for p in iter { + result.push_str(separator); + result.push_str(p.as_ref()); + } + debug_assert_eq!(size, result.len()); + result } diff --git a/src/installed.rs b/src/installed.rs index 31c967a..3611162 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -2,53 +2,42 @@ // // Refer to the project root for licensing information. // +use crate::authenticator_delegate::{DefaultInstalledFlowDelegate, InstalledFlowDelegate}; +use crate::error::Error; +use crate::types::{ApplicationSecret, TokenInfo}; + use std::convert::AsRef; +use std::net::SocketAddr; use std::sync::{Arc, Mutex}; -use futures::prelude::*; -use futures::stream::Stream; -use futures::sync::oneshot; -use hyper; -use hyper::{header, StatusCode, Uri}; +use hyper::header; +use tokio::sync::oneshot; use url::form_urlencoded; use url::percent_encoding::{percent_encode, QUERY_ENCODE_SET}; -use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate}; -use crate::types::{ApplicationSecret, GetToken, RequestError, Token}; - -const OOB_REDIRECT_URI: &'static str = "urn:ietf:wg:oauth:2.0:oob"; +const OOB_REDIRECT_URI: &str = "urn:ietf:wg:oauth:2.0:oob"; /// Assembles a URL to request an authorization token (with user interaction). /// Note that the redirect_uri here has to be either None or some variation of /// http://localhost:{port}, or the authorization won't work (error "redirect_uri_mismatch") -fn build_authentication_request_url<'a, T, I>( +fn build_authentication_request_url( auth_uri: &str, client_id: &str, - scopes: I, - redirect_uri: Option, + scopes: &[T], + redirect_uri: Option<&str>, ) -> String where - T: AsRef + 'a, - I: IntoIterator, + T: AsRef, { let mut url = String::new(); - let mut scopes_string = scopes.into_iter().fold(String::new(), |mut acc, sc| { - acc.push_str(sc.as_ref()); - acc.push_str(" "); - acc - }); - // Remove last space - scopes_string.pop(); + let scopes_string = crate::helper::join(scopes, " "); url.push_str(auth_uri); vec![ format!("?scope={}", scopes_string), - format!("&access_type=offline"), - format!( - "&redirect_uri={}", - redirect_uri.unwrap_or(OOB_REDIRECT_URI.to_string()) - ), - format!("&response_type=code"), + "&access_type=offline".to_string(), + format!("&redirect_uri={}", redirect_uri.unwrap_or(OOB_REDIRECT_URI)), + "&response_type=code".to_string(), format!("&client_id={}", client_id), ] .into_iter() @@ -58,35 +47,6 @@ where }) } -impl - GetToken for InstalledFlowImpl -{ - fn token( - &mut self, - scopes: I, - ) -> Box + Send> - where - T: Into, - I: IntoIterator, - { - Box::new(self.obtain_token(scopes.into_iter().map(Into::into).collect())) - } - fn api_key(&mut self) -> Option { - None - } - fn application_secret(&self) -> ApplicationSecret { - self.appsecret.clone() - } -} - -/// The InstalledFlow implementation. -pub struct InstalledFlowImpl { - method: InstalledFlowReturnMethod, - client: hyper::client::Client, - fd: FD, - appsecret: ApplicationSecret, -} - /// cf. https://developers.google.com/identity/protocols/OAuth2InstalledApp#choosingredirecturi pub enum InstalledFlowReturnMethod { /// Involves showing a URL to the user and asking to copy a code from their browser @@ -94,392 +54,245 @@ pub enum InstalledFlowReturnMethod { Interactive, /// Involves spinning up a local HTTP server and Google redirecting the browser to /// the server with a URL containing the code (preferred, but not as reliable). - HTTPRedirectEphemeral, - /// Involves spinning up a local HTTP server and Google redirecting the browser to - /// the server with a URL containing the code (preferred, but not as reliable). The - /// parameter is the port to listen on. - HTTPRedirect(u16), + HTTPRedirect, } /// InstalledFlowImpl provides tokens for services that follow the "Installed" OAuth flow. (See /// https://www.oauth.com/oauth2-servers/authorization/, /// https://developers.google.com/identity/protocols/OAuth2InstalledApp). -pub struct InstalledFlow { - method: InstalledFlowReturnMethod, - flow_delegate: FD, - appsecret: ApplicationSecret, +pub struct InstalledFlow { + pub(crate) app_secret: ApplicationSecret, + pub(crate) method: InstalledFlowReturnMethod, + pub(crate) flow_delegate: Box, } -impl InstalledFlow { +impl InstalledFlow { /// Create a new InstalledFlow with the provided secret and method. - pub fn new( - secret: ApplicationSecret, + pub(crate) fn new( + app_secret: ApplicationSecret, method: InstalledFlowReturnMethod, - ) -> InstalledFlow { + ) -> InstalledFlow { InstalledFlow { + app_secret, method, - flow_delegate: DefaultFlowDelegate, - appsecret: secret, + flow_delegate: Box::new(DefaultInstalledFlowDelegate), } } -} -impl InstalledFlow -where - FD: FlowDelegate, -{ - /// Use the provided FlowDelegate. - pub fn delegate(self, delegate: NewFD) -> InstalledFlow { - InstalledFlow { - method: self.method, - flow_delegate: delegate, - appsecret: self.appsecret, - } - } -} - -impl crate::authenticator::AuthFlow for InstalledFlow -where - FD: FlowDelegate + Send + 'static, - C: hyper::client::connect::Connect + 'static, -{ - type TokenGetter = InstalledFlowImpl; - - fn build_token_getter(self, client: hyper::Client) -> Self::TokenGetter { - InstalledFlowImpl { - method: self.method, - fd: self.flow_delegate, - appsecret: self.appsecret, - client, - } - } -} - -impl<'c, FD: 'static + FlowDelegate + Clone + Send, C: 'c + hyper::client::connect::Connect> - InstalledFlowImpl -{ /// Handles the token request flow; it consists of the following steps: /// . Obtain a authorization code with user cooperation or internal redirect. /// . Obtain a token and refresh token using that code. /// . Return that token /// - /// It's recommended not to use the DefaultFlowDelegate, but a specialized one. - fn obtain_token<'a>( - &mut self, - scopes: Vec, // Note: I haven't found a better way to give a list of strings here, due to ownership issues with futures. - ) -> impl 'a + Future + Send { - let rduri = self.fd.redirect_uri(); - // Start server on localhost to accept auth code. - let server_bind_port = match self.method { - InstalledFlowReturnMethod::HTTPRedirect(port) => Some(port), - InstalledFlowReturnMethod::HTTPRedirectEphemeral => Some(0), - _ => None, - }; - let server = if let Some(port) = server_bind_port { - match InstalledFlowServer::new(port) { - Result::Err(e) => Err(RequestError::ClientError(e)), - Result::Ok(server) => Ok(Some(server)), + /// It's recommended not to use the DefaultInstalledFlowDelegate, but a specialized one. + pub(crate) async fn token( + &self, + hyper_client: &hyper::Client, + scopes: &[T], + ) -> Result + where + T: AsRef, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, + { + match self.method { + InstalledFlowReturnMethod::HTTPRedirect => { + self.ask_auth_code_via_http(hyper_client, &self.app_secret, scopes) + .await } - } else { - Ok(None) - }; - let port = if let Ok(Some(ref srv)) = server { - Some(srv.port) - } else { - None - }; - let client = self.client.clone(); - let (appsecclone, appsecclone2) = (self.appsecret.clone(), self.appsecret.clone()); - let auth_delegate = self.fd.clone(); - server - .into_future() - // First: Obtain authorization code from user. - .and_then(move |server| { - Self::ask_authorization_code(server, auth_delegate, &appsecclone, scopes.iter()) - }) - // Exchange the authorization code provided by Google/the provider for a refresh and an - // access token. - .and_then(move |authcode| { - let request = Self::request_token(appsecclone2, authcode, rduri, port); - let result = client.request(request); - // Handle result here, it makes ownership tracking easier. - result - .and_then(move |r| { - r.into_body() - .concat2() - .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) - // TODO: error handling - }) - .then(|body_or| { - let resp = match body_or { - Err(e) => return Err(RequestError::ClientError(e)), - Ok(s) => s, - }; - - let token_resp: Result = - serde_json::from_str(&resp); - - match token_resp { - Err(e) => { - return Err(RequestError::JSONError(e)); - } - Ok(tok) => { - if tok.error.is_some() { - Err(RequestError::NegativeServerResponse( - tok.error.unwrap(), - tok.error_description, - )) - } else { - Ok(tok) - } - } - } - }) - }) - // Return the combined token. - .and_then(|tokens| { - // Successful response - if tokens.access_token.is_some() { - let mut token = Token { - access_token: tokens.access_token.unwrap(), - refresh_token: Some(tokens.refresh_token.unwrap()), - token_type: tokens.token_type.unwrap(), - expires_in: tokens.expires_in, - expires_in_timestamp: None, - }; - - token.set_expiry_absolute(); - Ok(token) - } else { - Err(RequestError::NegativeServerResponse( - tokens.error.unwrap(), - tokens.error_description, - )) - } - }) + InstalledFlowReturnMethod::Interactive => { + self.ask_auth_code_interactively(hyper_client, &self.app_secret, scopes) + .await + } + } } - fn ask_authorization_code<'a, S, T>( - server: Option, - mut auth_delegate: FD, - appsecret: &ApplicationSecret, - scopes: S, - ) -> Box + Send> + async fn ask_auth_code_interactively( + &self, + hyper_client: &hyper::Client, + app_secret: &ApplicationSecret, + scopes: &[T], + ) -> Result where - T: AsRef + 'a, - S: Iterator, + T: AsRef, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, { - if server.is_none() { - let url = build_authentication_request_url( - &appsecret.auth_uri, - &appsecret.client_id, - scopes, - auth_delegate.redirect_uri(), - ); - Box::new( - auth_delegate - .present_user_url(&url, true /* need_code */) - .then(|r| { - match r { - Ok(Some(mut code)) => { - // Partial backwards compatibility in case an implementation adds a new line - // due to previous behaviour. - let ends_with_newline = - code.chars().last().map(|c| c == '\n').unwrap_or(false); - if ends_with_newline { - code.pop(); - } - Ok(code) - } - _ => Err(RequestError::UserError("couldn't read code".to_string())), - } - }), - ) - } else { - let mut server = server.unwrap(); - // The redirect URI must be this very localhost URL, otherwise authorization is refused - // by certain providers. - let url = build_authentication_request_url( - &appsecret.auth_uri, - &appsecret.client_id, - scopes, - auth_delegate - .redirect_uri() - .or_else(|| Some(format!("http://localhost:{}", server.port))), - ); - Box::new( - auth_delegate - .present_user_url(&url, false /* need_code */) - .then(move |_| server.block_till_auth()) - .map_err(|e| { - RequestError::UserError(format!( - "could not obtain token via redirect: {}", - e - )) - }), - ) - } + let url = build_authentication_request_url( + &app_secret.auth_uri, + &app_secret.client_id, + scopes, + self.flow_delegate.redirect_uri(), + ); + log::debug!("Presenting auth url to user: {}", url); + let auth_code = self + .flow_delegate + .present_user_url(&url, true /* need code */) + .await + .map_err(Error::UserError)?; + log::debug!("Received auth code: {}", auth_code); + self.exchange_auth_code(&auth_code, hyper_client, app_secret, None) + .await + } + + async fn ask_auth_code_via_http( + &self, + hyper_client: &hyper::Client, + app_secret: &ApplicationSecret, + scopes: &[T], + ) -> Result + where + T: AsRef, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, + { + use std::borrow::Cow; + let server = InstalledFlowServer::run()?; + let server_addr = server.local_addr(); + + // Present url to user. + // The redirect URI must be this very localhost URL, otherwise authorization is refused + // by certain providers. + let redirect_uri: Cow = match self.flow_delegate.redirect_uri() { + Some(uri) => uri.into(), + None => format!("http://{}", server_addr).into(), + }; + let url = build_authentication_request_url( + &app_secret.auth_uri, + &app_secret.client_id, + scopes, + Some(redirect_uri.as_ref()), + ); + log::debug!("Presenting auth url to user: {}", url); + let _ = self + .flow_delegate + .present_user_url(&url, false /* need code */) + .await; + let auth_code = server.wait_for_auth_code().await; + self.exchange_auth_code(&auth_code, hyper_client, app_secret, Some(server_addr)) + .await + } + + async fn exchange_auth_code( + &self, + authcode: &str, + hyper_client: &hyper::Client, + app_secret: &ApplicationSecret, + server_addr: Option, + ) -> Result + where + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, + { + let redirect_uri = self.flow_delegate.redirect_uri(); + let request = Self::request_token(app_secret, authcode, redirect_uri, server_addr); + log::debug!("Sending request: {:?}", request); + let (head, body) = hyper_client.request(request).await?.into_parts(); + let body = hyper::body::to_bytes(body).await?; + log::debug!("Received response; head: {:?} body: {:?}", head, body); + TokenInfo::from_json(&body) } /// Sends the authorization code to the provider in order to obtain access and refresh tokens. - fn request_token<'a>( - appsecret: ApplicationSecret, - authcode: String, - custom_redirect_uri: Option, - port: Option, + fn request_token( + app_secret: &ApplicationSecret, + authcode: &str, + custom_redirect_uri: Option<&str>, + server_addr: Option, ) -> hyper::Request { - let redirect_uri = custom_redirect_uri.unwrap_or_else(|| match port { - None => OOB_REDIRECT_URI.to_string(), - Some(port) => format!("http://localhost:{}", port), - }); + use std::borrow::Cow; + let redirect_uri: Cow = match (custom_redirect_uri, server_addr) { + (Some(uri), _) => uri.into(), + (None, Some(addr)) => format!("http://{}", addr).into(), + (None, None) => OOB_REDIRECT_URI.into(), + }; let body = form_urlencoded::Serializer::new(String::new()) .extend_pairs(vec![ - ("code".to_string(), authcode.to_string()), - ("client_id".to_string(), appsecret.client_id.clone()), - ("client_secret".to_string(), appsecret.client_secret.clone()), - ("redirect_uri".to_string(), redirect_uri), - ("grant_type".to_string(), "authorization_code".to_string()), + ("code", authcode), + ("client_id", app_secret.client_id.as_str()), + ("client_secret", app_secret.client_secret.as_str()), + ("redirect_uri", redirect_uri.as_ref()), + ("grant_type", "authorization_code"), ]) .finish(); - let request = hyper::Request::post(appsecret.token_uri) + hyper::Request::post(&app_secret.token_uri) .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(body)) - .unwrap(); // TODO: error check - request + .unwrap() // TODO: error check } } -#[derive(Deserialize)] -struct JSONTokenResponse { - access_token: Option, - refresh_token: Option, - token_type: Option, - expires_in: Option, - - error: Option, - error_description: Option, -} - struct InstalledFlowServer { - port: u16, - shutdown_tx: Option>, - auth_code_rx: Option>, - threadpool: Option, + addr: SocketAddr, + auth_code_rx: oneshot::Receiver, + trigger_shutdown_tx: oneshot::Sender<()>, + shutdown_complete: tokio::task::JoinHandle<()>, } impl InstalledFlowServer { - fn new(port: u16) -> Result { + fn run() -> Result { + use hyper::service::{make_service_fn, service_fn}; let (auth_code_tx, auth_code_rx) = oneshot::channel::(); - let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let (trigger_shutdown_tx, trigger_shutdown_rx) = oneshot::channel::<()>(); + let auth_code_tx = Arc::new(Mutex::new(Some(auth_code_tx))); - let threadpool = tokio_threadpool::Builder::new() - .pool_size(1) - .name_prefix("InstalledFlowServer-") - .build(); - let service_maker = InstalledFlowServiceMaker::new(auth_code_tx); - - let addr: std::net::SocketAddr = ([127, 0, 0, 1], port).into(); - let builder = hyper::server::Server::try_bind(&addr)?; - let server = builder.http1_only(true).serve(service_maker); - let port = server.local_addr().port(); - let server_future = server - .with_graceful_shutdown(shutdown_rx) - .map_err(|err| panic!("Failed badly: {}", err)); - - threadpool.spawn(server_future); - - Result::Ok(InstalledFlowServer { - port: port, - shutdown_tx: Some(shutdown_tx), - auth_code_rx: Some(auth_code_rx), - threadpool: Some(threadpool), + let service = make_service_fn(move |_| { + let auth_code_tx = auth_code_tx.clone(); + async move { + use std::convert::Infallible; + Ok::<_, Infallible>(service_fn(move |req| { + installed_flow_server::handle_req(req, auth_code_tx.clone()) + })) + } + }); + let addr: std::net::SocketAddr = ([127, 0, 0, 1], 0).into(); + let server = hyper::server::Server::try_bind(&addr)?; + let server = server.http1_only(true).serve(service); + let addr = server.local_addr(); + let shutdown_complete = tokio::spawn(async { + let _ = server + .with_graceful_shutdown(async move { + let _ = trigger_shutdown_rx.await; + }) + .await; + }); + log::debug!("HTTP server listening on {}", addr); + Ok(InstalledFlowServer { + addr, + auth_code_rx, + trigger_shutdown_tx, + shutdown_complete, }) } - fn block_till_auth(&mut self) -> Result { - match self.auth_code_rx.take() { - Some(auth_code_rx) => auth_code_rx.wait(), - None => Result::Err(oneshot::Canceled), - } + fn local_addr(&self) -> SocketAddr { + self.addr + } + + async fn wait_for_auth_code(self) -> String { + log::debug!("Waiting for HTTP server to receive auth code"); + // Wait for the auth code from the server. + let auth_code = self + .auth_code_rx + .await + .expect("server shutdown while waiting for auth_code"); + log::debug!("HTTP server received auth code: {}", auth_code); + log::debug!("Shutting down HTTP server"); + // auth code received. shutdown the server + let _ = self.trigger_shutdown_tx.send(()); + let _ = self.shutdown_complete.await; + auth_code } } -impl std::ops::Drop for InstalledFlowServer { - fn drop(&mut self) { - self.shutdown_tx.take().map(|tx| tx.send(())); - self.auth_code_rx.take().map(|mut rx| rx.close()); - self.threadpool.take(); - } -} +mod installed_flow_server { + use hyper::{Body, Request, Response, StatusCode, Uri}; + use std::sync::{Arc, Mutex}; + use tokio::sync::oneshot; + use url::form_urlencoded; -pub struct InstalledFlowHandlerResponseFuture { - inner: Box< - dyn futures::Future, Error = hyper::http::Error> + Send, - >, -} - -impl InstalledFlowHandlerResponseFuture { - fn new( - fut: Box< - dyn futures::Future, Error = hyper::http::Error> - + Send, - >, - ) -> Self { - Self { inner: fut } - } -} - -impl futures::Future for InstalledFlowHandlerResponseFuture { - type Item = hyper::Response; - type Error = hyper::http::Error; - - fn poll(&mut self) -> futures::Poll { - self.inner.poll() - } -} - -/// Creates InstalledFlowService on demand -struct InstalledFlowServiceMaker { - auth_code_tx: Arc>>>, -} - -impl InstalledFlowServiceMaker { - fn new(auth_code_tx: oneshot::Sender) -> InstalledFlowServiceMaker { - let auth_code_tx = Arc::new(Mutex::new(Option::Some(auth_code_tx))); - InstalledFlowServiceMaker { auth_code_tx } - } -} - -impl hyper::service::MakeService for InstalledFlowServiceMaker { - type ReqBody = hyper::Body; - type ResBody = hyper::Body; - type Error = hyper::http::Error; - type Service = InstalledFlowService; - type Future = futures::future::FutureResult; - type MakeError = hyper::http::Error; - - fn make_service(&mut self, _ctx: Ctx) -> Self::Future { - let service = InstalledFlowService { - auth_code_tx: self.auth_code_tx.clone(), - }; - futures::future::ok(service) - } -} - -/// HTTP service handling the redirect from the provider. -struct InstalledFlowService { - auth_code_tx: Arc>>>, -} - -impl hyper::service::Service for InstalledFlowService { - type ReqBody = hyper::Body; - type ResBody = hyper::Body; - type Error = hyper::http::Error; - type Future = InstalledFlowHandlerResponseFuture; - - fn call(&mut self, req: hyper::Request) -> Self::Future { + pub(super) async fn handle_req( + req: Request, + auth_code_tx: Arc>>>, + ) -> Result, http::Error> { match req.uri().path_and_query() { Some(path_and_query) => { // We use a fake URL because the redirect goes to a URL, meaning we @@ -491,240 +304,51 @@ impl hyper::service::Service for InstalledFlowService { .path_and_query(path_and_query.clone()) .build(); - if url.is_err() { - let response = hyper::Response::builder() + match url { + Err(_) => hyper::Response::builder() .status(StatusCode::BAD_REQUEST) - .body(hyper::Body::from("Unparseable URL")); - - match response { - Ok(response) => InstalledFlowHandlerResponseFuture::new(Box::new( - futures::future::ok(response), - )), - Err(err) => InstalledFlowHandlerResponseFuture::new(Box::new( - futures::future::err(err), - )), - } - } else { - self.handle_url(url.unwrap()); - let response = - hyper::Response::builder() - .status(StatusCode::OK) - .body(hyper::Body::from( - "SuccessYou may now \ - close this window.", - )); - - match response { - Ok(response) => InstalledFlowHandlerResponseFuture::new(Box::new( - futures::future::ok(response), - )), - Err(err) => InstalledFlowHandlerResponseFuture::new(Box::new( - futures::future::err(err), - )), - } - } - } - None => { - let response = hyper::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(hyper::Body::from("Invalid Request!")); - - match response { - Ok(response) => InstalledFlowHandlerResponseFuture::new(Box::new( - futures::future::ok(response), - )), - Err(err) => { - InstalledFlowHandlerResponseFuture::new(Box::new(futures::future::err(err))) - } + .body(hyper::Body::from("Unparseable URL")), + Ok(url) => match auth_code_from_url(url) { + Some(auth_code) => { + if let Some(sender) = auth_code_tx.lock().unwrap().take() { + let _ = sender.send(auth_code); + } + hyper::Response::builder().status(StatusCode::OK).body( + hyper::Body::from( + "SuccessYou may now \ + close this window.", + ), + ) + } + None => hyper::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(hyper::Body::from("No `code` in URL")), + }, } } + None => hyper::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(hyper::Body::from("Invalid Request!")), } } -} -impl InstalledFlowService { - fn handle_url(&mut self, url: hyper::Uri) { + fn auth_code_from_url(url: hyper::Uri) -> Option { // The provider redirects to the specified localhost URL, appending the authorization // code, like this: http://localhost:8080/xyz/?code=4/731fJ3BheyCouCniPufAd280GHNV5Ju35yYcGs - // We take that code and send it to the ask_authorization_code() function that - // waits for it. - for (param, val) in form_urlencoded::parse(url.query().unwrap_or("").as_bytes()) { - if param == "code".to_string() { - let mut auth_code_tx = self.auth_code_tx.lock().unwrap(); - match auth_code_tx.take() { - Some(auth_code_tx) => { - let _ = auth_code_tx.send(val.to_owned().to_string()); - } - None => { - // call to the server after a previous call. Each server is only designed - // to receive a single request. - } - }; + form_urlencoded::parse(url.query().unwrap_or("").as_bytes()).find_map(|(param, val)| { + if param == "code" { + Some(val.into_owned()) + } else { + None } - } + }) } } #[cfg(test)] mod tests { - use std::error::Error; - use std::fmt; - use std::str::FromStr; - - use hyper; - use hyper::client::connect::HttpConnector; - use hyper_rustls::HttpsConnector; - use mockito::{self, mock}; - use tokio; - use super::*; - use crate::authenticator::AuthFlow; - use crate::authenticator_delegate::FlowDelegate; - use crate::helper::*; - use crate::types::StringError; - - #[test] - fn test_end2end() { - #[derive(Clone)] - struct FD( - String, - hyper::Client, hyper::Body>, - ); - impl FlowDelegate for FD { - /// Depending on need_code, return the pre-set code or send the code to the server at - /// the redirect_uri given in the url. - fn present_user_url + fmt::Display>( - &mut self, - url: S, - need_code: bool, - ) -> Box, Error = Box> + Send> - { - if need_code { - Box::new(Ok(Some(self.0.clone())).into_future()) - } else { - // Parse presented url to obtain redirect_uri with location of local - // code-accepting server. - let uri = Uri::from_str(url.as_ref()).unwrap(); - let query = uri.query().unwrap(); - let parsed = form_urlencoded::parse(query.as_bytes()).into_owned(); - let mut rduri = None; - for (k, v) in parsed { - if k == "redirect_uri" { - rduri = Some(v); - break; - } - } - if rduri.is_none() { - return Box::new( - Err(Box::new(StringError::new("no redirect uri!", None)) - as Box) - .into_future(), - ); - } - let mut rduri = rduri.unwrap(); - rduri.push_str(&format!("?code={}", self.0)); - let rduri = Uri::from_str(rduri.as_ref()).unwrap(); - // Hit server. - return Box::new( - self.1 - .get(rduri) - .map_err(|e| Box::new(e) as Box) - .map(|_| None), - ); - } - } - } - - let server_url = mockito::server_url(); - let app_secret = r#"{"installed":{"client_id":"902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com","project_id":"yup-test-243420","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://oauth2.googleapis.com/token","auth_provider_x509_cert_url":"https://www.googleapis.com/oauth2/v1/certs","client_secret":"iuMPN6Ne1PD7cos29Tk9rlqH","redirect_uris":["urn:ietf:wg:oauth:2.0:oob","http://localhost"]}}"#; - let mut app_secret = parse_application_secret(app_secret).unwrap(); - app_secret.token_uri = format!("{}/token", server_url); - - let https = HttpsConnector::new(1); - let client = hyper::Client::builder() - .keep_alive(false) - .build::<_, hyper::Body>(https); - - let fd = FD("authorizationcode".to_string(), client.clone()); - let mut inf = - InstalledFlow::new(app_secret.clone(), InstalledFlowReturnMethod::Interactive) - .delegate(fd) - .build_token_getter(client.clone()); - - let mut rt = tokio::runtime::Builder::new() - .core_threads(1) - .panic_handler(|e| std::panic::resume_unwind(e)) - .build() - .unwrap(); - - // Successful path. - { - let _m = mock("POST", "/token") - .match_body(mockito::Matcher::Regex(".*code=authorizationcode.*client_id=9022167.*".to_string())) - .with_body(r#"{"access_token": "accesstoken", "refresh_token": "refreshtoken", "token_type": "Bearer", "expires_in": 12345678}"#) - .expect(1) - .create(); - - let fut = inf - .token(vec!["https://googleapis.com/some/scope"]) - .and_then(|tok| { - assert_eq!("accesstoken", tok.access_token); - assert_eq!("refreshtoken", tok.refresh_token.unwrap()); - assert_eq!("Bearer", tok.token_type); - Ok(()) - }); - rt.block_on(fut).expect("block on"); - _m.assert(); - } - // Successful path with HTTP redirect. - { - let mut inf = - InstalledFlow::new(app_secret, InstalledFlowReturnMethod::HTTPRedirect(8081)) - .delegate(FD( - "authorizationcodefromlocalserver".to_string(), - client.clone(), - )) - .build_token_getter(client.clone()); - let _m = mock("POST", "/token") - .match_body(mockito::Matcher::Regex(".*code=authorizationcodefromlocalserver.*client_id=9022167.*".to_string())) - .with_body(r#"{"access_token": "accesstoken", "refresh_token": "refreshtoken", "token_type": "Bearer", "expires_in": 12345678}"#) - .expect(1) - .create(); - - let fut = inf - .token(vec!["https://googleapis.com/some/scope"]) - .and_then(|tok| { - assert_eq!("accesstoken", tok.access_token); - assert_eq!("refreshtoken", tok.refresh_token.unwrap()); - assert_eq!("Bearer", tok.token_type); - Ok(()) - }); - rt.block_on(fut).expect("block on"); - _m.assert(); - } - // Error from server. - { - let _m = mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*code=authorizationcode.*client_id=9022167.*".to_string(), - )) - .with_status(400) - .with_body(r#"{"error": "invalid_code"}"#) - .expect(1) - .create(); - - let fut = inf - .token(vec!["https://googleapis.com/some/scope"]) - .then(|tokr| { - assert!(tokr.is_err()); - assert!(format!("{}", tokr.unwrap_err()).contains("invalid_code")); - Ok(()) as Result<(), ()> - }); - rt.block_on(fut).expect("block on"); - _m.assert(); - } - rt.shutdown_on_idle().wait().expect("shutdown"); - } + use hyper::Uri; #[test] fn test_request_url_builder() { @@ -737,47 +361,59 @@ mod tests { "https://accounts.google.com/o/oauth2/auth", "812741506391-h38jh0j4fv0ce1krdkiq0hfvt6n5am\ rf.apps.googleusercontent.com", - vec![&"email".to_string(), &"profile".to_string()], + &["email", "profile"], None ) ); } - #[test] - fn test_server_random_local_port() { - let addr1 = InstalledFlowServer::new(0).unwrap(); - let addr2 = InstalledFlowServer::new(0).unwrap(); - assert_ne!(addr1.port, addr2.port); + #[tokio::test] + async fn test_server_random_local_port() { + let addr1 = InstalledFlowServer::run().unwrap().local_addr(); + let addr2 = InstalledFlowServer::run().unwrap().local_addr(); + assert_ne!(addr1.port(), addr2.port()); } - #[test] - fn test_http_handle_url() { + #[tokio::test] + async fn test_http_handle_url() { let (tx, rx) = oneshot::channel(); - let mut handler = InstalledFlowService { - auth_code_tx: Arc::new(Mutex::new(Option::Some(tx))), - }; // URLs are usually a bit botched let url: Uri = "http://example.com:1234/?code=ab/c%2Fd#".parse().unwrap(); - handler.handle_url(url); - assert_eq!(rx.wait().unwrap(), "ab/c/d".to_string()); + let req = hyper::Request::get(url) + .body(hyper::body::Body::empty()) + .unwrap(); + installed_flow_server::handle_req(req, Arc::new(Mutex::new(Some(tx)))) + .await + .unwrap(); + assert_eq!(rx.await.unwrap().as_str(), "ab/c/d"); } - #[test] - fn test_server() { - let runtime = tokio::runtime::Runtime::new().unwrap(); + #[tokio::test] + async fn test_server() { let client: hyper::Client = - hyper::Client::builder() - .executor(runtime.executor()) - .build_http(); - let mut server = InstalledFlowServer::new(0).unwrap(); + hyper::Client::builder().build_http(); + let server = InstalledFlowServer::run().unwrap(); + + let response = client + .get(format!("http://{}/", server.local_addr()).parse().unwrap()) + .await; + match response { + Result::Ok(_response) => { + // TODO: Do we really want this to assert success? + //assert!(response.status().is_success()); + } + Result::Err(err) => { + assert!(false, "Failed to request from local server: {:?}", err); + } + } let response = client .get( - format!("http://127.0.0.1:{}/", server.port) + format!("http://{}/?code=ab/c%2Fd#", server.local_addr()) .parse() .unwrap(), ) - .wait(); + .await; match response { Result::Ok(response) => { assert!(response.status().is_success()); @@ -787,29 +423,6 @@ mod tests { } } - let response = client - .get( - format!("http://127.0.0.1:{}/?code=ab/c%2Fd#", server.port) - .parse() - .unwrap(), - ) - .wait(); - match response { - Result::Ok(response) => { - assert!(response.status().is_success()); - } - Result::Err(err) => { - assert!(false, "Failed to request from local server: {:?}", err); - } - } - - match server.block_till_auth() { - Result::Ok(response) => { - assert_eq!(response, "ab/c/d".to_string()); - } - Result::Err(err) => { - assert!(false, "Server failed to pass on the message: {:?}", err); - } - } + assert_eq!(server.wait_for_auth_code().await.as_str(), "ab/c/d"); } } diff --git a/src/lib.rs b/src/lib.rs index 48e3b28..7fa2e58 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,72 +20,60 @@ //! based on the Google APIs; it may or may not work with other providers. //! //! # Installed Flow Usage -//! The `InstalledFlow` involves showing a URL to the user (or opening it in a browser) +//! The installed flow involves showing a URL to the user (or opening it in a browser) //! and then either prompting the user to enter a displayed code, or make the authorizing //! website redirect to a web server spun up by this library and running on localhost. //! -//! In order to use the interactive method, use the `InstalledInteractive` `FlowType`; -//! for the redirect method, use `InstalledRedirect`, with the port number to let the -//! server listen on. +//! In order to use the interactive method, use the `Interactive` `InstalledFlowReturnMethod`; +//! for the redirect method, use `HTTPRedirect`. //! //! You can implement your own `AuthenticatorDelegate` in order to customize the flow; -//! the `InstalledFlow` uses the `present_user_url` method. +//! the installed flow uses the `present_user_url` method. //! -//! The returned `Token` is stored permanently in the given token storage in order to -//! authorize future API requests to the same scopes. +//! The returned `Token` will be stored in memory in order to authorize future +//! API requests to the same scopes. The tokens can optionally be persisted to +//! disk by using `persist_tokens_to_disk` when creating the authenticator. //! //! The following example, which is derived from the (actual and runnable) example in //! `examples/test-installed/`, shows the basics of using this crate: //! //! ```test_harness,no_run -//! use futures::prelude::*; -//! use yup_oauth2::GetToken; -//! use yup_oauth2::{Authenticator, InstalledFlow}; +//! use yup_oauth2::{InstalledFlowAuthenticator, InstalledFlowReturnMethod}; //! -//! use hyper::client::Client; -//! use hyper_rustls::HttpsConnector; -//! -//! use std::path::Path; -//! -//! fn main() { +//! #[tokio::main] +//! async fn main() { //! // Read application secret from a file. Sometimes it's easier to compile it directly into //! // the binary. The clientsecret file contains JSON like `{"installed":{"client_id": ... }}` -//! let secret = yup_oauth2::read_application_secret(Path::new("clientsecret.json")) +//! let secret = yup_oauth2::read_application_secret("clientsecret.json") +//! .await //! .expect("clientsecret.json"); //! //! // Create an authenticator that uses an InstalledFlow to authenticate. The -//! // authentication tokens are persisted to a file named tokencache.json. The -//! // authenticator takes care of caching tokens to disk and refreshing tokens once -//! // they've expired. -//! let mut auth = Authenticator::new( -//! InstalledFlow::new(secret, yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect(0)) -//! ) +//! // authentication tokens are persisted to a file named tokencache.json. The +//! // authenticator takes care of caching tokens to disk and refreshing tokens once +//! // they've expired. +//! let mut auth = InstalledFlowAuthenticator::builder(secret, InstalledFlowReturnMethod::HTTPRedirect) //! .persist_tokens_to_disk("tokencache.json") //! .build() +//! .await //! .unwrap(); //! -//! let s = "https://www.googleapis.com/auth/drive.file".to_string(); -//! let scopes = vec![s]; +//! let scopes = &["https://www.googleapis.com/auth/drive.file"]; //! //! // token() is the one important function of this crate; it does everything to //! // obtain a token that can be sent e.g. as Bearer token. -//! let tok = auth.token(scopes); -//! // Finally we print the token. -//! let fut = tok.map_err(|e| println!("error: {:?}", e)).and_then(|t| { -//! println!("The token is {:?}", t); -//! Ok(()) -//! }); -//! -//! tokio::run(fut) +//! match auth.token(scopes).await { +//! Ok(token) => println!("The token is {:?}", token), +//! Err(e) => println!("error: {:?}", e), +//! } //! } //! ``` //! -#[macro_use] -extern crate serde_derive; - -mod authenticator; -mod authenticator_delegate; +#![deny(missing_docs)] +pub mod authenticator; +pub mod authenticator_delegate; mod device; +pub mod error; mod helper; mod installed; mod refresh; @@ -93,17 +81,16 @@ mod service_account; mod storage; mod types; -pub use crate::authenticator::{AuthFlow, Authenticator}; -pub use crate::authenticator_delegate::{ - AuthenticatorDelegate, DefaultAuthenticatorDelegate, DefaultFlowDelegate, FlowDelegate, - PollInformation, +#[doc(inline)] +pub use crate::authenticator::{ + DeviceFlowAuthenticator, InstalledFlowAuthenticator, ServiceAccountAuthenticator, }; -pub use crate::device::{DeviceFlow, GOOGLE_DEVICE_CODE_URL}; + pub use crate::helper::*; -pub use crate::installed::{InstalledFlow, InstalledFlowReturnMethod}; -pub use crate::service_account::*; -pub use crate::storage::{DiskTokenStorage, MemoryStorage, NullStorage, TokenStorage}; -pub use crate::types::{ - ApplicationSecret, ConsoleApplicationSecret, FlowType, GetToken, PollError, RefreshResult, - RequestError, Scheme, Token, TokenType, -}; +pub use crate::installed::InstalledFlowReturnMethod; + +pub use crate::service_account::ServiceAccountKey; + +#[doc(inline)] +pub use crate::error::Error; +pub use crate::types::{AccessToken, ApplicationSecret, ConsoleApplicationSecret}; diff --git a/src/refresh.rs b/src/refresh.rs index 1175ecd..53d12fa 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -1,12 +1,7 @@ -use crate::types::{ApplicationSecret, JsonError, RefreshResult, RequestError}; +use crate::error::Error; +use crate::types::{ApplicationSecret, TokenInfo}; -use super::Token; -use chrono::Utc; -use futures::stream::Stream; -use futures::Future; -use hyper; use hyper::header; -use serde_json as json; use url::form_urlencoded; /// Implements the [OAuth2 Refresh Token Flow](https://developers.google.com/youtube/v3/guides/authentication#devices). @@ -14,7 +9,7 @@ use url::form_urlencoded; /// Refresh an expired access token, as obtained by any other authentication flow. /// This flow is useful when your `Token` is expired and allows to obtain a new /// and valid access token. -pub struct RefreshFlow; +pub(crate) struct RefreshFlow; impl RefreshFlow { /// Attempt to refresh the given token, and obtain a new, valid one. @@ -31,155 +26,41 @@ impl RefreshFlow { /// /// # Examples /// Please see the crate landing page for an example. - pub fn refresh_token<'a, C: 'static + hyper::client::connect::Connect>( - client: hyper::Client, - client_secret: ApplicationSecret, - refresh_token: String, - ) -> impl 'a + Future { + pub(crate) async fn refresh_token( + client: &hyper::Client, + client_secret: &ApplicationSecret, + refresh_token: &str, + ) -> Result + where + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, + { + log::debug!( + "refreshing access token with refresh token: {}", + refresh_token + ); let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ - ("client_id", client_secret.client_id.clone()), - ("client_secret", client_secret.client_secret.clone()), - ("refresh_token", refresh_token.to_string()), - ("grant_type", "refresh_token".to_string()), + ("client_id", client_secret.client_id.as_str()), + ("client_secret", client_secret.client_secret.as_str()), + ("refresh_token", refresh_token), + ("grant_type", "refresh_token"), ]) .finish(); - let request = hyper::Request::post(client_secret.token_uri.clone()) + let request = hyper::Request::post(&client_secret.token_uri) .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(req)) - .unwrap(); // TODO: error handling - - client - .request(request) - .then(|r| { - match r { - Err(err) => return Err(RefreshResult::Error(err)), - Ok(res) => { - Ok(res - .into_body() - .concat2() - .wait() - .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) - .unwrap()) // TODO: error handling - } - } - }) - .then(move |maybe_json_str: Result| { - if let Err(e) = maybe_json_str { - return Ok(e); - } - let json_str = maybe_json_str.unwrap(); - #[derive(Deserialize)] - struct JsonToken { - access_token: String, - token_type: String, - expires_in: i64, - } - - match json::from_str::(&json_str) { - Err(_) => {} - Ok(res) => { - return Ok(RefreshResult::RefreshError( - res.error, - res.error_description, - )) - } - } - - let t: JsonToken = json::from_str(&json_str).unwrap(); - Ok(RefreshResult::Success(Token { - access_token: t.access_token, - token_type: t.token_type, - refresh_token: Some(refresh_token.to_string()), - expires_in: None, - expires_in_timestamp: Some(Utc::now().timestamp() + t.expires_in), - })) - }) - .map_err(RequestError::Refresh) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::helper; - - use hyper; - use hyper_rustls::HttpsConnector; - use mockito; - use tokio; - - #[test] - fn test_refresh_end2end() { - let server_url = mockito::server_url(); - - let app_secret = r#"{"installed":{"client_id":"902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com","project_id":"yup-test-243420","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://oauth2.googleapis.com/token","auth_provider_x509_cert_url":"https://www.googleapis.com/oauth2/v1/certs","client_secret":"iuMPN6Ne1PD7cos29Tk9rlqH","redirect_uris":["urn:ietf:wg:oauth:2.0:oob","http://localhost"]}}"#; - let mut app_secret = helper::parse_application_secret(app_secret).unwrap(); - app_secret.token_uri = format!("{}/token", server_url); - let refresh_token = "my-refresh-token".to_string(); - - let https = HttpsConnector::new(1); - let client = hyper::Client::builder() - .keep_alive(false) - .build::<_, hyper::Body>(https); - - let mut rt = tokio::runtime::Builder::new() - .core_threads(1) - .panic_handler(|e| std::panic::resume_unwind(e)) - .build() .unwrap(); - - // Success - { - let _m = mockito::mock("POST", "/token") - .match_body( - mockito::Matcher::Regex(".*client_id=902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com.*refresh_token=my-refresh-token.*".to_string())) - .with_status(200) - .with_body(r#"{"access_token": "new-access-token", "token_type": "Bearer", "expires_in": 1234567}"#) - .create(); - let fut = RefreshFlow::refresh_token( - client.clone(), - app_secret.clone(), - refresh_token.clone(), - ) - .then(|rr| { - let rr = rr.unwrap(); - match rr { - RefreshResult::Success(tok) => { - assert_eq!("new-access-token", tok.access_token); - assert_eq!("Bearer", tok.token_type); - } - _ => panic!(format!("unexpected RefreshResult {:?}", rr)), - } - Ok(()) as Result<(), ()> - }); - - rt.block_on(fut).expect("block_on"); - _m.assert(); - } - // Refresh error. - { - let _m = mockito::mock("POST", "/token") - .match_body( - mockito::Matcher::Regex(".*client_id=902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com.*refresh_token=my-refresh-token.*".to_string())) - .with_status(400) - .with_body(r#"{"error": "invalid_token"}"#) - .create(); - - let fut = RefreshFlow::refresh_token(client, app_secret, refresh_token).then(|rr| { - let rr = rr.unwrap(); - match rr { - RefreshResult::RefreshError(e, None) => { - assert_eq!(e, "invalid_token"); - } - _ => panic!(format!("unexpected RefreshResult {:?}", rr)), - } - Ok(()) - }); - - tokio::run(fut); - _m.assert(); - } + log::debug!("Sending request: {:?}", request); + let (head, body) = client.request(request).await?.into_parts(); + let body = hyper::body::to_bytes(body).await?; + log::debug!("Received response; head: {:?}, body: {:?}", head, body); + let mut token = TokenInfo::from_json(&body)?; + // If the refresh result contains a refresh_token use it, otherwise + // continue using our previous refresh_token. + token + .refresh_token + .get_or_insert_with(|| refresh_token.to_owned()); + Ok(token) } } diff --git a/src/service_account.rs b/src/service_account.rs index b61448f..81d4c6e 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -11,551 +11,237 @@ //! Copyright (c) 2016 Google Inc (lewinb@google.com). //! -use std::default::Default; -use std::sync::{Arc, Mutex}; +use crate::error::Error; +use crate::types::TokenInfo; -use crate::authenticator::{DefaultHyperClient, HyperClientBuilder}; -use crate::storage::{hash_scopes, MemoryStorage, TokenStorage}; -use crate::types::{ApplicationSecret, GetToken, JsonError, RequestError, StringError, Token}; +use std::io; -use futures::stream::Stream; -use futures::{future, prelude::*}; use hyper::header; -use url::form_urlencoded; - use rustls::{ self, internal::pemfile, sign::{self, SigningKey}, PrivateKey, }; -use std::io; +use serde::{Deserialize, Serialize}; +use url::form_urlencoded; -use base64; -use chrono; -use hyper; -use serde_json; - -const GRANT_TYPE: &'static str = "urn:ietf:params:oauth:grant-type:jwt-bearer"; -const GOOGLE_RS256_HEAD: &'static str = "{\"alg\":\"RS256\",\"typ\":\"JWT\"}"; +const GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:jwt-bearer"; +const GOOGLE_RS256_HEAD: &str = r#"{"alg":"RS256","typ":"JWT"}"#; /// Encodes s as Base64 -fn encode_base64>(s: T) -> String { - base64::encode_config(s.as_ref(), base64::URL_SAFE) +fn append_base64 + ?Sized>(s: &T, out: &mut String) { + base64::encode_config_buf(s, base64::URL_SAFE, out) } /// Decode a PKCS8 formatted RSA key. fn decode_rsa_key(pem_pkcs8: &str) -> Result { - let private = pem_pkcs8.to_string().replace("\\n", "\n").into_bytes(); - let mut private_reader: &[u8] = private.as_ref(); - let private_keys = pemfile::pkcs8_private_keys(&mut private_reader); + let private_keys = pemfile::pkcs8_private_keys(&mut pem_pkcs8.as_bytes()); - if let Ok(pk) = private_keys { - if pk.len() > 0 { - Ok(pk[0].clone()) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Not enough private keys in PEM", - )) + match private_keys { + Ok(mut keys) if !keys.is_empty() => { + keys.truncate(1); + Ok(keys.remove(0)) } - } else { - Err(io::Error::new( + Ok(_) => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Not enough private keys in PEM", + )), + Err(_) => Err(io::Error::new( io::ErrorKind::InvalidInput, "Error reading key from PEM", - )) + )), } } /// JSON schema of secret service account key. You can obtain the key from /// the Cloud Console at https://console.cloud.google.com/. /// -/// You can use `helpers::service_account_key_from_file()` as a quick way to read a JSON client +/// You can use `helpers::read_service_account_key()` as a quick way to read a JSON client /// secret into a ServiceAccountKey. #[derive(Serialize, Deserialize, Debug, Clone)] pub struct ServiceAccountKey { #[serde(rename = "type")] + /// key_type pub key_type: Option, + /// project_id pub project_id: Option, + /// private_key_id pub private_key_id: Option, - pub private_key: Option, - pub client_email: Option, + /// private_key + pub private_key: String, + /// client_email + pub client_email: String, + /// client_id pub client_id: Option, + /// auth_uri pub auth_uri: Option, - pub token_uri: Option, + /// token_uri + pub token_uri: String, + /// auth_provider_x509_cert_url pub auth_provider_x509_cert_url: Option, + /// client_x509_cert_url pub client_x509_cert_url: Option, } /// Permissions requested for a JWT. /// See https://developers.google.com/identity/protocols/OAuth2ServiceAccount#authorizingrequests. #[derive(Serialize, Debug)] -struct Claims { - iss: String, - aud: String, +struct Claims<'a> { + iss: &'a str, + aud: &'a str, exp: i64, iat: i64, - sub: Option, + subject: Option<&'a str>, scope: String, } -/// A JSON Web Token ready for signing. -struct JWT { - /// The value of GOOGLE_RS256_HEAD. - header: String, - /// A Claims struct, expressing the set of desired permissions etc. - claims: Claims, -} +impl<'a> Claims<'a> { + fn new(key: &'a ServiceAccountKey, scopes: &[T], subject: Option<&'a str>) -> Self + where + T: AsRef, + { + let iat = chrono::Utc::now().timestamp(); + let expiry = iat + 3600 - 5; // Max validity is 1h. -impl JWT { - /// Create a new JWT from claims. - fn new(claims: Claims) -> JWT { - JWT { - header: GOOGLE_RS256_HEAD.to_string(), - claims: claims, + let scope = crate::helper::join(scopes, " "); + Claims { + iss: &key.client_email, + aud: &key.token_uri, + exp: expiry, + iat, + subject, + scope, } } +} - /// Set JWT header. Default is `{"alg":"RS256","typ":"JWT"}`. - #[allow(dead_code)] - pub fn set_header(&mut self, head: String) { - self.header = head; - } +/// A JSON Web Token ready for signing. +pub(crate) struct JWTSigner { + signer: Box, +} - /// Encodes the first two parts (header and claims) to base64 and assembles them into a form - /// ready to be signed. - fn encode_claims(&self) -> String { - let mut head = encode_base64(&self.header); - let claims = encode_base64(serde_json::to_string(&self.claims).unwrap()); - - head.push_str("."); - head.push_str(&claims); - head - } - - /// Sign a JWT base string with `private_key`, which is a PKCS8 string. - fn sign(&self, private_key: &str) -> Result { - let mut jwt_head = self.encode_claims(); +impl JWTSigner { + fn new(private_key: &str) -> Result { let key = decode_rsa_key(private_key)?; let signing_key = sign::RSASigningKey::new(&key) .map_err(|_| io::Error::new(io::ErrorKind::Other, "Couldn't initialize signer"))?; let signer = signing_key .choose_scheme(&[rustls::SignatureScheme::RSA_PKCS1_SHA256]) - .ok_or(io::Error::new( - io::ErrorKind::Other, - "Couldn't choose signing scheme", - ))?; - let signature = signer - .sign(jwt_head.as_bytes()) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("{}", e)))?; - let signature_b64 = encode_base64(signature); + .ok_or_else(|| { + io::Error::new(io::ErrorKind::Other, "Couldn't choose signing scheme") + })?; + Ok(JWTSigner { signer }) + } + fn sign_claims(&self, claims: &Claims) -> Result { + let mut jwt_head = Self::encode_claims(claims); + let signature = self.signer.sign(jwt_head.as_bytes())?; jwt_head.push_str("."); - jwt_head.push_str(&signature_b64); - + append_base64(&signature, &mut jwt_head); Ok(jwt_head) } -} -/// Set `iss`, `aud`, `exp`, `iat`, `scope` field in the returned `Claims`. `scopes` is an iterator -/// yielding strings with OAuth scopes. -fn init_claims_from_key<'a, I, T>(key: &ServiceAccountKey, scopes: I) -> Claims -where - T: AsRef + 'a, - I: IntoIterator, -{ - let iat = chrono::Utc::now().timestamp(); - let expiry = iat + 3600 - 5; // Max validity is 1h. - - let mut scopes_string = scopes.into_iter().fold(String::new(), |mut acc, sc| { - acc.push_str(sc.as_ref()); - acc.push_str(" "); - acc - }); - scopes_string.pop(); - - Claims { - iss: key.client_email.clone().unwrap(), - aud: key.token_uri.clone().unwrap(), - exp: expiry, - iat: iat, - sub: None, - scope: scopes_string, + /// Encodes the first two parts (header and claims) to base64 and assembles them into a form + /// ready to be signed. + fn encode_claims(claims: &Claims) -> String { + let mut head = String::new(); + append_base64(GOOGLE_RS256_HEAD, &mut head); + head.push_str("."); + append_base64(&serde_json::to_string(&claims).unwrap(), &mut head); + head } } -/// A token source (`GetToken`) yielding OAuth tokens for services that use ServiceAccount authorization. -/// This token source caches token and automatically renews expired ones, meaning you do not need -/// (and you also should not) use this with `Authenticator`. Just use it directly. -#[derive(Clone)] -pub struct ServiceAccountAccess { - client: C, +pub struct ServiceAccountFlowOpts { + pub(crate) key: ServiceAccountKey, + pub(crate) subject: Option, +} + +/// ServiceAccountFlow can fetch oauth tokens using a service account. +pub struct ServiceAccountFlow { key: ServiceAccountKey, - sub: Option, + subject: Option, + signer: JWTSigner, } -impl ServiceAccountAccess { - /// Create a new ServiceAccountAccess with the provided key. - pub fn new(key: ServiceAccountKey) -> Self { - ServiceAccountAccess { - client: DefaultHyperClient, - key, - sub: None, - } - } -} - -impl ServiceAccountAccess -where - C: HyperClientBuilder, - C::Connector: 'static, -{ - /// Use the provided hyper client. - pub fn hyper_client( - self, - hyper_client: NewC, - ) -> ServiceAccountAccess { - ServiceAccountAccess { - client: hyper_client, - key: self.key, - sub: self.sub, - } +impl ServiceAccountFlow { + pub(crate) fn new(opts: ServiceAccountFlowOpts) -> Result { + let signer = JWTSigner::new(&opts.key.private_key)?; + Ok(ServiceAccountFlow { + key: opts.key, + subject: opts.subject, + signer, + }) } - /// Use the provided sub. - pub fn sub(self, sub: String) -> Self { - ServiceAccountAccess { - sub: Some(sub), - ..self - } - } - - /// Build the configured ServiceAccountAccess. - pub fn build(self) -> impl GetToken { - ServiceAccountAccessImpl::new(self.client.build_hyper_client(), self.key, self.sub) - } -} - -#[derive(Clone)] -struct ServiceAccountAccessImpl { - client: hyper::Client, - key: ServiceAccountKey, - cache: Arc>, - sub: Option, -} - -impl ServiceAccountAccessImpl -where - C: hyper::client::connect::Connect, -{ - fn new(client: hyper::Client, key: ServiceAccountKey, sub: Option) -> Self { - ServiceAccountAccessImpl { - client, - key, - cache: Arc::new(Mutex::new(MemoryStorage::default())), - sub, - } - } -} - -/// This is the schema of the server's response. -#[derive(Deserialize, Debug)] -struct TokenResponse { - access_token: Option, - token_type: Option, - expires_in: Option, -} - -impl TokenResponse { - fn to_oauth_token(self) -> Token { - let expires_ts = chrono::Utc::now().timestamp() + self.expires_in.unwrap_or(0); - - Token { - access_token: self.access_token.unwrap(), - token_type: self.token_type.unwrap(), - refresh_token: Some(String::new()), - expires_in: self.expires_in, - expires_in_timestamp: Some(expires_ts), - } - } -} - -impl<'a, C: 'static + hyper::client::connect::Connect> ServiceAccountAccessImpl { /// Send a request for a new Bearer token to the OAuth provider. - fn request_token( - client: hyper::client::Client, - sub: Option, - key: ServiceAccountKey, - scopes: Vec, - ) -> impl Future { - let mut claims = init_claims_from_key(&key, &scopes); - claims.sub = sub.clone(); - let signed = JWT::new(claims) - .sign(key.private_key.as_ref().unwrap()) - .into_future(); - signed - .map_err(RequestError::LowLevelError) - .map(|signed| { - form_urlencoded::Serializer::new(String::new()) - .extend_pairs(vec![ - ("grant_type".to_string(), GRANT_TYPE.to_string()), - ("assertion".to_string(), signed), - ]) - .finish() - }) - .map(|rqbody| { - hyper::Request::post(key.token_uri.unwrap()) - .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") - .body(hyper::Body::from(rqbody)) - .unwrap() - }) - .and_then(move |request| client.request(request).map_err(RequestError::ClientError)) - .and_then(|response| { - response - .into_body() - .concat2() - .map_err(RequestError::ClientError) - }) - .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) - .and_then(|s| { - if let Ok(jse) = serde_json::from_str::(&s) { - Err(RequestError::NegativeServerResponse( - jse.error, - jse.error_description, - )) - } else { - serde_json::from_str(&s).map_err(RequestError::JSONError) - } - }) - .then(|token: Result| match token { - Err(e) => return Err(e), - Ok(token) => { - if token.access_token.is_none() - || token.token_type.is_none() - || token.expires_in.is_none() - { - Err(RequestError::BadServerResponse(format!( - "Token response lacks fields: {:?}", - token - ))) - } else { - Ok(token.to_oauth_token()) - } - } - }) - } -} - -impl GetToken for ServiceAccountAccessImpl -where - C: hyper::client::connect::Connect, -{ - fn token( - &mut self, - scopes: I, - ) -> Box + Send> + pub(crate) async fn token( + &self, + hyper_client: &hyper::Client, + scopes: &[T], + ) -> Result where - T: Into, - I: IntoIterator, + T: AsRef, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, { - let (hash, scps0) = hash_scopes(scopes); - let cache = self.cache.clone(); - let scps = scps0.clone(); - - let cache_lookup = futures::lazy(move || { - match cache - .lock() - .unwrap() - .get(hash, &scps.iter().map(|s| s.as_str()).collect()) - { - Ok(Some(token)) => { - if !token.expired() { - return Ok(token); - } - return Err(StringError::new("expired token in cache", None)); - } - Err(e) => return Err(StringError::new(format!("cache lookup error: {}", e), None)), - Ok(None) => return Err(StringError::new("no token in cache", None)), - } - }); - - let cache = self.cache.clone(); - let req_token = Self::request_token( - self.client.clone(), - self.sub.clone(), - self.key.clone(), - scps0.iter().map(|s| s.to_string()).collect(), - ) - .then(move |r| match r { - Ok(token) => { - let _ = cache.lock().unwrap().set( - hash, - &scps0.iter().map(|s| s.as_str()).collect(), - Some(token.clone()), - ); - Box::new(future::ok(token)) - } - Err(e) => Box::new(future::err(e)), - }); - - Box::new(cache_lookup.then(|r| match r { - Ok(t) => Box::new(Ok(t).into_future()) - as Box + Send>, - Err(_) => { - Box::new(req_token) as Box + Send> - } - })) - } - - /// Returns an empty ApplicationSecret as tokens for service accounts don't need to be - /// refreshed (they are simply reissued). - fn application_secret(&self) -> ApplicationSecret { - Default::default() - } - - fn api_key(&mut self) -> Option { - None + let claims = Claims::new(&self.key, scopes, self.subject.as_ref().map(|x| x.as_str())); + let signed = self.signer.sign_claims(&claims).map_err(|_| { + Error::LowLevelError(io::Error::new( + io::ErrorKind::Other, + "unable to sign claims", + )) + })?; + let rqbody = form_urlencoded::Serializer::new(String::new()) + .extend_pairs(&[("grant_type", GRANT_TYPE), ("assertion", signed.as_str())]) + .finish(); + let request = hyper::Request::post(&self.key.token_uri) + .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") + .body(hyper::Body::from(rqbody)) + .unwrap(); + log::debug!("requesting token from service account: {:?}", request); + let (head, body) = hyper_client.request(request).await?.into_parts(); + let body = hyper::body::to_bytes(body).await?; + log::debug!("received response; head: {:?}, body: {:?}", head, body); + TokenInfo::from_json(&body) } } #[cfg(test)] mod tests { use super::*; - use crate::helper::service_account_key_from_file; - use crate::types::GetToken; - - use hyper; + use crate::helper::read_service_account_key; use hyper_rustls::HttpsConnector; - use mockito::{self, mock}; - use tokio; - - #[test] - fn test_mocked_http() { - env_logger::try_init().unwrap(); - let server_url = &mockito::server_url(); - let client_secret = r#"{ - "type": "service_account", - "project_id": "yup-test-243420", - "private_key_id": "26de294916614a5ebdf7a065307ed3ea9941902b", - "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDemmylrvp1KcOn\n9yTAVVKPpnpYznvBvcAU8Qjwr2fSKylpn7FQI54wCk5VJVom0jHpAmhxDmNiP8yv\nHaqsef+87Oc0n1yZ71/IbeRcHZc2OBB33/LCFqf272kThyJo3qspEqhuAw0e8neg\nLQb4jpm9PsqR8IjOoAtXQSu3j0zkXemMYFy93PWHjVpPEUX16NGfsWH7oxspBHOk\n9JPGJL8VJdbiAoDSDgF0y9RjJY5I52UeHNhMsAkTYs6mIG4kKXt2+T9tAyHw8aho\nwmuytQAfydTflTfTG8abRtliF3nil2taAc5VB07dP1b4dVYy/9r6M8Z0z4XM7aP+\nNdn2TKm3AgMBAAECggEAWi54nqTlXcr2M5l535uRb5Xz0f+Q/pv3ceR2iT+ekXQf\n+mUSShOr9e1u76rKu5iDVNE/a7H3DGopa7ZamzZvp2PYhSacttZV2RbAIZtxU6th\n7JajPAM+t9klGh6wj4jKEcE30B3XVnbHhPJI9TCcUyFZoscuPXt0LLy/z8Uz0v4B\nd5JARwyxDMb53VXwukQ8nNY2jP7WtUig6zwE5lWBPFMbi8GwGkeGZOruAK5sPPwY\nGBAlfofKANI7xKx9UXhRwisB4+/XI1L0Q6xJySv9P+IAhDUI6z6kxR+WkyT/YpG3\nX9gSZJc7qEaxTIuDjtep9GTaoEqiGntjaFBRKoe+VQKBgQDzM1+Ii+REQqrGlUJo\nx7KiVNAIY/zggu866VyziU6h5wjpsoW+2Npv6Dv7nWvsvFodrwe50Y3IzKtquIal\nVd8aa50E72JNImtK/o5Nx6xK0VySjHX6cyKENxHRDnBmNfbALRM+vbD9zMD0lz2q\nmns/RwRGq3/98EqxP+nHgHSr9QKBgQDqUYsFAAfvfT4I75Glc9svRv8IsaemOm07\nW1LCwPnj1MWOhsTxpNF23YmCBupZGZPSBFQobgmHVjQ3AIo6I2ioV6A+G2Xq/JCF\nmzfbvZfqtbbd+nVgF9Jr1Ic5T4thQhAvDHGUN77BpjEqZCQLAnUWJx9x7e2xvuBl\n1A6XDwH/ewKBgQDv4hVyNyIR3nxaYjFd7tQZYHTOQenVffEAd9wzTtVbxuo4sRlR\nNM7JIRXBSvaATQzKSLHjLHqgvJi8LITLIlds1QbNLl4U3UVddJbiy3f7WGTqPFfG\nkLhUF4mgXpCpkMLxrcRU14Bz5vnQiDmQRM4ajS7/kfwue00BZpxuZxst3QKBgQCI\nRI3FhaQXyc0m4zPfdYYVc4NjqfVmfXoC1/REYHey4I1XetbT9Nb/+ow6ew0UbgSC\nUZQjwwJ1m1NYXU8FyovVwsfk9ogJ5YGiwYb1msfbbnv/keVq0c/Ed9+AG9th30qM\nIf93hAfClITpMz2mzXIMRQpLdmQSR4A2l+E4RjkSOwKBgQCB78AyIdIHSkDAnCxz\nupJjhxEhtQ88uoADxRoEga7H/2OFmmPsqfytU4+TWIdal4K+nBCBWRvAX1cU47vH\nJOlSOZI0gRKe0O4bRBQc8GXJn/ubhYSxI02IgkdGrIKpOb5GG10m85ZvqsXw3bKn\nRVHMD0ObF5iORjZUqD0yRitAdg==\n-----END PRIVATE KEY-----\n", - "client_email": "yup-test-sa-1@yup-test-243420.iam.gserviceaccount.com", - "client_id": "102851967901799660408", - "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": "", - "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", - "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/yup-test-sa-1%40yup-test-243420.iam.gserviceaccount.com" -}"#; - let mut key: ServiceAccountKey = serde_json::from_str(client_secret).unwrap(); - key.token_uri = Some(format!("{}/token", server_url)); - - let json_response = r#"{ - "access_token": "ya29.c.ElouBywiys0LyNaZoLPJcp1Fdi2KjFMxzvYKLXkTdvM-rDfqKlvEq6PiMhGoGHx97t5FAvz3eb_ahdwlBjSStxHtDVQB4ZPRJQ_EOi-iS7PnayahU2S9Jp8S6rk", - "expires_in": 3600, - "token_type": "Bearer" -}"#; - let bad_json_response = r#"{ - "access_token": "ya29.c.ElouBywiys0LyNaZoLPJcp1Fdi2KjFMxzvYKLXkTdvM-rDfqKlvEq6PiMhGoGHx97t5FAvz3eb_ahdwlBjSStxHtDVQB4ZPRJQ_EOi-iS7PnayahU2S9Jp8S6rk", - "token_type": "Bearer" -}"#; - - let https = HttpsConnector::new(1); - let client = hyper::Client::builder() - .keep_alive(false) - .build::<_, hyper::Body>(https); - let mut rt = tokio::runtime::Builder::new() - .core_threads(1) - .panic_handler(|e| std::panic::resume_unwind(e)) - .build() - .unwrap(); - - // Successful path. - { - let _m = mock("POST", "/token") - .with_status(200) - .with_header("content-type", "text/json") - .with_body(json_response) - .expect(1) - .create(); - let mut acc = ServiceAccountAccessImpl::new(client.clone(), key.clone(), None); - let fut = acc - .token(vec!["https://www.googleapis.com/auth/pubsub"]) - .and_then(|tok| { - assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); - assert_eq!(Some(3600), tok.expires_in); - Ok(()) - }); - rt.block_on(fut).expect("block_on"); - - assert!(acc - .cache - .lock() - .unwrap() - .get( - 3502164897243251857, - &vec!["https://www.googleapis.com/auth/pubsub"] - ) - .unwrap() - .is_some()); - // Test that token is in cache (otherwise mock will tell us) - let fut = acc - .token(vec!["https://www.googleapis.com/auth/pubsub"]) - .and_then(|tok| { - assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); - assert_eq!(Some(3600), tok.expires_in); - Ok(()) - }); - rt.block_on(fut).expect("block_on 2"); - - _m.assert(); - } - // Malformed response. - { - let _m = mock("POST", "/token") - .with_status(200) - .with_header("content-type", "text/json") - .with_body(bad_json_response) - .create(); - let mut acc = ServiceAccountAccess::new(key.clone()) - .hyper_client(client.clone()) - .build(); - let fut = acc - .token(vec!["https://www.googleapis.com/auth/pubsub"]) - .then(|result| { - assert!(result.is_err()); - Ok(()) as Result<(), ()> - }); - rt.block_on(fut).expect("block_on"); - _m.assert(); - } - rt.shutdown_on_idle().wait().expect("shutdown"); - } // Valid but deactivated key. const TEST_PRIVATE_KEY_PATH: &'static str = "examples/Sanguine-69411a0c0eea.json"; // Uncomment this test to verify that we can successfully obtain tokens. - //#[test] + //#[tokio::test] #[allow(dead_code)] - fn test_service_account_e2e() { - let key = service_account_key_from_file(&TEST_PRIVATE_KEY_PATH.to_string()).unwrap(); - let https = HttpsConnector::new(4); - let runtime = tokio::runtime::Runtime::new().unwrap(); + async fn test_service_account_e2e() { + let key = read_service_account_key(TEST_PRIVATE_KEY_PATH) + .await + .unwrap(); + let acc = ServiceAccountFlow::new(ServiceAccountFlowOpts { key, subject: None }).unwrap(); + let https = HttpsConnector::new(); let client = hyper::Client::builder() - .executor(runtime.executor()) - .build(https); - let mut acc = ServiceAccountAccess::new(key).hyper_client(client).build(); + .keep_alive(false) + .build::<_, hyper::Body>(https); println!( "{:?}", - acc.token(vec!["https://www.googleapis.com/auth/pubsub"]) - .wait() + acc.token(&client, &["https://www.googleapis.com/auth/pubsub"]) + .await ); } - #[test] - fn test_jwt_initialize_claims() { - let key = service_account_key_from_file(TEST_PRIVATE_KEY_PATH).unwrap(); + #[tokio::test] + async fn test_jwt_initialize_claims() { + let key = read_service_account_key(TEST_PRIVATE_KEY_PATH) + .await + .unwrap(); let scopes = vec!["scope1", "scope2", "scope3"]; - let claims = super::init_claims_from_key(&key, &scopes); + let claims = Claims::new(&key, &scopes, None); assert_eq!( claims.iss, @@ -571,13 +257,15 @@ mod tests { assert_eq!(claims.exp - claims.iat, 3595); } - #[test] - fn test_jwt_sign() { - let key = service_account_key_from_file(TEST_PRIVATE_KEY_PATH).unwrap(); + #[tokio::test] + async fn test_jwt_sign() { + let key = read_service_account_key(TEST_PRIVATE_KEY_PATH) + .await + .unwrap(); let scopes = vec!["scope1", "scope2", "scope3"]; - let claims = super::init_claims_from_key(&key, &scopes); - let jwt = super::JWT::new(claims); - let signature = jwt.sign(key.private_key.as_ref().unwrap()); + let signer = JWTSigner::new(&key.private_key).unwrap(); + let claims = Claims::new(&key, &scopes, None); + let signature = signer.sign_claims(&claims); assert!(signature.is_ok()); diff --git a/src/storage.rs b/src/storage.rs index 950d015..4acc30a 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -2,291 +2,430 @@ // // See project root for licensing information. // +use crate::types::TokenInfo; -use std::cmp::Ordering; -use std::collections::hash_map::DefaultHasher; -use std::error::Error; -use std::fmt; -use std::fs; -use std::hash::{Hash, Hasher}; +use std::collections::HashMap; use std::io; -use std::io::{Read, Write}; +use std::path::{Path, PathBuf}; +use std::sync::Mutex; -use crate::types::Token; -use itertools::Itertools; +use serde::{Deserialize, Serialize}; -/// Implements a specialized storage to set and retrieve `Token` instances. -/// The `scope_hash` represents the signature of the scopes for which the given token -/// should be stored or retrieved. -/// For completeness, the underlying, sorted scopes are provided as well. They might be -/// useful for presentation to the user. -pub trait TokenStorage { - type Error: 'static + Error + Send + Sync; +// The storage layer allows retrieving tokens for scopes that have been +// previously granted tokens. One wrinkle is that a token granted for a set +// of scopes X is also valid for any subset of X's scopes. So when retrieving a +// token for a set of scopes provided by the caller it's beneficial to compare +// that set to all previously stored tokens to see if it is a subset of any +// existing set. To do this efficiently we store a bloom filter along with each +// token that represents the set of scopes the token is associated with. The +// bloom filter allows for efficiently skipping any entries that are +// definitively not a superset. +// The current implementation uses a 64bit bloom filter with 4 hash functions. - /// If `token` is None, it is invalid or revoked and should be removed from storage. - /// Otherwise, it should be saved. - fn set( - &mut self, - scope_hash: u64, - scopes: &Vec<&str>, - token: Option, - ) -> Result<(), Self::Error>; - /// A `None` result indicates that there is no token for the given scope_hash. - fn get(&self, scope_hash: u64, scopes: &Vec<&str>) -> Result, Self::Error>; +/// ScopeHash is a hash value derived from a list of scopes. The hash value +/// represents a fingerprint of the set of scopes *independent* of the ordering. +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +struct ScopeHash(u64); + +/// ScopeFilter represents a filter for a set of scopes. It can definitively +/// prove that a given list of scopes is not a subset of another. +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +struct ScopeFilter(u64); + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +enum FilterResponse { + Maybe, + No, } -/// Calculate a hash value describing the scopes, and return a sorted Vec of the scopes. -pub fn hash_scopes(scopes: I) -> (u64, Vec) -where - T: Into, - I: IntoIterator, -{ - let mut sv: Vec = scopes.into_iter().map(Into::into).collect(); - sv.sort(); - let mut sh = DefaultHasher::new(); - sv.hash(&mut sh); - (sh.finish(), sv) +impl ScopeFilter { + /// Determine if this ScopeFilter could be a subset of the provided filter. + fn is_subset_of(self, filter: ScopeFilter) -> FilterResponse { + if self.0 & filter.0 == self.0 { + FilterResponse::Maybe + } else { + FilterResponse::No + } + } } -/// A storage that remembers nothing. -#[derive(Default)] -pub struct NullStorage; - #[derive(Debug)] -pub struct NullError; - -impl Error for NullError { - fn description(&self) -> &str { - "NULL" - } +pub(crate) struct ScopeSet<'a, T> { + hash: ScopeHash, + filter: ScopeFilter, + scopes: &'a [T], } -impl fmt::Display for NullError { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - "NULL-ERROR".fmt(f) - } -} - -impl TokenStorage for NullStorage { - type Error = NullError; - fn set(&mut self, _: u64, _: &Vec<&str>, _: Option) -> Result<(), NullError> { - Ok(()) - } - fn get(&self, _: u64, _: &Vec<&str>) -> Result, NullError> { - Ok(None) - } -} - -/// A storage that remembers values for one session only. -#[derive(Debug, Default)] -pub struct MemoryStorage { - tokens: Vec, -} - -impl MemoryStorage { - pub fn new() -> MemoryStorage { - Default::default() - } -} - -impl TokenStorage for MemoryStorage { - type Error = NullError; - - fn set( - &mut self, - scope_hash: u64, - scopes: &Vec<&str>, - token: Option, - ) -> Result<(), NullError> { - let matched = self.tokens.iter().find_position(|x| x.hash == scope_hash); - if let Some(_) = matched { - self.tokens.retain(|x| x.hash != scope_hash); +// Implement Clone manually. Auto derive fails to work correctly because we want +// Clone to be implemented regardless of whether T is Clone or not. +impl<'a, T> Clone for ScopeSet<'a, T> { + fn clone(&self) -> Self { + ScopeSet { + hash: self.hash, + filter: self.filter, + scopes: self.scopes, } - - match token { - Some(t) => { - self.tokens.push(JSONToken { - hash: scope_hash, - scopes: Some(scopes.iter().map(|x| x.to_string()).collect()), - token: t.clone(), - }); - () - } - None => {} - }; - Ok(()) } +} +impl<'a, T> Copy for ScopeSet<'a, T> {} - fn get(&self, scope_hash: u64, scopes: &Vec<&str>) -> Result, NullError> { - let scopes: Vec<_> = scopes.iter().sorted().unique().collect(); +impl<'a, T> ScopeSet<'a, T> +where + T: AsRef, +{ + // implement an inherent from method even though From is implemented. This + // is because passing an array ref like &[&str; 1] (&["foo"]) will be auto + // deref'd to a slice on function boundaries, but it will not implement the + // From trait. This inherent method just serves to auto deref from array + // refs to slices and proxy to the From impl. + pub fn from(scopes: &'a [T]) -> Self { + let (hash, filter) = scopes.iter().fold( + (ScopeHash(0), ScopeFilter(0)), + |(mut scope_hash, mut scope_filter), scope| { + let h = seahash::hash(scope.as_ref().as_bytes()); - for t in &self.tokens { - if let Some(token_scopes) = &t.scopes { - let matched = token_scopes - .iter() - .filter(|x| scopes.contains(&&&x[..])) - .count(); - if matched >= scopes.len() { - return Result::Ok(Some(t.token.clone())); + // Use the first 4 6-bit chunks of the seahash as the 4 hash values + // in the bloom filter. + for i in 0..4 { + // h is a hash derived value in the range 0..64 + let h = (h >> (6 * i)) & 0b11_1111; + scope_filter.0 |= 1 << h; } - } else if scope_hash == t.hash { - return Result::Ok(Some(t.token.clone())); - } + + // xor the hashes together to get an order independent fingerprint. + scope_hash.0 ^= h; + (scope_hash, scope_filter) + }, + ); + ScopeSet { + hash, + filter, + scopes, + } + } +} + +pub(crate) enum Storage { + Memory { tokens: Mutex }, + Disk(DiskStorage), +} + +impl Storage { + pub(crate) async fn set( + &self, + scopes: ScopeSet<'_, T>, + token: TokenInfo, + ) -> Result<(), io::Error> + where + T: AsRef, + { + match self { + Storage::Memory { tokens } => tokens.lock().unwrap().set(scopes, token), + Storage::Disk(disk_storage) => disk_storage.set(scopes, token).await, + } + } + + pub(crate) fn get(&self, scopes: ScopeSet) -> Option + where + T: AsRef, + { + match self { + Storage::Memory { tokens } => tokens.lock().unwrap().get(scopes), + Storage::Disk(disk_storage) => disk_storage.get(scopes), } - Result::Ok(None) } } /// A single stored token. -#[derive(Debug, Clone, Serialize, Deserialize)] + +#[derive(Debug, Clone)] struct JSONToken { - pub hash: u64, - pub scopes: Option>, - pub token: Token, + scopes: Vec, + token: TokenInfo, + hash: ScopeHash, + filter: ScopeFilter, } -impl PartialEq for JSONToken { - fn eq(&self, other: &Self) -> bool { - self.hash == other.hash +impl<'de> Deserialize<'de> for JSONToken { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + struct RawJSONToken { + scopes: Vec, + token: TokenInfo, + } + let RawJSONToken { scopes, token } = RawJSONToken::deserialize(deserializer)?; + let ScopeSet { hash, filter, .. } = ScopeSet::from(&scopes); + Ok(JSONToken { + scopes, + token, + hash, + filter, + }) } } -impl Eq for JSONToken {} - -impl PartialOrd for JSONToken { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for JSONToken { - fn cmp(&self, other: &Self) -> Ordering { - self.hash.cmp(&other.hash) +impl Serialize for JSONToken { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + #[derive(Serialize)] + struct RawJSONToken<'a> { + scopes: &'a [String], + token: &'a TokenInfo, + } + RawJSONToken { + scopes: &self.scopes, + token: &self.token, + } + .serialize(serializer) } } /// List of tokens in a JSON object -#[derive(Serialize, Deserialize)] -struct JSONTokens { - pub tokens: Vec, +#[derive(Debug, Clone)] +pub(crate) struct JSONTokens { + token_map: HashMap, } -/// Serializes tokens to a JSON file on disk. -#[derive(Default)] -pub struct DiskTokenStorage { - location: String, - tokens: Vec, +impl Serialize for JSONTokens { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.collect_seq(self.token_map.values()) + } } -impl DiskTokenStorage { - pub fn new>(location: S) -> Result { - let mut dts = DiskTokenStorage { - location: location.as_ref().to_owned(), - tokens: Vec::new(), +impl<'de> Deserialize<'de> for JSONTokens { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct V; + impl<'de> serde::de::Visitor<'de> for V { + type Value = JSONTokens; + + // Format a message stating what data this Visitor expects to receive. + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a sequence of JSONToken's") + } + + fn visit_seq(self, mut access: M) -> Result + where + M: serde::de::SeqAccess<'de>, + { + let mut token_map = HashMap::with_capacity(access.size_hint().unwrap_or(0)); + while let Some(json_token) = access.next_element::()? { + token_map.insert(json_token.hash, json_token); + } + Ok(JSONTokens { token_map }) + } + } + + // Instantiate our Visitor and ask the Deserializer to drive + // it over the input data. + deserializer.deserialize_seq(V) + } +} + +impl JSONTokens { + pub(crate) fn new() -> Self { + JSONTokens { + token_map: HashMap::new(), + } + } + + async fn load_from_file(filename: &Path) -> Result { + let contents = tokio::fs::read(filename).await?; + serde_json::from_slice(&contents).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) + } + + fn get( + &self, + ScopeSet { + hash, + filter, + scopes, + }: ScopeSet, + ) -> Option + where + T: AsRef, + { + if let Some(json_token) = self.token_map.get(&hash) { + return Some(json_token.token.clone()); + } + + let requested_scopes_are_subset_of = |other_scopes: &[String]| { + scopes + .iter() + .all(|s| other_scopes.iter().any(|t| t.as_str() == s.as_ref())) + }; + // No exact match for the scopes provided. Search for any tokens that + // exist for a superset of the scopes requested. + self.token_map + .values() + .filter(|json_token| filter.is_subset_of(json_token.filter) == FilterResponse::Maybe) + .find(|v: &&JSONToken| requested_scopes_are_subset_of(&v.scopes)) + .map(|t: &JSONToken| t.token.clone()) + } + + fn set( + &mut self, + ScopeSet { + hash, + filter, + scopes, + }: ScopeSet, + token: TokenInfo, + ) -> Result<(), io::Error> + where + T: AsRef, + { + use std::collections::hash_map::Entry; + match self.token_map.entry(hash) { + Entry::Occupied(mut entry) => { + entry.get_mut().token = token; + } + Entry::Vacant(entry) => { + let json_token = JSONToken { + scopes: scopes.iter().map(|x| x.as_ref().to_owned()).collect(), + token, + hash, + filter, + }; + entry.insert(json_token.clone()); + } + } + Ok(()) + } +} + +pub(crate) struct DiskStorage { + tokens: Mutex, + filename: PathBuf, +} + +impl DiskStorage { + pub(crate) async fn new(filename: PathBuf) -> Result { + let tokens = match JSONTokens::load_from_file(&filename).await { + Ok(tokens) => tokens, + Err(e) if e.kind() == io::ErrorKind::NotFound => JSONTokens::new(), + Err(e) => return Err(e), }; - // best-effort - let read_result = dts.load_from_file(); - - match read_result { - Result::Ok(()) => Result::Ok(dts), - Result::Err(e) => { - match e.kind() { - io::ErrorKind::NotFound => Result::Ok(dts), // File not found; ignore and create new one - _ => Result::Err(e), // e.g. PermissionDenied - } - } - } + Ok(DiskStorage { + tokens: Mutex::new(tokens), + filename, + }) } - fn load_from_file(&mut self) -> Result<(), io::Error> { - let mut f = fs::OpenOptions::new().read(true).open(&self.location)?; - let mut contents = String::new(); - - match f.read_to_string(&mut contents) { - Result::Err(e) => return Result::Err(e), - Result::Ok(_sz) => (), - } - - let tokens: JSONTokens; - - match serde_json::from_str(&contents) { - Result::Err(e) => return Result::Err(io::Error::new(io::ErrorKind::InvalidData, e)), - Result::Ok(t) => tokens = t, - } - - for t in tokens.tokens { - self.tokens.push(t); - } - return Result::Ok(()); + pub(crate) async fn set( + &self, + scopes: ScopeSet<'_, T>, + token: TokenInfo, + ) -> Result<(), io::Error> + where + T: AsRef, + { + use tokio::io::AsyncWriteExt; + let json = { + use std::ops::Deref; + let mut lock = self.tokens.lock().unwrap(); + lock.set(scopes, token)?; + serde_json::to_string(lock.deref()) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? + }; + let mut f = open_writeable_file(&self.filename).await?; + f.write_all(json.as_bytes()).await?; + Ok(()) } - pub fn dump_to_file(&mut self) -> Result<(), io::Error> { - let mut jsontokens = JSONTokens { tokens: Vec::new() }; - - for token in self.tokens.iter() { - jsontokens.tokens.push((*token).clone()); - } - - let serialized; - - match serde_json::to_string(&jsontokens) { - Result::Err(e) => return Result::Err(io::Error::new(io::ErrorKind::InvalidData, e)), - Result::Ok(s) => serialized = s, - } - - let mut f = fs::OpenOptions::new() - .create(true) - .write(true) - .truncate(true) - .open(&self.location)?; - f.write(serialized.as_ref()).map(|_| ()) + pub(crate) fn get(&self, scopes: ScopeSet) -> Option + where + T: AsRef, + { + self.tokens.lock().unwrap().get(scopes) } } -impl TokenStorage for DiskTokenStorage { - type Error = io::Error; - fn set( - &mut self, - scope_hash: u64, - scopes: &Vec<&str>, - token: Option, - ) -> Result<(), Self::Error> { - let matched = self.tokens.iter().find_position(|x| x.hash == scope_hash); - if let Some(_) = matched { - self.tokens.retain(|x| x.hash != scope_hash); - } +#[cfg(unix)] +async fn open_writeable_file( + filename: impl AsRef, +) -> Result { + // Ensure if the file is created it's only readable and writable by the + // current user. + use std::os::unix::fs::OpenOptionsExt; + let opts: tokio::fs::OpenOptions = { + let mut opts = std::fs::OpenOptions::new(); + opts.write(true).create(true).truncate(true).mode(0o600); + opts.into() + }; + opts.open(filename).await +} - match token { - None => (), - Some(t) => { - self.tokens.push(JSONToken { - hash: scope_hash, - scopes: Some(scopes.iter().map(|x| x.to_string()).collect()), - token: t.clone(), - }); - () - } - } - self.dump_to_file() +#[cfg(not(unix))] +async fn open_writeable_file( + filename: impl AsRef, +) -> Result { + // I don't have knowledge of windows or other platforms to know how to + // create a file that's only readable by the current user. + tokio::fs::File::create(filename).await +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_scope_filter() { + let foo = ScopeSet::from(&["foo"]).filter; + let bar = ScopeSet::from(&["bar"]).filter; + let foobar = ScopeSet::from(&["foo", "bar"]).filter; + + // foo and bar are both subsets of foobar. This condition should hold no + // matter what changes are made to the bloom filter implementation. + assert!(foo.is_subset_of(foobar) == FilterResponse::Maybe); + assert!(bar.is_subset_of(foobar) == FilterResponse::Maybe); + + // These conditions hold under the current bloom filter implementation + // because "foo" and "bar" don't collide, but if the bloom filter + // implementations change it could be valid for them to return Maybe. + assert!(foo.is_subset_of(bar) == FilterResponse::No); + assert!(bar.is_subset_of(foo) == FilterResponse::No); + assert!(foobar.is_subset_of(foo) == FilterResponse::No); + assert!(foobar.is_subset_of(bar) == FilterResponse::No); } - fn get(&self, scope_hash: u64, scopes: &Vec<&str>) -> Result, Self::Error> { - let scopes: Vec<_> = scopes.iter().sorted().unique().collect(); - for t in &self.tokens { - if let Some(token_scopes) = &t.scopes { - let matched = token_scopes - .iter() - .filter(|x| scopes.contains(&&&x[..])) - .count(); - // we may have some of the tokens as denormalized (many namespaces repeated) - if matched >= scopes.len() { - return Result::Ok(Some(t.token.clone())); - } - } else if scope_hash == t.hash { - return Result::Ok(Some(t.token.clone())); - } + #[tokio::test] + async fn test_disk_storage() { + let new_token = |access_token: &str| TokenInfo { + access_token: access_token.to_owned(), + refresh_token: None, + expires_at: None, + }; + let scope_set = ScopeSet::from(&["myscope"]); + let tempdir = tempfile::tempdir().unwrap(); + { + let storage = DiskStorage::new(tempdir.path().join("tokenstorage.json")) + .await + .unwrap(); + assert!(storage.get(scope_set).is_none()); + storage + .set(scope_set, new_token("my_access_token")) + .await + .unwrap(); + assert_eq!(storage.get(scope_set), Some(new_token("my_access_token"))); + } + { + // Create a new DiskStorage instance and verify the tokens were read from disk correctly. + let storage = DiskStorage::new(tempdir.path().join("tokenstorage.json")) + .await + .unwrap(); + assert_eq!(storage.get(scope_set), Some(new_token("my_access_token"))); } - Result::Ok(None) } } diff --git a/src/types.rs b/src/types.rs index 90553bc..5633f0d 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,343 +1,115 @@ -use chrono::{DateTime, TimeZone, Utc}; -use hyper; -use std::error::Error; -use std::fmt; -use std::io; -use std::str::FromStr; +use crate::error::{AuthErrorOr, Error}; -use futures::prelude::*; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; -/// A marker trait for all Flows -pub trait Flow { - fn type_id() -> FlowType; +/// Represents an access token returned by oauth2 servers. All access tokens are +/// Bearer tokens. Other types of tokens are not supported. +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] +pub struct AccessToken { + value: String, + expires_at: Option>, } -#[derive(Deserialize, Debug)] -pub struct JsonError { - pub error: String, - pub error_description: Option, - pub error_uri: Option, -} +impl AccessToken { + /// A string representation of the access token. + pub fn as_str(&self) -> &str { + &self.value + } -/// All possible outcomes of the refresh flow -#[derive(Debug)] -pub enum RefreshResult { - /// Indicates connection failure - Error(hyper::Error), - /// The server did not answer with a new token, providing the server message - RefreshError(String, Option), - /// The refresh operation finished successfully, providing a new `Token` - Success(Token), -} + /// The time the access token will expire, if any. + pub fn expiration_time(&self) -> Option> { + self.expires_at + } -/// Encapsulates all possible results of a `poll_token(...)` operation in the Device flow. -#[derive(Debug)] -pub enum PollError { - /// Connection failure - retry if you think it's worth it - HttpError(hyper::Error), - /// Indicates we are expired, including the expiration date - Expired(DateTime), - /// Indicates that the user declined access. String is server response - AccessDenied, - /// Indicates that too many attempts failed. - TimedOut, - /// Other type of error. - Other(String), -} - -/// Encapsulates all possible results of the `token(...)` operation -#[derive(Debug)] -pub enum RequestError { - /// Indicates connection failure - ClientError(hyper::Error), - /// The OAuth client was not found - InvalidClient, - /// Some requested scopes were invalid. String contains the scopes as part of - /// the server error message - InvalidScope(String), - /// A 'catch-all' variant containing the server error and description - /// First string is the error code, the second may be a more detailed description - NegativeServerResponse(String, Option), - /// A malformed server response. - BadServerResponse(String), - /// Error while decoding a JSON response. - JSONError(serde_json::error::Error), - /// Error within user input. - UserError(String), - /// A lower level IO error. - LowLevelError(io::Error), - /// A poll error occurred in the DeviceFlow. - Poll(PollError), - /// An error occurred while refreshing tokens. - Refresh(RefreshResult), - /// Error in token cache layer - Cache(Box), -} - -impl From for RequestError { - fn from(error: hyper::Error) -> RequestError { - RequestError::ClientError(error) + /// Determine if the access token is expired. + /// This will report that the token is expired 1 minute prior to the + /// expiration time to ensure that when the token is actually sent to the + /// server it's still valid. + pub fn is_expired(&self) -> bool { + // Consider the token expired if it's within 1 minute of it's expiration + // time. + self.expires_at + .map(|expiration_time| expiration_time - chrono::Duration::minutes(1) <= Utc::now()) + .unwrap_or(false) } } -impl From for RequestError { - fn from(value: JsonError) -> RequestError { - match &*value.error { - "invalid_client" => RequestError::InvalidClient, - "invalid_scope" => RequestError::InvalidScope( - value - .error_description - .unwrap_or("no description provided".to_string()), - ), - _ => RequestError::NegativeServerResponse(value.error, value.error_description), +impl AsRef for AccessToken { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl From for AccessToken { + fn from(value: TokenInfo) -> Self { + AccessToken { + value: value.access_token, + expires_at: value.expires_at, } } } -impl fmt::Display for RequestError { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - match *self { - RequestError::ClientError(ref err) => err.fmt(f), - RequestError::InvalidClient => "Invalid Client".fmt(f), - RequestError::InvalidScope(ref scope) => writeln!(f, "Invalid Scope: '{}'", scope), - RequestError::NegativeServerResponse(ref error, ref desc) => { - error.fmt(f)?; - if let &Some(ref desc) = desc { - write!(f, ": {}", desc)?; - } - "\n".fmt(f) - } - RequestError::BadServerResponse(ref s) => s.fmt(f), - RequestError::JSONError(ref e) => format!( - "JSON Error; this might be a bug with unexpected server responses! {}", - e - ) - .fmt(f), - RequestError::UserError(ref s) => s.fmt(f), - RequestError::LowLevelError(ref e) => e.fmt(f), - RequestError::Poll(ref pe) => pe.fmt(f), - RequestError::Refresh(ref rr) => format!("{:?}", rr).fmt(f), - RequestError::Cache(ref e) => e.fmt(f), - } - } -} - -impl Error for RequestError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - match *self { - RequestError::ClientError(ref err) => Some(err), - RequestError::LowLevelError(ref err) => Some(err), - RequestError::JSONError(ref err) => Some(err), - _ => None, - } - } -} - -#[derive(Debug)] -pub struct StringError { - error: String, -} - -impl fmt::Display for StringError { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - self.description().fmt(f) - } -} - -impl StringError { - pub fn new>(error: S, desc: Option) -> StringError { - let mut error = error.as_ref().to_string(); - if let Some(d) = desc { - error.push_str(": "); - error.push_str(d.as_ref()); - } - - StringError { error: error } - } -} - -impl<'a> From<&'a dyn Error> for StringError { - fn from(err: &'a dyn Error) -> StringError { - StringError::new(err.description().to_string(), None) - } -} - -impl From for StringError { - fn from(value: String) -> StringError { - StringError::new(value, None) - } -} - -impl Error for StringError { - fn description(&self) -> &str { - &self.error - } -} - -/// Represents all implemented token types -#[derive(Clone, PartialEq, Debug)] -pub enum TokenType { - /// Means that whoever bears the access token will be granted access - Bearer, -} - -impl AsRef for TokenType { - fn as_ref(&self) -> &'static str { - match *self { - TokenType::Bearer => "Bearer", - } - } -} - -impl FromStr for TokenType { - type Err = (); - fn from_str(s: &str) -> Result { - match s { - "Bearer" => Ok(TokenType::Bearer), - _ => Err(()), - } - } -} - -/// A scheme for use in `hyper::header::Authorization` -#[derive(Clone, PartialEq, Debug)] -pub struct Scheme { - /// The type of our access token - pub token_type: TokenType, - /// The token returned by one of the Authorization Flows - pub access_token: String, -} - -impl std::convert::Into for Scheme { - fn into(self) -> hyper::header::HeaderValue { - hyper::header::HeaderValue::from_str(&format!( - "{} {}", - self.token_type.as_ref(), - self.access_token - )) - .expect("Invalid Scheme header value") - } -} - -impl FromStr for Scheme { - type Err = &'static str; - fn from_str(s: &str) -> Result { - let parts: Vec<&str> = s.split(' ').collect(); - if parts.len() != 2 { - return Err("Expected two parts: "); - } - match ::from_str(parts[0]) { - Ok(t) => Ok(Scheme { - token_type: t, - access_token: parts[1].to_string(), - }), - Err(_) => Err("Couldn't parse token type"), - } - } -} - -/// A provider for authorization tokens, yielding tokens valid for a given scope. -/// The `api_key()` method is an alternative in case there are no scopes or -/// if no user is involved. -pub trait GetToken { - fn token( - &mut self, - scopes: I, - ) -> Box + Send> - where - T: Into, - I: IntoIterator; - - fn api_key(&mut self) -> Option; - - /// Return an application secret with at least token_uri, client_secret, and client_id filled - /// in. This is used for refreshing tokens without interaction from the flow. - fn application_secret(&self) -> ApplicationSecret; -} - /// Represents a token as returned by OAuth2 servers. /// /// It is produced by all authentication flows. /// It authenticates certain operations, and must be refreshed once /// it reached it's expiry date. -/// -/// The type is tuned to be suitable for direct de-serialization from server -/// replies, as well as for serialization for later reuse. This is the reason -/// for the two fields dealing with expiry - once in relative in and once in -/// absolute terms. -/// -/// Utility methods make common queries easier, see `expired()`. #[derive(Clone, PartialEq, Debug, Deserialize, Serialize)] -pub struct Token { +pub(crate) struct TokenInfo { /// used when authenticating calls to oauth2 enabled services. - pub access_token: String, + pub(crate) access_token: String, /// used to refresh an expired access_token. - pub refresh_token: Option, - /// The token type as string - usually 'Bearer'. - pub token_type: String, - /// access_token will expire after this amount of time. - /// Prefer using expiry_date() - pub expires_in: Option, - /// timestamp is seconds since epoch indicating when the token will expire in absolute terms. - /// use expiry_date() to convert to DateTime. - pub expires_in_timestamp: Option, + pub(crate) refresh_token: Option, + /// The time when the token expires. + pub(crate) expires_at: Option>, } -impl Token { +impl TokenInfo { + pub(crate) fn from_json(json_data: &[u8]) -> Result { + #[derive(Deserialize)] + struct RawToken { + access_token: String, + refresh_token: Option, + token_type: String, + expires_in: Option, + } + + let RawToken { + access_token, + refresh_token, + token_type, + expires_in, + } = serde_json::from_slice::>(json_data)?.into_result()?; + + if token_type.to_lowercase().as_str() != "bearer" { + use std::io; + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + r#"unknown token type returned; expected "bearer" found {}"#, + token_type + ), + ) + .into()); + } + + let expires_at = expires_in + .map(|seconds_from_now| Utc::now() + chrono::Duration::seconds(seconds_from_now)); + + Ok(TokenInfo { + access_token, + refresh_token, + expires_at, + }) + } + /// Returns true if we are expired. - /// - /// # Panics - /// * if our access_token is unset - pub fn expired(&self) -> bool { - if self.access_token.len() == 0 { - panic!("called expired() on unset token"); - } - if let Some(expiry_date) = self.expiry_date() { - expiry_date - chrono::Duration::minutes(1) <= Utc::now() - } else { - false - } + pub fn is_expired(&self) -> bool { + self.expires_at + .map(|expiration_time| expiration_time - chrono::Duration::minutes(1) <= Utc::now()) + .unwrap_or(false) } - - /// Returns a DateTime object representing our expiry date. - pub fn expiry_date(&self) -> Option> { - let expires_in_timestamp = self.expires_in_timestamp?; - - Utc.timestamp(expires_in_timestamp, 0).into() - } - - /// Adjust our stored expiry format to be absolute, using the current time. - pub fn set_expiry_absolute(&mut self) -> &mut Token { - if self.expires_in_timestamp.is_some() { - assert!(self.expires_in.is_none()); - return self; - } - - if let Some(expires_in) = self.expires_in { - self.expires_in_timestamp = Some(Utc::now().timestamp() + expires_in); - self.expires_in = None; - } - - self - } -} - -/// All known authentication types, for suitable constants -#[derive(Clone)] -pub enum FlowType { - /// [device authentication](https://developers.google.com/youtube/v3/guides/authentication#devices). Only works - /// for certain scopes. - /// Contains the device token URL; for google, that is - /// https://accounts.google.com/o/oauth2/device/code (exported as `GOOGLE_DEVICE_CODE_URL`) - Device(String), - /// [installed app flow](https://developers.google.com/identity/protocols/OAuth2InstalledApp). Required - /// for Drive, Calendar, Gmail...; Requires user to paste a code from the browser. - InstalledInteractive, - /// Same as InstalledInteractive, but uses a redirect: The OAuth provider redirects the user's - /// browser to a web server that is running on localhost. This may not work as well with the - /// Windows Firewall, but is more comfortable otherwise. The integer describes which port to - /// bind to (default: 8080) - InstalledRedirect(u16), } /// Represents either 'installed' or 'web' applications in a json secrets file. @@ -352,8 +124,8 @@ pub struct ApplicationSecret { pub token_uri: String, /// The authorization server endpoint URI. pub auth_uri: String, + /// The redirect uris. pub redirect_uris: Vec, - /// Name of the google project the credentials are associated with pub project_id: Option, /// The service account email associated with the client. @@ -369,14 +141,15 @@ pub struct ApplicationSecret { /// as returned by the [google developer console](https://code.google.com/apis/console) #[derive(Deserialize, Serialize, Default)] pub struct ConsoleApplicationSecret { + /// web app secret pub web: Option, + /// installed app secret pub installed: Option, } #[cfg(test)] pub mod tests { use super::*; - use hyper; pub const SECRET: &'static str = "{\"installed\":{\"auth_uri\":\"https://accounts.google.com/o/oauth2/auth\",\ @@ -394,25 +167,4 @@ pub mod tests { Err(err) => panic!(err), } } - - #[test] - fn schema() { - let s = Scheme { - token_type: TokenType::Bearer, - access_token: "foo".to_string(), - }; - let mut headers = hyper::HeaderMap::new(); - headers.insert(hyper::header::AUTHORIZATION, s.into()); - assert_eq!( - format!("{:?}", headers), - "{\"authorization\": \"Bearer foo\"}".to_string() - ); - } - - #[test] - fn parse_schema() { - let auth = Scheme::from_str("Bearer foo").unwrap(); - assert_eq!(auth.token_type, TokenType::Bearer); - assert_eq!(auth.access_token, "foo".to_string()); - } } diff --git a/tests/tests.rs b/tests/tests.rs new file mode 100644 index 0000000..265916c --- /dev/null +++ b/tests/tests.rs @@ -0,0 +1,618 @@ +use yup_oauth2::{ + authenticator::Authenticator, + authenticator_delegate::{DeviceAuthResponse, DeviceFlowDelegate, InstalledFlowDelegate}, + error::{AuthError, AuthErrorCode}, + ApplicationSecret, DeviceFlowAuthenticator, Error, InstalledFlowAuthenticator, + InstalledFlowReturnMethod, ServiceAccountAuthenticator, ServiceAccountKey, +}; + +use std::future::Future; +use std::path::PathBuf; +use std::pin::Pin; + +use httptest::{mappers::*, responders::json_encoded, Expectation, Server}; +use hyper::client::connect::HttpConnector; +use hyper::Uri; +use hyper_rustls::HttpsConnector; +use url::form_urlencoded; + +/// Utility function for parsing json. Useful in unit tests. Simply wrap the +/// json! macro in a from_value to deserialize the contents to arbitrary structs. +macro_rules! parse_json { + ($($json:tt)+) => { + ::serde_json::from_value(::serde_json::json!($($json)+)).expect("failed to deserialize") + } +} + +async fn create_device_flow_auth(server: &Server) -> Authenticator> { + let app_secret: ApplicationSecret = parse_json!({ + "client_id": "902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com", + "project_id": "yup-test-243420", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": server.url_str("/token"), + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_secret": "iuMPN6Ne1PD7cos29Tk9rlqH", + "redirect_uris": ["urn:ietf:wg:oauth:2.0:oob","http://localhost"], + }); + struct FD; + impl DeviceFlowDelegate for FD { + fn present_user_code<'a>( + &'a self, + pi: &'a DeviceAuthResponse, + ) -> Pin + 'a + Send>> { + assert_eq!("https://example.com/verify", pi.verification_uri); + Box::pin(async {}) + } + } + + DeviceFlowAuthenticator::builder(app_secret) + .flow_delegate(Box::new(FD)) + .device_code_url(server.url_str("/code")) + .build() + .await + .unwrap() +} + +#[tokio::test] +async fn test_device_success() { + let _ = env_logger::try_init(); + let server = Server::run(); + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/code"), + request::body(url_decoded(contains_entry(( + "client_id", + matches("902216714886") + )))), + ]) + .respond_with(json_encoded(serde_json::json!({ + "device_code": "devicecode", + "user_code": "usercode", + "verification_url": "https://example.com/verify", + "expires_in": 1234567, + "interval": 1 + }))), + ); + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("client_secret", "iuMPN6Ne1PD7cos29Tk9rlqH")), + contains_entry(("code", "devicecode")), + ])), + ]) + .respond_with(json_encoded(serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 1234567 + }))), + ); + + let auth = create_device_flow_auth(&server).await; + let token = auth + .token(&["https://www.googleapis.com/scope/1"]) + .await + .expect("token failed"); + assert_eq!("accesstoken", token.as_str()); +} + +#[tokio::test] +async fn test_device_no_code() { + let _ = env_logger::try_init(); + let server = Server::run(); + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/code"), + request::body(url_decoded(contains_entry(( + "client_id", + matches("902216714886") + )))), + ]) + .respond_with(json_encoded(serde_json::json!({ + "error": "invalid_client_id", + "error_description": "description" + }))), + ); + let auth = create_device_flow_auth(&server).await; + let res = auth.token(&["https://www.googleapis.com/scope/1"]).await; + assert!(res.is_err()); + assert!(format!("{}", res.unwrap_err()).contains("invalid_client_id")); +} + +#[tokio::test] +async fn test_device_no_token() { + let _ = env_logger::try_init(); + let server = Server::run(); + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/code"), + request::body(url_decoded(contains_entry(( + "client_id", + matches("902216714886") + )))), + ]) + .respond_with(json_encoded(serde_json::json!({ + "device_code": "devicecode", + "user_code": "usercode", + "verification_url": "https://example.com/verify", + "expires_in": 1234567, + "interval": 1 + }))), + ); + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("client_secret", "iuMPN6Ne1PD7cos29Tk9rlqH")), + contains_entry(("code", "devicecode")), + ])), + ]) + .respond_with(json_encoded(serde_json::json!({ + "error": "access_denied" + }))), + ); + let auth = create_device_flow_auth(&server).await; + let res = auth.token(&["https://www.googleapis.com/scope/1"]).await; + assert!(res.is_err()); + assert!(format!("{}", res.unwrap_err()).contains("access_denied")); +} + +async fn create_installed_flow_auth( + server: &Server, + method: InstalledFlowReturnMethod, + filename: Option, +) -> Authenticator> { + let app_secret: ApplicationSecret = parse_json!({ + "client_id": "902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com", + "project_id": "yup-test-243420", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": server.url_str("/token"), + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_secret": "iuMPN6Ne1PD7cos29Tk9rlqH", + "redirect_uris": ["urn:ietf:wg:oauth:2.0:oob","http://localhost"], + }); + struct FD(hyper::Client>); + impl InstalledFlowDelegate for FD { + /// Depending on need_code, return the pre-set code or send the code to the server at + /// the redirect_uri given in the url. + fn present_user_url<'a>( + &'a self, + url: &'a str, + need_code: bool, + ) -> Pin> + Send + 'a>> { + use std::str::FromStr; + Box::pin(async move { + if need_code { + Ok("authorizationcode".to_owned()) + } else { + // Parse presented url to obtain redirect_uri with location of local + // code-accepting server. + let uri = Uri::from_str(url.as_ref()).unwrap(); + let query = uri.query().unwrap(); + let parsed = form_urlencoded::parse(query.as_bytes()).into_owned(); + let mut rduri = None; + for (k, v) in parsed { + if k == "redirect_uri" { + rduri = Some(v); + break; + } + } + if rduri.is_none() { + return Err("no redirect_uri!".into()); + } + let mut rduri = rduri.unwrap(); + rduri.push_str("?code=authorizationcode"); + let rduri = Uri::from_str(rduri.as_ref()).unwrap(); + // Hit server. + self.0 + .get(rduri) + .await + .map_err(|e| e.to_string()) + .map(|_| "".to_string()) + } + }) + } + } + + let mut builder = InstalledFlowAuthenticator::builder(app_secret, method).flow_delegate( + Box::new(FD(hyper::Client::builder().build(HttpsConnector::new()))), + ); + + builder = if let Some(filename) = filename { + builder.persist_tokens_to_disk(filename) + } else { + builder + }; + + builder.build().await.unwrap() +} + +#[tokio::test] +async fn test_installed_interactive_success() { + let _ = env_logger::try_init(); + let server = Server::run(); + let auth = + create_installed_flow_auth(&server, InstalledFlowReturnMethod::Interactive, None).await; + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("code", "authorizationcode")), + contains_entry(("client_id", matches("9022167.*"))), + ])) + ]) + .respond_with(json_encoded(serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 12345678 + }))), + ); + + let tok = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!("accesstoken", tok.as_str()); +} + +#[tokio::test] +async fn test_installed_redirect_success() { + let _ = env_logger::try_init(); + let server = Server::run(); + let auth = + create_installed_flow_auth(&server, InstalledFlowReturnMethod::HTTPRedirect, None).await; + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("code", "authorizationcode")), + contains_entry(("client_id", matches("9022167.*"))), + ])) + ]) + .respond_with(json_encoded(serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 12345678 + }))), + ); + + let tok = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!("accesstoken", tok.as_str()); +} + +#[tokio::test] +async fn test_installed_error() { + let _ = env_logger::try_init(); + let server = Server::run(); + let auth = + create_installed_flow_auth(&server, InstalledFlowReturnMethod::Interactive, None).await; + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("code", "authorizationcode")), + contains_entry(("client_id", matches("9022167.*"))), + ])) + ]) + .respond_with( + http::Response::builder() + .status(404) + .body(serde_json::json!({"error": "invalid_code"}).to_string()) + .unwrap(), + ), + ); + + let tokr = auth.token(&["https://googleapis.com/some/scope"]).await; + assert!(tokr.is_err()); + assert!(format!("{}", tokr.unwrap_err()).contains("invalid_code")); +} + +async fn create_service_account_auth( + server: &Server, +) -> Authenticator> { + let key: ServiceAccountKey = parse_json!({ + "type": "service_account", + "project_id": "yup-test-243420", + "private_key_id": "26de294916614a5ebdf7a065307ed3ea9941902b", + "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDemmylrvp1KcOn\n9yTAVVKPpnpYznvBvcAU8Qjwr2fSKylpn7FQI54wCk5VJVom0jHpAmhxDmNiP8yv\nHaqsef+87Oc0n1yZ71/IbeRcHZc2OBB33/LCFqf272kThyJo3qspEqhuAw0e8neg\nLQb4jpm9PsqR8IjOoAtXQSu3j0zkXemMYFy93PWHjVpPEUX16NGfsWH7oxspBHOk\n9JPGJL8VJdbiAoDSDgF0y9RjJY5I52UeHNhMsAkTYs6mIG4kKXt2+T9tAyHw8aho\nwmuytQAfydTflTfTG8abRtliF3nil2taAc5VB07dP1b4dVYy/9r6M8Z0z4XM7aP+\nNdn2TKm3AgMBAAECggEAWi54nqTlXcr2M5l535uRb5Xz0f+Q/pv3ceR2iT+ekXQf\n+mUSShOr9e1u76rKu5iDVNE/a7H3DGopa7ZamzZvp2PYhSacttZV2RbAIZtxU6th\n7JajPAM+t9klGh6wj4jKEcE30B3XVnbHhPJI9TCcUyFZoscuPXt0LLy/z8Uz0v4B\nd5JARwyxDMb53VXwukQ8nNY2jP7WtUig6zwE5lWBPFMbi8GwGkeGZOruAK5sPPwY\nGBAlfofKANI7xKx9UXhRwisB4+/XI1L0Q6xJySv9P+IAhDUI6z6kxR+WkyT/YpG3\nX9gSZJc7qEaxTIuDjtep9GTaoEqiGntjaFBRKoe+VQKBgQDzM1+Ii+REQqrGlUJo\nx7KiVNAIY/zggu866VyziU6h5wjpsoW+2Npv6Dv7nWvsvFodrwe50Y3IzKtquIal\nVd8aa50E72JNImtK/o5Nx6xK0VySjHX6cyKENxHRDnBmNfbALRM+vbD9zMD0lz2q\nmns/RwRGq3/98EqxP+nHgHSr9QKBgQDqUYsFAAfvfT4I75Glc9svRv8IsaemOm07\nW1LCwPnj1MWOhsTxpNF23YmCBupZGZPSBFQobgmHVjQ3AIo6I2ioV6A+G2Xq/JCF\nmzfbvZfqtbbd+nVgF9Jr1Ic5T4thQhAvDHGUN77BpjEqZCQLAnUWJx9x7e2xvuBl\n1A6XDwH/ewKBgQDv4hVyNyIR3nxaYjFd7tQZYHTOQenVffEAd9wzTtVbxuo4sRlR\nNM7JIRXBSvaATQzKSLHjLHqgvJi8LITLIlds1QbNLl4U3UVddJbiy3f7WGTqPFfG\nkLhUF4mgXpCpkMLxrcRU14Bz5vnQiDmQRM4ajS7/kfwue00BZpxuZxst3QKBgQCI\nRI3FhaQXyc0m4zPfdYYVc4NjqfVmfXoC1/REYHey4I1XetbT9Nb/+ow6ew0UbgSC\nUZQjwwJ1m1NYXU8FyovVwsfk9ogJ5YGiwYb1msfbbnv/keVq0c/Ed9+AG9th30qM\nIf93hAfClITpMz2mzXIMRQpLdmQSR4A2l+E4RjkSOwKBgQCB78AyIdIHSkDAnCxz\nupJjhxEhtQ88uoADxRoEga7H/2OFmmPsqfytU4+TWIdal4K+nBCBWRvAX1cU47vH\nJOlSOZI0gRKe0O4bRBQc8GXJn/ubhYSxI02IgkdGrIKpOb5GG10m85ZvqsXw3bKn\nRVHMD0ObF5iORjZUqD0yRitAdg==\n-----END PRIVATE KEY-----\n", + "client_email": "yup-test-sa-1@yup-test-243420.iam.gserviceaccount.com", + "client_id": "102851967901799660408", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": server.url_str("/token"), + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/yup-test-sa-1%40yup-test-243420.iam.gserviceaccount.com" + }); + + ServiceAccountAuthenticator::builder(key) + .build() + .await + .unwrap() +} + +#[tokio::test] +async fn test_service_account_success() { + use chrono::Utc; + let _ = env_logger::try_init(); + let server = Server::run(); + let auth = create_service_account_auth(&server).await; + + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + ]).respond_with(json_encoded(serde_json::json!({ + "access_token": "ya29.c.ElouBywiys0LyNaZoLPJcp1Fdi2KjFMxzvYKLXkTdvM-rDfqKlvEq6PiMhGoGHx97t5FAvz3eb_ahdwlBjSStxHtDVQB4ZPRJQ_EOi-iS7PnayahU2S9Jp8S6rk", + "expires_in": 3600, + "token_type": "Bearer" + }))) + ); + let tok = auth + .token(&["https://www.googleapis.com/auth/pubsub"]) + .await + .expect("token failed"); + assert!(tok.as_str().contains("ya29.c.ElouBywiys0Ly")); + assert!(Utc::now() + chrono::Duration::seconds(3600) >= tok.expiration_time().unwrap()); +} + +#[tokio::test] +async fn test_service_account_error() { + let _ = env_logger::try_init(); + let server = Server::run(); + let auth = create_service_account_auth(&server).await; + server.expect( + Expectation::matching(all_of![request::method("POST"), request::path("/token"),]) + .respond_with(json_encoded(serde_json::json!({ + "error": "access_denied", + }))), + ); + + let result = auth + .token(&["https://www.googleapis.com/auth/pubsub"]) + .await; + assert!(result.is_err()); +} + +#[tokio::test] +async fn test_refresh() { + let _ = env_logger::try_init(); + let server = Server::run(); + let auth = + create_installed_flow_auth(&server, InstalledFlowReturnMethod::Interactive, None).await; + // We refresh a token whenever it's within 1 minute of expiring. So + // acquiring a token that expires in 59 seconds will force a refresh on + // the next token call. + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("code", "authorizationcode")), + contains_entry(("client_id", matches("^9022167"))), + ])) + ]) + .respond_with(json_encoded(serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 59, + }))), + ); + let tok = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!("accesstoken", tok.as_str()); + + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("refresh_token", "refreshtoken")), + contains_entry(("client_id", matches("^9022167"))), + ])) + ]) + .respond_with(json_encoded(serde_json::json!({ + "access_token": "accesstoken2", + "token_type": "Bearer", + "expires_in": 59, + }))), + ); + + let tok = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!("accesstoken2", tok.as_str()); + + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("refresh_token", "refreshtoken")), + contains_entry(("client_id", matches("^9022167"))), + ])) + ]) + .respond_with(json_encoded(serde_json::json!({ + "access_token": "accesstoken3", + "token_type": "Bearer", + "expires_in": 59, + }))), + ); + + let tok = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!("accesstoken3", tok.as_str()); + + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("refresh_token", "refreshtoken")), + contains_entry(("client_id", matches("^9022167"))), + ])) + ]) + .respond_with(json_encoded(serde_json::json!({ + "error": "invalid_request", + }))), + ); + + let tok_err = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect_err("token refresh succeeded unexpectedly"); + match tok_err { + Error::AuthError(AuthError { + error: AuthErrorCode::InvalidRequest, + .. + }) => {} + e => panic!("unexpected error on refresh: {:?}", e), + } +} + +#[tokio::test] +async fn test_memory_storage() { + let _ = env_logger::try_init(); + let server = Server::run(); + let auth = + create_installed_flow_auth(&server, InstalledFlowReturnMethod::Interactive, None).await; + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("code", "authorizationcode")), + contains_entry(("client_id", matches("^9022167"))), + ])) + ]) + .respond_with(json_encoded(serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 12345678, + }))), + ); + + // Call token twice. Ensure that only one http request is made and + // identical tokens are returned. + let token1 = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + let token2 = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!(token1.as_str(), "accesstoken"); + assert_eq!(token1, token2); + + // Create a new authenticator. This authenticator does not share a cache + // with the previous one. Validate that it receives a different token. + let auth2 = + create_installed_flow_auth(&server, InstalledFlowReturnMethod::Interactive, None).await; + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("code", "authorizationcode")), + contains_entry(("client_id", matches("^9022167"))), + ])) + ]) + .respond_with(json_encoded(serde_json::json!({ + "access_token": "accesstoken2", + "refresh_token": "refreshtoken2", + "token_type": "Bearer", + "expires_in": 12345678, + }))), + ); + let token3 = auth2 + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!(token3.as_str(), "accesstoken2"); +} + +#[tokio::test] +async fn test_disk_storage() { + let _ = env_logger::try_init(); + let server = Server::run(); + let tempdir = tempfile::tempdir().unwrap(); + let storage_path = tempdir.path().join("tokenstorage.json"); + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("code", "authorizationcode")), + contains_entry(("client_id", matches("^9022167"))), + ])), + ]) + .respond_with(json_encoded(serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 12345678 + }))), + ); + { + let auth = create_installed_flow_auth( + &server, + InstalledFlowReturnMethod::Interactive, + Some(storage_path.clone()), + ) + .await; + + // Call token twice. Ensure that only one http request is made and + // identical tokens are returned. + let token1 = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + let token2 = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!(token1.as_str(), "accesstoken"); + assert_eq!(token1, token2); + } + + // Create a new authenticator. This authenticator uses the same token + // storage file as the previous one so should receive a token without + // making any http requests. + let auth = create_installed_flow_auth( + &server, + InstalledFlowReturnMethod::Interactive, + Some(storage_path.clone()), + ) + .await; + // Call token twice. Ensure that identical tokens are returned. + let token1 = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + let token2 = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!(token1.as_str(), "accesstoken"); + assert_eq!(token1, token2); +}