From 93cbd913415d1bbfa9ff11841685440b5b8d1fc5 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Thu, 7 Nov 2019 14:23:40 -0800 Subject: [PATCH 01/71] Move to std::futures to support async/await. --- Cargo.toml | 11 +- examples/test-device/Cargo.toml | 8 +- examples/test-device/src/main.rs | 14 +- examples/test-installed/Cargo.toml | 8 +- examples/test-installed/src/main.rs | 27 +- examples/test-svc-acct/Cargo.toml | 8 +- examples/test-svc-acct/src/main.rs | 33 +- src/authenticator.rs | 220 ++++---- src/authenticator_delegate.rs | 74 +-- src/device.rs | 341 ++++++------- src/installed.rs | 765 ++++++++++++---------------- src/lib.rs | 15 +- src/refresh.rs | 124 +++-- src/service_account.rs | 284 +++++------ src/storage.rs | 128 +++-- src/types.rs | 11 +- 16 files changed, 923 insertions(+), 1148 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 49a3df1..e1ca60f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,8 +14,8 @@ edition = "2018" base64 = "0.10" chrono = "0.4" http = "0.1" -hyper = {version = "0.12", default-features = false} -hyper-rustls = "0.17" +hyper = {version = "0.13.0-alpha.4", features = ["unstable-stream"]} +hyper-rustls = "=0.18.0-alpha.2" itertools = "0.8" log = "0.3" rustls = "0.16" @@ -23,10 +23,9 @@ serde = "1.0" serde_json = "1.0" serde_derive = "1.0" url = "1" -futures = "0.1" -tokio-threadpool = "0.1" -tokio = "0.1" -tokio-timer = "0.2" +futures-preview = "=0.3.0-alpha.19" +tokio = "=0.2.0-alpha.6" +futures-util-preview = "=0.3.0-alpha.19" [dev-dependencies] getopts = "0.2" diff --git a/examples/test-device/Cargo.toml b/examples/test-device/Cargo.toml index 1b4ed95..39ca484 100644 --- a/examples/test-device/Cargo.toml +++ b/examples/test-device/Cargo.toml @@ -6,7 +6,7 @@ edition = "2018" [dependencies] yup-oauth2 = { path = "../../" } -hyper = "0.12" -hyper-rustls = "0.17" -futures = "0.1" -tokio = "0.1" +hyper = {version = "0.13.0-alpha.4", features = ["unstable-stream"]} +hyper-rustls = "=0.18.0-alpha.2" +futures-preview = "=0.3.0-alpha.19" +tokio = "=0.2.0-alpha.6" diff --git a/examples/test-device/src/main.rs b/examples/test-device/src/main.rs index 278fd78..42b3ab8 100644 --- a/examples/test-device/src/main.rs +++ b/examples/test-device/src/main.rs @@ -1,20 +1,20 @@ -use futures::prelude::*; use yup_oauth2::{self, Authenticator, DeviceFlow, GetToken}; use std::path; use tokio; -fn main() { +#[tokio::main] +async fn main() { let creds = yup_oauth2::read_application_secret(path::Path::new("clientsecret.json")) .expect("clientsecret"); - let mut auth = Authenticator::new(DeviceFlow::new(creds)) + let auth = Authenticator::new(DeviceFlow::new(creds)) .persist_tokens_to_disk("tokenstorage.json") .build() .expect("authenticator"); let scopes = vec!["https://www.googleapis.com/auth/youtube.readonly"]; - let mut rt = tokio::runtime::Runtime::new().unwrap(); - let fut = auth.token(scopes).and_then(|tok| Ok(println!("{:?}", tok))); - - println!("{:?}", rt.block_on(fut)); + match auth.token(scopes).await { + Err(e) => println!("error: {:?}", e), + Ok(t) => println!("token: {:?}", t), + } } diff --git a/examples/test-installed/Cargo.toml b/examples/test-installed/Cargo.toml index 0d6e654..e7fa5d2 100644 --- a/examples/test-installed/Cargo.toml +++ b/examples/test-installed/Cargo.toml @@ -6,7 +6,7 @@ edition = "2018" [dependencies] yup-oauth2 = { path = "../../" } -hyper = "0.12" -hyper-rustls = "0.17" -futures = "0.1" -tokio = "0.1" +hyper = {version = "0.13.0-alpha.4", features = ["unstable-stream"]} +hyper-rustls = "=0.18.0-alpha.2" +futures-preview = "=0.3.0-alpha.19" +tokio = "=0.2.0-alpha.6" diff --git a/examples/test-installed/src/main.rs b/examples/test-installed/src/main.rs index 3aa29ca..f4909a3 100644 --- a/examples/test-installed/src/main.rs +++ b/examples/test-installed/src/main.rs @@ -1,24 +1,16 @@ -use futures::prelude::*; use yup_oauth2::GetToken; use yup_oauth2::{Authenticator, InstalledFlow}; -use hyper::client::Client; -use hyper_rustls::HttpsConnector; - use std::path::Path; -fn main() { - let https = HttpsConnector::new(1); - let client = Client::builder() - .keep_alive(false) - .build::<_, hyper::Body>(https); - let ad = yup_oauth2::DefaultFlowDelegate; +#[tokio::main] +async fn main() { let secret = yup_oauth2::read_application_secret(Path::new("clientsecret.json")) .expect("clientsecret.json"); - let mut auth = Authenticator::new(InstalledFlow::new( + let auth = Authenticator::new(InstalledFlow::new( secret, - yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect(8081), + yup_oauth2::InstalledFlowReturnMethod::HTTPRedirectEphemeral, )) .persist_tokens_to_disk("tokencache.json") .build() @@ -26,11 +18,8 @@ fn main() { let s = "https://www.googleapis.com/auth/drive.file".to_string(); let scopes = vec![s]; - let tok = auth.token(scopes); - let fut = tok.map_err(|e| println!("error: {:?}", e)).and_then(|t| { - println!("The token is {:?}", t); - Ok(()) - }); - - tokio::run(fut) + match auth.token(scopes).await { + Err(e) => println!("error: {:?}", e), + Ok(t) => println!("The token is {:?}", t), + } } diff --git a/examples/test-svc-acct/Cargo.toml b/examples/test-svc-acct/Cargo.toml index 4bccf09..14c7d9b 100644 --- a/examples/test-svc-acct/Cargo.toml +++ b/examples/test-svc-acct/Cargo.toml @@ -6,7 +6,7 @@ edition = "2018" [dependencies] yup-oauth2 = { path = "../../" } -hyper = "0.12" -hyper-rustls = "0.17" -futures = "0.1" -tokio = "0.1" +hyper = {version = "0.13.0-alpha.4", features = ["unstable-stream"]} +hyper-rustls = "=0.18.0-alpha.2" +futures-preview = "=0.3.0-alpha.19" +tokio = "=0.2.0-alpha.6" diff --git a/examples/test-svc-acct/src/main.rs b/examples/test-svc-acct/src/main.rs index ebaaac1..3d18fdc 100644 --- a/examples/test-svc-acct/src/main.rs +++ b/examples/test-svc-acct/src/main.rs @@ -1,29 +1,22 @@ +use std::path; +use tokio; use yup_oauth2; - -use futures::prelude::*; use yup_oauth2::GetToken; -use tokio; - -use std::path; - -fn main() { +#[tokio::main] +async fn main() { let creds = yup_oauth2::service_account_key_from_file(path::Path::new("serviceaccount.json")).unwrap(); - let mut sa = yup_oauth2::ServiceAccountAccess::new(creds).build(); + let sa = yup_oauth2::ServiceAccountAccess::new(creds).build(); - let fut = sa + let tok = sa .token(vec!["https://www.googleapis.com/auth/pubsub"]) - .and_then(|tok| { - println!("token is: {:?}", tok); - Ok(()) - }); - let fut2 = sa + .await + .unwrap(); + println!("token is: {:?}", tok); + let tok = sa .token(vec!["https://www.googleapis.com/auth/pubsub"]) - .and_then(|tok| { - println!("cached token is {:?} and should be identical", tok); - Ok(()) - }); - let all = fut.join(fut2).then(|_| Ok(())); - tokio::run(all) + .await + .unwrap(); + println!("cached token is {:?} and should be identical", tok); } diff --git a/src/authenticator.rs b/src/authenticator.rs index 3a7189e..cde9fe4 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -3,13 +3,13 @@ use crate::refresh::RefreshFlow; use crate::storage::{hash_scopes, DiskTokenStorage, MemoryStorage, TokenStorage}; use crate::types::{ApplicationSecret, GetToken, RefreshResult, RequestError, Token}; -use futures::{future, prelude::*}; -use tokio_timer; +use futures::prelude::*; use std::error::Error; use std::io; use std::path::Path; -use std::sync::{Arc, Mutex}; +use std::pin::Pin; +use std::sync::Arc; /// Authenticator abstracts different `GetToken` implementations behind one type and handles /// caching received tokens. It's important to use it (instead of the flows directly) because @@ -28,8 +28,8 @@ struct AuthenticatorImpl< C: hyper::client::connect::Connect, > { client: hyper::Client, - inner: Arc>, - store: Arc>, + inner: Arc, + store: Arc, delegate: AD, } @@ -48,7 +48,7 @@ impl HyperClientBuilder for DefaultHyperClient { fn build_hyper_client(self) -> hyper::Client { hyper::Client::builder() .keep_alive(false) - .build::<_, hyper::Body>(hyper_rustls::HttpsConnector::new(1)) + .build::<_, hyper::Body>(hyper_rustls::HttpsConnector::new()) } } @@ -169,14 +169,12 @@ where where T::TokenGetter: 'static + GetToken + Send, S: 'static + Send, - AD: 'static + Send, + AD: 'static + Send + Sync, C::Connector: 'static + Clone + Send, { let client = self.client.build_hyper_client(); - let store = Arc::new(Mutex::new(self.store?)); - let inner = Arc::new(Mutex::new( - self.token_getter.build_token_getter(client.clone()), - )); + let store = Arc::new(self.store?); + let inner = Arc::new(self.token_getter.build_token_getter(client.clone())); Ok(AuthenticatorImpl { client, @@ -187,147 +185,125 @@ where } } -impl< - GT: 'static + GetToken + Send, - S: 'static + TokenStorage + Send, - AD: 'static + AuthenticatorDelegate + Send, - C: 'static + hyper::client::connect::Connect + Clone + Send, - > GetToken for AuthenticatorImpl +impl AuthenticatorImpl +where + GT: 'static + GetToken, + S: 'static + TokenStorage, + AD: 'static + AuthenticatorDelegate + Send + Sync, + C: 'static + hyper::client::connect::Connect + Clone + Send, { - /// Returns the API Key of the inner flow. - fn api_key(&mut self) -> Option { - self.inner.lock().unwrap().api_key() - } - /// Returns the application secret of the inner flow. - fn application_secret(&self) -> ApplicationSecret { - self.inner.lock().unwrap().application_secret() - } - - fn token( - &mut self, - scopes: I, - ) -> Box + Send> - where - T: Into, - I: IntoIterator, - { - let (scope_key, scopes) = hash_scopes(scopes); + async fn get_token(&self, scope_key: u64, scopes: Vec) -> Result { let store = self.store.clone(); let mut delegate = self.delegate.clone(); let client = self.client.clone(); - let appsecret = self.inner.lock().unwrap().application_secret(); + let appsecret = self.inner.application_secret(); let gettoken = self.inner.clone(); - let loopfn = move |()| -> Box< - dyn Future, Error = RequestError> + Send, - > { - // How well does this work with tokio? - match store.lock().unwrap().get( + loop { + match store.get( scope_key.clone(), &scopes.iter().map(|s| s.as_str()).collect(), ) { Ok(Some(t)) => { if !t.expired() { - return Box::new(Ok(future::Loop::Break(t)).into_future()); + return Ok(t); } // Implement refresh flow. let refresh_token = t.refresh_token.clone(); let mut delegate = delegate.clone(); let store = store.clone(); let scopes = scopes.clone(); - let refresh_fut = RefreshFlow::refresh_token( + let rr = RefreshFlow::refresh_token( client.clone(), appsecret.clone(), refresh_token.unwrap(), ) - .and_then(move |rr| -> Box, Error=RequestError> + Send> { - match rr { - RefreshResult::Error(ref e) => { - delegate.token_refresh_failed( - format!("{}", e.description().to_string()), - &Some("the request has likely timed out".to_string()), - ); - Box::new(Err(RequestError::Refresh(rr)).into_future()) + .await?; + match rr { + RefreshResult::Error(ref e) => { + delegate.token_refresh_failed( + format!("{}", e.description().to_string()), + &Some("the request has likely timed out".to_string()), + ); + return Err(RequestError::Refresh(rr)); + } + RefreshResult::RefreshError(ref s, ref ss) => { + delegate.token_refresh_failed( + format!("{} {}", s, ss.clone().map(|s| format!("({})", s)).unwrap_or("".to_string())), + &Some("the refresh token is likely invalid and your authorization has been revoked".to_string()), + ); + return Err(RequestError::Refresh(rr)); + } + RefreshResult::Success(t) => { + let x = store.set( + scope_key, + &scopes.iter().map(|s| s.as_str()).collect(), + Some(t.clone()), + ); + if let Err(e) = x { + match delegate.token_storage_failure(true, &e) { + Retry::Skip => return Ok(t), + Retry::Abort => return Err(RequestError::Cache(Box::new(e))), + Retry::After(d) => tokio::timer::delay_for(d).await, } - RefreshResult::RefreshError(ref s, ref ss) => { - delegate.token_refresh_failed( - format!("{} {}", s, ss.clone().map(|s| format!("({})", s)).unwrap_or("".to_string())), - &Some("the refresh token is likely invalid and your authorization has been revoked".to_string()), - ); - Box::new(Err(RequestError::Refresh(rr)).into_future()) - } - RefreshResult::Success(t) => { - if let Err(e) = store.lock().unwrap().set(scope_key, &scopes.iter().map(|s| s.as_str()).collect(), Some(t.clone())) { - match delegate.token_storage_failure(true, &e) { - Retry::Skip => Box::new(Ok(future::Loop::Break(t)).into_future()), - Retry::Abort => Box::new(Err(RequestError::Cache(Box::new(e))).into_future()), - Retry::After(d) => Box::new( - tokio_timer::sleep(d) - .then(|_| Ok(future::Loop::Continue(()))), - ) - as Box< - dyn Future< - Item = future::Loop, - Error = RequestError> + Send>, - } - } else { - Box::new(Ok(future::Loop::Break(t)).into_future()) - } - }, + } else { + return Ok(t); } - }); - Box::new(refresh_fut) + } + } } Ok(None) => { let store = store.clone(); let scopes = scopes.clone(); let mut delegate = delegate.clone(); - Box::new( - gettoken - .lock() - .unwrap() - .token(scopes.clone()) - .and_then(move |t| { - if let Err(e) = store.lock().unwrap().set( - scope_key, - &scopes.iter().map(|s| s.as_str()).collect(), - Some(t.clone()), - ) { - match delegate.token_storage_failure(true, &e) { - Retry::Skip => { - Box::new(Ok(future::Loop::Break(t)).into_future()) - } - Retry::Abort => Box::new( - Err(RequestError::Cache(Box::new(e))).into_future(), - ), - Retry::After(d) => Box::new( - tokio_timer::sleep(d) - .then(|_| Ok(future::Loop::Continue(()))), - ) - as Box< - dyn Future< - Item = future::Loop, - Error = RequestError, - > + Send, - >, - } - } else { - Box::new(Ok(future::Loop::Break(t)).into_future()) - } - }), - ) + let t = gettoken.token(scopes.clone()).await?; + if let Err(e) = store.set( + scope_key, + &scopes.iter().map(|s| s.as_str()).collect(), + Some(t.clone()), + ) { + match delegate.token_storage_failure(true, &e) { + Retry::Skip => return Ok(t), + Retry::Abort => return Err(RequestError::Cache(Box::new(e))), + Retry::After(d) => tokio::timer::delay_for(d).await, + } + } else { + return Ok(t); + } } Err(err) => match delegate.token_storage_failure(false, &err) { - Retry::Abort | Retry::Skip => { - return Box::new(Err(RequestError::Cache(Box::new(err))).into_future()) - } - Retry::After(d) => { - return Box::new( - tokio_timer::sleep(d).then(|_| Ok(future::Loop::Continue(()))), - ) - } + Retry::Abort | Retry::Skip => return Err(RequestError::Cache(Box::new(err))), + Retry::After(d) => tokio::timer::delay_for(d).await, }, } - }; - Box::new(future::loop_fn((), loopfn)) + } + } +} + +impl< + GT: 'static + GetToken, + S: 'static + TokenStorage, + AD: 'static + AuthenticatorDelegate + Send + Sync, + C: 'static + hyper::client::connect::Connect + Clone + Send, + > GetToken for AuthenticatorImpl +{ + /// Returns the API Key of the inner flow. + fn api_key(&self) -> Option { + self.inner.api_key() + } + /// Returns the application secret of the inner flow. + fn application_secret(&self) -> ApplicationSecret { + self.inner.application_secret() + } + + fn token<'a, I, T>( + &'a self, + scopes: I, + ) -> Pin> + Send + 'a>> + where + T: Into, + I: IntoIterator, + { + let (scope_key, scopes) = hash_scopes(scopes); + Box::pin(self.get_token(scope_key, scopes)) } } diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index 9077f6f..eb3bceb 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -2,14 +2,15 @@ use hyper; use std::error::Error; use std::fmt; -use std::io; +use std::pin::Pin; use crate::types::{PollError, RequestError}; use chrono::{DateTime, Local, Utc}; use std::time::Duration; -use futures::{future, prelude::*}; +use futures::prelude::*; +use tio::AsyncBufReadExt; use tokio::io as tio; /// A utility type to indicate how operations DeviceFlowHelper operations should be retried @@ -83,7 +84,7 @@ pub trait AuthenticatorDelegate: Clone { /// This can be useful if the underlying `TokenStorage` may fail occasionally. /// if `is_set` is true, the failure resulted from `TokenStorage.set(...)`. Otherwise, /// it was `TokenStorage.get(...)` - fn token_storage_failure(&mut self, is_set: bool, _: &dyn Error) -> Retry { + fn token_storage_failure(&mut self, is_set: bool, _: &(dyn Error + Send + Sync)) -> Retry { let _ = is_set; Retry::Abort } @@ -114,11 +115,11 @@ pub trait FlowDelegate: Clone { /// Called if the request code is expired. You will have to start over in this case. /// This will be the last call the delegate receives. /// Given `DateTime` is the expiration date - fn expired(&mut self, _: &DateTime) {} + fn expired(&self, _: &DateTime) {} /// Called if the user denied access. You would have to start over. /// This will be the last call the delegate receives. - fn denied(&mut self) {} + fn denied(&self) {} /// Called as long as we are waiting for the user to authorize us. /// Can be used to print progress information, or decide to time-out. @@ -127,7 +128,7 @@ pub trait FlowDelegate: Clone { /// # Notes /// * Only used in `DeviceFlow`. Return value will only be used if it /// is larger than the interval desired by the server. - fn pending(&mut self, _: &PollInformation) -> Retry { + fn pending(&self, _: &PollInformation) -> Retry { Retry::After(Duration::from_secs(5)) } @@ -140,7 +141,7 @@ pub trait FlowDelegate: Clone { /// # Notes /// * Will be called exactly once, provided we didn't abort during `request_code` phase. /// * Will only be called if the Authenticator's flow_type is `FlowType::Device`. - fn present_user_code(&mut self, pi: &PollInformation) { + fn present_user_code(&self, pi: &PollInformation) { println!( "Please enter {} at {} and grant access to this application", pi.user_code, pi.verification_url @@ -156,35 +157,44 @@ pub trait FlowDelegate: Clone { /// We need the user to navigate to a URL using their browser and potentially paste back a code /// (or maybe not). Whether they have to enter a code depends on the InstalledFlowReturnMethod /// used. - fn present_user_url + fmt::Display>( - &mut self, + fn present_user_url<'a, S: AsRef + fmt::Display + Send + Sync + 'a>( + &'a self, url: S, need_code: bool, - ) -> Box, Error = Box> + Send> { - if need_code { - println!( - "Please direct your browser to {}, follow the instructions and enter the \ - code displayed here: ", - url - ); + ) -> Pin>> + Send + 'a>> + { + Box::pin(present_user_url(url, need_code)) + } +} - Box::new( - tio::lines(io::BufReader::new(tio::stdin())) - .into_future() - .map_err(|(e, _)| { - println!("{:?}", e); - Box::new(e) as Box - }) - .and_then(|(l, _)| Ok(l)), - ) - } else { - println!( - "Please direct your browser to {} and follow the instructions displayed \ - there.", - url - ); - Box::new(future::ok(None)) +async fn present_user_url + fmt::Display>( + url: S, + need_code: bool, +) -> Result> { + if need_code { + println!( + "Please direct your browser to {}, follow the instructions and enter the \ + code displayed here: ", + url + ); + let mut user_input = String::new(); + match tio::BufReader::new(tio::stdin()) + .read_line(&mut user_input) + .await + { + Err(err) => { + println!("{:?}", err); + Err(Box::new(err) as Box) + } + Ok(_) => Ok(user_input), } + } else { + println!( + "Please direct your browser to {} and follow the instructions displayed \ + there.", + url + ); + Ok(String::new()) } } diff --git a/src/device.rs b/src/device.rs index 78cce65..dc4cfdf 100644 --- a/src/device.rs +++ b/src/device.rs @@ -1,16 +1,14 @@ use std::iter::{FromIterator, IntoIterator}; +use std::pin::Pin; use std::time::Duration; use ::log::{error, log}; use chrono::{self, Utc}; -use futures::stream::Stream; -use futures::{future, prelude::*}; -use http; +use futures::{prelude::*}; use hyper; use hyper::header; use itertools::Itertools; use serde_json as json; -use tokio_timer; use url::form_urlencoded; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate, PollInformation, Retry}; @@ -75,7 +73,7 @@ impl DeviceFlow { impl crate::authenticator::AuthFlow for DeviceFlow where - FD: FlowDelegate + Send + 'static, + FD: FlowDelegate + Send + Sync + 'static, C: hyper::client::connect::Connect + 'static, { type TokenGetter = DeviceFlowImpl; @@ -108,21 +106,21 @@ impl Flow for DeviceFlowImpl { } impl< - FD: FlowDelegate + Clone + Send + 'static, + FD: FlowDelegate + Clone + Send + Sync + 'static, C: hyper::client::connect::Connect + Sync + 'static, > GetToken for DeviceFlowImpl { - fn token( - &mut self, + fn token<'a, I, T>( + &'a self, scopes: I, - ) -> Box + Send> + ) -> Pin> + Send + 'a>> where T: Into, I: IntoIterator, { - self.retrieve_device_token(Vec::from_iter(scopes.into_iter().map(Into::into))) + Box::pin(self.retrieve_device_token(Vec::from_iter(scopes.into_iter().map(Into::into)))) } - fn api_key(&mut self) -> Option { + fn api_key(&self) -> Option { None } fn application_secret(&self) -> ApplicationSecret { @@ -139,75 +137,51 @@ where { /// Essentially what `GetToken::token` does: Retrieve a token for the given scopes without /// caching. - fn retrieve_device_token<'a>( - &mut self, + pub async fn retrieve_device_token<'a>( + &self, scopes: Vec, - ) -> Box + Send> { + ) -> Result { let application_secret = self.application_secret.clone(); let client = self.client.clone(); let wait = self.wait; - let mut fd = self.fd.clone(); - let request_code = Self::request_code( + let fd = self.fd.clone(); + let (pollinf, device_code) = Self::request_code( application_secret.clone(), client.clone(), self.device_code_url.clone(), scopes, ) - .and_then(move |(pollinf, device_code)| { - fd.present_user_code(&pollinf); - Ok((pollinf, device_code)) - }); - let fd = self.fd.clone(); - Box::new(request_code.and_then(move |(pollinf, device_code)| { - future::loop_fn(0, move |i| { - // Make a copy of everything every time, because the loop function needs to be - // repeatable, i.e. we can't move anything out. - let pt = Self::poll_token( - application_secret.clone(), - client.clone(), - device_code.clone(), - pollinf.clone(), - fd.clone(), - ); - let maxn = wait.as_secs() / pollinf.interval.as_secs(); - let mut fd = fd.clone(); - let pollinf = pollinf.clone(); - tokio_timer::sleep(pollinf.interval) - .then(|_| pt) - .then(move |r| match r { - Ok(None) if i < maxn => match fd.pending(&pollinf) { - Retry::Abort | Retry::Skip => { - Box::new(Err(RequestError::Poll(PollError::TimedOut)).into_future()) - } - Retry::After(d) => Box::new( - tokio_timer::sleep(d) - .then(move |_| Ok(future::Loop::Continue(i + 1))), - ) - as Box< - dyn Future< - Item = future::Loop, - Error = RequestError, - > + Send, - >, - }, - Ok(Some(tok)) => Box::new(Ok(future::Loop::Break(tok)).into_future()), - Err(e @ PollError::AccessDenied) - | Err(e @ PollError::TimedOut) - | Err(e @ PollError::Expired(_)) => { - Box::new(Err(RequestError::Poll(e)).into_future()) - } - Err(ref e) if i < maxn => { - error!("Unknown error from poll token api: {}", e); - Box::new(Ok(future::Loop::Continue(i + 1)).into_future()) - } - // Too many attempts. - Ok(None) | Err(_) => { - error!("Too many poll attempts"); - Box::new(Err(RequestError::Poll(PollError::TimedOut)).into_future()) - } - }) - }) - })) + .await?; + fd.present_user_code(&pollinf); + let maxn = wait.as_secs() / pollinf.interval.as_secs(); + for _ in 0..maxn { + let fd = fd.clone(); + let pollinf = pollinf.clone(); + tokio::timer::delay_for(pollinf.interval).await; + let r = Self::poll_token( + application_secret.clone(), + client.clone(), + device_code.clone(), + pollinf.clone(), + fd.clone(), + ) + .await; + match r { + Ok(None) => match fd.pending(&pollinf) { + Retry::Abort | Retry::Skip => { + return Err(RequestError::Poll(PollError::TimedOut)) + } + Retry::After(d) => tokio::timer::delay_for(d).await, + }, + Ok(Some(tok)) => return Ok(tok), + Err(e @ PollError::AccessDenied) + | Err(e @ PollError::TimedOut) + | Err(e @ PollError::Expired(_)) => return Err(RequestError::Poll(e)), + Err(ref e) => error!("Unknown error from poll token api: {}", e), + } + } + error!("Too many poll attempts"); + Err(RequestError::Poll(PollError::TimedOut)) } /// The first step involves asking the server for a code that the user @@ -225,12 +199,12 @@ where /// * If called after a successful result was returned at least once. /// # Examples /// See test-cases in source code for a more complete example. - fn request_code( + async fn request_code( application_secret: ApplicationSecret, client: hyper::Client, device_code_url: String, scopes: Vec, - ) -> impl Future { + ) -> Result<(PollInformation, String), RequestError> { // note: cloned() shouldn't be needed, see issue // https://github.com/servo/rust-url/issues/81 let req = form_urlencoded::Serializer::new(String::new()) @@ -248,66 +222,48 @@ where // note: works around bug in rustlang // https://github.com/rust-lang/rust/issues/22252 - let request = hyper::Request::post(device_code_url) + let req = hyper::Request::post(device_code_url) .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(req)) - .into_future(); - request - .then( - move |request: Result, http::Error>| { - let request = request.unwrap(); - client.request(request) - }, - ) - .then( - |r: Result, hyper::error::Error>| { - match r { - Err(err) => { - return Err(RequestError::ClientError(err)); - } - Ok(res) => { - // This return type is defined in https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.2 - // The alias is present as Google use a non-standard name for verification_uri. - // According to the standard interval is optional, however, all tested implementations provide it. - // verification_uri_complete is optional in the standard but not provided in tested implementations. - #[derive(Deserialize)] - struct JsonData { - device_code: String, - user_code: String, - #[serde(alias = "verification_url")] - verification_uri: String, - expires_in: Option, - interval: i64, - } + .unwrap(); + let resp = client + .request(req) + .await + .map_err(|e| RequestError::ClientError(e))?; + // This return type is defined in https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.2 + // The alias is present as Google use a non-standard name for verification_uri. + // According to the standard interval is optional, however, all tested implementations provide it. + // verification_uri_complete is optional in the standard but not provided in tested implementations. + #[derive(Deserialize)] + struct JsonData { + device_code: String, + user_code: String, + #[serde(alias = "verification_url")] + verification_uri: String, + expires_in: Option, + interval: i64, + } - let json_str: String = res - .into_body() - .concat2() - .wait() - .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) - .unwrap(); // TODO: error handling + let json_bytes = resp.into_body().try_concat().await?; - // check for error - match json::from_str::(&json_str) { - Err(_) => {} // ignore, move on - Ok(res) => return Err(RequestError::from(res)), - } + // check for error + match json::from_slice::(&json_bytes) { + Err(_) => {} // ignore, move on + Ok(res) => return Err(RequestError::from(res)), + } - let decoded: JsonData = json::from_str(&json_str).unwrap(); + let decoded: JsonData = + json::from_slice(&json_bytes).map_err(|e| RequestError::JSONError(e))?; - let expires_in = decoded.expires_in.unwrap_or(60 * 60); + let expires_in = decoded.expires_in.unwrap_or(60 * 60); - let pi = PollInformation { - user_code: decoded.user_code, - verification_url: decoded.verification_uri, - expires_at: Utc::now() + chrono::Duration::seconds(expires_in), - interval: Duration::from_secs(i64::abs(decoded.interval) as u64), - }; - Ok((pi, decoded.device_code)) - } - } - }, - ) + let pi = PollInformation { + user_code: decoded.user_code, + verification_url: decoded.verification_uri, + expires_at: Utc::now() + chrono::Duration::seconds(expires_in), + interval: Duration::from_secs(i64::abs(decoded.interval) as u64), + }; + Ok((pi, decoded.device_code)) } /// If the first call is successful, this method may be called. @@ -328,19 +284,17 @@ where /// /// # Examples /// See test-cases in source code for a more complete example. - fn poll_token<'a>( + async fn poll_token<'a>( application_secret: ApplicationSecret, client: hyper::Client, device_code: String, pi: PollInformation, - mut fd: FD, - ) -> impl Future, Error = PollError> { - let expired = if pi.expires_at <= Utc::now() { + fd: FD, + ) -> Result, PollError> { + if pi.expires_at <= Utc::now() { fd.expired(&pi.expires_at); - Err(PollError::Expired(pi.expires_at)).into_future() - } else { - Ok(()).into_future() - }; + return Err(PollError::Expired(pi.expires_at)); + } // We should be ready for a new request let req = form_urlencoded::Serializer::new(String::new()) @@ -356,46 +310,44 @@ where .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(req)) .unwrap(); // TODO: Error checking - expired - .and_then(move |_| client.request(request).map_err(|e| PollError::HttpError(e))) - .map(|res| { - res.into_body() - .concat2() - .wait() - .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) - .unwrap() // TODO: error handling - }) - .and_then(move |json_str: String| { - #[derive(Deserialize)] - struct JsonError { - error: String, - } + let res = client + .request(request) + .await + .map_err(|e| PollError::HttpError(e))?; + let body = res + .into_body() + .try_concat() + .await + .map_err(|e| PollError::HttpError(e))?; + #[derive(Deserialize)] + struct JsonError { + error: String, + } - match json::from_str::(&json_str) { - Err(_) => {} // ignore, move on, it's not an error - Ok(res) => { - match res.error.as_ref() { - "access_denied" => { - fd.denied(); - return Err(PollError::AccessDenied); - } - "authorization_pending" => return Ok(None), - s => { - return Err(PollError::Other(format!( - "server message '{}' not understood", - s - ))) - } - }; + match json::from_slice::(&body) { + Err(_) => {} // ignore, move on, it's not an error + Ok(res) => { + match res.error.as_ref() { + "access_denied" => { + fd.denied(); + return Err(PollError::AccessDenied); } - } + "authorization_pending" => return Ok(None), + s => { + return Err(PollError::Other(format!( + "server message '{}' not understood", + s + ))) + } + }; + } + } - // yes, we expect that ! - let mut t: Token = json::from_str(&json_str).unwrap(); - t.set_expiry_absolute(); + // yes, we expect that ! + let mut t: Token = json::from_slice(&body).unwrap(); + t.set_expiry_absolute(); - Ok(Some(t.clone())) - }) + Ok(Some(t)) } } @@ -415,7 +367,7 @@ mod tests { #[derive(Clone)] struct FD; impl FlowDelegate for FD { - fn present_user_code(&mut self, pi: &PollInformation) { + fn present_user_code(&self, pi: &PollInformation) { assert_eq!("https://example.com/verify", pi.verification_url); } } @@ -426,17 +378,17 @@ mod tests { app_secret.token_uri = format!("{}/token", server_url); let device_code_url = format!("{}/code", server_url); - let https = HttpsConnector::new(1); + let https = HttpsConnector::new(); let client = hyper::Client::builder() .keep_alive(false) .build::<_, hyper::Body>(https); - let mut flow = DeviceFlow::new(app_secret) + let flow = DeviceFlow::new(app_secret) .delegate(FD) .device_code_url(device_code_url) .build_token_getter(client); - let mut rt = tokio::runtime::Builder::new() + let rt = tokio::runtime::Builder::new() .core_threads(1) .panic_handler(|e| std::panic::resume_unwind(e)) .build() @@ -461,13 +413,14 @@ mod tests { .with_body(token_response) .create(); - let fut = flow - .token(vec!["https://www.googleapis.com/scope/1"]) - .then(|token| { - let token = token.unwrap(); - assert_eq!("accesstoken", token.access_token); - Ok(()) as Result<(), ()> - }); + let fut = async { + let token = flow + .token(vec!["https://www.googleapis.com/scope/1"]) + .await + .unwrap(); + assert_eq!("accesstoken", token.access_token); + Ok(()) as Result<(), ()> + }; rt.block_on(fut).expect("block_on"); _m.assert(); @@ -493,13 +446,12 @@ mod tests { .expect(0) // Never called! .create(); - let fut = flow - .token(vec!["https://www.googleapis.com/scope/1"]) - .then(|token| { - assert!(token.is_err()); - assert!(format!("{}", token.unwrap_err()).contains("invalid_client_id")); - Ok(()) as Result<(), ()> - }); + let fut = async { + let res = flow.token(vec!["https://www.googleapis.com/scope/1"]).await; + assert!(res.is_err()); + assert!(format!("{}", res.unwrap_err()).contains("invalid_client_id")); + Ok(()) as Result<(), ()> + }; rt.block_on(fut).expect("block_on"); _m.assert(); @@ -524,13 +476,12 @@ mod tests { .expect(1) .create(); - let fut = flow - .token(vec!["https://www.googleapis.com/scope/1"]) - .then(|token| { - assert!(token.is_err()); - assert!(format!("{}", token.unwrap_err()).contains("Access denied by user")); - Ok(()) as Result<(), ()> - }); + let fut = async { + let res = flow.token(vec!["https://www.googleapis.com/scope/1"]).await; + assert!(res.is_err()); + assert!(format!("{}", res.unwrap_err()).contains("Access denied by user")); + Ok(()) as Result<(), ()> + }; rt.block_on(fut).expect("block_on"); _m.assert(); diff --git a/src/installed.rs b/src/installed.rs index 31c967a..1733e6f 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -3,13 +3,16 @@ // Refer to the project root for licensing information. // use std::convert::AsRef; +use std::future::Future; +use std::net::SocketAddr; +use std::pin::Pin; use std::sync::{Arc, Mutex}; -use futures::prelude::*; -use futures::stream::Stream; -use futures::sync::oneshot; +use futures::future::FutureExt; +use futures_util::try_stream::TryStreamExt; use hyper; -use hyper::{header, StatusCode, Uri}; +use hyper::header; +use tokio::sync::oneshot; use url::form_urlencoded; use url::percent_encoding::{percent_encode, QUERY_ENCODE_SET}; @@ -58,20 +61,22 @@ where }) } -impl - GetToken for InstalledFlowImpl +impl< + FD: FlowDelegate + 'static + Send + Sync + Clone, + C: hyper::client::connect::Connect + 'static, + > GetToken for InstalledFlowImpl { - fn token( - &mut self, + fn token<'a, I, T>( + &'a self, scopes: I, - ) -> Box + Send> + ) -> Pin> + Send + 'a>> where T: Into, I: IntoIterator, { - Box::new(self.obtain_token(scopes.into_iter().map(Into::into).collect())) + Box::pin(self.obtain_token(scopes.into_iter().map(Into::into).collect())) } - fn api_key(&mut self) -> Option { + fn api_key(&self) -> Option { None } fn application_secret(&self) -> ApplicationSecret { @@ -140,7 +145,7 @@ where impl crate::authenticator::AuthFlow for InstalledFlow where - FD: FlowDelegate + Send + 'static, + FD: FlowDelegate + Send + Sync + 'static, C: hyper::client::connect::Connect + 'static, { type TokenGetter = InstalledFlowImpl; @@ -164,160 +169,136 @@ impl<'c, FD: 'static + FlowDelegate + Clone + Send, C: 'c + hyper::client::conne /// . Return that token /// /// It's recommended not to use the DefaultFlowDelegate, but a specialized one. - fn obtain_token<'a>( - &mut self, + async fn obtain_token<'a>( + &self, scopes: Vec, // Note: I haven't found a better way to give a list of strings here, due to ownership issues with futures. - ) -> impl 'a + Future + Send { - let rduri = self.fd.redirect_uri(); - // Start server on localhost to accept auth code. - let server_bind_port = match self.method { - InstalledFlowReturnMethod::HTTPRedirect(port) => Some(port), - InstalledFlowReturnMethod::HTTPRedirectEphemeral => Some(0), - _ => None, - }; - let server = if let Some(port) = server_bind_port { - match InstalledFlowServer::new(port) { - Result::Err(e) => Err(RequestError::ClientError(e)), - Result::Ok(server) => Ok(Some(server)), + ) -> Result { + match self.method { + InstalledFlowReturnMethod::HTTPRedirect(port) => { + self.ask_auth_code_via_http(scopes.iter(), port).await } - } else { - Ok(None) - }; - let port = if let Ok(Some(ref srv)) = server { - Some(srv.port) - } else { - None - }; - let client = self.client.clone(); - let (appsecclone, appsecclone2) = (self.appsecret.clone(), self.appsecret.clone()); - let auth_delegate = self.fd.clone(); - server - .into_future() - // First: Obtain authorization code from user. - .and_then(move |server| { - Self::ask_authorization_code(server, auth_delegate, &appsecclone, scopes.iter()) - }) - // Exchange the authorization code provided by Google/the provider for a refresh and an - // access token. - .and_then(move |authcode| { - let request = Self::request_token(appsecclone2, authcode, rduri, port); - let result = client.request(request); - // Handle result here, it makes ownership tracking easier. - result - .and_then(move |r| { - r.into_body() - .concat2() - .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) - // TODO: error handling - }) - .then(|body_or| { - let resp = match body_or { - Err(e) => return Err(RequestError::ClientError(e)), - Ok(s) => s, - }; - - let token_resp: Result = - serde_json::from_str(&resp); - - match token_resp { - Err(e) => { - return Err(RequestError::JSONError(e)); - } - Ok(tok) => { - if tok.error.is_some() { - Err(RequestError::NegativeServerResponse( - tok.error.unwrap(), - tok.error_description, - )) - } else { - Ok(tok) - } - } - } - }) - }) - // Return the combined token. - .and_then(|tokens| { - // Successful response - if tokens.access_token.is_some() { - let mut token = Token { - access_token: tokens.access_token.unwrap(), - refresh_token: Some(tokens.refresh_token.unwrap()), - token_type: tokens.token_type.unwrap(), - expires_in: tokens.expires_in, - expires_in_timestamp: None, - }; - - token.set_expiry_absolute(); - Ok(token) - } else { - Err(RequestError::NegativeServerResponse( - tokens.error.unwrap(), - tokens.error_description, - )) - } - }) + InstalledFlowReturnMethod::HTTPRedirectEphemeral => { + self.ask_auth_code_via_http(scopes.iter(), 0).await + } + InstalledFlowReturnMethod::Interactive => { + self.ask_auth_code_interactively(scopes.iter()).await + } + } } - fn ask_authorization_code<'a, S, T>( - server: Option, - mut auth_delegate: FD, - appsecret: &ApplicationSecret, - scopes: S, - ) -> Box + Send> + async fn ask_auth_code_interactively<'a, S, T>(&self, scopes: S) -> Result where T: AsRef + 'a, S: Iterator, { - if server.is_none() { - let url = build_authentication_request_url( - &appsecret.auth_uri, - &appsecret.client_id, - scopes, - auth_delegate.redirect_uri(), - ); - Box::new( - auth_delegate - .present_user_url(&url, true /* need_code */) - .then(|r| { - match r { - Ok(Some(mut code)) => { - // Partial backwards compatibility in case an implementation adds a new line - // due to previous behaviour. - let ends_with_newline = - code.chars().last().map(|c| c == '\n').unwrap_or(false); - if ends_with_newline { - code.pop(); - } - Ok(code) - } - _ => Err(RequestError::UserError("couldn't read code".to_string())), - } - }), - ) - } else { - let mut server = server.unwrap(); - // The redirect URI must be this very localhost URL, otherwise authorization is refused - // by certain providers. - let url = build_authentication_request_url( - &appsecret.auth_uri, - &appsecret.client_id, - scopes, - auth_delegate - .redirect_uri() - .or_else(|| Some(format!("http://localhost:{}", server.port))), - ); - Box::new( - auth_delegate - .present_user_url(&url, false /* need_code */) - .then(move |_| server.block_till_auth()) - .map_err(|e| { - RequestError::UserError(format!( - "could not obtain token via redirect: {}", - e - )) - }), - ) + let auth_delegate = &self.fd; + let appsecret = &self.appsecret; + let url = build_authentication_request_url( + &appsecret.auth_uri, + &appsecret.client_id, + scopes, + auth_delegate.redirect_uri(), + ); + let authcode = match auth_delegate + .present_user_url(&url, true /* need code */) + .await + { + Ok(mut code) => { + // Partial backwards compatibility in case an implementation adds a new line + // due to previous behaviour. + let ends_with_newline = code.chars().last().map(|c| c == '\n').unwrap_or(false); + if ends_with_newline { + code.pop(); + } + code + } + _ => return Err(RequestError::UserError("couldn't read code".to_string())), + }; + self.exchange_auth_code(authcode, None).await + } + + async fn ask_auth_code_via_http<'a, S, T>( + &self, + scopes: S, + desired_port: u16, + ) -> Result + where + T: AsRef + 'a, + S: Iterator, + { + let auth_delegate = &self.fd; + let appsecret = &self.appsecret; + let server = InstalledFlowServer::run(desired_port)?; + let bound_port = server.local_addr().port(); + + // Present url to user. + // The redirect URI must be this very localhost URL, otherwise authorization is refused + // by certain providers. + let url = build_authentication_request_url( + &appsecret.auth_uri, + &appsecret.client_id, + scopes, + auth_delegate + .redirect_uri() + .or_else(|| Some(format!("http://localhost:{}", bound_port))), + ); + let _ = auth_delegate + .present_user_url(&url, false /* need code */) + .await; + + let auth_code = server.wait_for_auth_code().await; + self.exchange_auth_code(auth_code, Some(bound_port)).await + } + + async fn exchange_auth_code( + &self, + authcode: String, + port: Option, + ) -> Result { + let appsec = &self.appsecret; + let redirect_uri = &self.fd.redirect_uri(); + let request = Self::request_token(appsec.clone(), authcode, redirect_uri.clone(), port); + let resp = self + .client + .request(request) + .await + .map_err(|e| RequestError::ClientError(e))?; + let body = resp + .into_body() + .try_concat() + .await + .map_err(|e| RequestError::ClientError(e))?; + let tokens: JSONTokenResponse = + serde_json::from_slice(&body).map_err(|e| RequestError::JSONError(e))?; + match tokens { + JSONTokenResponse { + error: Some(err), + error_description, + .. + } => Err(RequestError::NegativeServerResponse(err, error_description)), + JSONTokenResponse { + access_token: Some(access_token), + refresh_token, + token_type: Some(token_type), + expires_in, + .. + } => { + let mut token = Token { + access_token, + refresh_token, + token_type, + expires_in, + expires_in_timestamp: None, + }; + token.set_expiry_absolute(); + Ok(token) + } + JSONTokenResponse { + error_description, .. + } => Err(RequestError::NegativeServerResponse( + "".to_owned(), + error_description, + )), } } @@ -362,124 +343,86 @@ struct JSONTokenResponse { error_description: Option, } +fn spawn_with_handle(f: F) -> impl Future +where + F: Future + 'static + Send, +{ + let (tx, rx) = oneshot::channel(); + tokio::spawn(f.map(move |_| tx.send(()).unwrap())); + async { + let _ = rx.await; + } +} + struct InstalledFlowServer { - port: u16, - shutdown_tx: Option>, - auth_code_rx: Option>, - threadpool: Option, + addr: SocketAddr, + auth_code_rx: oneshot::Receiver, + trigger_shutdown_tx: oneshot::Sender<()>, + shutdown_complete: Pin + Send>>, } impl InstalledFlowServer { - fn new(port: u16) -> Result { + fn run(desired_port: u16) -> Result { + use hyper::service::{make_service_fn, service_fn}; let (auth_code_tx, auth_code_rx) = oneshot::channel::(); - let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let (trigger_shutdown_tx, trigger_shutdown_rx) = oneshot::channel::<()>(); + let auth_code_tx = Arc::new(Mutex::new(Some(auth_code_tx))); - let threadpool = tokio_threadpool::Builder::new() - .pool_size(1) - .name_prefix("InstalledFlowServer-") - .build(); - let service_maker = InstalledFlowServiceMaker::new(auth_code_tx); - - let addr: std::net::SocketAddr = ([127, 0, 0, 1], port).into(); - let builder = hyper::server::Server::try_bind(&addr)?; - let server = builder.http1_only(true).serve(service_maker); - let port = server.local_addr().port(); - let server_future = server - .with_graceful_shutdown(shutdown_rx) - .map_err(|err| panic!("Failed badly: {}", err)); - - threadpool.spawn(server_future); - - Result::Ok(InstalledFlowServer { - port: port, - shutdown_tx: Some(shutdown_tx), - auth_code_rx: Some(auth_code_rx), - threadpool: Some(threadpool), + let service = make_service_fn(move |_| { + let auth_code_tx = auth_code_tx.clone(); + async move { + use std::convert::Infallible; + Ok::<_, Infallible>(service_fn(move |req| { + installed_flow_server::handle_req(req, auth_code_tx.clone()) + })) + } + }); + let addr: std::net::SocketAddr = ([127, 0, 0, 1], desired_port).into(); + let server = hyper::server::Server::try_bind(&addr)?; + let server = server.http1_only(true).serve(service); + let addr = server.local_addr(); + let shutdown_complete = spawn_with_handle(async { + let _ = server + .with_graceful_shutdown(async move { + let _ = trigger_shutdown_rx.await; + }) + .await; + }); + Ok(InstalledFlowServer { + addr, + auth_code_rx, + trigger_shutdown_tx, + shutdown_complete: Box::pin(shutdown_complete), }) } - fn block_till_auth(&mut self) -> Result { - match self.auth_code_rx.take() { - Some(auth_code_rx) => auth_code_rx.wait(), - None => Result::Err(oneshot::Canceled), - } + fn local_addr(&self) -> SocketAddr { + self.addr + } + + async fn wait_for_auth_code(self) -> String { + // Wait for the auth code from the server. + let auth_code = self + .auth_code_rx + .await + .expect("server shutdown while waiting for auth_code"); + // auth code received. shutdown the server + let _ = self.trigger_shutdown_tx.send(()); + self.shutdown_complete.await; + auth_code } } -impl std::ops::Drop for InstalledFlowServer { - fn drop(&mut self) { - self.shutdown_tx.take().map(|tx| tx.send(())); - self.auth_code_rx.take().map(|mut rx| rx.close()); - self.threadpool.take(); - } -} +mod installed_flow_server { + use hyper::{Body, Request, Response, StatusCode, Uri}; + use std::sync::{Arc, Mutex}; + use tokio::sync::oneshot; + use url::form_urlencoded; -pub struct InstalledFlowHandlerResponseFuture { - inner: Box< - dyn futures::Future, Error = hyper::http::Error> + Send, - >, -} - -impl InstalledFlowHandlerResponseFuture { - fn new( - fut: Box< - dyn futures::Future, Error = hyper::http::Error> - + Send, - >, - ) -> Self { - Self { inner: fut } - } -} - -impl futures::Future for InstalledFlowHandlerResponseFuture { - type Item = hyper::Response; - type Error = hyper::http::Error; - - fn poll(&mut self) -> futures::Poll { - self.inner.poll() - } -} - -/// Creates InstalledFlowService on demand -struct InstalledFlowServiceMaker { - auth_code_tx: Arc>>>, -} - -impl InstalledFlowServiceMaker { - fn new(auth_code_tx: oneshot::Sender) -> InstalledFlowServiceMaker { - let auth_code_tx = Arc::new(Mutex::new(Option::Some(auth_code_tx))); - InstalledFlowServiceMaker { auth_code_tx } - } -} - -impl hyper::service::MakeService for InstalledFlowServiceMaker { - type ReqBody = hyper::Body; - type ResBody = hyper::Body; - type Error = hyper::http::Error; - type Service = InstalledFlowService; - type Future = futures::future::FutureResult; - type MakeError = hyper::http::Error; - - fn make_service(&mut self, _ctx: Ctx) -> Self::Future { - let service = InstalledFlowService { - auth_code_tx: self.auth_code_tx.clone(), - }; - futures::future::ok(service) - } -} - -/// HTTP service handling the redirect from the provider. -struct InstalledFlowService { - auth_code_tx: Arc>>>, -} - -impl hyper::service::Service for InstalledFlowService { - type ReqBody = hyper::Body; - type ResBody = hyper::Body; - type Error = hyper::http::Error; - type Future = InstalledFlowHandlerResponseFuture; - - fn call(&mut self, req: hyper::Request) -> Self::Future { + pub(super) async fn handle_req( + req: Request, + auth_code_tx: Arc>>>, + ) -> Result, http::Error> { match req.uri().path_and_query() { Some(path_and_query) => { // We use a fake URL because the redirect goes to a URL, meaning we @@ -491,77 +434,44 @@ impl hyper::service::Service for InstalledFlowService { .path_and_query(path_and_query.clone()) .build(); - if url.is_err() { - let response = hyper::Response::builder() + match url { + Err(_) => hyper::Response::builder() .status(StatusCode::BAD_REQUEST) - .body(hyper::Body::from("Unparseable URL")); - - match response { - Ok(response) => InstalledFlowHandlerResponseFuture::new(Box::new( - futures::future::ok(response), - )), - Err(err) => InstalledFlowHandlerResponseFuture::new(Box::new( - futures::future::err(err), - )), - } - } else { - self.handle_url(url.unwrap()); - let response = - hyper::Response::builder() - .status(StatusCode::OK) - .body(hyper::Body::from( - "SuccessYou may now \ - close this window.", - )); - - match response { - Ok(response) => InstalledFlowHandlerResponseFuture::new(Box::new( - futures::future::ok(response), - )), - Err(err) => InstalledFlowHandlerResponseFuture::new(Box::new( - futures::future::err(err), - )), - } - } - } - None => { - let response = hyper::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(hyper::Body::from("Invalid Request!")); - - match response { - Ok(response) => InstalledFlowHandlerResponseFuture::new(Box::new( - futures::future::ok(response), - )), - Err(err) => { - InstalledFlowHandlerResponseFuture::new(Box::new(futures::future::err(err))) - } + .body(hyper::Body::from("Unparseable URL")), + Ok(url) => match auth_code_from_url(url) { + Some(auth_code) => { + if let Some(sender) = auth_code_tx.lock().unwrap().take() { + let _ = sender.send(auth_code); + } + hyper::Response::builder().status(StatusCode::OK).body( + hyper::Body::from( + "SuccessYou may now \ + close this window.", + ), + ) + } + None => hyper::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(hyper::Body::from("No `code` in URL")), + }, } } + None => hyper::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(hyper::Body::from("Invalid Request!")), } } -} -impl InstalledFlowService { - fn handle_url(&mut self, url: hyper::Uri) { + fn auth_code_from_url(url: hyper::Uri) -> Option { // The provider redirects to the specified localhost URL, appending the authorization // code, like this: http://localhost:8080/xyz/?code=4/731fJ3BheyCouCniPufAd280GHNV5Ju35yYcGs - // We take that code and send it to the ask_authorization_code() function that - // waits for it. - for (param, val) in form_urlencoded::parse(url.query().unwrap_or("").as_bytes()) { - if param == "code".to_string() { - let mut auth_code_tx = self.auth_code_tx.lock().unwrap(); - match auth_code_tx.take() { - Some(auth_code_tx) => { - let _ = auth_code_tx.send(val.to_owned().to_string()); - } - None => { - // call to the server after a previous call. Each server is only designed - // to receive a single request. - } - }; + form_urlencoded::parse(url.query().unwrap_or("").as_bytes()).find_map(|(param, val)| { + if param == "code" { + Some(val.into_owned()) + } else { + None } - } + }) } } @@ -571,7 +481,7 @@ mod tests { use std::fmt; use std::str::FromStr; - use hyper; + use hyper::Uri; use hyper::client::connect::HttpConnector; use hyper_rustls::HttpsConnector; use mockito::{self, mock}; @@ -593,45 +503,44 @@ mod tests { impl FlowDelegate for FD { /// Depending on need_code, return the pre-set code or send the code to the server at /// the redirect_uri given in the url. - fn present_user_url + fmt::Display>( - &mut self, + fn present_user_url<'a, S: AsRef + fmt::Display + Send + Sync + 'a>( + &'a self, url: S, need_code: bool, - ) -> Box, Error = Box> + Send> - { - if need_code { - Box::new(Ok(Some(self.0.clone())).into_future()) - } else { - // Parse presented url to obtain redirect_uri with location of local - // code-accepting server. - let uri = Uri::from_str(url.as_ref()).unwrap(); - let query = uri.query().unwrap(); - let parsed = form_urlencoded::parse(query.as_bytes()).into_owned(); - let mut rduri = None; - for (k, v) in parsed { - if k == "redirect_uri" { - rduri = Some(v); - break; + ) -> Pin< + Box>> + Send + 'a>, + > { + Box::pin(async move { + if need_code { + Ok(self.0.clone()) + } else { + // Parse presented url to obtain redirect_uri with location of local + // code-accepting server. + let uri = Uri::from_str(url.as_ref()).unwrap(); + let query = uri.query().unwrap(); + let parsed = form_urlencoded::parse(query.as_bytes()).into_owned(); + let mut rduri = None; + for (k, v) in parsed { + if k == "redirect_uri" { + rduri = Some(v); + break; + } } - } - if rduri.is_none() { - return Box::new( - Err(Box::new(StringError::new("no redirect uri!", None)) - as Box) - .into_future(), - ); - } - let mut rduri = rduri.unwrap(); - rduri.push_str(&format!("?code={}", self.0)); - let rduri = Uri::from_str(rduri.as_ref()).unwrap(); - // Hit server. - return Box::new( + if rduri.is_none() { + return Err(Box::new(StringError::new("no redirect uri!", None)) + as Box); + } + let mut rduri = rduri.unwrap(); + rduri.push_str(&format!("?code={}", self.0)); + let rduri = Uri::from_str(rduri.as_ref()).unwrap(); + // Hit server. self.1 .get(rduri) - .map_err(|e| Box::new(e) as Box) - .map(|_| None), - ); - } + .await + .map_err(|e| Box::new(e) as Box) + .map(|_| "".to_string()) + } + }) } } @@ -640,18 +549,18 @@ mod tests { let mut app_secret = parse_application_secret(app_secret).unwrap(); app_secret.token_uri = format!("{}/token", server_url); - let https = HttpsConnector::new(1); + let https = HttpsConnector::new(); let client = hyper::Client::builder() .keep_alive(false) .build::<_, hyper::Body>(https); let fd = FD("authorizationcode".to_string(), client.clone()); - let mut inf = + let inf = InstalledFlow::new(app_secret.clone(), InstalledFlowReturnMethod::Interactive) .delegate(fd) .build_token_getter(client.clone()); - let mut rt = tokio::runtime::Builder::new() + let rt = tokio::runtime::Builder::new() .core_threads(1) .panic_handler(|e| std::panic::resume_unwind(e)) .build() @@ -665,20 +574,24 @@ mod tests { .expect(1) .create(); - let fut = inf - .token(vec!["https://googleapis.com/some/scope"]) - .and_then(|tok| { + let fut = || { + async { + let tok = inf + .token(vec!["https://googleapis.com/some/scope"]) + .await + .map_err(|_| ())?; assert_eq!("accesstoken", tok.access_token); assert_eq!("refreshtoken", tok.refresh_token.unwrap()); assert_eq!("Bearer", tok.token_type); - Ok(()) - }); - rt.block_on(fut).expect("block on"); + Ok(()) as Result<(), ()> + } + }; + rt.block_on(fut()).expect("block on"); _m.assert(); } // Successful path with HTTP redirect. { - let mut inf = + let inf = InstalledFlow::new(app_secret, InstalledFlowReturnMethod::HTTPRedirect(8081)) .delegate(FD( "authorizationcodefromlocalserver".to_string(), @@ -691,14 +604,16 @@ mod tests { .expect(1) .create(); - let fut = inf - .token(vec!["https://googleapis.com/some/scope"]) - .and_then(|tok| { - assert_eq!("accesstoken", tok.access_token); - assert_eq!("refreshtoken", tok.refresh_token.unwrap()); - assert_eq!("Bearer", tok.token_type); - Ok(()) - }); + let fut = async { + let tok = inf + .token(vec!["https://googleapis.com/some/scope"]) + .await + .map_err(|_| ())?; + assert_eq!("accesstoken", tok.access_token); + assert_eq!("refreshtoken", tok.refresh_token.unwrap()); + assert_eq!("Bearer", tok.token_type); + Ok(()) as Result<(), ()> + }; rt.block_on(fut).expect("block on"); _m.assert(); } @@ -713,17 +628,16 @@ mod tests { .expect(1) .create(); - let fut = inf - .token(vec!["https://googleapis.com/some/scope"]) - .then(|tokr| { - assert!(tokr.is_err()); - assert!(format!("{}", tokr.unwrap_err()).contains("invalid_code")); - Ok(()) as Result<(), ()> - }); + let fut = async { + let tokr = inf.token(vec!["https://googleapis.com/some/scope"]).await; + assert!(tokr.is_err()); + assert!(format!("{}", tokr.unwrap_err()).contains("invalid_code")); + Ok(()) as Result<(), ()> + }; rt.block_on(fut).expect("block on"); _m.assert(); } - rt.shutdown_on_idle().wait().expect("shutdown"); + rt.shutdown_on_idle(); } #[test] @@ -743,41 +657,53 @@ mod tests { ); } - #[test] - fn test_server_random_local_port() { - let addr1 = InstalledFlowServer::new(0).unwrap(); - let addr2 = InstalledFlowServer::new(0).unwrap(); - assert_ne!(addr1.port, addr2.port); + #[tokio::test] + async fn test_server_random_local_port() { + let addr1 = InstalledFlowServer::run(0).unwrap().local_addr(); + let addr2 = InstalledFlowServer::run(0).unwrap().local_addr(); + assert_ne!(addr1.port(), addr2.port()); } - #[test] - fn test_http_handle_url() { + #[tokio::test] + async fn test_http_handle_url() { let (tx, rx) = oneshot::channel(); - let mut handler = InstalledFlowService { - auth_code_tx: Arc::new(Mutex::new(Option::Some(tx))), - }; // URLs are usually a bit botched let url: Uri = "http://example.com:1234/?code=ab/c%2Fd#".parse().unwrap(); - handler.handle_url(url); - assert_eq!(rx.wait().unwrap(), "ab/c/d".to_string()); + let req = hyper::Request::get(url) + .body(hyper::body::Body::empty()) + .unwrap(); + installed_flow_server::handle_req(req, Arc::new(Mutex::new(Some(tx)))) + .await + .unwrap(); + assert_eq!(rx.await.unwrap().as_str(), "ab/c/d"); } - #[test] - fn test_server() { - let runtime = tokio::runtime::Runtime::new().unwrap(); + #[tokio::test] + async fn test_server() { let client: hyper::Client = - hyper::Client::builder() - .executor(runtime.executor()) - .build_http(); - let mut server = InstalledFlowServer::new(0).unwrap(); + hyper::Client::builder().build_http(); + let server = InstalledFlowServer::run(0).unwrap(); + + let response = client + .get(format!("http://{}/", server.local_addr()).parse().unwrap()) + .await; + match response { + Result::Ok(_response) => { + // TODO: Do we really want this to assert success? + //assert!(response.status().is_success()); + } + Result::Err(err) => { + assert!(false, "Failed to request from local server: {:?}", err); + } + } let response = client .get( - format!("http://127.0.0.1:{}/", server.port) + format!("http://{}/?code=ab/c%2Fd#", server.local_addr()) .parse() .unwrap(), ) - .wait(); + .await; match response { Result::Ok(response) => { assert!(response.status().is_success()); @@ -787,29 +713,6 @@ mod tests { } } - let response = client - .get( - format!("http://127.0.0.1:{}/?code=ab/c%2Fd#", server.port) - .parse() - .unwrap(), - ) - .wait(); - match response { - Result::Ok(response) => { - assert!(response.status().is_success()); - } - Result::Err(err) => { - assert!(false, "Failed to request from local server: {:?}", err); - } - } - - match server.block_till_auth() { - Result::Ok(response) => { - assert_eq!(response, "ab/c/d".to_string()); - } - Result::Err(err) => { - assert!(false, "Server failed to pass on the message: {:?}", err); - } - } + assert_eq!(server.wait_for_auth_code().await.as_str(), "ab/c/d"); } } diff --git a/src/lib.rs b/src/lib.rs index 48e3b28..f83144e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,7 +47,8 @@ //! //! use std::path::Path; //! -//! fn main() { +//! #[tokio::main] +//! async fn main() { //! // Read application secret from a file. Sometimes it's easier to compile it directly into //! // the binary. The clientsecret file contains JSON like `{"installed":{"client_id": ... }}` //! let secret = yup_oauth2::read_application_secret(Path::new("clientsecret.json")) @@ -69,14 +70,10 @@ //! //! // token() is the one important function of this crate; it does everything to //! // obtain a token that can be sent e.g. as Bearer token. -//! let tok = auth.token(scopes); -//! // Finally we print the token. -//! let fut = tok.map_err(|e| println!("error: {:?}", e)).and_then(|t| { -//! println!("The token is {:?}", t); -//! Ok(()) -//! }); -//! -//! tokio::run(fut) +//! match auth.token(scopes).await { +//! Ok(token) => println!("The token is {:?}", token), +//! Err(e) => println!("error: {:?}", e), +//! } //! } //! ``` //! diff --git a/src/refresh.rs b/src/refresh.rs index 1175ecd..717016f 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -2,8 +2,7 @@ use crate::types::{ApplicationSecret, JsonError, RefreshResult, RequestError}; use super::Token; use chrono::Utc; -use futures::stream::Stream; -use futures::Future; +use futures_util::try_stream::TryStreamExt; use hyper; use hyper::header; use serde_json as json; @@ -31,11 +30,12 @@ impl RefreshFlow { /// /// # Examples /// Please see the crate landing page for an example. - pub fn refresh_token<'a, C: 'static + hyper::client::connect::Connect>( + pub async fn refresh_token( client: hyper::Client, client_secret: ApplicationSecret, refresh_token: String, - ) -> impl 'a + Future { + ) -> Result { + // TODO: Does this function ever return RequestError? Maybe have it just return RefreshResult. let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ ("client_id", client_secret.client_id.clone()), @@ -50,53 +50,42 @@ impl RefreshFlow { .body(hyper::Body::from(req)) .unwrap(); // TODO: error handling - client - .request(request) - .then(|r| { - match r { - Err(err) => return Err(RefreshResult::Error(err)), - Ok(res) => { - Ok(res - .into_body() - .concat2() - .wait() - .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) - .unwrap()) // TODO: error handling - } - } - }) - .then(move |maybe_json_str: Result| { - if let Err(e) = maybe_json_str { - return Ok(e); - } - let json_str = maybe_json_str.unwrap(); - #[derive(Deserialize)] - struct JsonToken { - access_token: String, - token_type: String, - expires_in: i64, - } - - match json::from_str::(&json_str) { - Err(_) => {} - Ok(res) => { - return Ok(RefreshResult::RefreshError( - res.error, - res.error_description, - )) - } - } - - let t: JsonToken = json::from_str(&json_str).unwrap(); - Ok(RefreshResult::Success(Token { - access_token: t.access_token, - token_type: t.token_type, - refresh_token: Some(refresh_token.to_string()), - expires_in: None, - expires_in_timestamp: Some(Utc::now().timestamp() + t.expires_in), - })) - }) - .map_err(RequestError::Refresh) + let resp = match client.request(request).await { + Ok(resp) => resp, + Err(err) => return Ok(RefreshResult::Error(err)), + }; + let body = match resp.into_body().try_concat().await { + Ok(body) => body, + Err(err) => return Ok(RefreshResult::Error(err)), + }; + if let Ok(json_err) = json::from_slice::(&body) { + return Ok(RefreshResult::RefreshError( + json_err.error, + json_err.error_description, + )); + } + #[derive(Deserialize)] + struct JsonToken { + access_token: String, + token_type: String, + expires_in: i64, + } + let t: JsonToken = match json::from_slice(&body) { + Err(_) => { + return Ok(RefreshResult::RefreshError( + "failed to deserialized json token from refresh response".to_owned(), + None, + )) + } + Ok(token) => token, + }; + Ok(RefreshResult::Success(Token { + access_token: t.access_token, + token_type: t.token_type, + refresh_token: Some(refresh_token.to_string()), + expires_in: None, + expires_in_timestamp: Some(Utc::now().timestamp() + t.expires_in), + })) } } @@ -119,12 +108,12 @@ mod tests { app_secret.token_uri = format!("{}/token", server_url); let refresh_token = "my-refresh-token".to_string(); - let https = HttpsConnector::new(1); + let https = HttpsConnector::new(); let client = hyper::Client::builder() .keep_alive(false) .build::<_, hyper::Body>(https); - let mut rt = tokio::runtime::Builder::new() + let rt = tokio::runtime::Builder::new() .core_threads(1) .panic_handler(|e| std::panic::resume_unwind(e)) .build() @@ -138,13 +127,14 @@ mod tests { .with_status(200) .with_body(r#"{"access_token": "new-access-token", "token_type": "Bearer", "expires_in": 1234567}"#) .create(); - let fut = RefreshFlow::refresh_token( - client.clone(), - app_secret.clone(), - refresh_token.clone(), - ) - .then(|rr| { - let rr = rr.unwrap(); + let fut = async { + let rr = RefreshFlow::refresh_token( + client.clone(), + app_secret.clone(), + refresh_token.clone(), + ) + .await + .unwrap(); match rr { RefreshResult::Success(tok) => { assert_eq!("new-access-token", tok.access_token); @@ -153,7 +143,7 @@ mod tests { _ => panic!(format!("unexpected RefreshResult {:?}", rr)), } Ok(()) as Result<(), ()> - }); + }; rt.block_on(fut).expect("block_on"); _m.assert(); @@ -167,18 +157,20 @@ mod tests { .with_body(r#"{"error": "invalid_token"}"#) .create(); - let fut = RefreshFlow::refresh_token(client, app_secret, refresh_token).then(|rr| { - let rr = rr.unwrap(); + let fut = async { + let rr = RefreshFlow::refresh_token(client, app_secret, refresh_token) + .await + .unwrap(); match rr { RefreshResult::RefreshError(e, None) => { assert_eq!(e, "invalid_token"); } _ => panic!(format!("unexpected RefreshResult {:?}", rr)), } - Ok(()) - }); + Ok(()) as Result<(), ()> + }; - tokio::run(fut); + rt.block_on(fut).expect("block_on"); _m.assert(); } } diff --git a/src/service_account.rs b/src/service_account.rs index b61448f..eb63200 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -12,14 +12,14 @@ //! use std::default::Default; +use std::pin::Pin; use std::sync::{Arc, Mutex}; use crate::authenticator::{DefaultHyperClient, HyperClientBuilder}; use crate::storage::{hash_scopes, MemoryStorage, TokenStorage}; -use crate::types::{ApplicationSecret, GetToken, JsonError, RequestError, StringError, Token}; +use crate::types::{ApplicationSecret, GetToken, JsonError, RequestError, Token}; -use futures::stream::Stream; -use futures::{future, prelude::*}; +use futures::prelude::*; use hyper::header; use url::form_urlencoded; @@ -266,83 +266,95 @@ struct TokenResponse { expires_in: Option, } -impl TokenResponse { - fn to_oauth_token(self) -> Token { - let expires_ts = chrono::Utc::now().timestamp() + self.expires_in.unwrap_or(0); - - Token { - access_token: self.access_token.unwrap(), - token_type: self.token_type.unwrap(), - refresh_token: Some(String::new()), - expires_in: self.expires_in, - expires_in_timestamp: Some(expires_ts), - } - } -} - impl<'a, C: 'static + hyper::client::connect::Connect> ServiceAccountAccessImpl { /// Send a request for a new Bearer token to the OAuth provider. - fn request_token( + async fn request_token( client: hyper::client::Client, sub: Option, key: ServiceAccountKey, scopes: Vec, - ) -> impl Future { + ) -> Result { let mut claims = init_claims_from_key(&key, &scopes); claims.sub = sub.clone(); let signed = JWT::new(claims) .sign(key.private_key.as_ref().unwrap()) - .into_future(); - signed - .map_err(RequestError::LowLevelError) - .map(|signed| { - form_urlencoded::Serializer::new(String::new()) - .extend_pairs(vec![ - ("grant_type".to_string(), GRANT_TYPE.to_string()), - ("assertion".to_string(), signed), - ]) - .finish() - }) - .map(|rqbody| { - hyper::Request::post(key.token_uri.unwrap()) - .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") - .body(hyper::Body::from(rqbody)) - .unwrap() - }) - .and_then(move |request| client.request(request).map_err(RequestError::ClientError)) - .and_then(|response| { - response - .into_body() - .concat2() - .map_err(RequestError::ClientError) - }) - .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) - .and_then(|s| { - if let Ok(jse) = serde_json::from_str::(&s) { - Err(RequestError::NegativeServerResponse( - jse.error, - jse.error_description, - )) - } else { - serde_json::from_str(&s).map_err(RequestError::JSONError) + .map_err(RequestError::LowLevelError)?; + let rqbody = form_urlencoded::Serializer::new(String::new()) + .extend_pairs(vec![ + ("grant_type".to_string(), GRANT_TYPE.to_string()), + ("assertion".to_string(), signed), + ]) + .finish(); + let request = hyper::Request::post(key.token_uri.unwrap()) + .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") + .body(hyper::Body::from(rqbody)) + .unwrap(); + let response = client + .request(request) + .await + .map_err(RequestError::ClientError)?; + let body = response + .into_body() + .try_concat() + .await + .map_err(RequestError::ClientError)?; + if let Ok(jse) = serde_json::from_slice::(&body) { + return Err(RequestError::NegativeServerResponse( + jse.error, + jse.error_description, + )); + } + let token: TokenResponse = + serde_json::from_slice(&body).map_err(RequestError::JSONError)?; + let token = match token { + TokenResponse { + access_token: Some(access_token), + token_type: Some(token_type), + expires_in: Some(expires_in), + .. + } => { + let expires_ts = chrono::Utc::now().timestamp() + expires_in; + Token { + access_token, + token_type, + refresh_token: None, + expires_in: Some(expires_in), + expires_in_timestamp: Some(expires_ts), } - }) - .then(|token: Result| match token { - Err(e) => return Err(e), - Ok(token) => { - if token.access_token.is_none() - || token.token_type.is_none() - || token.expires_in.is_none() - { - Err(RequestError::BadServerResponse(format!( - "Token response lacks fields: {:?}", - token - ))) - } else { - Ok(token.to_oauth_token()) - } - } - }) + } + _ => { + return Err(RequestError::BadServerResponse(format!( + "Token response lacks fields: {:?}", + token + ))) + } + }; + Ok(token) + } + + async fn get_token(&self, hash: u64, scopes: Vec) -> Result { + let cache = self.cache.clone(); + match cache + .lock() + .unwrap() + .get(hash, &scopes.iter().map(|s| s.as_str()).collect()) + { + Ok(Some(token)) if !token.expired() => return Ok(token), + _ => {} + } + let token = Self::request_token( + self.client.clone(), + self.sub.clone(), + self.key.clone(), + scopes.iter().map(|s| s.to_string()).collect(), + ) + .await?; + let _ = cache.lock().unwrap().set( + hash, + &scopes.iter().map(|s| s.as_str()).collect(), + Some(token.clone()), + ); + Ok(token) } } @@ -350,61 +362,16 @@ impl GetToken for ServiceAccountAccessImpl where C: hyper::client::connect::Connect, { - fn token( - &mut self, + fn token<'a, I, T>( + &'a self, scopes: I, - ) -> Box + Send> + ) -> Pin> + Send + 'a>> where T: Into, I: IntoIterator, { let (hash, scps0) = hash_scopes(scopes); - let cache = self.cache.clone(); - let scps = scps0.clone(); - - let cache_lookup = futures::lazy(move || { - match cache - .lock() - .unwrap() - .get(hash, &scps.iter().map(|s| s.as_str()).collect()) - { - Ok(Some(token)) => { - if !token.expired() { - return Ok(token); - } - return Err(StringError::new("expired token in cache", None)); - } - Err(e) => return Err(StringError::new(format!("cache lookup error: {}", e), None)), - Ok(None) => return Err(StringError::new("no token in cache", None)), - } - }); - - let cache = self.cache.clone(); - let req_token = Self::request_token( - self.client.clone(), - self.sub.clone(), - self.key.clone(), - scps0.iter().map(|s| s.to_string()).collect(), - ) - .then(move |r| match r { - Ok(token) => { - let _ = cache.lock().unwrap().set( - hash, - &scps0.iter().map(|s| s.as_str()).collect(), - Some(token.clone()), - ); - Box::new(future::ok(token)) - } - Err(e) => Box::new(future::err(e)), - }); - - Box::new(cache_lookup.then(|r| match r { - Ok(t) => Box::new(Ok(t).into_future()) - as Box + Send>, - Err(_) => { - Box::new(req_token) as Box + Send> - } - })) + Box::pin(self.get_token(hash, scps0)) } /// Returns an empty ApplicationSecret as tokens for service accounts don't need to be @@ -413,7 +380,7 @@ where Default::default() } - fn api_key(&mut self) -> Option { + fn api_key(&self) -> Option { None } } @@ -458,11 +425,11 @@ mod tests { "token_type": "Bearer" }"#; - let https = HttpsConnector::new(1); + let https = HttpsConnector::new(); let client = hyper::Client::builder() .keep_alive(false) .build::<_, hyper::Body>(https); - let mut rt = tokio::runtime::Builder::new() + let rt = tokio::runtime::Builder::new() .core_threads(1) .panic_handler(|e| std::panic::resume_unwind(e)) .build() @@ -476,14 +443,15 @@ mod tests { .with_body(json_response) .expect(1) .create(); - let mut acc = ServiceAccountAccessImpl::new(client.clone(), key.clone(), None); - let fut = acc - .token(vec!["https://www.googleapis.com/auth/pubsub"]) - .and_then(|tok| { - assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); - assert_eq!(Some(3600), tok.expires_in); - Ok(()) - }); + let acc = ServiceAccountAccessImpl::new(client.clone(), key.clone(), None); + let fut = async { + let tok = acc + .token(vec!["https://www.googleapis.com/auth/pubsub"]) + .await?; + assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); + assert_eq!(Some(3600), tok.expires_in); + Ok(()) as Result<(), RequestError> + }; rt.block_on(fut).expect("block_on"); assert!(acc @@ -497,13 +465,14 @@ mod tests { .unwrap() .is_some()); // Test that token is in cache (otherwise mock will tell us) - let fut = acc - .token(vec!["https://www.googleapis.com/auth/pubsub"]) - .and_then(|tok| { - assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); - assert_eq!(Some(3600), tok.expires_in); - Ok(()) - }); + let fut = async { + let tok = acc + .token(vec!["https://www.googleapis.com/auth/pubsub"]) + .await?; + assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); + assert_eq!(Some(3600), tok.expires_in); + Ok(()) as Result<(), RequestError> + }; rt.block_on(fut).expect("block_on 2"); _m.assert(); @@ -515,19 +484,20 @@ mod tests { .with_header("content-type", "text/json") .with_body(bad_json_response) .create(); - let mut acc = ServiceAccountAccess::new(key.clone()) + let acc = ServiceAccountAccess::new(key.clone()) .hyper_client(client.clone()) .build(); - let fut = acc - .token(vec!["https://www.googleapis.com/auth/pubsub"]) - .then(|result| { - assert!(result.is_err()); - Ok(()) as Result<(), ()> - }); + let fut = async { + let result = acc + .token(vec!["https://www.googleapis.com/auth/pubsub"]) + .await; + assert!(result.is_err()); + Ok(()) as Result<(), ()> + }; rt.block_on(fut).expect("block_on"); _m.assert(); } - rt.shutdown_on_idle().wait().expect("shutdown"); + rt.shutdown_on_idle(); } // Valid but deactivated key. @@ -538,17 +508,21 @@ mod tests { #[allow(dead_code)] fn test_service_account_e2e() { let key = service_account_key_from_file(&TEST_PRIVATE_KEY_PATH.to_string()).unwrap(); - let https = HttpsConnector::new(4); - let runtime = tokio::runtime::Runtime::new().unwrap(); - let client = hyper::Client::builder() - .executor(runtime.executor()) - .build(https); - let mut acc = ServiceAccountAccess::new(key).hyper_client(client).build(); - println!( - "{:?}", - acc.token(vec!["https://www.googleapis.com/auth/pubsub"]) - .wait() - ); + let https = HttpsConnector::new(); + let client = hyper::Client::builder().build(https); + let acc = ServiceAccountAccess::new(key).hyper_client(client).build(); + let rt = tokio::runtime::Builder::new() + .core_threads(1) + .panic_handler(|e| std::panic::resume_unwind(e)) + .build() + .unwrap(); + rt.block_on(async { + println!( + "{:?}", + acc.token(vec!["https://www.googleapis.com/auth/pubsub"]) + .await + ); + }); } #[test] diff --git a/src/storage.rs b/src/storage.rs index 950d015..81049ba 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -10,7 +10,9 @@ use std::fmt; use std::fs; use std::hash::{Hash, Hasher}; use std::io; -use std::io::{Read, Write}; +use std::io::Write; +use std::path::{Path, PathBuf}; +use std::sync::Mutex; use crate::types::Token; use itertools::Itertools; @@ -20,13 +22,13 @@ use itertools::Itertools; /// should be stored or retrieved. /// For completeness, the underlying, sorted scopes are provided as well. They might be /// useful for presentation to the user. -pub trait TokenStorage { +pub trait TokenStorage: Send + Sync { type Error: 'static + Error + Send + Sync; /// If `token` is None, it is invalid or revoked and should be removed from storage. /// Otherwise, it should be saved. fn set( - &mut self, + &self, scope_hash: u64, scopes: &Vec<&str>, token: Option, @@ -69,7 +71,7 @@ impl fmt::Display for NullError { impl TokenStorage for NullStorage { type Error = NullError; - fn set(&mut self, _: u64, _: &Vec<&str>, _: Option) -> Result<(), NullError> { + fn set(&self, _: u64, _: &Vec<&str>, _: Option) -> Result<(), NullError> { Ok(()) } fn get(&self, _: u64, _: &Vec<&str>) -> Result, NullError> { @@ -80,7 +82,7 @@ impl TokenStorage for NullStorage { /// A storage that remembers values for one session only. #[derive(Debug, Default)] pub struct MemoryStorage { - tokens: Vec, + tokens: Mutex>, } impl MemoryStorage { @@ -93,19 +95,20 @@ impl TokenStorage for MemoryStorage { type Error = NullError; fn set( - &mut self, + &self, scope_hash: u64, scopes: &Vec<&str>, token: Option, ) -> Result<(), NullError> { - let matched = self.tokens.iter().find_position(|x| x.hash == scope_hash); - if let Some(_) = matched { + let mut tokens = self.tokens.lock().expect("poisoned mutex"); + let matched = tokens.iter().find_position(|x| x.hash == scope_hash); + if let Some((idx, _)) = matched { self.tokens.retain(|x| x.hash != scope_hash); } match token { Some(t) => { - self.tokens.push(JSONToken { + tokens.push(JSONToken { hash: scope_hash, scopes: Some(scopes.iter().map(|x| x.to_string()).collect()), token: t.clone(), @@ -120,7 +123,8 @@ impl TokenStorage for MemoryStorage { fn get(&self, scope_hash: u64, scopes: &Vec<&str>) -> Result, NullError> { let scopes: Vec<_> = scopes.iter().sorted().unique().collect(); - for t in &self.tokens { + let tokens = self.tokens.lock().expect("poisoned mutex"); + for t in tokens.iter() { if let Some(token_scopes) = &t.scopes { let matched = token_scopes .iter() @@ -174,58 +178,32 @@ struct JSONTokens { /// Serializes tokens to a JSON file on disk. #[derive(Default)] pub struct DiskTokenStorage { - location: String, - tokens: Vec, + location: PathBuf, + tokens: Mutex>, } impl DiskTokenStorage { - pub fn new>(location: S) -> Result { - let mut dts = DiskTokenStorage { - location: location.as_ref().to_owned(), - tokens: Vec::new(), + pub fn new>(location: S) -> Result { + let filename = location.into(); + let tokens = match load_from_file(&filename) { + Ok(tokens) => tokens, + Err(e) if e.kind() == io::ErrorKind::NotFound => Vec::new(), + Err(e) => return Err(e), }; - - // best-effort - let read_result = dts.load_from_file(); - - match read_result { - Result::Ok(()) => Result::Ok(dts), - Result::Err(e) => { - match e.kind() { - io::ErrorKind::NotFound => Result::Ok(dts), // File not found; ignore and create new one - _ => Result::Err(e), // e.g. PermissionDenied - } - } - } + Ok(DiskTokenStorage { + location: filename, + tokens: Mutex::new(tokens), + }) } - fn load_from_file(&mut self) -> Result<(), io::Error> { - let mut f = fs::OpenOptions::new().read(true).open(&self.location)?; - let mut contents = String::new(); - - match f.read_to_string(&mut contents) { - Result::Err(e) => return Result::Err(e), - Result::Ok(_sz) => (), - } - - let tokens: JSONTokens; - - match serde_json::from_str(&contents) { - Result::Err(e) => return Result::Err(io::Error::new(io::ErrorKind::InvalidData, e)), - Result::Ok(t) => tokens = t, - } - - for t in tokens.tokens { - self.tokens.push(t); - } - return Result::Ok(()); - } - - pub fn dump_to_file(&mut self) -> Result<(), io::Error> { + pub fn dump_to_file(&self) -> Result<(), io::Error> { let mut jsontokens = JSONTokens { tokens: Vec::new() }; - for token in self.tokens.iter() { - jsontokens.tokens.push((*token).clone()); + { + let tokens = self.tokens.lock().expect("mutex poisoned"); + for token in tokens.iter() { + jsontokens.tokens.push((*token).clone()); + } } let serialized; @@ -235,6 +213,7 @@ impl DiskTokenStorage { Result::Ok(s) => serialized = s, } + // TODO: Write to disk asynchronously so that we don't stall the eventloop if invoked in async context. let mut f = fs::OpenOptions::new() .create(true) .write(true) @@ -244,28 +223,38 @@ impl DiskTokenStorage { } } +fn load_from_file(filename: &Path) -> Result, io::Error> { + let contents = std::fs::read_to_string(filename)?; + let container: JSONTokens = serde_json::from_str(&contents) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok(container.tokens) +} + impl TokenStorage for DiskTokenStorage { type Error = io::Error; fn set( - &mut self, + &self, scope_hash: u64, scopes: &Vec<&str>, token: Option, ) -> Result<(), Self::Error> { - let matched = self.tokens.iter().find_position(|x| x.hash == scope_hash); - if let Some(_) = matched { - self.tokens.retain(|x| x.hash != scope_hash); - } + { + let mut tokens = self.tokens.lock().expect("poisoned mutex"); + let matched = tokens.iter().find_position(|x| x.hash == scope_hash); + if let Some((idx, _)) = matched { + self.tokens.retain(|x| x.hash != scope_hash); + } - match token { - None => (), - Some(t) => { - self.tokens.push(JSONToken { - hash: scope_hash, - scopes: Some(scopes.iter().map(|x| x.to_string()).collect()), - token: t.clone(), - }); - () + match token { + None => (), + Some(t) => { + tokens.push(JSONToken { + hash: scope_hash, + scopes: Some(scopes.iter().map(|x| x.to_string()).collect()), + token: t.clone(), + }); + () + } } } self.dump_to_file() @@ -273,7 +262,8 @@ impl TokenStorage for DiskTokenStorage { fn get(&self, scope_hash: u64, scopes: &Vec<&str>) -> Result, Self::Error> { let scopes: Vec<_> = scopes.iter().sorted().unique().collect(); - for t in &self.tokens { + let tokens = self.tokens.lock().expect("poisoned mutex"); + for t in tokens.iter() { if let Some(token_scopes) = &t.scopes { let matched = token_scopes .iter() diff --git a/src/types.rs b/src/types.rs index 90553bc..3f96e14 100644 --- a/src/types.rs +++ b/src/types.rs @@ -3,6 +3,7 @@ use hyper; use std::error::Error; use std::fmt; use std::io; +use std::pin::Pin; use std::str::FromStr; use futures::prelude::*; @@ -239,16 +240,16 @@ impl FromStr for Scheme { /// A provider for authorization tokens, yielding tokens valid for a given scope. /// The `api_key()` method is an alternative in case there are no scopes or /// if no user is involved. -pub trait GetToken { - fn token( - &mut self, +pub trait GetToken: Send + Sync { + fn token<'a, I, T>( + &'a self, scopes: I, - ) -> Box + Send> + ) -> Pin> + Send + 'a>> where T: Into, I: IntoIterator; - fn api_key(&mut self) -> Option; + fn api_key(&self) -> Option; /// Return an application secret with at least token_uri, client_secret, and client_id filled /// in. This is used for refreshing tokens without interaction from the flow. From a4c9b6034efd101c88f74999687f95d2cc3bf4e4 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Thu, 7 Nov 2019 15:21:38 -0800 Subject: [PATCH 02/71] Require trait implementations to be Send + Sync. Tidy up some of the trait bounds on types and methods. --- src/authenticator.rs | 12 +++++------- src/authenticator_delegate.rs | 12 ++++++------ src/device.rs | 14 +++++++------- src/installed.rs | 24 +++++++++++++++--------- src/service_account.rs | 9 ++++++--- src/storage.rs | 28 ++++++---------------------- 6 files changed, 45 insertions(+), 54 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index cde9fe4..6e17d2e 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -167,9 +167,9 @@ where /// Create the authenticator. pub fn build(self) -> io::Result where - T::TokenGetter: 'static + GetToken + Send, + T::TokenGetter: 'static + GetToken, S: 'static + Send, - AD: 'static + Send + Sync, + AD: 'static, C::Connector: 'static + Clone + Send, { let client = self.client.build_hyper_client(); @@ -189,12 +189,12 @@ impl AuthenticatorImpl where GT: 'static + GetToken, S: 'static + TokenStorage, - AD: 'static + AuthenticatorDelegate + Send + Sync, + AD: 'static + AuthenticatorDelegate, C: 'static + hyper::client::connect::Connect + Clone + Send, { async fn get_token(&self, scope_key: u64, scopes: Vec) -> Result { let store = self.store.clone(); - let mut delegate = self.delegate.clone(); + let delegate = &self.delegate; let client = self.client.clone(); let appsecret = self.inner.application_secret(); let gettoken = self.inner.clone(); @@ -209,7 +209,6 @@ where } // Implement refresh flow. let refresh_token = t.refresh_token.clone(); - let mut delegate = delegate.clone(); let store = store.clone(); let scopes = scopes.clone(); let rr = RefreshFlow::refresh_token( @@ -254,7 +253,6 @@ where Ok(None) => { let store = store.clone(); let scopes = scopes.clone(); - let mut delegate = delegate.clone(); let t = gettoken.token(scopes.clone()).await?; if let Err(e) = store.set( scope_key, @@ -282,7 +280,7 @@ where impl< GT: 'static + GetToken, S: 'static + TokenStorage, - AD: 'static + AuthenticatorDelegate + Send + Sync, + AD: 'static + AuthenticatorDelegate, C: 'static + hyper::client::connect::Connect + Clone + Send, > GetToken for AuthenticatorImpl { diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index eb3bceb..0ac916b 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -71,11 +71,11 @@ impl Error for PollError { /// /// The only method that needs to be implemented manually is `present_user_code(...)`, /// as no assumptions are made on how this presentation should happen. -pub trait AuthenticatorDelegate: Clone { +pub trait AuthenticatorDelegate: Clone + Send + Sync { /// Called whenever there is an client, usually if there are network problems. /// /// Return retry information. - fn client_error(&mut self, _: &hyper::Error) -> Retry { + fn client_error(&self, _: &hyper::Error) -> Retry { Retry::Abort } @@ -84,19 +84,19 @@ pub trait AuthenticatorDelegate: Clone { /// This can be useful if the underlying `TokenStorage` may fail occasionally. /// if `is_set` is true, the failure resulted from `TokenStorage.set(...)`. Otherwise, /// it was `TokenStorage.get(...)` - fn token_storage_failure(&mut self, is_set: bool, _: &(dyn Error + Send + Sync)) -> Retry { + fn token_storage_failure(&self, is_set: bool, _: &(dyn Error + Send + Sync)) -> Retry { let _ = is_set; Retry::Abort } /// The server denied the attempt to obtain a request code - fn request_failure(&mut self, _: RequestError) {} + fn request_failure(&self, _: RequestError) {} /// Called if we could not acquire a refresh token for a reason possibly specified /// by the server. /// This call is made for the delegate's information only. fn token_refresh_failed>( - &mut self, + &self, error: S, error_description: &Option, ) { @@ -111,7 +111,7 @@ pub trait AuthenticatorDelegate: Clone { /// FlowDelegate methods are called when an OAuth flow needs to ask the application what to do in /// certain cases. -pub trait FlowDelegate: Clone { +pub trait FlowDelegate: Clone + Send + Sync { /// Called if the request code is expired. You will have to start over in this case. /// This will be the last call the delegate receives. /// Given `DateTime` is the expiration date diff --git a/src/device.rs b/src/device.rs index dc4cfdf..02f19c7 100644 --- a/src/device.rs +++ b/src/device.rs @@ -73,7 +73,7 @@ impl DeviceFlow { impl crate::authenticator::AuthFlow for DeviceFlow where - FD: FlowDelegate + Send + Sync + 'static, + FD: FlowDelegate + 'static, C: hyper::client::connect::Connect + 'static, { type TokenGetter = DeviceFlowImpl; @@ -105,10 +105,10 @@ impl Flow for DeviceFlowImpl { } } -impl< - FD: FlowDelegate + Clone + Send + Sync + 'static, - C: hyper::client::connect::Connect + Sync + 'static, - > GetToken for DeviceFlowImpl +impl GetToken for DeviceFlowImpl +where + FD: FlowDelegate + 'static, + C: hyper::client::connect::Connect + 'static, { fn token<'a, I, T>( &'a self, @@ -130,10 +130,10 @@ impl< impl DeviceFlowImpl where - C: hyper::client::connect::Connect + Sync + 'static, + C: hyper::client::connect::Connect + 'static, C::Transport: 'static, C::Future: 'static, - FD: FlowDelegate + Clone + Send + 'static, + FD: FlowDelegate + 'static, { /// Essentially what `GetToken::token` does: Retrieve a token for the given scopes without /// caching. diff --git a/src/installed.rs b/src/installed.rs index 1733e6f..ba6758e 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -61,10 +61,10 @@ where }) } -impl< - FD: FlowDelegate + 'static + Send + Sync + Clone, - C: hyper::client::connect::Connect + 'static, - > GetToken for InstalledFlowImpl +impl GetToken for InstalledFlowImpl +where + FD: FlowDelegate + 'static, + C: hyper::client::connect::Connect + 'static, { fn token<'a, I, T>( &'a self, @@ -85,7 +85,11 @@ impl< } /// The InstalledFlow implementation. -pub struct InstalledFlowImpl { +pub struct InstalledFlowImpl +where + FD: FlowDelegate + 'static, + C: hyper::client::connect::Connect + 'static, +{ method: InstalledFlowReturnMethod, client: hyper::client::Client, fd: FD, @@ -109,7 +113,7 @@ pub enum InstalledFlowReturnMethod { /// InstalledFlowImpl provides tokens for services that follow the "Installed" OAuth flow. (See /// https://www.oauth.com/oauth2-servers/authorization/, /// https://developers.google.com/identity/protocols/OAuth2InstalledApp). -pub struct InstalledFlow { +pub struct InstalledFlow { method: InstalledFlowReturnMethod, flow_delegate: FD, appsecret: ApplicationSecret, @@ -145,7 +149,7 @@ where impl crate::authenticator::AuthFlow for InstalledFlow where - FD: FlowDelegate + Send + Sync + 'static, + FD: FlowDelegate + 'static, C: hyper::client::connect::Connect + 'static, { type TokenGetter = InstalledFlowImpl; @@ -160,8 +164,10 @@ where } } -impl<'c, FD: 'static + FlowDelegate + Clone + Send, C: 'c + hyper::client::connect::Connect> - InstalledFlowImpl +impl InstalledFlowImpl +where + FD: FlowDelegate + 'static, + C: hyper::client::connect::Connect + 'static, { /// Handles the token request flow; it consists of the following steps: /// . Obtain a authorization code with user cooperation or internal redirect. diff --git a/src/service_account.rs b/src/service_account.rs index eb63200..c56a9cb 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -266,7 +266,10 @@ struct TokenResponse { expires_in: Option, } -impl<'a, C: 'static + hyper::client::connect::Connect> ServiceAccountAccessImpl { +impl ServiceAccountAccessImpl +where + C: hyper::client::connect::Connect + 'static, +{ /// Send a request for a new Bearer token to the OAuth provider. async fn request_token( client: hyper::client::Client, @@ -358,9 +361,9 @@ impl<'a, C: 'static + hyper::client::connect::Connect> ServiceAccountAccessImpl< } } -impl GetToken for ServiceAccountAccessImpl +impl GetToken for ServiceAccountAccessImpl where - C: hyper::client::connect::Connect, + C: hyper::client::connect::Connect + 'static, { fn token<'a, I, T>( &'a self, diff --git a/src/storage.rs b/src/storage.rs index 81049ba..a1224e7 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -6,7 +6,6 @@ use std::cmp::Ordering; use std::collections::hash_map::DefaultHasher; use std::error::Error; -use std::fmt; use std::fs; use std::hash::{Hash, Hasher}; use std::io; @@ -54,27 +53,12 @@ where #[derive(Default)] pub struct NullStorage; -#[derive(Debug)] -pub struct NullError; - -impl Error for NullError { - fn description(&self) -> &str { - "NULL" - } -} - -impl fmt::Display for NullError { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - "NULL-ERROR".fmt(f) - } -} - impl TokenStorage for NullStorage { - type Error = NullError; - fn set(&self, _: u64, _: &Vec<&str>, _: Option) -> Result<(), NullError> { + type Error = std::convert::Infallible; + fn set(&self, _: u64, _: &Vec<&str>, _: Option) -> Result<(), Self::Error> { Ok(()) } - fn get(&self, _: u64, _: &Vec<&str>) -> Result, NullError> { + fn get(&self, _: u64, _: &Vec<&str>) -> Result, Self::Error> { Ok(None) } } @@ -92,14 +76,14 @@ impl MemoryStorage { } impl TokenStorage for MemoryStorage { - type Error = NullError; + type Error = std::convert::Infallible; fn set( &self, scope_hash: u64, scopes: &Vec<&str>, token: Option, - ) -> Result<(), NullError> { + ) -> Result<(), Self::Error> { let mut tokens = self.tokens.lock().expect("poisoned mutex"); let matched = tokens.iter().find_position(|x| x.hash == scope_hash); if let Some((idx, _)) = matched { @@ -120,7 +104,7 @@ impl TokenStorage for MemoryStorage { Ok(()) } - fn get(&self, scope_hash: u64, scopes: &Vec<&str>) -> Result, NullError> { + fn get(&self, scope_hash: u64, scopes: &Vec<&str>) -> Result, Self::Error> { let scopes: Vec<_> = scopes.iter().sorted().unique().collect(); let tokens = self.tokens.lock().expect("poisoned mutex"); From 0f29c258c607e571042be04057e36252258185bb Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Thu, 7 Nov 2019 15:27:22 -0800 Subject: [PATCH 03/71] FlowType isn't used for anything. Remove it. --- src/authenticator_delegate.rs | 2 +- src/device.rs | 8 +------- src/lib.rs | 2 +- src/types.rs | 23 ----------------------- 4 files changed, 3 insertions(+), 32 deletions(-) diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index 0ac916b..5312c13 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -140,7 +140,7 @@ pub trait FlowDelegate: Clone + Send + Sync { /// along with the `verification_url`. /// # Notes /// * Will be called exactly once, provided we didn't abort during `request_code` phase. - /// * Will only be called if the Authenticator's flow_type is `FlowType::Device`. + /// * Will only be called if the Authenticator's flow_type is `DeviceFlow`. fn present_user_code(&self, pi: &PollInformation) { println!( "Please enter {} at {} and grant access to this application", diff --git a/src/device.rs b/src/device.rs index 02f19c7..a924f17 100644 --- a/src/device.rs +++ b/src/device.rs @@ -13,7 +13,7 @@ use url::form_urlencoded; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate, PollInformation, Retry}; use crate::types::{ - ApplicationSecret, Flow, FlowType, GetToken, JsonError, PollError, RequestError, Token, + ApplicationSecret, GetToken, JsonError, PollError, RequestError, Token, }; pub const GOOGLE_DEVICE_CODE_URL: &'static str = "https://accounts.google.com/o/oauth2/device/code"; @@ -99,12 +99,6 @@ pub struct DeviceFlowImpl { wait: Duration, } -impl Flow for DeviceFlowImpl { - fn type_id() -> FlowType { - FlowType::Device(String::new()) - } -} - impl GetToken for DeviceFlowImpl where FD: FlowDelegate + 'static, diff --git a/src/lib.rs b/src/lib.rs index f83144e..523e703 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -101,6 +101,6 @@ pub use crate::installed::{InstalledFlow, InstalledFlowReturnMethod}; pub use crate::service_account::*; pub use crate::storage::{DiskTokenStorage, MemoryStorage, NullStorage, TokenStorage}; pub use crate::types::{ - ApplicationSecret, ConsoleApplicationSecret, FlowType, GetToken, PollError, RefreshResult, + ApplicationSecret, ConsoleApplicationSecret, GetToken, PollError, RefreshResult, RequestError, Scheme, Token, TokenType, }; diff --git a/src/types.rs b/src/types.rs index 3f96e14..a47a0fa 100644 --- a/src/types.rs +++ b/src/types.rs @@ -8,11 +8,6 @@ use std::str::FromStr; use futures::prelude::*; -/// A marker trait for all Flows -pub trait Flow { - fn type_id() -> FlowType; -} - #[derive(Deserialize, Debug)] pub struct JsonError { pub error: String, @@ -323,24 +318,6 @@ impl Token { } } -/// All known authentication types, for suitable constants -#[derive(Clone)] -pub enum FlowType { - /// [device authentication](https://developers.google.com/youtube/v3/guides/authentication#devices). Only works - /// for certain scopes. - /// Contains the device token URL; for google, that is - /// https://accounts.google.com/o/oauth2/device/code (exported as `GOOGLE_DEVICE_CODE_URL`) - Device(String), - /// [installed app flow](https://developers.google.com/identity/protocols/OAuth2InstalledApp). Required - /// for Drive, Calendar, Gmail...; Requires user to paste a code from the browser. - InstalledInteractive, - /// Same as InstalledInteractive, but uses a redirect: The OAuth provider redirects the user's - /// browser to a web server that is running on localhost. This may not work as well with the - /// Windows Firewall, but is more comfortable otherwise. The integer describes which port to - /// bind to (default: 8080) - InstalledRedirect(u16), -} - /// Represents either 'installed' or 'web' applications in a json secrets file. /// See `ConsoleApplicationSecret` for more information #[derive(Deserialize, Serialize, Clone, Default)] From 7e210a22c5459362d8ffd226b56bc38aa1de152d Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Thu, 7 Nov 2019 16:22:09 -0800 Subject: [PATCH 04/71] Have TokenStorage take scopes by iterator rather than Vec. This reduces the number of allocations needed. --- src/authenticator.rs | 8 ++- src/service_account.rs | 6 +-- src/storage.rs | 115 ++++++++++++++++++++++++----------------- 3 files changed, 73 insertions(+), 56 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index 6e17d2e..f39461f 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -201,7 +201,7 @@ where loop { match store.get( scope_key.clone(), - &scopes.iter().map(|s| s.as_str()).collect(), + &scopes, ) { Ok(Some(t)) => { if !t.expired() { @@ -210,7 +210,6 @@ where // Implement refresh flow. let refresh_token = t.refresh_token.clone(); let store = store.clone(); - let scopes = scopes.clone(); let rr = RefreshFlow::refresh_token( client.clone(), appsecret.clone(), @@ -235,7 +234,7 @@ where RefreshResult::Success(t) => { let x = store.set( scope_key, - &scopes.iter().map(|s| s.as_str()).collect(), + &scopes, Some(t.clone()), ); if let Err(e) = x { @@ -252,11 +251,10 @@ where } Ok(None) => { let store = store.clone(); - let scopes = scopes.clone(); let t = gettoken.token(scopes.clone()).await?; if let Err(e) = store.set( scope_key, - &scopes.iter().map(|s| s.as_str()).collect(), + &scopes, Some(t.clone()), ) { match delegate.token_storage_failure(true, &e) { diff --git a/src/service_account.rs b/src/service_account.rs index c56a9cb..355f77e 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -340,7 +340,7 @@ where match cache .lock() .unwrap() - .get(hash, &scopes.iter().map(|s| s.as_str()).collect()) + .get(hash, scopes.iter()) { Ok(Some(token)) if !token.expired() => return Ok(token), _ => {} @@ -354,7 +354,7 @@ where .await?; let _ = cache.lock().unwrap().set( hash, - &scopes.iter().map(|s| s.as_str()).collect(), + scopes.iter(), Some(token.clone()), ); Ok(token) @@ -463,7 +463,7 @@ mod tests { .unwrap() .get( 3502164897243251857, - &vec!["https://www.googleapis.com/auth/pubsub"] + ["https://www.googleapis.com/auth/pubsub"].iter(), ) .unwrap() .is_some()); diff --git a/src/storage.rs b/src/storage.rs index a1224e7..551c563 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -26,14 +26,21 @@ pub trait TokenStorage: Send + Sync { /// If `token` is None, it is invalid or revoked and should be removed from storage. /// Otherwise, it should be saved. - fn set( + fn set( &self, scope_hash: u64, - scopes: &Vec<&str>, + scopes: I, token: Option, - ) -> Result<(), Self::Error>; + ) -> Result<(), Self::Error> + where + I: IntoIterator, + I::Item: AsRef; + /// A `None` result indicates that there is no token for the given scope_hash. - fn get(&self, scope_hash: u64, scopes: &Vec<&str>) -> Result, Self::Error>; + fn get(&self, scope_hash: u64, scopes: I) -> Result, Self::Error> + where + I: IntoIterator + Clone, + I::Item: AsRef; } /// Calculate a hash value describing the scopes, and return a sorted Vec of the scopes. @@ -55,10 +62,19 @@ pub struct NullStorage; impl TokenStorage for NullStorage { type Error = std::convert::Infallible; - fn set(&self, _: u64, _: &Vec<&str>, _: Option) -> Result<(), Self::Error> { + fn set(&self, _: u64, _: I, _: Option) -> Result<(), Self::Error> + where + I: IntoIterator, + I::Item: AsRef, + { Ok(()) } - fn get(&self, _: u64, _: &Vec<&str>) -> Result, Self::Error> { + + fn get(&self, _: u64, _: I) -> Result, Self::Error> + where + I: IntoIterator + Clone, + I::Item: AsRef, + { Ok(None) } } @@ -78,12 +94,16 @@ impl MemoryStorage { impl TokenStorage for MemoryStorage { type Error = std::convert::Infallible; - fn set( + fn set( &self, scope_hash: u64, - scopes: &Vec<&str>, + scopes: I, token: Option, - ) -> Result<(), Self::Error> { + ) -> Result<(), Self::Error> + where + I: IntoIterator, + I::Item: AsRef, + { let mut tokens = self.tokens.lock().expect("poisoned mutex"); let matched = tokens.iter().find_position(|x| x.hash == scope_hash); if let Some((idx, _)) = matched { @@ -94,7 +114,7 @@ impl TokenStorage for MemoryStorage { Some(t) => { tokens.push(JSONToken { hash: scope_hash, - scopes: Some(scopes.iter().map(|x| x.to_string()).collect()), + scopes: Some(scopes.into_iter().map(|x| x.as_ref().to_string()).collect()), token: t.clone(), }); () @@ -104,24 +124,13 @@ impl TokenStorage for MemoryStorage { Ok(()) } - fn get(&self, scope_hash: u64, scopes: &Vec<&str>) -> Result, Self::Error> { - let scopes: Vec<_> = scopes.iter().sorted().unique().collect(); - + fn get(&self, scope_hash: u64, scopes: I) -> Result, Self::Error> + where + I: IntoIterator + Clone, + I::Item: AsRef, + { let tokens = self.tokens.lock().expect("poisoned mutex"); - for t in tokens.iter() { - if let Some(token_scopes) = &t.scopes { - let matched = token_scopes - .iter() - .filter(|x| scopes.contains(&&&x[..])) - .count(); - if matched >= scopes.len() { - return Result::Ok(Some(t.token.clone())); - } - } else if scope_hash == t.hash { - return Result::Ok(Some(t.token.clone())); - } - } - Result::Ok(None) + Ok(token_for_scopes(&tokens, scope_hash, scopes)) } } @@ -216,12 +225,16 @@ fn load_from_file(filename: &Path) -> Result, io::Error> { impl TokenStorage for DiskTokenStorage { type Error = io::Error; - fn set( + fn set( &self, scope_hash: u64, - scopes: &Vec<&str>, + scopes: I, token: Option, - ) -> Result<(), Self::Error> { + ) -> Result<(), Self::Error> + where + I: IntoIterator, + I::Item: AsRef, + { { let mut tokens = self.tokens.lock().expect("poisoned mutex"); let matched = tokens.iter().find_position(|x| x.hash == scope_hash); @@ -234,7 +247,7 @@ impl TokenStorage for DiskTokenStorage { Some(t) => { tokens.push(JSONToken { hash: scope_hash, - scopes: Some(scopes.iter().map(|x| x.to_string()).collect()), + scopes: Some(scopes.into_iter().map(|x| x.as_ref().to_string()).collect()), token: t.clone(), }); () @@ -243,24 +256,30 @@ impl TokenStorage for DiskTokenStorage { } self.dump_to_file() } - fn get(&self, scope_hash: u64, scopes: &Vec<&str>) -> Result, Self::Error> { - let scopes: Vec<_> = scopes.iter().sorted().unique().collect(); + fn get(&self, scope_hash: u64, scopes: I) -> Result, Self::Error> + where + I: IntoIterator + Clone, + I::Item: AsRef, + { let tokens = self.tokens.lock().expect("poisoned mutex"); - for t in tokens.iter() { - if let Some(token_scopes) = &t.scopes { - let matched = token_scopes - .iter() - .filter(|x| scopes.contains(&&&x[..])) - .count(); - // we may have some of the tokens as denormalized (many namespaces repeated) - if matched >= scopes.len() { - return Result::Ok(Some(t.token.clone())); - } - } else if scope_hash == t.hash { - return Result::Ok(Some(t.token.clone())); - } - } - Result::Ok(None) + Ok(token_for_scopes(&tokens, scope_hash, scopes)) } } + +fn token_for_scopes(tokens: &[JSONToken], scope_hash: u64, scopes: I) -> Option +where + I: IntoIterator + Clone, + I::Item: AsRef, +{ + for t in tokens.iter() { + if let Some(token_scopes) = &t.scopes { + if scopes.clone().into_iter().all(|s| token_scopes.iter().any(|t| t == s.as_ref())) { + return Some(t.token.clone()); + } + } else if scope_hash == t.hash { + return Some(t.token.clone()) + } + } + None +} From 696577aa01c2bda04714e82081020e92a896f1b5 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 8 Nov 2019 12:43:17 -0800 Subject: [PATCH 05/71] Accept scopes as a slice of anything that can produce a &str. Along with the public facing change the implementation has been modified to no longer clone the scopes instead using the pointer to the scopes the user provided. This greatly reduces the number of allocations on each token() call. Note that this also changes the hashing method used for token storage in an incompatible way with the previous implementation. The previous implementation pre-sorted the vector and hashed the contents to make the result independent of the ordering of the scopes. Instead we now combine the hash values of each scope together with XOR, thus producing a hash value that does not depend on order without needing to allocate another vector and sort. --- examples/test-device/src/main.rs | 2 +- examples/test-installed/src/main.rs | 3 +- examples/test-svc-acct/src/main.rs | 5 +- src/authenticator.rs | 26 ++++---- src/device.rs | 43 +++++++------- src/helper.rs | 22 +++++++ src/installed.rs | 42 ++++++------- src/lib.rs | 3 +- src/service_account.rs | 53 +++++++++-------- src/storage.rs | 92 ++++++++++++++++------------- src/types.rs | 7 +-- 11 files changed, 166 insertions(+), 132 deletions(-) diff --git a/examples/test-device/src/main.rs b/examples/test-device/src/main.rs index 42b3ab8..64e413c 100644 --- a/examples/test-device/src/main.rs +++ b/examples/test-device/src/main.rs @@ -12,7 +12,7 @@ async fn main() { .build() .expect("authenticator"); - let scopes = vec!["https://www.googleapis.com/auth/youtube.readonly"]; + let scopes = &["https://www.googleapis.com/auth/youtube.readonly"]; match auth.token(scopes).await { Err(e) => println!("error: {:?}", e), Ok(t) => println!("token: {:?}", t), diff --git a/examples/test-installed/src/main.rs b/examples/test-installed/src/main.rs index f4909a3..c333255 100644 --- a/examples/test-installed/src/main.rs +++ b/examples/test-installed/src/main.rs @@ -15,8 +15,7 @@ async fn main() { .persist_tokens_to_disk("tokencache.json") .build() .unwrap(); - let s = "https://www.googleapis.com/auth/drive.file".to_string(); - let scopes = vec![s]; + let scopes = &["https://www.googleapis.com/auth/drive.file"]; match auth.token(scopes).await { Err(e) => println!("error: {:?}", e), diff --git a/examples/test-svc-acct/src/main.rs b/examples/test-svc-acct/src/main.rs index 3d18fdc..6ad49f1 100644 --- a/examples/test-svc-acct/src/main.rs +++ b/examples/test-svc-acct/src/main.rs @@ -8,14 +8,15 @@ async fn main() { let creds = yup_oauth2::service_account_key_from_file(path::Path::new("serviceaccount.json")).unwrap(); let sa = yup_oauth2::ServiceAccountAccess::new(creds).build(); + let scopes = &["https://www.googleapis.com/auth/pubsub"]; let tok = sa - .token(vec!["https://www.googleapis.com/auth/pubsub"]) + .token(scopes) .await .unwrap(); println!("token is: {:?}", tok); let tok = sa - .token(vec!["https://www.googleapis.com/auth/pubsub"]) + .token(scopes) .await .unwrap(); println!("cached token is {:?} and should be identical", tok); diff --git a/src/authenticator.rs b/src/authenticator.rs index f39461f..bbd6781 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -192,7 +192,11 @@ where AD: 'static + AuthenticatorDelegate, C: 'static + hyper::client::connect::Connect + Clone + Send, { - async fn get_token(&self, scope_key: u64, scopes: Vec) -> Result { + async fn get_token(&self, scopes: &[T]) -> Result + where + T: AsRef + Sync, + { + let scope_key = hash_scopes(scopes); let store = self.store.clone(); let delegate = &self.delegate; let client = self.client.clone(); @@ -200,8 +204,8 @@ where let gettoken = self.inner.clone(); loop { match store.get( - scope_key.clone(), - &scopes, + scope_key, + scopes, ) { Ok(Some(t)) => { if !t.expired() { @@ -234,7 +238,7 @@ where RefreshResult::Success(t) => { let x = store.set( scope_key, - &scopes, + scopes, Some(t.clone()), ); if let Err(e) = x { @@ -251,10 +255,10 @@ where } Ok(None) => { let store = store.clone(); - let t = gettoken.token(scopes.clone()).await?; + let t = gettoken.token(scopes).await?; if let Err(e) = store.set( scope_key, - &scopes, + scopes, Some(t.clone()), ) { match delegate.token_storage_failure(true, &e) { @@ -291,15 +295,13 @@ impl< self.inner.application_secret() } - fn token<'a, I, T>( + fn token<'a, T>( &'a self, - scopes: I, + scopes: &'a [T], ) -> Pin> + Send + 'a>> where - T: Into, - I: IntoIterator, + T: AsRef + Sync, { - let (scope_key, scopes) = hash_scopes(scopes); - Box::pin(self.get_token(scope_key, scopes)) + Box::pin(self.get_token(scopes)) } } diff --git a/src/device.rs b/src/device.rs index a924f17..018bde7 100644 --- a/src/device.rs +++ b/src/device.rs @@ -1,4 +1,3 @@ -use std::iter::{FromIterator, IntoIterator}; use std::pin::Pin; use std::time::Duration; @@ -7,7 +6,6 @@ use chrono::{self, Utc}; use futures::{prelude::*}; use hyper; use hyper::header; -use itertools::Itertools; use serde_json as json; use url::form_urlencoded; @@ -104,15 +102,14 @@ where FD: FlowDelegate + 'static, C: hyper::client::connect::Connect + 'static, { - fn token<'a, I, T>( + fn token<'a, T>( &'a self, - scopes: I, + scopes: &'a [T], ) -> Pin> + Send + 'a>> where - T: Into, - I: IntoIterator, + T: AsRef + Sync, { - Box::pin(self.retrieve_device_token(Vec::from_iter(scopes.into_iter().map(Into::into)))) + Box::pin(self.retrieve_device_token(scopes)) } fn api_key(&self) -> Option { None @@ -131,10 +128,13 @@ where { /// Essentially what `GetToken::token` does: Retrieve a token for the given scopes without /// caching. - pub async fn retrieve_device_token<'a>( + pub async fn retrieve_device_token( &self, - scopes: Vec, - ) -> Result { + scopes: &[T], + ) -> Result + where + T: AsRef, + { let application_secret = self.application_secret.clone(); let client = self.client.clone(); let wait = self.wait; @@ -193,24 +193,21 @@ where /// * If called after a successful result was returned at least once. /// # Examples /// See test-cases in source code for a more complete example. - async fn request_code( + async fn request_code( application_secret: ApplicationSecret, client: hyper::Client, device_code_url: String, - scopes: Vec, - ) -> Result<(PollInformation, String), RequestError> { + scopes: &[T], + ) -> Result<(PollInformation, String), RequestError> + where + T: AsRef, + { // note: cloned() shouldn't be needed, see issue // https://github.com/servo/rust-url/issues/81 let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ ("client_id", application_secret.client_id.clone()), - ( - "scope", - scopes - .into_iter() - .intersperse(" ".to_string()) - .collect::(), - ), + ("scope", crate::helper::join(scopes, " ")), ]) .finish(); @@ -409,7 +406,7 @@ mod tests { let fut = async { let token = flow - .token(vec!["https://www.googleapis.com/scope/1"]) + .token(&["https://www.googleapis.com/scope/1"]) .await .unwrap(); assert_eq!("accesstoken", token.access_token); @@ -441,7 +438,7 @@ mod tests { .create(); let fut = async { - let res = flow.token(vec!["https://www.googleapis.com/scope/1"]).await; + let res = flow.token(&["https://www.googleapis.com/scope/1"]).await; assert!(res.is_err()); assert!(format!("{}", res.unwrap_err()).contains("invalid_client_id")); Ok(()) as Result<(), ()> @@ -471,7 +468,7 @@ mod tests { .create(); let fut = async { - let res = flow.token(vec!["https://www.googleapis.com/scope/1"]).await; + let res = flow.token(&["https://www.googleapis.com/scope/1"]).await; assert!(res.is_err()); assert!(format!("{}", res.unwrap_err()).contains("Access denied by user")); Ok(()) as Result<(), ()> diff --git a/src/helper.rs b/src/helper.rs index c9471fc..7ef7d06 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -61,3 +61,25 @@ pub fn service_account_key_from_file>(path: S) -> io::Result Ok(decoded), } } + +pub(crate) fn join(pieces: &[T], separator: &str) -> String +where + T: AsRef, +{ + let mut iter = pieces.iter(); + let first = match iter.next() { + Some(p) => p, + None => return String::new(), + }; + let num_separators = pieces.len() - 1; + let pieces_size: usize = pieces.iter().map(|p| p.as_ref().len()).sum(); + let size = pieces_size + separator.len() * num_separators; + let mut result = String::with_capacity(size); + result.push_str(first.as_ref()); + for p in iter { + result.push_str(separator); + result.push_str(p.as_ref()); + } + debug_assert_eq!(size, result.len()); + result +} \ No newline at end of file diff --git a/src/installed.rs b/src/installed.rs index ba6758e..26ee573 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -66,15 +66,14 @@ where FD: FlowDelegate + 'static, C: hyper::client::connect::Connect + 'static, { - fn token<'a, I, T>( + fn token<'a, T>( &'a self, - scopes: I, + scopes: &'a [T], ) -> Pin> + Send + 'a>> where - T: Into, - I: IntoIterator, + T: AsRef + Sync, { - Box::pin(self.obtain_token(scopes.into_iter().map(Into::into).collect())) + Box::pin(self.obtain_token(scopes)) } fn api_key(&self) -> Option { None @@ -175,27 +174,29 @@ where /// . Return that token /// /// It's recommended not to use the DefaultFlowDelegate, but a specialized one. - async fn obtain_token<'a>( + async fn obtain_token( &self, - scopes: Vec, // Note: I haven't found a better way to give a list of strings here, due to ownership issues with futures. - ) -> Result { + scopes: &[T], + ) -> Result + where + T: AsRef, + { match self.method { InstalledFlowReturnMethod::HTTPRedirect(port) => { - self.ask_auth_code_via_http(scopes.iter(), port).await + self.ask_auth_code_via_http(scopes, port).await } InstalledFlowReturnMethod::HTTPRedirectEphemeral => { - self.ask_auth_code_via_http(scopes.iter(), 0).await + self.ask_auth_code_via_http(scopes, 0).await } InstalledFlowReturnMethod::Interactive => { - self.ask_auth_code_interactively(scopes.iter()).await + self.ask_auth_code_interactively(scopes).await } } } - async fn ask_auth_code_interactively<'a, S, T>(&self, scopes: S) -> Result + async fn ask_auth_code_interactively(&self, scopes: &[T]) -> Result where - T: AsRef + 'a, - S: Iterator, + T: AsRef, { let auth_delegate = &self.fd; let appsecret = &self.appsecret; @@ -223,14 +224,13 @@ where self.exchange_auth_code(authcode, None).await } - async fn ask_auth_code_via_http<'a, S, T>( + async fn ask_auth_code_via_http( &self, - scopes: S, + scopes: &[T], desired_port: u16, ) -> Result where - T: AsRef + 'a, - S: Iterator, + T: AsRef, { let auth_delegate = &self.fd; let appsecret = &self.appsecret; @@ -583,7 +583,7 @@ mod tests { let fut = || { async { let tok = inf - .token(vec!["https://googleapis.com/some/scope"]) + .token(&["https://googleapis.com/some/scope"]) .await .map_err(|_| ())?; assert_eq!("accesstoken", tok.access_token); @@ -612,7 +612,7 @@ mod tests { let fut = async { let tok = inf - .token(vec!["https://googleapis.com/some/scope"]) + .token(&["https://googleapis.com/some/scope"]) .await .map_err(|_| ())?; assert_eq!("accesstoken", tok.access_token); @@ -635,7 +635,7 @@ mod tests { .create(); let fut = async { - let tokr = inf.token(vec!["https://googleapis.com/some/scope"]).await; + let tokr = inf.token(&["https://googleapis.com/some/scope"]).await; assert!(tokr.is_err()); assert!(format!("{}", tokr.unwrap_err()).contains("invalid_code")); Ok(()) as Result<(), ()> diff --git a/src/lib.rs b/src/lib.rs index 523e703..181e1e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,8 +65,7 @@ //! .build() //! .unwrap(); //! -//! let s = "https://www.googleapis.com/auth/drive.file".to_string(); -//! let scopes = vec![s]; +//! let scopes = &["https://www.googleapis.com/auth/drive.file"]; //! //! // token() is the one important function of this crate; it does everything to //! // obtain a token that can be sent e.g. as Bearer token. diff --git a/src/service_account.rs b/src/service_account.rs index 355f77e..88831c8 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -157,17 +157,15 @@ impl JWT { } } -/// Set `iss`, `aud`, `exp`, `iat`, `scope` field in the returned `Claims`. `scopes` is an iterator -/// yielding strings with OAuth scopes. -fn init_claims_from_key<'a, I, T>(key: &ServiceAccountKey, scopes: I) -> Claims +/// Set `iss`, `aud`, `exp`, `iat`, `scope` field in the returned `Claims`. +fn init_claims_from_key(key: &ServiceAccountKey, scopes: &[T]) -> Claims where - T: AsRef + 'a, - I: IntoIterator, + T: AsRef, { let iat = chrono::Utc::now().timestamp(); let expiry = iat + 3600 - 5; // Max validity is 1h. - let mut scopes_string = scopes.into_iter().fold(String::new(), |mut acc, sc| { + let mut scopes_string = scopes.iter().fold(String::new(), |mut acc, sc| { acc.push_str(sc.as_ref()); acc.push_str(" "); acc @@ -271,13 +269,16 @@ where C: hyper::client::connect::Connect + 'static, { /// Send a request for a new Bearer token to the OAuth provider. - async fn request_token( + async fn request_token( client: hyper::client::Client, sub: Option, key: ServiceAccountKey, - scopes: Vec, - ) -> Result { - let mut claims = init_claims_from_key(&key, &scopes); + scopes: &[T], + ) -> Result + where + T: AsRef, + { + let mut claims = init_claims_from_key(&key, scopes); claims.sub = sub.clone(); let signed = JWT::new(claims) .sign(key.private_key.as_ref().unwrap()) @@ -335,12 +336,16 @@ where Ok(token) } - async fn get_token(&self, hash: u64, scopes: Vec) -> Result { + async fn get_token(&self, scopes: &[T]) -> Result + where + T: AsRef, + { + let hash = hash_scopes(scopes); let cache = self.cache.clone(); match cache .lock() .unwrap() - .get(hash, scopes.iter()) + .get(hash, scopes) { Ok(Some(token)) if !token.expired() => return Ok(token), _ => {} @@ -349,12 +354,12 @@ where self.client.clone(), self.sub.clone(), self.key.clone(), - scopes.iter().map(|s| s.to_string()).collect(), + scopes, ) .await?; let _ = cache.lock().unwrap().set( hash, - scopes.iter(), + scopes, Some(token.clone()), ); Ok(token) @@ -365,16 +370,14 @@ impl GetToken for ServiceAccountAccessImpl where C: hyper::client::connect::Connect + 'static, { - fn token<'a, I, T>( + fn token<'a, T>( &'a self, - scopes: I, + scopes: &'a [T], ) -> Pin> + Send + 'a>> where - T: Into, - I: IntoIterator, + T: AsRef + Sync, { - let (hash, scps0) = hash_scopes(scopes); - Box::pin(self.get_token(hash, scps0)) + Box::pin(self.get_token(scopes)) } /// Returns an empty ApplicationSecret as tokens for service accounts don't need to be @@ -449,7 +452,7 @@ mod tests { let acc = ServiceAccountAccessImpl::new(client.clone(), key.clone(), None); let fut = async { let tok = acc - .token(vec!["https://www.googleapis.com/auth/pubsub"]) + .token(&["https://www.googleapis.com/auth/pubsub"]) .await?; assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); assert_eq!(Some(3600), tok.expires_in); @@ -463,14 +466,14 @@ mod tests { .unwrap() .get( 3502164897243251857, - ["https://www.googleapis.com/auth/pubsub"].iter(), + &["https://www.googleapis.com/auth/pubsub"], ) .unwrap() .is_some()); // Test that token is in cache (otherwise mock will tell us) let fut = async { let tok = acc - .token(vec!["https://www.googleapis.com/auth/pubsub"]) + .token(&["https://www.googleapis.com/auth/pubsub"]) .await?; assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); assert_eq!(Some(3600), tok.expires_in); @@ -492,7 +495,7 @@ mod tests { .build(); let fut = async { let result = acc - .token(vec!["https://www.googleapis.com/auth/pubsub"]) + .token(&["https://www.googleapis.com/auth/pubsub"]) .await; assert!(result.is_err()); Ok(()) as Result<(), ()> @@ -522,7 +525,7 @@ mod tests { rt.block_on(async { println!( "{:?}", - acc.token(vec!["https://www.googleapis.com/auth/pubsub"]) + acc.token(&["https://www.googleapis.com/auth/pubsub"]) .await ); }); diff --git a/src/storage.rs b/src/storage.rs index 551c563..c1403ca 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -26,34 +26,35 @@ pub trait TokenStorage: Send + Sync { /// If `token` is None, it is invalid or revoked and should be removed from storage. /// Otherwise, it should be saved. - fn set( + fn set( &self, scope_hash: u64, - scopes: I, + scopes: &[T], token: Option, ) -> Result<(), Self::Error> where - I: IntoIterator, - I::Item: AsRef; + T: AsRef; /// A `None` result indicates that there is no token for the given scope_hash. - fn get(&self, scope_hash: u64, scopes: I) -> Result, Self::Error> + fn get(&self, scope_hash: u64, scopes: &[T]) -> Result, Self::Error> where - I: IntoIterator + Clone, - I::Item: AsRef; + T: AsRef; } -/// Calculate a hash value describing the scopes, and return a sorted Vec of the scopes. -pub fn hash_scopes(scopes: I) -> (u64, Vec) +/// Calculate a hash value describing the scopes. The order of the scopes in the +/// list does not change the hash value. i.e. two lists that contains the exact +/// same scopes, but in different order will return the same hash value. +pub fn hash_scopes(scopes: &[T]) -> u64 where - T: Into, - I: IntoIterator, + T: AsRef, { - let mut sv: Vec = scopes.into_iter().map(Into::into).collect(); - sv.sort(); - let mut sh = DefaultHasher::new(); - sv.hash(&mut sh); - (sh.finish(), sv) + let mut hash_sum = DefaultHasher::new().finish(); + for scope in scopes { + let mut hasher = DefaultHasher::new(); + scope.as_ref().hash(&mut hasher); + hash_sum ^= hasher.finish(); + } + hash_sum } /// A storage that remembers nothing. @@ -62,18 +63,16 @@ pub struct NullStorage; impl TokenStorage for NullStorage { type Error = std::convert::Infallible; - fn set(&self, _: u64, _: I, _: Option) -> Result<(), Self::Error> + fn set(&self, _: u64, _: &[T], _: Option) -> Result<(), Self::Error> where - I: IntoIterator, - I::Item: AsRef, + T: AsRef { Ok(()) } - fn get(&self, _: u64, _: I) -> Result, Self::Error> + fn get(&self, _: u64, _: &[T]) -> Result, Self::Error> where - I: IntoIterator + Clone, - I::Item: AsRef, + T: AsRef { Ok(None) } @@ -94,15 +93,14 @@ impl MemoryStorage { impl TokenStorage for MemoryStorage { type Error = std::convert::Infallible; - fn set( + fn set( &self, scope_hash: u64, - scopes: I, + scopes: &[T], token: Option, ) -> Result<(), Self::Error> where - I: IntoIterator, - I::Item: AsRef, + T: AsRef { let mut tokens = self.tokens.lock().expect("poisoned mutex"); let matched = tokens.iter().find_position(|x| x.hash == scope_hash); @@ -124,10 +122,9 @@ impl TokenStorage for MemoryStorage { Ok(()) } - fn get(&self, scope_hash: u64, scopes: I) -> Result, Self::Error> + fn get(&self, scope_hash: u64, scopes: &[T]) -> Result, Self::Error> where - I: IntoIterator + Clone, - I::Item: AsRef, + T: AsRef { let tokens = self.tokens.lock().expect("poisoned mutex"); Ok(token_for_scopes(&tokens, scope_hash, scopes)) @@ -225,15 +222,14 @@ fn load_from_file(filename: &Path) -> Result, io::Error> { impl TokenStorage for DiskTokenStorage { type Error = io::Error; - fn set( + fn set( &self, scope_hash: u64, - scopes: I, + scopes: &[T], token: Option, ) -> Result<(), Self::Error> where - I: IntoIterator, - I::Item: AsRef, + T: AsRef { { let mut tokens = self.tokens.lock().expect("poisoned mutex"); @@ -257,24 +253,22 @@ impl TokenStorage for DiskTokenStorage { self.dump_to_file() } - fn get(&self, scope_hash: u64, scopes: I) -> Result, Self::Error> + fn get(&self, scope_hash: u64, scopes: &[T]) -> Result, Self::Error> where - I: IntoIterator + Clone, - I::Item: AsRef, + T: AsRef { let tokens = self.tokens.lock().expect("poisoned mutex"); Ok(token_for_scopes(&tokens, scope_hash, scopes)) } } -fn token_for_scopes(tokens: &[JSONToken], scope_hash: u64, scopes: I) -> Option +fn token_for_scopes(tokens: &[JSONToken], scope_hash: u64, scopes: &[T]) -> Option where - I: IntoIterator + Clone, - I::Item: AsRef, + T: AsRef, { for t in tokens.iter() { if let Some(token_scopes) = &t.scopes { - if scopes.clone().into_iter().all(|s| token_scopes.iter().any(|t| t == s.as_ref())) { + if scopes.iter().all(|s| token_scopes.iter().any(|t| t == s.as_ref())) { return Some(t.token.clone()); } } else if scope_hash == t.hash { @@ -283,3 +277,21 @@ where } None } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_hash_scopes() { + // Idential list should hash equal. + assert_eq!(hash_scopes(&["foo", "bar"]), hash_scopes(&["foo", "bar"])); + // The hash should be order independent. + assert_eq!(hash_scopes(&["bar", "foo"]), hash_scopes(&["foo", "bar"])); + assert_eq!(hash_scopes(&["bar", "baz", "bat"]), hash_scopes(&["baz", "bar", "bat"])); + + // Ensure hashes differ when the contents are different by more than + // just order. + assert_ne!(hash_scopes(&["foo", "bar", "baz"]), hash_scopes(&["foo", "bar"])); + } +} diff --git a/src/types.rs b/src/types.rs index a47a0fa..5d5b54f 100644 --- a/src/types.rs +++ b/src/types.rs @@ -236,13 +236,12 @@ impl FromStr for Scheme { /// The `api_key()` method is an alternative in case there are no scopes or /// if no user is involved. pub trait GetToken: Send + Sync { - fn token<'a, I, T>( + fn token<'a, T>( &'a self, - scopes: I, + scopes: &'a [T], ) -> Pin> + Send + 'a>> where - T: Into, - I: IntoIterator; + T: AsRef + Sync; fn api_key(&self) -> Option; From 9542e3a9f18bc33effa6e5d72aefa66b857d17e1 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 8 Nov 2019 13:40:34 -0800 Subject: [PATCH 06/71] Remove instances of cloning ApplicationSecret ApplicationSecret is not a small struct. This removes the instances where it's cloned in favor of passing a shared reference. --- src/authenticator.rs | 4 +-- src/authenticator_delegate.rs | 2 +- src/device.rs | 22 ++++++------ src/installed.rs | 64 +++++++++++++++++------------------ src/refresh.rs | 16 ++++----- src/service_account.rs | 5 +-- src/types.rs | 18 +++++++++- 7 files changed, 73 insertions(+), 58 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index bbd6781..32f67ca 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -216,7 +216,7 @@ where let store = store.clone(); let rr = RefreshFlow::refresh_token( client.clone(), - appsecret.clone(), + appsecret, refresh_token.unwrap(), ) .await?; @@ -291,7 +291,7 @@ impl< self.inner.api_key() } /// Returns the application secret of the inner flow. - fn application_secret(&self) -> ApplicationSecret { + fn application_secret(&self) -> &ApplicationSecret { self.inner.application_secret() } diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index 5312c13..4c25969 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -133,7 +133,7 @@ pub trait FlowDelegate: Clone + Send + Sync { } /// Configure a custom redirect uri if needed. - fn redirect_uri(&self) -> Option { + fn redirect_uri(&self) -> Option<&str> { None } /// The server has returned a `user_code` which must be shown to the user, diff --git a/src/device.rs b/src/device.rs index 018bde7..c8573f2 100644 --- a/src/device.rs +++ b/src/device.rs @@ -114,8 +114,8 @@ where fn api_key(&self) -> Option { None } - fn application_secret(&self) -> ApplicationSecret { - self.application_secret.clone() + fn application_secret(&self) -> &ApplicationSecret { + &self.application_secret } } @@ -135,12 +135,12 @@ where where T: AsRef, { - let application_secret = self.application_secret.clone(); + let application_secret = &self.application_secret; let client = self.client.clone(); let wait = self.wait; let fd = self.fd.clone(); let (pollinf, device_code) = Self::request_code( - application_secret.clone(), + application_secret, client.clone(), self.device_code_url.clone(), scopes, @@ -153,7 +153,7 @@ where let pollinf = pollinf.clone(); tokio::timer::delay_for(pollinf.interval).await; let r = Self::poll_token( - application_secret.clone(), + application_secret, client.clone(), device_code.clone(), pollinf.clone(), @@ -194,7 +194,7 @@ where /// # Examples /// See test-cases in source code for a more complete example. async fn request_code( - application_secret: ApplicationSecret, + application_secret: &ApplicationSecret, client: hyper::Client, device_code_url: String, scopes: &[T], @@ -206,8 +206,8 @@ where // https://github.com/servo/rust-url/issues/81 let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ - ("client_id", application_secret.client_id.clone()), - ("scope", crate::helper::join(scopes, " ")), + ("client_id", application_secret.client_id.as_str()), + ("scope", crate::helper::join(scopes, " ").as_str()), ]) .finish(); @@ -276,7 +276,7 @@ where /// # Examples /// See test-cases in source code for a more complete example. async fn poll_token<'a>( - application_secret: ApplicationSecret, + application_secret: &ApplicationSecret, client: hyper::Client, device_code: String, pi: PollInformation, @@ -290,8 +290,8 @@ where // We should be ready for a new request let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ - ("client_id", &application_secret.client_id[..]), - ("client_secret", &application_secret.client_secret), + ("client_id", application_secret.client_id.as_str()), + ("client_secret", application_secret.client_secret.as_str()), ("code", &device_code), ("grant_type", "http://oauth.net/grant_type/device/1.0"), ]) diff --git a/src/installed.rs b/src/installed.rs index 26ee573..0f1db98 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -24,24 +24,17 @@ const OOB_REDIRECT_URI: &'static str = "urn:ietf:wg:oauth:2.0:oob"; /// Assembles a URL to request an authorization token (with user interaction). /// Note that the redirect_uri here has to be either None or some variation of /// http://localhost:{port}, or the authorization won't work (error "redirect_uri_mismatch") -fn build_authentication_request_url<'a, T, I>( +fn build_authentication_request_url( auth_uri: &str, client_id: &str, - scopes: I, - redirect_uri: Option, + scopes: &[T], + redirect_uri: Option<&str>, ) -> String where - T: AsRef + 'a, - I: IntoIterator, + T: AsRef, { let mut url = String::new(); - let mut scopes_string = scopes.into_iter().fold(String::new(), |mut acc, sc| { - acc.push_str(sc.as_ref()); - acc.push_str(" "); - acc - }); - // Remove last space - scopes_string.pop(); + let scopes_string = crate::helper::join(scopes, " "); url.push_str(auth_uri); vec![ @@ -49,7 +42,7 @@ where format!("&access_type=offline"), format!( "&redirect_uri={}", - redirect_uri.unwrap_or(OOB_REDIRECT_URI.to_string()) + redirect_uri.unwrap_or(OOB_REDIRECT_URI) ), format!("&response_type=code"), format!("&client_id={}", client_id), @@ -78,8 +71,8 @@ where fn api_key(&self) -> Option { None } - fn application_secret(&self) -> ApplicationSecret { - self.appsecret.clone() + fn application_secret(&self) -> &ApplicationSecret { + &self.appsecret } } @@ -232,6 +225,7 @@ where where T: AsRef, { + use std::borrow::Cow; let auth_delegate = &self.fd; let appsecret = &self.appsecret; let server = InstalledFlowServer::run(desired_port)?; @@ -240,13 +234,15 @@ where // Present url to user. // The redirect URI must be this very localhost URL, otherwise authorization is refused // by certain providers. + let redirect_uri: Cow = match auth_delegate.redirect_uri() { + Some(uri) => uri.into(), + None => format!("http://localhost:{}", bound_port).into(), + }; let url = build_authentication_request_url( &appsecret.auth_uri, &appsecret.client_id, scopes, - auth_delegate - .redirect_uri() - .or_else(|| Some(format!("http://localhost:{}", bound_port))), + Some(redirect_uri.as_ref()), ); let _ = auth_delegate .present_user_url(&url, false /* need code */) @@ -262,8 +258,8 @@ where port: Option, ) -> Result { let appsec = &self.appsecret; - let redirect_uri = &self.fd.redirect_uri(); - let request = Self::request_token(appsec.clone(), authcode, redirect_uri.clone(), port); + let redirect_uri = self.fd.redirect_uri(); + let request = Self::request_token(appsec, authcode, redirect_uri, port); let resp = self .client .request(request) @@ -310,27 +306,29 @@ where /// Sends the authorization code to the provider in order to obtain access and refresh tokens. fn request_token<'a>( - appsecret: ApplicationSecret, + appsecret: &ApplicationSecret, authcode: String, - custom_redirect_uri: Option, + custom_redirect_uri: Option<&str>, port: Option, ) -> hyper::Request { - let redirect_uri = custom_redirect_uri.unwrap_or_else(|| match port { - None => OOB_REDIRECT_URI.to_string(), - Some(port) => format!("http://localhost:{}", port), - }); + use std::borrow::Cow; + let redirect_uri: Cow = match (custom_redirect_uri, port) { + (Some(uri), _) => uri.into(), + (None, Some(port)) => format!("http://localhost:{}", port).into(), + (None, None) => OOB_REDIRECT_URI.into(), + }; let body = form_urlencoded::Serializer::new(String::new()) .extend_pairs(vec![ - ("code".to_string(), authcode.to_string()), - ("client_id".to_string(), appsecret.client_id.clone()), - ("client_secret".to_string(), appsecret.client_secret.clone()), - ("redirect_uri".to_string(), redirect_uri), - ("grant_type".to_string(), "authorization_code".to_string()), + ("code", authcode.as_str()), + ("client_id", appsecret.client_id.as_str()), + ("client_secret", appsecret.client_secret.as_str()), + ("redirect_uri", redirect_uri.as_ref()), + ("grant_type", "authorization_code"), ]) .finish(); - let request = hyper::Request::post(appsecret.token_uri) + let request = hyper::Request::post(&appsecret.token_uri) .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(body)) .unwrap(); // TODO: error check @@ -657,7 +655,7 @@ mod tests { "https://accounts.google.com/o/oauth2/auth", "812741506391-h38jh0j4fv0ce1krdkiq0hfvt6n5am\ rf.apps.googleusercontent.com", - vec![&"email".to_string(), &"profile".to_string()], + &["email", "profile"], None ) ); diff --git a/src/refresh.rs b/src/refresh.rs index 717016f..da5bced 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -32,20 +32,20 @@ impl RefreshFlow { /// Please see the crate landing page for an example. pub async fn refresh_token( client: hyper::Client, - client_secret: ApplicationSecret, + client_secret: &ApplicationSecret, refresh_token: String, ) -> Result { // TODO: Does this function ever return RequestError? Maybe have it just return RefreshResult. let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ - ("client_id", client_secret.client_id.clone()), - ("client_secret", client_secret.client_secret.clone()), - ("refresh_token", refresh_token.to_string()), - ("grant_type", "refresh_token".to_string()), + ("client_id", client_secret.client_id.as_str()), + ("client_secret", client_secret.client_secret.as_str()), + ("refresh_token", refresh_token.as_str()), + ("grant_type", "refresh_token"), ]) .finish(); - let request = hyper::Request::post(client_secret.token_uri.clone()) + let request = hyper::Request::post(&client_secret.token_uri) .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(req)) .unwrap(); // TODO: error handling @@ -130,7 +130,7 @@ mod tests { let fut = async { let rr = RefreshFlow::refresh_token( client.clone(), - app_secret.clone(), + &app_secret, refresh_token.clone(), ) .await @@ -158,7 +158,7 @@ mod tests { .create(); let fut = async { - let rr = RefreshFlow::refresh_token(client, app_secret, refresh_token) + let rr = RefreshFlow::refresh_token(client, &app_secret, refresh_token) .await .unwrap(); match rr { diff --git a/src/service_account.rs b/src/service_account.rs index 88831c8..0c2298d 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -382,8 +382,9 @@ where /// Returns an empty ApplicationSecret as tokens for service accounts don't need to be /// refreshed (they are simply reissued). - fn application_secret(&self) -> ApplicationSecret { - Default::default() + fn application_secret(&self) -> &ApplicationSecret { + static APP_SECRET: ApplicationSecret = ApplicationSecret::empty(); + &APP_SECRET } fn api_key(&self) -> Option { diff --git a/src/types.rs b/src/types.rs index 5d5b54f..697ddb0 100644 --- a/src/types.rs +++ b/src/types.rs @@ -247,7 +247,7 @@ pub trait GetToken: Send + Sync { /// Return an application secret with at least token_uri, client_secret, and client_id filled /// in. This is used for refreshing tokens without interaction from the flow. - fn application_secret(&self) -> ApplicationSecret; + fn application_secret(&self) -> &ApplicationSecret; } /// Represents a token as returned by OAuth2 servers. @@ -342,6 +342,22 @@ pub struct ApplicationSecret { pub client_x509_cert_url: Option, } +impl ApplicationSecret { + pub const fn empty() -> Self { + ApplicationSecret{ + client_id: String::new(), + client_secret: String::new(), + token_uri: String::new(), + auth_uri: String::new(), + redirect_uris: Vec::new(), + project_id: None, + client_email: None, + auth_provider_x509_cert_url: None, + client_x509_cert_url: None, + } + } +} + /// A type to facilitate reading and writing the json secret file /// as returned by the [google developer console](https://code.google.com/apis/console) #[derive(Deserialize, Serialize, Default)] From 916aaa84e9a13093833d4bc06b605e7f9a5c99a6 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 8 Nov 2019 13:44:51 -0800 Subject: [PATCH 07/71] Authenticator.store no longer needs to be reference counted. --- src/authenticator.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index 32f67ca..32cc348 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -29,7 +29,7 @@ struct AuthenticatorImpl< > { client: hyper::Client, inner: Arc, - store: Arc, + store: S, delegate: AD, } @@ -173,7 +173,7 @@ where C::Connector: 'static + Clone + Send, { let client = self.client.build_hyper_client(); - let store = Arc::new(self.store?); + let store = self.store?; let inner = Arc::new(self.token_getter.build_token_getter(client.clone())); Ok(AuthenticatorImpl { @@ -197,7 +197,7 @@ where T: AsRef + Sync, { let scope_key = hash_scopes(scopes); - let store = self.store.clone(); + let store = &self.store; let delegate = &self.delegate; let client = self.client.clone(); let appsecret = self.inner.application_secret(); @@ -213,7 +213,6 @@ where } // Implement refresh flow. let refresh_token = t.refresh_token.clone(); - let store = store.clone(); let rr = RefreshFlow::refresh_token( client.clone(), appsecret, @@ -254,7 +253,6 @@ where } } Ok(None) => { - let store = store.clone(); let t = gettoken.token(scopes).await?; if let Err(e) = store.set( scope_key, From a0c73d6087299e76d47ccfadb7f8a7e50b180601 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 8 Nov 2019 13:50:41 -0800 Subject: [PATCH 08/71] No need to clone the hyper::Client The ownership behavior is straightforward and more clear when not cloning arbitrary handles. --- src/authenticator.rs | 4 ++-- src/refresh.rs | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index 32cc348..bc3a09a 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -199,7 +199,7 @@ where let scope_key = hash_scopes(scopes); let store = &self.store; let delegate = &self.delegate; - let client = self.client.clone(); + let client = &self.client; let appsecret = self.inner.application_secret(); let gettoken = self.inner.clone(); loop { @@ -214,7 +214,7 @@ where // Implement refresh flow. let refresh_token = t.refresh_token.clone(); let rr = RefreshFlow::refresh_token( - client.clone(), + client, appsecret, refresh_token.unwrap(), ) diff --git a/src/refresh.rs b/src/refresh.rs index da5bced..3f60561 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -31,7 +31,7 @@ impl RefreshFlow { /// # Examples /// Please see the crate landing page for an example. pub async fn refresh_token( - client: hyper::Client, + client: &hyper::Client, client_secret: &ApplicationSecret, refresh_token: String, ) -> Result { @@ -129,7 +129,7 @@ mod tests { .create(); let fut = async { let rr = RefreshFlow::refresh_token( - client.clone(), + &client, &app_secret, refresh_token.clone(), ) @@ -158,7 +158,7 @@ mod tests { .create(); let fut = async { - let rr = RefreshFlow::refresh_token(client, &app_secret, refresh_token) + let rr = RefreshFlow::refresh_token(&client, &app_secret, refresh_token) .await .unwrap(); match rr { From e9b2a3a0764c181322ddb2eb41e063d6322f0564 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 8 Nov 2019 13:53:21 -0800 Subject: [PATCH 09/71] The inner GetToken on Authenticator no longer needs to be reference counted. --- src/authenticator.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index bc3a09a..7fd25f3 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -9,7 +9,6 @@ use std::error::Error; use std::io; use std::path::Path; use std::pin::Pin; -use std::sync::Arc; /// Authenticator abstracts different `GetToken` implementations behind one type and handles /// caching received tokens. It's important to use it (instead of the flows directly) because @@ -28,7 +27,7 @@ struct AuthenticatorImpl< C: hyper::client::connect::Connect, > { client: hyper::Client, - inner: Arc, + inner: T, store: S, delegate: AD, } @@ -174,7 +173,7 @@ where { let client = self.client.build_hyper_client(); let store = self.store?; - let inner = Arc::new(self.token_getter.build_token_getter(client.clone())); + let inner = self.token_getter.build_token_getter(client.clone()); Ok(AuthenticatorImpl { client, @@ -200,8 +199,8 @@ where let store = &self.store; let delegate = &self.delegate; let client = &self.client; - let appsecret = self.inner.application_secret(); - let gettoken = self.inner.clone(); + let gettoken = &self.inner; + let appsecret = gettoken.application_secret(); loop { match store.get( scope_key, From bf0136067f301683d1a26788421b684e2915c63c Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 8 Nov 2019 14:10:24 -0800 Subject: [PATCH 10/71] Make some of the helpers a bit more idiomatic. --- src/helper.rs | 55 ++++++++++++++++++++------------------------------- 1 file changed, 21 insertions(+), 34 deletions(-) diff --git a/src/helper.rs b/src/helper.rs index 7ef7d06..cd3e3e2 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -9,57 +9,44 @@ use serde_json; -use std::fs; -use std::io::{self, Read}; +use std::io; use std::path::Path; use crate::service_account::ServiceAccountKey; use crate::types::{ApplicationSecret, ConsoleApplicationSecret}; /// Read an application secret from a file. -pub fn read_application_secret(path: &Path) -> io::Result { - let mut secret = String::new(); - let mut file = fs::OpenOptions::new().read(true).open(path)?; - file.read_to_string(&mut secret)?; - - parse_application_secret(&secret) +pub fn read_application_secret>(path: P) -> io::Result { + parse_application_secret(std::fs::read_to_string(path)?) } /// Read an application secret from a JSON string. pub fn parse_application_secret>(secret: S) -> io::Result { - let result: serde_json::Result = - serde_json::from_str(secret.as_ref()); - match result { - Err(e) => Err(io::Error::new( + let decoded: ConsoleApplicationSecret = serde_json::from_str(secret.as_ref()).map_err(|e| { + io::Error::new( io::ErrorKind::InvalidData, format!("Bad application secret: {}", e), - )), - Ok(decoded) => { - if decoded.web.is_some() { - Ok(decoded.web.unwrap()) - } else if decoded.installed.is_some() { - Ok(decoded.installed.unwrap()) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "Unknown application secret format", - )) - } - } + ) + })?; + + if let Some(web) = decoded.web { + Ok(web) + } else if let Some(installed) = decoded.installed { + Ok(installed) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "Unknown application secret format", + )) } } /// Read a service account key from a JSON file. You can download the JSON keys from the Google /// Cloud Console or the respective console of your service provider. pub fn service_account_key_from_file>(path: S) -> io::Result { - let mut key = String::new(); - let mut file = fs::OpenOptions::new().read(true).open(path)?; - file.read_to_string(&mut key)?; - - match serde_json::from_str(&key) { - Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, format!("{}", e))), - Ok(decoded) => Ok(decoded), - } + let key = std::fs::read_to_string(path)?; + serde_json::from_str(&key) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("{}", e))) } pub(crate) fn join(pieces: &[T], separator: &str) -> String @@ -82,4 +69,4 @@ where } debug_assert_eq!(size, result.len()); result -} \ No newline at end of file +} From 29f800ba7f9adfe5c0e61723733721b7b2ecec82 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 8 Nov 2019 14:10:46 -0800 Subject: [PATCH 11/71] Some more improvements to reduce unnecessary allocations. --- src/authenticator.rs | 3 +-- src/device.rs | 10 +++++----- src/installed.rs | 10 +++++----- src/refresh.rs | 8 ++++---- src/service_account.rs | 18 +++++++++--------- src/storage.rs | 4 ++-- 6 files changed, 26 insertions(+), 27 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index 7fd25f3..4b5a5ac 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -211,11 +211,10 @@ where return Ok(t); } // Implement refresh flow. - let refresh_token = t.refresh_token.clone(); let rr = RefreshFlow::refresh_token( client, appsecret, - refresh_token.unwrap(), + &t.refresh_token.as_ref().unwrap(), ) .await?; match rr { diff --git a/src/device.rs b/src/device.rs index c8573f2..44700c0 100644 --- a/src/device.rs +++ b/src/device.rs @@ -142,7 +142,7 @@ where let (pollinf, device_code) = Self::request_code( application_secret, client.clone(), - self.device_code_url.clone(), + &self.device_code_url, scopes, ) .await?; @@ -155,7 +155,7 @@ where let r = Self::poll_token( application_secret, client.clone(), - device_code.clone(), + &device_code, pollinf.clone(), fd.clone(), ) @@ -196,7 +196,7 @@ where async fn request_code( application_secret: &ApplicationSecret, client: hyper::Client, - device_code_url: String, + device_code_url: &str, scopes: &[T], ) -> Result<(PollInformation, String), RequestError> where @@ -278,7 +278,7 @@ where async fn poll_token<'a>( application_secret: &ApplicationSecret, client: hyper::Client, - device_code: String, + device_code: &str, pi: PollInformation, fd: FD, ) -> Result, PollError> { @@ -292,7 +292,7 @@ where .extend_pairs(&[ ("client_id", application_secret.client_id.as_str()), ("client_secret", application_secret.client_secret.as_str()), - ("code", &device_code), + ("code", device_code), ("grant_type", "http://oauth.net/grant_type/device/1.0"), ]) .finish(); diff --git a/src/installed.rs b/src/installed.rs index 0f1db98..3a2a5f0 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -214,7 +214,7 @@ where } _ => return Err(RequestError::UserError("couldn't read code".to_string())), }; - self.exchange_auth_code(authcode, None).await + self.exchange_auth_code(&authcode, None).await } async fn ask_auth_code_via_http( @@ -249,12 +249,12 @@ where .await; let auth_code = server.wait_for_auth_code().await; - self.exchange_auth_code(auth_code, Some(bound_port)).await + self.exchange_auth_code(&auth_code, Some(bound_port)).await } async fn exchange_auth_code( &self, - authcode: String, + authcode: &str, port: Option, ) -> Result { let appsec = &self.appsecret; @@ -307,7 +307,7 @@ where /// Sends the authorization code to the provider in order to obtain access and refresh tokens. fn request_token<'a>( appsecret: &ApplicationSecret, - authcode: String, + authcode: &str, custom_redirect_uri: Option<&str>, port: Option, ) -> hyper::Request { @@ -320,7 +320,7 @@ where let body = form_urlencoded::Serializer::new(String::new()) .extend_pairs(vec![ - ("code", authcode.as_str()), + ("code", authcode), ("client_id", appsecret.client_id.as_str()), ("client_secret", appsecret.client_secret.as_str()), ("redirect_uri", redirect_uri.as_ref()), diff --git a/src/refresh.rs b/src/refresh.rs index 3f60561..8963f92 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -33,14 +33,14 @@ impl RefreshFlow { pub async fn refresh_token( client: &hyper::Client, client_secret: &ApplicationSecret, - refresh_token: String, + refresh_token: &str, ) -> Result { // TODO: Does this function ever return RequestError? Maybe have it just return RefreshResult. let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ ("client_id", client_secret.client_id.as_str()), ("client_secret", client_secret.client_secret.as_str()), - ("refresh_token", refresh_token.as_str()), + ("refresh_token", refresh_token), ("grant_type", "refresh_token"), ]) .finish(); @@ -106,7 +106,7 @@ mod tests { let app_secret = r#"{"installed":{"client_id":"902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com","project_id":"yup-test-243420","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://oauth2.googleapis.com/token","auth_provider_x509_cert_url":"https://www.googleapis.com/oauth2/v1/certs","client_secret":"iuMPN6Ne1PD7cos29Tk9rlqH","redirect_uris":["urn:ietf:wg:oauth:2.0:oob","http://localhost"]}}"#; let mut app_secret = helper::parse_application_secret(app_secret).unwrap(); app_secret.token_uri = format!("{}/token", server_url); - let refresh_token = "my-refresh-token".to_string(); + let refresh_token = "my-refresh-token"; let https = HttpsConnector::new(); let client = hyper::Client::builder() @@ -131,7 +131,7 @@ mod tests { let rr = RefreshFlow::refresh_token( &client, &app_secret, - refresh_token.clone(), + refresh_token, ) .await .unwrap(); diff --git a/src/service_account.rs b/src/service_account.rs index 0c2298d..eee3fe4 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -270,16 +270,16 @@ where { /// Send a request for a new Bearer token to the OAuth provider. async fn request_token( - client: hyper::client::Client, - sub: Option, - key: ServiceAccountKey, + client: &hyper::client::Client, + sub: Option<&str>, + key: &ServiceAccountKey, scopes: &[T], ) -> Result where T: AsRef, { let mut claims = init_claims_from_key(&key, scopes); - claims.sub = sub.clone(); + claims.sub = sub.map(|x| x.to_owned()); let signed = JWT::new(claims) .sign(key.private_key.as_ref().unwrap()) .map_err(RequestError::LowLevelError)?; @@ -289,7 +289,7 @@ where ("assertion".to_string(), signed), ]) .finish(); - let request = hyper::Request::post(key.token_uri.unwrap()) + let request = hyper::Request::post(key.token_uri.as_ref().unwrap()) .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(rqbody)) .unwrap(); @@ -341,7 +341,7 @@ where T: AsRef, { let hash = hash_scopes(scopes); - let cache = self.cache.clone(); + let cache = &self.cache; match cache .lock() .unwrap() @@ -351,9 +351,9 @@ where _ => {} } let token = Self::request_token( - self.client.clone(), - self.sub.clone(), - self.key.clone(), + &self.client, + self.sub.as_ref().map(|x| x.as_str()), + &self.key, scopes, ) .await?; diff --git a/src/storage.rs b/src/storage.rs index c1403ca..51afac4 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -113,7 +113,7 @@ impl TokenStorage for MemoryStorage { tokens.push(JSONToken { hash: scope_hash, scopes: Some(scopes.into_iter().map(|x| x.as_ref().to_string()).collect()), - token: t.clone(), + token: t, }); () } @@ -244,7 +244,7 @@ impl TokenStorage for DiskTokenStorage { tokens.push(JSONToken { hash: scope_hash, scopes: Some(scopes.into_iter().map(|x| x.as_ref().to_string()).collect()), - token: t.clone(), + token: t, }); () } From 2cf2e465d12e2c01354594613beabbae3e3505b5 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 8 Nov 2019 15:22:38 -0800 Subject: [PATCH 12/71] Add JsonErrorOr enum to make json error handling more concise/consistent. JsonErrorOr is an untagged enum that is generic over arbitrary data. This means that when deserializing JsonErrorOr it will first check the json field for an 'error' attribute. If one exists it will deserialize into the JsonErrorOr::Err variant that contains a JsonError. If the message doesn't contain an 'error' field it will attempt to deserialize T into he JsonErrorOr::Data variant. --- src/device.rs | 31 ++++++++++++--------------- src/installed.rs | 48 ++++++++++++------------------------------ src/refresh.rs | 41 +++++++++++++++--------------------- src/service_account.rs | 32 ++++++++++++---------------- src/types.rs | 16 +++++++++++++- 5 files changed, 72 insertions(+), 96 deletions(-) diff --git a/src/device.rs b/src/device.rs index 44700c0..a618243 100644 --- a/src/device.rs +++ b/src/device.rs @@ -11,7 +11,7 @@ use url::form_urlencoded; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate, PollInformation, Retry}; use crate::types::{ - ApplicationSecret, GetToken, JsonError, PollError, RequestError, Token, + ApplicationSecret, GetToken, JsonErrorOr, PollError, RequestError, Token, }; pub const GOOGLE_DEVICE_CODE_URL: &'static str = "https://accounts.google.com/o/oauth2/device/code"; @@ -236,25 +236,20 @@ where } let json_bytes = resp.into_body().try_concat().await?; + match json::from_slice::>(&json_bytes)? { + JsonErrorOr::Err(e) => Err(e.into()), + JsonErrorOr::Data(decoded) => { + let expires_in = decoded.expires_in.unwrap_or(60 * 60); - // check for error - match json::from_slice::(&json_bytes) { - Err(_) => {} // ignore, move on - Ok(res) => return Err(RequestError::from(res)), + let pi = PollInformation { + user_code: decoded.user_code, + verification_url: decoded.verification_uri, + expires_at: Utc::now() + chrono::Duration::seconds(expires_in), + interval: Duration::from_secs(i64::abs(decoded.interval) as u64), + }; + Ok((pi, decoded.device_code)) + } } - - let decoded: JsonData = - json::from_slice(&json_bytes).map_err(|e| RequestError::JSONError(e))?; - - let expires_in = decoded.expires_in.unwrap_or(60 * 60); - - let pi = PollInformation { - user_code: decoded.user_code, - verification_url: decoded.verification_uri, - expires_at: Utc::now() + chrono::Duration::seconds(expires_in), - interval: Duration::from_secs(i64::abs(decoded.interval) as u64), - }; - Ok((pi, decoded.device_code)) } /// If the first call is successful, this method may be called. diff --git a/src/installed.rs b/src/installed.rs index 3a2a5f0..88d39bb 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -17,7 +17,7 @@ use url::form_urlencoded; use url::percent_encoding::{percent_encode, QUERY_ENCODE_SET}; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate}; -use crate::types::{ApplicationSecret, GetToken, RequestError, Token}; +use crate::types::{ApplicationSecret, GetToken, RequestError, Token, JsonErrorOr}; const OOB_REDIRECT_URI: &'static str = "urn:ietf:wg:oauth:2.0:oob"; @@ -270,24 +270,21 @@ where .try_concat() .await .map_err(|e| RequestError::ClientError(e))?; - let tokens: JSONTokenResponse = - serde_json::from_slice(&body).map_err(|e| RequestError::JSONError(e))?; - match tokens { - JSONTokenResponse { - error: Some(err), - error_description, - .. - } => Err(RequestError::NegativeServerResponse(err, error_description)), - JSONTokenResponse { - access_token: Some(access_token), - refresh_token, - token_type: Some(token_type), - expires_in, - .. - } => { + + #[derive(Deserialize)] + struct JSONTokenResponse { + access_token: String, + refresh_token: String, + token_type: String, + expires_in: Option, + } + + match serde_json::from_slice::>(&body)? { + JsonErrorOr::Err(err) => Err(err.into()), + JsonErrorOr::Data(JSONTokenResponse{access_token, refresh_token, token_type, expires_in}) => { let mut token = Token { access_token, - refresh_token, + refresh_token: Some(refresh_token), token_type, expires_in, expires_in_timestamp: None, @@ -295,12 +292,6 @@ where token.set_expiry_absolute(); Ok(token) } - JSONTokenResponse { - error_description, .. - } => Err(RequestError::NegativeServerResponse( - "".to_owned(), - error_description, - )), } } @@ -336,17 +327,6 @@ where } } -#[derive(Deserialize)] -struct JSONTokenResponse { - access_token: Option, - refresh_token: Option, - token_type: Option, - expires_in: Option, - - error: Option, - error_description: Option, -} - fn spawn_with_handle(f: F) -> impl Future where F: Future + 'static + Send, diff --git a/src/refresh.rs b/src/refresh.rs index 8963f92..4a4bd94 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -1,11 +1,10 @@ -use crate::types::{ApplicationSecret, JsonError, RefreshResult, RequestError}; +use crate::types::{ApplicationSecret, JsonErrorOr, RefreshResult, RequestError}; use super::Token; use chrono::Utc; use futures_util::try_stream::TryStreamExt; use hyper; use hyper::header; -use serde_json as json; use url::form_urlencoded; /// Implements the [OAuth2 Refresh Token Flow](https://developers.google.com/youtube/v3/guides/authentication#devices). @@ -58,34 +57,28 @@ impl RefreshFlow { Ok(body) => body, Err(err) => return Ok(RefreshResult::Error(err)), }; - if let Ok(json_err) = json::from_slice::(&body) { - return Ok(RefreshResult::RefreshError( - json_err.error, - json_err.error_description, - )); - } + #[derive(Deserialize)] struct JsonToken { access_token: String, token_type: String, expires_in: i64, } - let t: JsonToken = match json::from_slice(&body) { - Err(_) => { - return Ok(RefreshResult::RefreshError( - "failed to deserialized json token from refresh response".to_owned(), - None, - )) - } - Ok(token) => token, - }; - Ok(RefreshResult::Success(Token { - access_token: t.access_token, - token_type: t.token_type, - refresh_token: Some(refresh_token.to_string()), - expires_in: None, - expires_in_timestamp: Some(Utc::now().timestamp() + t.expires_in), - })) + + match serde_json::from_slice::>(&body) { + Err(_) => Ok(RefreshResult::RefreshError("failed to deserialized json token from refresh response".to_owned(), None)), + Ok(JsonErrorOr::Err(json_err)) => Ok(RefreshResult::RefreshError(json_err.error, json_err.error_description)), + Ok(JsonErrorOr::Data(JsonToken{access_token, token_type, expires_in})) => { + Ok(RefreshResult::Success( + Token{ + access_token, + token_type, + refresh_token: Some(refresh_token.to_string()), + expires_in: None, + expires_in_timestamp: Some(Utc::now().timestamp() + expires_in), + })) + }, + } } } diff --git a/src/service_account.rs b/src/service_account.rs index eee3fe4..9b93c94 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -17,7 +17,7 @@ use std::sync::{Arc, Mutex}; use crate::authenticator::{DefaultHyperClient, HyperClientBuilder}; use crate::storage::{hash_scopes, MemoryStorage, TokenStorage}; -use crate::types::{ApplicationSecret, GetToken, JsonError, RequestError, Token}; +use crate::types::{ApplicationSecret, GetToken, JsonErrorOr, RequestError, Token}; use futures::prelude::*; use hyper::header; @@ -302,38 +302,32 @@ where .try_concat() .await .map_err(RequestError::ClientError)?; - if let Ok(jse) = serde_json::from_slice::(&body) { - return Err(RequestError::NegativeServerResponse( - jse.error, - jse.error_description, - )); - } - let token: TokenResponse = - serde_json::from_slice(&body).map_err(RequestError::JSONError)?; - let token = match token { - TokenResponse { + match serde_json::from_slice::>(&body)? { + JsonErrorOr::Err(err) => { + Err(err.into()) + }, + JsonErrorOr::Data(TokenResponse { access_token: Some(access_token), token_type: Some(token_type), expires_in: Some(expires_in), .. - } => { + }) => { let expires_ts = chrono::Utc::now().timestamp() + expires_in; - Token { + Ok(Token { access_token, token_type, refresh_token: None, expires_in: Some(expires_in), expires_in_timestamp: Some(expires_ts), - } - } - _ => { - return Err(RequestError::BadServerResponse(format!( + }) + }, + JsonErrorOr::Data(token) => { + Err(RequestError::BadServerResponse(format!( "Token response lacks fields: {:?}", token ))) } - }; - Ok(token) + } } async fn get_token(&self, scopes: &[T]) -> Result diff --git a/src/types.rs b/src/types.rs index 697ddb0..2c575ec 100644 --- a/src/types.rs +++ b/src/types.rs @@ -15,6 +15,14 @@ pub struct JsonError { pub error_uri: Option, } +/// A helper type to deserialize either a JsonError or another piece of data. +#[derive(Deserialize, Debug)] +#[serde(untagged)] +pub enum JsonErrorOr { + Err(JsonError), + Data(T), +} + /// All possible outcomes of the refresh flow #[derive(Debug)] pub enum RefreshResult { @@ -57,7 +65,7 @@ pub enum RequestError { /// A malformed server response. BadServerResponse(String), /// Error while decoding a JSON response. - JSONError(serde_json::error::Error), + JSONError(serde_json::Error), /// Error within user input. UserError(String), /// A lower level IO error. @@ -90,6 +98,12 @@ impl From for RequestError { } } +impl From for RequestError { + fn from(value: serde_json::Error) -> RequestError { + RequestError::JSONError(value) + } +} + impl fmt::Display for RequestError { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { match *self { From 4bd81c3263084ea55ee7913567b1f2bfc2e32d57 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 8 Nov 2019 15:25:52 -0800 Subject: [PATCH 13/71] cargo fmt --- examples/test-svc-acct/src/main.rs | 10 ++----- src/authenticator.rs | 17 +++--------- src/authenticator_delegate.rs | 6 +---- src/device.rs | 11 +++----- src/installed.rs | 41 +++++++++++++--------------- src/lib.rs | 4 +-- src/refresh.rs | 41 +++++++++++++++------------- src/service_account.rs | 35 +++++++----------------- src/storage.rs | 43 +++++++++++++++--------------- src/types.rs | 2 +- 10 files changed, 84 insertions(+), 126 deletions(-) diff --git a/examples/test-svc-acct/src/main.rs b/examples/test-svc-acct/src/main.rs index 6ad49f1..4945adc 100644 --- a/examples/test-svc-acct/src/main.rs +++ b/examples/test-svc-acct/src/main.rs @@ -10,14 +10,8 @@ async fn main() { let sa = yup_oauth2::ServiceAccountAccess::new(creds).build(); let scopes = &["https://www.googleapis.com/auth/pubsub"]; - let tok = sa - .token(scopes) - .await - .unwrap(); + let tok = sa.token(scopes).await.unwrap(); println!("token is: {:?}", tok); - let tok = sa - .token(scopes) - .await - .unwrap(); + let tok = sa.token(scopes).await.unwrap(); println!("cached token is {:?} and should be identical", tok); } diff --git a/src/authenticator.rs b/src/authenticator.rs index 4b5a5ac..6c21771 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -202,10 +202,7 @@ where let gettoken = &self.inner; let appsecret = gettoken.application_secret(); loop { - match store.get( - scope_key, - scopes, - ) { + match store.get(scope_key, scopes) { Ok(Some(t)) => { if !t.expired() { return Ok(t); @@ -233,11 +230,7 @@ where return Err(RequestError::Refresh(rr)); } RefreshResult::Success(t) => { - let x = store.set( - scope_key, - scopes, - Some(t.clone()), - ); + let x = store.set(scope_key, scopes, Some(t.clone())); if let Err(e) = x { match delegate.token_storage_failure(true, &e) { Retry::Skip => return Ok(t), @@ -252,11 +245,7 @@ where } Ok(None) => { let t = gettoken.token(scopes).await?; - if let Err(e) = store.set( - scope_key, - scopes, - Some(t.clone()), - ) { + if let Err(e) = store.set(scope_key, scopes, Some(t.clone())) { match delegate.token_storage_failure(true, &e) { Retry::Skip => return Ok(t), Retry::Abort => return Err(RequestError::Cache(Box::new(e))), diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index 4c25969..245cc54 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -95,11 +95,7 @@ pub trait AuthenticatorDelegate: Clone + Send + Sync { /// Called if we could not acquire a refresh token for a reason possibly specified /// by the server. /// This call is made for the delegate's information only. - fn token_refresh_failed>( - &self, - error: S, - error_description: &Option, - ) { + fn token_refresh_failed>(&self, error: S, error_description: &Option) { { let _ = error; } diff --git a/src/device.rs b/src/device.rs index a618243..a5bbb60 100644 --- a/src/device.rs +++ b/src/device.rs @@ -3,16 +3,14 @@ use std::time::Duration; use ::log::{error, log}; use chrono::{self, Utc}; -use futures::{prelude::*}; +use futures::prelude::*; use hyper; use hyper::header; use serde_json as json; use url::form_urlencoded; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate, PollInformation, Retry}; -use crate::types::{ - ApplicationSecret, GetToken, JsonErrorOr, PollError, RequestError, Token, -}; +use crate::types::{ApplicationSecret, GetToken, JsonErrorOr, PollError, RequestError, Token}; pub const GOOGLE_DEVICE_CODE_URL: &'static str = "https://accounts.google.com/o/oauth2/device/code"; @@ -128,10 +126,7 @@ where { /// Essentially what `GetToken::token` does: Retrieve a token for the given scopes without /// caching. - pub async fn retrieve_device_token( - &self, - scopes: &[T], - ) -> Result + pub async fn retrieve_device_token(&self, scopes: &[T]) -> Result where T: AsRef, { diff --git a/src/installed.rs b/src/installed.rs index 88d39bb..88c90c0 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -17,7 +17,7 @@ use url::form_urlencoded; use url::percent_encoding::{percent_encode, QUERY_ENCODE_SET}; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate}; -use crate::types::{ApplicationSecret, GetToken, RequestError, Token, JsonErrorOr}; +use crate::types::{ApplicationSecret, GetToken, JsonErrorOr, RequestError, Token}; const OOB_REDIRECT_URI: &'static str = "urn:ietf:wg:oauth:2.0:oob"; @@ -40,10 +40,7 @@ where vec![ format!("?scope={}", scopes_string), format!("&access_type=offline"), - format!( - "&redirect_uri={}", - redirect_uri.unwrap_or(OOB_REDIRECT_URI) - ), + format!("&redirect_uri={}", redirect_uri.unwrap_or(OOB_REDIRECT_URI)), format!("&response_type=code"), format!("&client_id={}", client_id), ] @@ -167,10 +164,7 @@ where /// . Return that token /// /// It's recommended not to use the DefaultFlowDelegate, but a specialized one. - async fn obtain_token( - &self, - scopes: &[T], - ) -> Result + async fn obtain_token(&self, scopes: &[T]) -> Result where T: AsRef, { @@ -281,7 +275,12 @@ where match serde_json::from_slice::>(&body)? { JsonErrorOr::Err(err) => Err(err.into()), - JsonErrorOr::Data(JSONTokenResponse{access_token, refresh_token, token_type, expires_in}) => { + JsonErrorOr::Data(JSONTokenResponse { + access_token, + refresh_token, + token_type, + expires_in, + }) => { let mut token = Token { access_token, refresh_token: Some(refresh_token), @@ -465,8 +464,8 @@ mod tests { use std::fmt; use std::str::FromStr; - use hyper::Uri; use hyper::client::connect::HttpConnector; + use hyper::Uri; use hyper_rustls::HttpsConnector; use mockito::{self, mock}; use tokio; @@ -539,10 +538,9 @@ mod tests { .build::<_, hyper::Body>(https); let fd = FD("authorizationcode".to_string(), client.clone()); - let inf = - InstalledFlow::new(app_secret.clone(), InstalledFlowReturnMethod::Interactive) - .delegate(fd) - .build_token_getter(client.clone()); + let inf = InstalledFlow::new(app_secret.clone(), InstalledFlowReturnMethod::Interactive) + .delegate(fd) + .build_token_getter(client.clone()); let rt = tokio::runtime::Builder::new() .core_threads(1) @@ -575,13 +573,12 @@ mod tests { } // Successful path with HTTP redirect. { - let inf = - InstalledFlow::new(app_secret, InstalledFlowReturnMethod::HTTPRedirect(8081)) - .delegate(FD( - "authorizationcodefromlocalserver".to_string(), - client.clone(), - )) - .build_token_getter(client.clone()); + let inf = InstalledFlow::new(app_secret, InstalledFlowReturnMethod::HTTPRedirect(8081)) + .delegate(FD( + "authorizationcodefromlocalserver".to_string(), + client.clone(), + )) + .build_token_getter(client.clone()); let _m = mock("POST", "/token") .match_body(mockito::Matcher::Regex(".*code=authorizationcodefromlocalserver.*client_id=9022167.*".to_string())) .with_body(r#"{"access_token": "accesstoken", "refresh_token": "refreshtoken", "token_type": "Bearer", "expires_in": 12345678}"#) diff --git a/src/lib.rs b/src/lib.rs index 181e1e2..ea9a9b7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -100,6 +100,6 @@ pub use crate::installed::{InstalledFlow, InstalledFlowReturnMethod}; pub use crate::service_account::*; pub use crate::storage::{DiskTokenStorage, MemoryStorage, NullStorage, TokenStorage}; pub use crate::types::{ - ApplicationSecret, ConsoleApplicationSecret, GetToken, PollError, RefreshResult, - RequestError, Scheme, Token, TokenType, + ApplicationSecret, ConsoleApplicationSecret, GetToken, PollError, RefreshResult, RequestError, + Scheme, Token, TokenType, }; diff --git a/src/refresh.rs b/src/refresh.rs index 4a4bd94..6c94c67 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -66,18 +66,25 @@ impl RefreshFlow { } match serde_json::from_slice::>(&body) { - Err(_) => Ok(RefreshResult::RefreshError("failed to deserialized json token from refresh response".to_owned(), None)), - Ok(JsonErrorOr::Err(json_err)) => Ok(RefreshResult::RefreshError(json_err.error, json_err.error_description)), - Ok(JsonErrorOr::Data(JsonToken{access_token, token_type, expires_in})) => { - Ok(RefreshResult::Success( - Token{ - access_token, - token_type, - refresh_token: Some(refresh_token.to_string()), - expires_in: None, - expires_in_timestamp: Some(Utc::now().timestamp() + expires_in), - })) - }, + Err(_) => Ok(RefreshResult::RefreshError( + "failed to deserialized json token from refresh response".to_owned(), + None, + )), + Ok(JsonErrorOr::Err(json_err)) => Ok(RefreshResult::RefreshError( + json_err.error, + json_err.error_description, + )), + Ok(JsonErrorOr::Data(JsonToken { + access_token, + token_type, + expires_in, + })) => Ok(RefreshResult::Success(Token { + access_token, + token_type, + refresh_token: Some(refresh_token.to_string()), + expires_in: None, + expires_in_timestamp: Some(Utc::now().timestamp() + expires_in), + })), } } } @@ -121,13 +128,9 @@ mod tests { .with_body(r#"{"access_token": "new-access-token", "token_type": "Bearer", "expires_in": 1234567}"#) .create(); let fut = async { - let rr = RefreshFlow::refresh_token( - &client, - &app_secret, - refresh_token, - ) - .await - .unwrap(); + let rr = RefreshFlow::refresh_token(&client, &app_secret, refresh_token) + .await + .unwrap(); match rr { RefreshResult::Success(tok) => { assert_eq!("new-access-token", tok.access_token); diff --git a/src/service_account.rs b/src/service_account.rs index 9b93c94..b2c93aa 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -264,7 +264,7 @@ struct TokenResponse { expires_in: Option, } -impl ServiceAccountAccessImpl +impl ServiceAccountAccessImpl where C: hyper::client::connect::Connect + 'static, { @@ -303,9 +303,7 @@ where .await .map_err(RequestError::ClientError)?; match serde_json::from_slice::>(&body)? { - JsonErrorOr::Err(err) => { - Err(err.into()) - }, + JsonErrorOr::Err(err) => Err(err.into()), JsonErrorOr::Data(TokenResponse { access_token: Some(access_token), token_type: Some(token_type), @@ -320,13 +318,11 @@ where expires_in: Some(expires_in), expires_in_timestamp: Some(expires_ts), }) - }, - JsonErrorOr::Data(token) => { - Err(RequestError::BadServerResponse(format!( - "Token response lacks fields: {:?}", - token - ))) } + JsonErrorOr::Data(token) => Err(RequestError::BadServerResponse(format!( + "Token response lacks fields: {:?}", + token + ))), } } @@ -336,11 +332,7 @@ where { let hash = hash_scopes(scopes); let cache = &self.cache; - match cache - .lock() - .unwrap() - .get(hash, scopes) - { + match cache.lock().unwrap().get(hash, scopes) { Ok(Some(token)) if !token.expired() => return Ok(token), _ => {} } @@ -351,11 +343,7 @@ where scopes, ) .await?; - let _ = cache.lock().unwrap().set( - hash, - scopes, - Some(token.clone()), - ); + let _ = cache.lock().unwrap().set(hash, scopes, Some(token.clone())); Ok(token) } } @@ -489,9 +477,7 @@ mod tests { .hyper_client(client.clone()) .build(); let fut = async { - let result = acc - .token(&["https://www.googleapis.com/auth/pubsub"]) - .await; + let result = acc.token(&["https://www.googleapis.com/auth/pubsub"]).await; assert!(result.is_err()); Ok(()) as Result<(), ()> }; @@ -520,8 +506,7 @@ mod tests { rt.block_on(async { println!( "{:?}", - acc.token(&["https://www.googleapis.com/auth/pubsub"]) - .await + acc.token(&["https://www.googleapis.com/auth/pubsub"]).await ); }); } diff --git a/src/storage.rs b/src/storage.rs index 51afac4..8d9df5c 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -65,14 +65,14 @@ impl TokenStorage for NullStorage { type Error = std::convert::Infallible; fn set(&self, _: u64, _: &[T], _: Option) -> Result<(), Self::Error> where - T: AsRef + T: AsRef, { Ok(()) } fn get(&self, _: u64, _: &[T]) -> Result, Self::Error> where - T: AsRef + T: AsRef, { Ok(None) } @@ -93,14 +93,9 @@ impl MemoryStorage { impl TokenStorage for MemoryStorage { type Error = std::convert::Infallible; - fn set( - &self, - scope_hash: u64, - scopes: &[T], - token: Option, - ) -> Result<(), Self::Error> + fn set(&self, scope_hash: u64, scopes: &[T], token: Option) -> Result<(), Self::Error> where - T: AsRef + T: AsRef, { let mut tokens = self.tokens.lock().expect("poisoned mutex"); let matched = tokens.iter().find_position(|x| x.hash == scope_hash); @@ -124,7 +119,7 @@ impl TokenStorage for MemoryStorage { fn get(&self, scope_hash: u64, scopes: &[T]) -> Result, Self::Error> where - T: AsRef + T: AsRef, { let tokens = self.tokens.lock().expect("poisoned mutex"); Ok(token_for_scopes(&tokens, scope_hash, scopes)) @@ -222,14 +217,9 @@ fn load_from_file(filename: &Path) -> Result, io::Error> { impl TokenStorage for DiskTokenStorage { type Error = io::Error; - fn set( - &self, - scope_hash: u64, - scopes: &[T], - token: Option, - ) -> Result<(), Self::Error> + fn set(&self, scope_hash: u64, scopes: &[T], token: Option) -> Result<(), Self::Error> where - T: AsRef + T: AsRef, { { let mut tokens = self.tokens.lock().expect("poisoned mutex"); @@ -255,7 +245,7 @@ impl TokenStorage for DiskTokenStorage { fn get(&self, scope_hash: u64, scopes: &[T]) -> Result, Self::Error> where - T: AsRef + T: AsRef, { let tokens = self.tokens.lock().expect("poisoned mutex"); Ok(token_for_scopes(&tokens, scope_hash, scopes)) @@ -268,11 +258,14 @@ where { for t in tokens.iter() { if let Some(token_scopes) = &t.scopes { - if scopes.iter().all(|s| token_scopes.iter().any(|t| t == s.as_ref())) { + if scopes + .iter() + .all(|s| token_scopes.iter().any(|t| t == s.as_ref())) + { return Some(t.token.clone()); } } else if scope_hash == t.hash { - return Some(t.token.clone()) + return Some(t.token.clone()); } } None @@ -288,10 +281,16 @@ mod tests { assert_eq!(hash_scopes(&["foo", "bar"]), hash_scopes(&["foo", "bar"])); // The hash should be order independent. assert_eq!(hash_scopes(&["bar", "foo"]), hash_scopes(&["foo", "bar"])); - assert_eq!(hash_scopes(&["bar", "baz", "bat"]), hash_scopes(&["baz", "bar", "bat"])); + assert_eq!( + hash_scopes(&["bar", "baz", "bat"]), + hash_scopes(&["baz", "bar", "bat"]) + ); // Ensure hashes differ when the contents are different by more than // just order. - assert_ne!(hash_scopes(&["foo", "bar", "baz"]), hash_scopes(&["foo", "bar"])); + assert_ne!( + hash_scopes(&["foo", "bar", "baz"]), + hash_scopes(&["foo", "bar"]) + ); } } diff --git a/src/types.rs b/src/types.rs index 2c575ec..35ab2a4 100644 --- a/src/types.rs +++ b/src/types.rs @@ -358,7 +358,7 @@ pub struct ApplicationSecret { impl ApplicationSecret { pub const fn empty() -> Self { - ApplicationSecret{ + ApplicationSecret { client_id: String::new(), client_secret: String::new(), token_uri: String::new(), From 8489f470a459f78e61917891bd0e327d51e5f5be Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 8 Nov 2019 15:49:07 -0800 Subject: [PATCH 14/71] cargo clippy fixes --- src/authenticator.rs | 10 +++++----- src/authenticator_delegate.rs | 2 +- src/device.rs | 8 ++++---- src/installed.rs | 17 ++++++++--------- src/service_account.rs | 17 ++++++++--------- src/storage.rs | 21 ++++++++------------- src/types.rs | 8 ++++---- 7 files changed, 38 insertions(+), 45 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index 6c21771..1beaddb 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -159,7 +159,7 @@ where client: self.client, token_getter: self.token_getter, store: self.store, - delegate: delegate, + delegate, } } @@ -217,15 +217,15 @@ where match rr { RefreshResult::Error(ref e) => { delegate.token_refresh_failed( - format!("{}", e.description().to_string()), - &Some("the request has likely timed out".to_string()), + e.description(), + Some("the request has likely timed out"), ); return Err(RequestError::Refresh(rr)); } RefreshResult::RefreshError(ref s, ref ss) => { delegate.token_refresh_failed( - format!("{} {}", s, ss.clone().map(|s| format!("({})", s)).unwrap_or("".to_string())), - &Some("the refresh token is likely invalid and your authorization has been revoked".to_string()), + &format!("{}{}", s, ss.as_ref().map(|s| format!(" ({})", s)).unwrap_or_else(String::new)), + Some("the refresh token is likely invalid and your authorization has been revoked"), ); return Err(RequestError::Refresh(rr)); } diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index 245cc54..ad3b37c 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -95,7 +95,7 @@ pub trait AuthenticatorDelegate: Clone + Send + Sync { /// Called if we could not acquire a refresh token for a reason possibly specified /// by the server. /// This call is made for the delegate's information only. - fn token_refresh_failed>(&self, error: S, error_description: &Option) { + fn token_refresh_failed(&self, error: &str, error_description: Option<&str>) { { let _ = error; } diff --git a/src/device.rs b/src/device.rs index a5bbb60..8afc1f3 100644 --- a/src/device.rs +++ b/src/device.rs @@ -12,7 +12,7 @@ use url::form_urlencoded; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate, PollInformation, Retry}; use crate::types::{ApplicationSecret, GetToken, JsonErrorOr, PollError, RequestError, Token}; -pub const GOOGLE_DEVICE_CODE_URL: &'static str = "https://accounts.google.com/o/oauth2/device/code"; +pub const GOOGLE_DEVICE_CODE_URL: &str = "https://accounts.google.com/o/oauth2/device/code"; /// Implements the [Oauth2 Device Flow](https://developers.google.com/youtube/v3/guides/authentication#devices) /// It operates in two steps: @@ -215,7 +215,7 @@ where let resp = client .request(req) .await - .map_err(|e| RequestError::ClientError(e))?; + .map_err(RequestError::ClientError)?; // This return type is defined in https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.2 // The alias is present as Google use a non-standard name for verification_uri. // According to the standard interval is optional, however, all tested implementations provide it. @@ -294,12 +294,12 @@ where let res = client .request(request) .await - .map_err(|e| PollError::HttpError(e))?; + .map_err(PollError::HttpError)?; let body = res .into_body() .try_concat() .await - .map_err(|e| PollError::HttpError(e))?; + .map_err(PollError::HttpError)?; #[derive(Deserialize)] struct JsonError { error: String, diff --git a/src/installed.rs b/src/installed.rs index 88c90c0..8928261 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -19,7 +19,7 @@ use url::percent_encoding::{percent_encode, QUERY_ENCODE_SET}; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate}; use crate::types::{ApplicationSecret, GetToken, JsonErrorOr, RequestError, Token}; -const OOB_REDIRECT_URI: &'static str = "urn:ietf:wg:oauth:2.0:oob"; +const OOB_REDIRECT_URI: &str = "urn:ietf:wg:oauth:2.0:oob"; /// Assembles a URL to request an authorization token (with user interaction). /// Note that the redirect_uri here has to be either None or some variation of @@ -39,9 +39,9 @@ where url.push_str(auth_uri); vec![ format!("?scope={}", scopes_string), - format!("&access_type=offline"), + "&access_type=offline".to_string(), format!("&redirect_uri={}", redirect_uri.unwrap_or(OOB_REDIRECT_URI)), - format!("&response_type=code"), + "&response_type=code".to_string(), format!("&client_id={}", client_id), ] .into_iter() @@ -258,12 +258,12 @@ where .client .request(request) .await - .map_err(|e| RequestError::ClientError(e))?; + .map_err(RequestError::ClientError)?; let body = resp .into_body() .try_concat() .await - .map_err(|e| RequestError::ClientError(e))?; + .map_err(RequestError::ClientError)?; #[derive(Deserialize)] struct JSONTokenResponse { @@ -295,7 +295,7 @@ where } /// Sends the authorization code to the provider in order to obtain access and refresh tokens. - fn request_token<'a>( + fn request_token( appsecret: &ApplicationSecret, authcode: &str, custom_redirect_uri: Option<&str>, @@ -318,11 +318,10 @@ where ]) .finish(); - let request = hyper::Request::post(&appsecret.token_uri) + hyper::Request::post(&appsecret.token_uri) .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(body)) - .unwrap(); // TODO: error check - request + .unwrap() // TODO: error check } } diff --git a/src/service_account.rs b/src/service_account.rs index b2c93aa..545d7f5 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -36,8 +36,8 @@ use chrono; use hyper; use serde_json; -const GRANT_TYPE: &'static str = "urn:ietf:params:oauth:grant-type:jwt-bearer"; -const GOOGLE_RS256_HEAD: &'static str = "{\"alg\":\"RS256\",\"typ\":\"JWT\"}"; +const GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:jwt-bearer"; +const GOOGLE_RS256_HEAD: &str = "{\"alg\":\"RS256\",\"typ\":\"JWT\"}"; /// Encodes s as Base64 fn encode_base64>(s: T) -> String { @@ -51,7 +51,7 @@ fn decode_rsa_key(pem_pkcs8: &str) -> Result { let private_keys = pemfile::pkcs8_private_keys(&mut private_reader); if let Ok(pk) = private_keys { - if pk.len() > 0 { + if !pk.is_empty() { Ok(pk[0].clone()) } else { Err(io::Error::new( @@ -112,7 +112,7 @@ impl JWT { fn new(claims: Claims) -> JWT { JWT { header: GOOGLE_RS256_HEAD.to_string(), - claims: claims, + claims, } } @@ -141,10 +141,9 @@ impl JWT { .map_err(|_| io::Error::new(io::ErrorKind::Other, "Couldn't initialize signer"))?; let signer = signing_key .choose_scheme(&[rustls::SignatureScheme::RSA_PKCS1_SHA256]) - .ok_or(io::Error::new( - io::ErrorKind::Other, - "Couldn't choose signing scheme", - ))?; + .ok_or_else(|| { + io::Error::new(io::ErrorKind::Other, "Couldn't choose signing scheme") + })?; let signature = signer .sign(jwt_head.as_bytes()) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("{}", e)))?; @@ -176,7 +175,7 @@ where iss: key.client_email.clone().unwrap(), aud: key.token_uri.clone().unwrap(), exp: expiry, - iat: iat, + iat, sub: None, scope: scopes_string, } diff --git a/src/storage.rs b/src/storage.rs index 8d9df5c..fef3c0a 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -103,17 +103,13 @@ impl TokenStorage for MemoryStorage { self.tokens.retain(|x| x.hash != scope_hash); } - match token { - Some(t) => { - tokens.push(JSONToken { - hash: scope_hash, - scopes: Some(scopes.into_iter().map(|x| x.as_ref().to_string()).collect()), - token: t, - }); - () - } - None => {} - }; + if let Some(t) = token { + tokens.push(JSONToken { + hash: scope_hash, + scopes: Some(scopes.iter().map(|x| x.as_ref().to_string()).collect()), + token: t, + }); + } Ok(()) } @@ -233,10 +229,9 @@ impl TokenStorage for DiskTokenStorage { Some(t) => { tokens.push(JSONToken { hash: scope_hash, - scopes: Some(scopes.into_iter().map(|x| x.as_ref().to_string()).collect()), + scopes: Some(scopes.iter().map(|x| x.as_ref().to_string()).collect()), token: t, }); - () } } } diff --git a/src/types.rs b/src/types.rs index 35ab2a4..a9d371c 100644 --- a/src/types.rs +++ b/src/types.rs @@ -91,7 +91,7 @@ impl From for RequestError { "invalid_scope" => RequestError::InvalidScope( value .error_description - .unwrap_or("no description provided".to_string()), + .unwrap_or_else(|| "no description provided".to_string()), ), _ => RequestError::NegativeServerResponse(value.error, value.error_description), } @@ -112,7 +112,7 @@ impl fmt::Display for RequestError { RequestError::InvalidScope(ref scope) => writeln!(f, "Invalid Scope: '{}'", scope), RequestError::NegativeServerResponse(ref error, ref desc) => { error.fmt(f)?; - if let &Some(ref desc) = desc { + if let Some(ref desc) = *desc { write!(f, ": {}", desc)?; } "\n".fmt(f) @@ -162,7 +162,7 @@ impl StringError { error.push_str(d.as_ref()); } - StringError { error: error } + StringError { error } } } @@ -298,7 +298,7 @@ impl Token { /// # Panics /// * if our access_token is unset pub fn expired(&self) -> bool { - if self.access_token.len() == 0 { + if self.access_token.is_empty() { panic!("called expired() on unset token"); } if let Some(expiry_date) = self.expiry_date() { From 0e9cf512ba91a8bdf070e94b8f48eb6dacb21f6e Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 8 Nov 2019 15:59:29 -0800 Subject: [PATCH 15/71] Remove the HTTPRedirectEphemeral variant. In favor of making it the default and removing the option to specify a port to listen on. If needed a variant can be added to specify a port explicitly, but most users should want an ephemeral port chosen so making it the default makes sense while other breaking changes are in flight. --- examples/test-installed/src/main.rs | 2 +- src/installed.rs | 49 +++++++++++------------------ src/lib.rs | 2 +- 3 files changed, 20 insertions(+), 33 deletions(-) diff --git a/examples/test-installed/src/main.rs b/examples/test-installed/src/main.rs index c333255..54be93f 100644 --- a/examples/test-installed/src/main.rs +++ b/examples/test-installed/src/main.rs @@ -10,7 +10,7 @@ async fn main() { let auth = Authenticator::new(InstalledFlow::new( secret, - yup_oauth2::InstalledFlowReturnMethod::HTTPRedirectEphemeral, + yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect, )) .persist_tokens_to_disk("tokencache.json") .build() diff --git a/src/installed.rs b/src/installed.rs index 8928261..d31ea38 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -92,11 +92,7 @@ pub enum InstalledFlowReturnMethod { Interactive, /// Involves spinning up a local HTTP server and Google redirecting the browser to /// the server with a URL containing the code (preferred, but not as reliable). - HTTPRedirectEphemeral, - /// Involves spinning up a local HTTP server and Google redirecting the browser to - /// the server with a URL containing the code (preferred, but not as reliable). The - /// parameter is the port to listen on. - HTTPRedirect(u16), + HTTPRedirect, } /// InstalledFlowImpl provides tokens for services that follow the "Installed" OAuth flow. (See @@ -169,12 +165,7 @@ where T: AsRef, { match self.method { - InstalledFlowReturnMethod::HTTPRedirect(port) => { - self.ask_auth_code_via_http(scopes, port).await - } - InstalledFlowReturnMethod::HTTPRedirectEphemeral => { - self.ask_auth_code_via_http(scopes, 0).await - } + InstalledFlowReturnMethod::HTTPRedirect => self.ask_auth_code_via_http(scopes).await, InstalledFlowReturnMethod::Interactive => { self.ask_auth_code_interactively(scopes).await } @@ -211,26 +202,22 @@ where self.exchange_auth_code(&authcode, None).await } - async fn ask_auth_code_via_http( - &self, - scopes: &[T], - desired_port: u16, - ) -> Result + async fn ask_auth_code_via_http(&self, scopes: &[T]) -> Result where T: AsRef, { use std::borrow::Cow; let auth_delegate = &self.fd; let appsecret = &self.appsecret; - let server = InstalledFlowServer::run(desired_port)?; - let bound_port = server.local_addr().port(); + let server = InstalledFlowServer::run()?; + let server_addr = server.local_addr(); // Present url to user. // The redirect URI must be this very localhost URL, otherwise authorization is refused // by certain providers. let redirect_uri: Cow = match auth_delegate.redirect_uri() { Some(uri) => uri.into(), - None => format!("http://localhost:{}", bound_port).into(), + None => format!("http://{}", server_addr).into(), }; let url = build_authentication_request_url( &appsecret.auth_uri, @@ -243,17 +230,17 @@ where .await; let auth_code = server.wait_for_auth_code().await; - self.exchange_auth_code(&auth_code, Some(bound_port)).await + self.exchange_auth_code(&auth_code, Some(server_addr)).await } async fn exchange_auth_code( &self, authcode: &str, - port: Option, + server_addr: Option, ) -> Result { let appsec = &self.appsecret; let redirect_uri = self.fd.redirect_uri(); - let request = Self::request_token(appsec, authcode, redirect_uri, port); + let request = Self::request_token(appsec, authcode, redirect_uri, server_addr); let resp = self .client .request(request) @@ -299,12 +286,12 @@ where appsecret: &ApplicationSecret, authcode: &str, custom_redirect_uri: Option<&str>, - port: Option, + server_addr: Option, ) -> hyper::Request { use std::borrow::Cow; - let redirect_uri: Cow = match (custom_redirect_uri, port) { + let redirect_uri: Cow = match (custom_redirect_uri, server_addr) { (Some(uri), _) => uri.into(), - (None, Some(port)) => format!("http://localhost:{}", port).into(), + (None, Some(addr)) => format!("http://{}", addr).into(), (None, None) => OOB_REDIRECT_URI.into(), }; @@ -344,7 +331,7 @@ struct InstalledFlowServer { } impl InstalledFlowServer { - fn run(desired_port: u16) -> Result { + fn run() -> Result { use hyper::service::{make_service_fn, service_fn}; let (auth_code_tx, auth_code_rx) = oneshot::channel::(); let (trigger_shutdown_tx, trigger_shutdown_rx) = oneshot::channel::<()>(); @@ -359,7 +346,7 @@ impl InstalledFlowServer { })) } }); - let addr: std::net::SocketAddr = ([127, 0, 0, 1], desired_port).into(); + let addr: std::net::SocketAddr = ([127, 0, 0, 1], 0).into(); let server = hyper::server::Server::try_bind(&addr)?; let server = server.http1_only(true).serve(service); let addr = server.local_addr(); @@ -572,7 +559,7 @@ mod tests { } // Successful path with HTTP redirect. { - let inf = InstalledFlow::new(app_secret, InstalledFlowReturnMethod::HTTPRedirect(8081)) + let inf = InstalledFlow::new(app_secret, InstalledFlowReturnMethod::HTTPRedirect) .delegate(FD( "authorizationcodefromlocalserver".to_string(), client.clone(), @@ -639,8 +626,8 @@ mod tests { #[tokio::test] async fn test_server_random_local_port() { - let addr1 = InstalledFlowServer::run(0).unwrap().local_addr(); - let addr2 = InstalledFlowServer::run(0).unwrap().local_addr(); + let addr1 = InstalledFlowServer::run().unwrap().local_addr(); + let addr2 = InstalledFlowServer::run().unwrap().local_addr(); assert_ne!(addr1.port(), addr2.port()); } @@ -662,7 +649,7 @@ mod tests { async fn test_server() { let client: hyper::Client = hyper::Client::builder().build_http(); - let server = InstalledFlowServer::run(0).unwrap(); + let server = InstalledFlowServer::run().unwrap(); let response = client .get(format!("http://{}/", server.local_addr()).parse().unwrap()) diff --git a/src/lib.rs b/src/lib.rs index ea9a9b7..209874f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -59,7 +59,7 @@ //! // authenticator takes care of caching tokens to disk and refreshing tokens once //! // they've expired. //! let mut auth = Authenticator::new( -//! InstalledFlow::new(secret, yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect(0)) +//! InstalledFlow::new(secret, yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect) //! ) //! .persist_tokens_to_disk("tokencache.json") //! .build() From 744620042104fc1f45d21cf70eb52c278d424d7b Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 8 Nov 2019 16:08:27 -0800 Subject: [PATCH 16/71] Remove unnecessary trait bounds on hyper connector. Send+Sync is implied by the trait, and Clone is no longer necessary. --- src/authenticator.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index 1beaddb..1057755 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -189,7 +189,7 @@ where GT: 'static + GetToken, S: 'static + TokenStorage, AD: 'static + AuthenticatorDelegate, - C: 'static + hyper::client::connect::Connect + Clone + Send, + C: 'static + hyper::client::connect::Connect, { async fn get_token(&self, scopes: &[T]) -> Result where @@ -264,12 +264,12 @@ where } } -impl< - GT: 'static + GetToken, - S: 'static + TokenStorage, - AD: 'static + AuthenticatorDelegate, - C: 'static + hyper::client::connect::Connect + Clone + Send, - > GetToken for AuthenticatorImpl +impl GetToken for AuthenticatorImpl +where + GT: 'static + GetToken, + S: 'static + TokenStorage, + AD: 'static + AuthenticatorDelegate, + C: 'static + hyper::client::connect::Connect, { /// Returns the API Key of the inner flow. fn api_key(&self) -> Option { From fa121d41b2b1ff9862986b0bcbccf0e9e11a466f Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Mon, 11 Nov 2019 08:43:40 -0800 Subject: [PATCH 17/71] Delegates no longer need to implement Clone. --- src/authenticator_delegate.rs | 4 ++-- src/device.rs | 6 +++--- src/lib.rs | 4 +--- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index ad3b37c..f5bc4a9 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -71,7 +71,7 @@ impl Error for PollError { /// /// The only method that needs to be implemented manually is `present_user_code(...)`, /// as no assumptions are made on how this presentation should happen. -pub trait AuthenticatorDelegate: Clone + Send + Sync { +pub trait AuthenticatorDelegate: Send + Sync { /// Called whenever there is an client, usually if there are network problems. /// /// Return retry information. @@ -107,7 +107,7 @@ pub trait AuthenticatorDelegate: Clone + Send + Sync { /// FlowDelegate methods are called when an OAuth flow needs to ask the application what to do in /// certain cases. -pub trait FlowDelegate: Clone + Send + Sync { +pub trait FlowDelegate: Send + Sync { /// Called if the request code is expired. You will have to start over in this case. /// This will be the last call the delegate receives. /// Given `DateTime` is the expiration date diff --git a/src/device.rs b/src/device.rs index 8afc1f3..82e1aee 100644 --- a/src/device.rs +++ b/src/device.rs @@ -133,7 +133,7 @@ where let application_secret = &self.application_secret; let client = self.client.clone(); let wait = self.wait; - let fd = self.fd.clone(); + let fd = &self.fd; let (pollinf, device_code) = Self::request_code( application_secret, client.clone(), @@ -152,7 +152,7 @@ where client.clone(), &device_code, pollinf.clone(), - fd.clone(), + fd, ) .await; match r { @@ -270,7 +270,7 @@ where client: hyper::Client, device_code: &str, pi: PollInformation, - fd: FD, + fd: &FD, ) -> Result, PollError> { if pi.expires_at <= Utc::now() { fd.expired(&pi.expires_at); diff --git a/src/lib.rs b/src/lib.rs index 209874f..f53d54c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,13 +45,11 @@ //! use hyper::client::Client; //! use hyper_rustls::HttpsConnector; //! -//! use std::path::Path; -//! //! #[tokio::main] //! async fn main() { //! // Read application secret from a file. Sometimes it's easier to compile it directly into //! // the binary. The clientsecret file contains JSON like `{"installed":{"client_id": ... }}` -//! let secret = yup_oauth2::read_application_secret(Path::new("clientsecret.json")) +//! let secret = yup_oauth2::read_application_secret("clientsecret.json") //! .expect("clientsecret.json"); //! //! // Create an authenticator that uses an InstalledFlow to authenticate. The From b6affacbf0d1cb9475490fc0e7aac654893a8d72 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Mon, 11 Nov 2019 08:56:02 -0800 Subject: [PATCH 18/71] Unify trait bounds on Authenticator::build --- src/authenticator.rs | 4 ++-- src/service_account.rs | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index 1057755..97e6ddb 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -167,9 +167,9 @@ where pub fn build(self) -> io::Result where T::TokenGetter: 'static + GetToken, - S: 'static + Send, + S: 'static, AD: 'static, - C::Connector: 'static + Clone + Send, + C::Connector: 'static + hyper::client::connect::Connect, { let client = self.client.build_hyper_client(); let store = self.store?; diff --git a/src/service_account.rs b/src/service_account.rs index 545d7f5..ebe938f 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -233,7 +233,6 @@ where } } -#[derive(Clone)] struct ServiceAccountAccessImpl { client: hyper::Client, key: ServiceAccountKey, From ef23eef31d199b5bef020b98c5644d07e67b0997 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Mon, 11 Nov 2019 09:41:30 -0800 Subject: [PATCH 19/71] Remove more unnecessary clones. --- src/authenticator_delegate.rs | 2 +- src/device.rs | 14 ++++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index f5bc4a9..eee7ec9 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -111,7 +111,7 @@ pub trait FlowDelegate: Send + Sync { /// Called if the request code is expired. You will have to start over in this case. /// This will be the last call the delegate receives. /// Given `DateTime` is the expiration date - fn expired(&self, _: &DateTime) {} + fn expired(&self, _: DateTime) {} /// Called if the user denied access. You would have to start over. /// This will be the last call the delegate receives. diff --git a/src/device.rs b/src/device.rs index 82e1aee..2c31af9 100644 --- a/src/device.rs +++ b/src/device.rs @@ -2,7 +2,7 @@ use std::pin::Pin; use std::time::Duration; use ::log::{error, log}; -use chrono::{self, Utc}; +use chrono::{DateTime, Utc}; use futures::prelude::*; use hyper; use hyper::header; @@ -144,14 +144,12 @@ where fd.present_user_code(&pollinf); let maxn = wait.as_secs() / pollinf.interval.as_secs(); for _ in 0..maxn { - let fd = fd.clone(); - let pollinf = pollinf.clone(); tokio::timer::delay_for(pollinf.interval).await; let r = Self::poll_token( application_secret, client.clone(), &device_code, - pollinf.clone(), + pollinf.expires_at, fd, ) .await; @@ -269,12 +267,12 @@ where application_secret: &ApplicationSecret, client: hyper::Client, device_code: &str, - pi: PollInformation, + expires_at: DateTime, fd: &FD, ) -> Result, PollError> { - if pi.expires_at <= Utc::now() { - fd.expired(&pi.expires_at); - return Err(PollError::Expired(pi.expires_at)); + if expires_at <= Utc::now() { + fd.expired(expires_at); + return Err(PollError::Expired(expires_at)); } // We should be ready for a new request From e8675fa1da6e8b8c845ce643365410065a9055d0 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Mon, 11 Nov 2019 10:40:05 -0800 Subject: [PATCH 20/71] Refactor retrieve_device_token. wait_for_device_token polls indefinitely at the specified intervals. Use tokio::timer::Timeout to bound the time it will poll for. --- src/device.rs | 53 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/src/device.rs b/src/device.rs index 2c31af9..2346e1a 100644 --- a/src/device.rs +++ b/src/device.rs @@ -131,44 +131,55 @@ where T: AsRef, { let application_secret = &self.application_secret; - let client = self.client.clone(); - let wait = self.wait; - let fd = &self.fd; let (pollinf, device_code) = Self::request_code( application_secret, - client.clone(), + &self.client, &self.device_code_url, scopes, ) .await?; - fd.present_user_code(&pollinf); - let maxn = wait.as_secs() / pollinf.interval.as_secs(); - for _ in 0..maxn { - tokio::timer::delay_for(pollinf.interval).await; + self.fd.present_user_code(&pollinf); + tokio::timer::Timeout::new( + self.wait_for_device_token(&pollinf, &device_code), + self.wait, + ) + .await + .map_err(|_| RequestError::Poll(PollError::TimedOut))? + } + + async fn wait_for_device_token( + &self, + pollinf: &PollInformation, + device_code: &str, + ) -> Result { + let mut interval = pollinf.interval; + loop { + tokio::timer::delay_for(interval).await; let r = Self::poll_token( - application_secret, - client.clone(), - &device_code, + &self.application_secret, + &self.client, + device_code, pollinf.expires_at, - fd, + &self.fd, ) .await; - match r { - Ok(None) => match fd.pending(&pollinf) { + interval = match r { + Ok(None) => match self.fd.pending(&pollinf) { Retry::Abort | Retry::Skip => { return Err(RequestError::Poll(PollError::TimedOut)) } - Retry::After(d) => tokio::timer::delay_for(d).await, + Retry::After(d) => d, }, Ok(Some(tok)) => return Ok(tok), Err(e @ PollError::AccessDenied) | Err(e @ PollError::TimedOut) | Err(e @ PollError::Expired(_)) => return Err(RequestError::Poll(e)), - Err(ref e) => error!("Unknown error from poll token api: {}", e), + Err(ref e) => { + error!("Unknown error from poll token api: {}", e); + pollinf.interval + } } } - error!("Too many poll attempts"); - Err(RequestError::Poll(PollError::TimedOut)) } /// The first step involves asking the server for a code that the user @@ -188,15 +199,13 @@ where /// See test-cases in source code for a more complete example. async fn request_code( application_secret: &ApplicationSecret, - client: hyper::Client, + client: &hyper::Client, device_code_url: &str, scopes: &[T], ) -> Result<(PollInformation, String), RequestError> where T: AsRef, { - // note: cloned() shouldn't be needed, see issue - // https://github.com/servo/rust-url/issues/81 let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ ("client_id", application_secret.client_id.as_str()), @@ -265,7 +274,7 @@ where /// See test-cases in source code for a more complete example. async fn poll_token<'a>( application_secret: &ApplicationSecret, - client: hyper::Client, + client: &hyper::Client, device_code: &str, expires_at: DateTime, fd: &FD, From 05f7c10533dbc86bf8239e86a619367e69ae78af Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Mon, 11 Nov 2019 13:11:16 -0800 Subject: [PATCH 21/71] Remove unnecessary 'static bounds --- src/authenticator.rs | 35 ++++++++++++++--------------------- src/device.rs | 8 +++----- src/installed.rs | 12 ++++++------ src/refresh.rs | 2 +- src/service_account.rs | 1 - 5 files changed, 24 insertions(+), 34 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index 97e6ddb..f7016be 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -34,7 +34,7 @@ struct AuthenticatorImpl< /// A trait implemented for any hyper::Client as well as teh DefaultHyperClient. pub trait HyperClientBuilder { - type Connector: hyper::client::connect::Connect; + type Connector: hyper::client::connect::Connect + 'static; fn build_hyper_client(self) -> hyper::Client; } @@ -53,7 +53,7 @@ impl HyperClientBuilder for DefaultHyperClient { impl HyperClientBuilder for hyper::Client where - C: hyper::client::connect::Connect, + C: hyper::client::connect::Connect + 'static, { type Connector = C; @@ -72,12 +72,7 @@ pub trait AuthFlow { /// An authenticator can be used with `InstalledFlow`'s or `DeviceFlow`'s and /// will refresh tokens as they expire as well as optionally persist tokens to /// disk. -pub struct Authenticator< - T: AuthFlow, - S: TokenStorage, - AD: AuthenticatorDelegate, - C: HyperClientBuilder, -> { +pub struct Authenticator { client: C, token_getter: T, store: io::Result, @@ -125,7 +120,7 @@ where hyper_client: hyper::Client, ) -> Authenticator> where - NewC: hyper::client::connect::Connect, + NewC: hyper::client::connect::Connect + 'static, T: AuthFlow, { Authenticator { @@ -166,10 +161,8 @@ where /// Create the authenticator. pub fn build(self) -> io::Result where - T::TokenGetter: 'static + GetToken, - S: 'static, - AD: 'static, - C::Connector: 'static + hyper::client::connect::Connect, + T::TokenGetter: GetToken, + C::Connector: hyper::client::connect::Connect + 'static, { let client = self.client.build_hyper_client(); let store = self.store?; @@ -186,10 +179,10 @@ where impl AuthenticatorImpl where - GT: 'static + GetToken, - S: 'static + TokenStorage, - AD: 'static + AuthenticatorDelegate, - C: 'static + hyper::client::connect::Connect, + GT: GetToken, + S: TokenStorage, + AD: AuthenticatorDelegate, + C: hyper::client::connect::Connect + 'static, { async fn get_token(&self, scopes: &[T]) -> Result where @@ -266,10 +259,10 @@ where impl GetToken for AuthenticatorImpl where - GT: 'static + GetToken, - S: 'static + TokenStorage, - AD: 'static + AuthenticatorDelegate, - C: 'static + hyper::client::connect::Connect, + GT: GetToken, + S: TokenStorage, + AD: AuthenticatorDelegate, + C: hyper::client::connect::Connect + 'static, { /// Returns the API Key of the inner flow. fn api_key(&self) -> Option { diff --git a/src/device.rs b/src/device.rs index 2346e1a..51d97b3 100644 --- a/src/device.rs +++ b/src/device.rs @@ -69,7 +69,7 @@ impl DeviceFlow { impl crate::authenticator::AuthFlow for DeviceFlow where - FD: FlowDelegate + 'static, + FD: FlowDelegate, C: hyper::client::connect::Connect + 'static, { type TokenGetter = DeviceFlowImpl; @@ -97,7 +97,7 @@ pub struct DeviceFlowImpl { impl GetToken for DeviceFlowImpl where - FD: FlowDelegate + 'static, + FD: FlowDelegate, C: hyper::client::connect::Connect + 'static, { fn token<'a, T>( @@ -120,9 +120,7 @@ where impl DeviceFlowImpl where C: hyper::client::connect::Connect + 'static, - C::Transport: 'static, - C::Future: 'static, - FD: FlowDelegate + 'static, + FD: FlowDelegate, { /// Essentially what `GetToken::token` does: Retrieve a token for the given scopes without /// caching. diff --git a/src/installed.rs b/src/installed.rs index d31ea38..5dc19ed 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -53,7 +53,7 @@ where impl GetToken for InstalledFlowImpl where - FD: FlowDelegate + 'static, + FD: FlowDelegate, C: hyper::client::connect::Connect + 'static, { fn token<'a, T>( @@ -76,8 +76,8 @@ where /// The InstalledFlow implementation. pub struct InstalledFlowImpl where - FD: FlowDelegate + 'static, - C: hyper::client::connect::Connect + 'static, + FD: FlowDelegate, + C: hyper::client::connect::Connect, { method: InstalledFlowReturnMethod, client: hyper::client::Client, @@ -98,7 +98,7 @@ pub enum InstalledFlowReturnMethod { /// InstalledFlowImpl provides tokens for services that follow the "Installed" OAuth flow. (See /// https://www.oauth.com/oauth2-servers/authorization/, /// https://developers.google.com/identity/protocols/OAuth2InstalledApp). -pub struct InstalledFlow { +pub struct InstalledFlow { method: InstalledFlowReturnMethod, flow_delegate: FD, appsecret: ApplicationSecret, @@ -134,7 +134,7 @@ where impl crate::authenticator::AuthFlow for InstalledFlow where - FD: FlowDelegate + 'static, + FD: FlowDelegate, C: hyper::client::connect::Connect + 'static, { type TokenGetter = InstalledFlowImpl; @@ -151,7 +151,7 @@ where impl InstalledFlowImpl where - FD: FlowDelegate + 'static, + FD: FlowDelegate, C: hyper::client::connect::Connect + 'static, { /// Handles the token request flow; it consists of the following steps: diff --git a/src/refresh.rs b/src/refresh.rs index 6c94c67..7b33269 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -29,7 +29,7 @@ impl RefreshFlow { /// /// # Examples /// Please see the crate landing page for an example. - pub async fn refresh_token( + pub async fn refresh_token( client: &hyper::Client, client_secret: &ApplicationSecret, refresh_token: &str, diff --git a/src/service_account.rs b/src/service_account.rs index ebe938f..2c4cb3f 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -205,7 +205,6 @@ impl ServiceAccountAccess { impl ServiceAccountAccess where C: HyperClientBuilder, - C::Connector: 'static, { /// Use the provided hyper client. pub fn hyper_client( From 060eb92bf7d2f7f90ae57db136910d0448dfc57e Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Mon, 11 Nov 2019 14:04:27 -0800 Subject: [PATCH 22/71] Refactor JWT handling in ServiceAccountAccess. Avoid reading and parsing the private key file on every invocation of token() in favor or reading it once when the ServiceAccountAccess is built. Also avoid unnecessary allocations when signing JWT tokens and renamed sub to subject to avoid any confusion with the std::ops::Sub trait. --- examples/test-svc-acct/src/main.rs | 4 +- src/service_account.rs | 195 ++++++++++++++--------------- 2 files changed, 94 insertions(+), 105 deletions(-) diff --git a/examples/test-svc-acct/src/main.rs b/examples/test-svc-acct/src/main.rs index 4945adc..e5ef33c 100644 --- a/examples/test-svc-acct/src/main.rs +++ b/examples/test-svc-acct/src/main.rs @@ -7,7 +7,9 @@ use yup_oauth2::GetToken; async fn main() { let creds = yup_oauth2::service_account_key_from_file(path::Path::new("serviceaccount.json")).unwrap(); - let sa = yup_oauth2::ServiceAccountAccess::new(creds).build(); + let sa = yup_oauth2::ServiceAccountAccess::new(creds) + .build() + .unwrap(); let scopes = &["https://www.googleapis.com/auth/pubsub"]; let tok = sa.token(scopes).await.unwrap(); diff --git a/src/service_account.rs b/src/service_account.rs index 2c4cb3f..03525a5 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -37,11 +37,11 @@ use hyper; use serde_json; const GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:jwt-bearer"; -const GOOGLE_RS256_HEAD: &str = "{\"alg\":\"RS256\",\"typ\":\"JWT\"}"; +const GOOGLE_RS256_HEAD: &str = r#"{"alg":"RS256","typ":"JWT"}"#; /// Encodes s as Base64 -fn encode_base64>(s: T) -> String { - base64::encode_config(s.as_ref(), base64::URL_SAFE) +fn append_base64 + ?Sized>(s: &T, out: &mut String) { + base64::encode_config_buf(s, base64::URL_SAFE, out) } /// Decode a PKCS8 formatted RSA key. @@ -78,11 +78,11 @@ pub struct ServiceAccountKey { pub key_type: Option, pub project_id: Option, pub private_key_id: Option, - pub private_key: Option, - pub client_email: Option, + pub private_key: String, + pub client_email: String, pub client_id: Option, pub auth_uri: Option, - pub token_uri: Option, + pub token_uri: String, pub auth_provider_x509_cert_url: Option, pub client_x509_cert_url: Option, } @@ -90,52 +90,42 @@ pub struct ServiceAccountKey { /// Permissions requested for a JWT. /// See https://developers.google.com/identity/protocols/OAuth2ServiceAccount#authorizingrequests. #[derive(Serialize, Debug)] -struct Claims { - iss: String, - aud: String, +struct Claims<'a> { + iss: &'a str, + aud: &'a str, exp: i64, iat: i64, - sub: Option, + subject: Option<&'a str>, scope: String, } -/// A JSON Web Token ready for signing. -struct JWT { - /// The value of GOOGLE_RS256_HEAD. - header: String, - /// A Claims struct, expressing the set of desired permissions etc. - claims: Claims, -} +impl<'a> Claims<'a> { + fn new(key: &'a ServiceAccountKey, scopes: &[T], subject: Option<&'a str>) -> Self + where + T: AsRef, + { + let iat = chrono::Utc::now().timestamp(); + let expiry = iat + 3600 - 5; // Max validity is 1h. -impl JWT { - /// Create a new JWT from claims. - fn new(claims: Claims) -> JWT { - JWT { - header: GOOGLE_RS256_HEAD.to_string(), - claims, + let scope = crate::helper::join(scopes, " "); + Claims { + iss: &key.client_email, + aud: &key.token_uri, + exp: expiry, + iat, + subject, + scope, } } +} - /// Set JWT header. Default is `{"alg":"RS256","typ":"JWT"}`. - #[allow(dead_code)] - pub fn set_header(&mut self, head: String) { - self.header = head; - } +/// A JSON Web Token ready for signing. +struct JWTSigner { + signer: Box, +} - /// Encodes the first two parts (header and claims) to base64 and assembles them into a form - /// ready to be signed. - fn encode_claims(&self) -> String { - let mut head = encode_base64(&self.header); - let claims = encode_base64(serde_json::to_string(&self.claims).unwrap()); - - head.push_str("."); - head.push_str(&claims); - head - } - - /// Sign a JWT base string with `private_key`, which is a PKCS8 string. - fn sign(&self, private_key: &str) -> Result { - let mut jwt_head = self.encode_claims(); +impl JWTSigner { + fn new(private_key: &str) -> Result { let key = decode_rsa_key(private_key)?; let signing_key = sign::RSASigningKey::new(&key) .map_err(|_| io::Error::new(io::ErrorKind::Other, "Couldn't initialize signer"))?; @@ -144,40 +134,25 @@ impl JWT { .ok_or_else(|| { io::Error::new(io::ErrorKind::Other, "Couldn't choose signing scheme") })?; - let signature = signer - .sign(jwt_head.as_bytes()) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("{}", e)))?; - let signature_b64 = encode_base64(signature); + Ok(JWTSigner { signer }) + } + fn sign_claims(&self, claims: &Claims) -> Result { + let mut jwt_head = Self::encode_claims(claims); + let signature = self.signer.sign(jwt_head.as_bytes())?; jwt_head.push_str("."); - jwt_head.push_str(&signature_b64); - + append_base64(&signature, &mut jwt_head); Ok(jwt_head) } -} -/// Set `iss`, `aud`, `exp`, `iat`, `scope` field in the returned `Claims`. -fn init_claims_from_key(key: &ServiceAccountKey, scopes: &[T]) -> Claims -where - T: AsRef, -{ - let iat = chrono::Utc::now().timestamp(); - let expiry = iat + 3600 - 5; // Max validity is 1h. - - let mut scopes_string = scopes.iter().fold(String::new(), |mut acc, sc| { - acc.push_str(sc.as_ref()); - acc.push_str(" "); - acc - }); - scopes_string.pop(); - - Claims { - iss: key.client_email.clone().unwrap(), - aud: key.token_uri.clone().unwrap(), - exp: expiry, - iat, - sub: None, - scope: scopes_string, + /// Encodes the first two parts (header and claims) to base64 and assembles them into a form + /// ready to be signed. + fn encode_claims(claims: &Claims) -> String { + let mut head = String::new(); + append_base64(GOOGLE_RS256_HEAD, &mut head); + head.push_str("."); + append_base64(&serde_json::to_string(&claims).unwrap(), &mut head); + head } } @@ -188,7 +163,7 @@ where pub struct ServiceAccountAccess { client: C, key: ServiceAccountKey, - sub: Option, + subject: Option, } impl ServiceAccountAccess { @@ -197,7 +172,7 @@ impl ServiceAccountAccess { ServiceAccountAccess { client: DefaultHyperClient, key, - sub: None, + subject: None, } } } @@ -214,21 +189,21 @@ where ServiceAccountAccess { client: hyper_client, key: self.key, - sub: self.sub, + subject: self.subject, } } - /// Use the provided sub. - pub fn sub(self, sub: String) -> Self { + /// Use the provided subject. + pub fn subject(self, subject: String) -> Self { ServiceAccountAccess { - sub: Some(sub), + subject: Some(subject), ..self } } /// Build the configured ServiceAccountAccess. - pub fn build(self) -> impl GetToken { - ServiceAccountAccessImpl::new(self.client.build_hyper_client(), self.key, self.sub) + pub fn build(self) -> Result { + ServiceAccountAccessImpl::new(self.client.build_hyper_client(), self.key, self.subject) } } @@ -236,20 +211,27 @@ struct ServiceAccountAccessImpl { client: hyper::Client, key: ServiceAccountKey, cache: Arc>, - sub: Option, + subject: Option, + signer: JWTSigner, } impl ServiceAccountAccessImpl where C: hyper::client::connect::Connect, { - fn new(client: hyper::Client, key: ServiceAccountKey, sub: Option) -> Self { - ServiceAccountAccessImpl { + fn new( + client: hyper::Client, + key: ServiceAccountKey, + subject: Option, + ) -> Result { + let signer = JWTSigner::new(&key.private_key)?; + Ok(ServiceAccountAccessImpl { client, key, cache: Arc::new(Mutex::new(MemoryStorage::default())), - sub, - } + subject, + signer, + }) } } @@ -268,25 +250,25 @@ where /// Send a request for a new Bearer token to the OAuth provider. async fn request_token( client: &hyper::client::Client, - sub: Option<&str>, + signer: &JWTSigner, + subject: Option<&str>, key: &ServiceAccountKey, scopes: &[T], ) -> Result where T: AsRef, { - let mut claims = init_claims_from_key(&key, scopes); - claims.sub = sub.map(|x| x.to_owned()); - let signed = JWT::new(claims) - .sign(key.private_key.as_ref().unwrap()) - .map_err(RequestError::LowLevelError)?; + let claims = Claims::new(key, scopes, subject); + let signed = signer.sign_claims(&claims).map_err(|_| { + RequestError::LowLevelError(io::Error::new( + io::ErrorKind::Other, + "unable to sign claims", + )) + })?; let rqbody = form_urlencoded::Serializer::new(String::new()) - .extend_pairs(vec![ - ("grant_type".to_string(), GRANT_TYPE.to_string()), - ("assertion".to_string(), signed), - ]) + .extend_pairs(&[("grant_type", GRANT_TYPE), ("assertion", signed.as_str())]) .finish(); - let request = hyper::Request::post(key.token_uri.as_ref().unwrap()) + let request = hyper::Request::post(&key.token_uri) .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(rqbody)) .unwrap(); @@ -335,7 +317,8 @@ where } let token = Self::request_token( &self.client, - self.sub.as_ref().map(|x| x.as_str()), + &self.signer, + self.subject.as_ref().map(|x| x.as_str()), &self.key, scopes, ) @@ -399,7 +382,7 @@ mod tests { "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/yup-test-sa-1%40yup-test-243420.iam.gserviceaccount.com" }"#; let mut key: ServiceAccountKey = serde_json::from_str(client_secret).unwrap(); - key.token_uri = Some(format!("{}/token", server_url)); + key.token_uri = format!("{}/token", server_url); let json_response = r#"{ "access_token": "ya29.c.ElouBywiys0LyNaZoLPJcp1Fdi2KjFMxzvYKLXkTdvM-rDfqKlvEq6PiMhGoGHx97t5FAvz3eb_ahdwlBjSStxHtDVQB4ZPRJQ_EOi-iS7PnayahU2S9Jp8S6rk", @@ -429,7 +412,7 @@ mod tests { .with_body(json_response) .expect(1) .create(); - let acc = ServiceAccountAccessImpl::new(client.clone(), key.clone(), None); + let acc = ServiceAccountAccessImpl::new(client.clone(), key.clone(), None).unwrap(); let fut = async { let tok = acc .token(&["https://www.googleapis.com/auth/pubsub"]) @@ -472,7 +455,8 @@ mod tests { .create(); let acc = ServiceAccountAccess::new(key.clone()) .hyper_client(client.clone()) - .build(); + .build() + .unwrap(); let fut = async { let result = acc.token(&["https://www.googleapis.com/auth/pubsub"]).await; assert!(result.is_err()); @@ -494,7 +478,10 @@ mod tests { let key = service_account_key_from_file(&TEST_PRIVATE_KEY_PATH.to_string()).unwrap(); let https = HttpsConnector::new(); let client = hyper::Client::builder().build(https); - let acc = ServiceAccountAccess::new(key).hyper_client(client).build(); + let acc = ServiceAccountAccess::new(key) + .hyper_client(client) + .build() + .unwrap(); let rt = tokio::runtime::Builder::new() .core_threads(1) .panic_handler(|e| std::panic::resume_unwind(e)) @@ -512,7 +499,7 @@ mod tests { fn test_jwt_initialize_claims() { let key = service_account_key_from_file(TEST_PRIVATE_KEY_PATH).unwrap(); let scopes = vec!["scope1", "scope2", "scope3"]; - let claims = super::init_claims_from_key(&key, &scopes); + let claims = Claims::new(&key, &scopes, None); assert_eq!( claims.iss, @@ -532,9 +519,9 @@ mod tests { fn test_jwt_sign() { let key = service_account_key_from_file(TEST_PRIVATE_KEY_PATH).unwrap(); let scopes = vec!["scope1", "scope2", "scope3"]; - let claims = super::init_claims_from_key(&key, &scopes); - let jwt = super::JWT::new(claims); - let signature = jwt.sign(key.private_key.as_ref().unwrap()); + let signer = JWTSigner::new(&key.private_key).unwrap(); + let claims = Claims::new(&key, &scopes, None); + let signature = signer.sign_claims(&claims); assert!(signature.is_ok()); From e1f08191561a7f79f21c0c223b71052e1643cdba Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Tue, 12 Nov 2019 10:12:08 -0800 Subject: [PATCH 23/71] Authenticator should handle the server not returning a refresh_token. Currently the authenticator will panic when trying to refresh an expired token that does not have a refresh token. This change handles it so that the authenticator will only attempt a refresh when a refresh_token exists, and otherwise will attempt to retrieve a fresh token. --- src/authenticator.rs | 28 ++++++++++++++++------------ src/installed.rs | 4 ++-- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index f7016be..dd2292b 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -196,17 +196,16 @@ where let appsecret = gettoken.application_secret(); loop { match store.get(scope_key, scopes) { - Ok(Some(t)) => { - if !t.expired() { - return Ok(t); - } - // Implement refresh flow. - let rr = RefreshFlow::refresh_token( - client, - appsecret, - &t.refresh_token.as_ref().unwrap(), - ) - .await?; + Ok(Some(t)) if !t.expired() => { + // unexpired token found + return Ok(t); + } + Ok(Some(Token { + refresh_token: Some(refresh_token), + .. + })) => { + // token is expired but has a refresh token. + let rr = RefreshFlow::refresh_token(client, appsecret, &refresh_token).await?; match rr { RefreshResult::Error(ref e) => { delegate.token_refresh_failed( @@ -236,7 +235,12 @@ where } } } - Ok(None) => { + Ok(None) + | Ok(Some(Token { + refresh_token: None, + .. + })) => { + // no token in the cache or the token returned does not contain a refresh token. let t = gettoken.token(scopes).await?; if let Err(e) = store.set(scope_key, scopes, Some(t.clone())) { match delegate.token_storage_failure(true, &e) { diff --git a/src/installed.rs b/src/installed.rs index 5dc19ed..d100fd0 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -255,7 +255,7 @@ where #[derive(Deserialize)] struct JSONTokenResponse { access_token: String, - refresh_token: String, + refresh_token: Option, token_type: String, expires_in: Option, } @@ -270,7 +270,7 @@ where }) => { let mut token = Token { access_token, - refresh_token: Some(refresh_token), + refresh_token, token_type, expires_in, expires_in_timestamp: None, From 88a8f74406327b2f6ade50d3985426cb7fe2e1cc Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Tue, 12 Nov 2019 13:03:01 -0800 Subject: [PATCH 24/71] Refactor token storage. The current code uses standard blocking i/o operations (std::fs::*) this is problematic as it would block the entire futures executor waiting for i/o. This change is a major refactoring to make the token storage mechansim async i/o friendly. The first major decision was to abandon the GetToken trait. The trait is only implemented internally and there was no mechanism for users to provide their own, but async fn's are not currently supported in trait impls so keeping the trait would have required Boxing futures. This probably would have been fine, but seemed unnecessary. Instead of a trait the storage mechanism is just an enum with a choice between Memory and Disk storage. The DiskStorage works primarily as it did before, rewriting the entire contents of the file on every set() invocation. The only difference is that we now defer the actual writing to a separate task so that it does not block the return of the Token to the user. If disk i/o is too slow to keep up with the rate of incoming writes it will push back and will eventually block the return of tokens, this is to prevent a buildup of in-flight requests. One major drawback to this approach is that any errors that happen on write are simply logged and no delegate function is invoked on error because the delegate no longer has the ability to say to sleep, retry, etc. --- Cargo.toml | 3 +- examples/test-device/src/main.rs | 1 + examples/test-installed/src/main.rs | 1 + src/authenticator.rs | 172 ++++++-------- src/authenticator_delegate.rs | 10 - src/device.rs | 2 +- src/lib.rs | 2 +- src/service_account.rs | 26 +-- src/storage.rs | 348 ++++++++++++---------------- 9 files changed, 241 insertions(+), 324 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e1ca60f..88618be 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,8 +16,7 @@ chrono = "0.4" http = "0.1" hyper = {version = "0.13.0-alpha.4", features = ["unstable-stream"]} hyper-rustls = "=0.18.0-alpha.2" -itertools = "0.8" -log = "0.3" +log = "0.4" rustls = "0.16" serde = "1.0" serde_json = "1.0" diff --git a/examples/test-device/src/main.rs b/examples/test-device/src/main.rs index 64e413c..62bdf10 100644 --- a/examples/test-device/src/main.rs +++ b/examples/test-device/src/main.rs @@ -10,6 +10,7 @@ async fn main() { let auth = Authenticator::new(DeviceFlow::new(creds)) .persist_tokens_to_disk("tokenstorage.json") .build() + .await .expect("authenticator"); let scopes = &["https://www.googleapis.com/auth/youtube.readonly"]; diff --git a/examples/test-installed/src/main.rs b/examples/test-installed/src/main.rs index 54be93f..1bdfee0 100644 --- a/examples/test-installed/src/main.rs +++ b/examples/test-installed/src/main.rs @@ -14,6 +14,7 @@ async fn main() { )) .persist_tokens_to_disk("tokencache.json") .build() + .await .unwrap(); let scopes = &["https://www.googleapis.com/auth/drive.file"]; diff --git a/src/authenticator.rs b/src/authenticator.rs index dd2292b..b332e5a 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -1,14 +1,15 @@ -use crate::authenticator_delegate::{AuthenticatorDelegate, DefaultAuthenticatorDelegate, Retry}; +use crate::authenticator_delegate::{AuthenticatorDelegate, DefaultAuthenticatorDelegate}; use crate::refresh::RefreshFlow; -use crate::storage::{hash_scopes, DiskTokenStorage, MemoryStorage, TokenStorage}; +use crate::storage::{self, Storage}; use crate::types::{ApplicationSecret, GetToken, RefreshResult, RequestError, Token}; use futures::prelude::*; use std::error::Error; use std::io; -use std::path::Path; +use std::path::PathBuf; use std::pin::Pin; +use std::sync::Mutex; /// Authenticator abstracts different `GetToken` implementations behind one type and handles /// caching received tokens. It's important to use it (instead of the flows directly) because @@ -20,15 +21,11 @@ use std::pin::Pin; /// NOTE: It is recommended to use a client constructed like this in order to prevent functions /// like `hyper::run()` from hanging: `let client = hyper::Client::builder().keep_alive(false);`. /// Due to token requests being rare, this should not result in a too bad performance problem. -struct AuthenticatorImpl< - T: GetToken, - S: TokenStorage, - AD: AuthenticatorDelegate, - C: hyper::client::connect::Connect, -> { +struct AuthenticatorImpl +{ client: hyper::Client, inner: T, - store: S, + store: Storage, delegate: AD, } @@ -69,17 +66,22 @@ pub trait AuthFlow { fn build_token_getter(self, client: hyper::Client) -> Self::TokenGetter; } +enum StorageType { + Memory, + Disk(PathBuf), +} + /// An authenticator can be used with `InstalledFlow`'s or `DeviceFlow`'s and /// will refresh tokens as they expire as well as optionally persist tokens to /// disk. -pub struct Authenticator { +pub struct Authenticator { client: C, token_getter: T, - store: io::Result, + storage_type: StorageType, delegate: AD, } -impl Authenticator +impl Authenticator where T: AuthFlow<::Connector>, { @@ -90,27 +92,27 @@ where /// /// Examples /// ``` + /// # #[tokio::main] + /// # async fn main() { /// use std::path::Path; /// use yup_oauth2::{ApplicationSecret, Authenticator, DeviceFlow}; /// let creds = ApplicationSecret::default(); - /// let auth = Authenticator::new(DeviceFlow::new(creds)).build().unwrap(); + /// let auth = Authenticator::new(DeviceFlow::new(creds)).build().await.unwrap(); + /// # } /// ``` - pub fn new( - flow: T, - ) -> Authenticator { + pub fn new(flow: T) -> Authenticator { Authenticator { client: DefaultHyperClient, token_getter: flow, - store: Ok(MemoryStorage::new()), + storage_type: StorageType::Memory, delegate: DefaultAuthenticatorDelegate, } } } -impl Authenticator +impl Authenticator where T: AuthFlow, - S: TokenStorage, AD: AuthenticatorDelegate, C: HyperClientBuilder, { @@ -118,7 +120,7 @@ where pub fn hyper_client( self, hyper_client: hyper::Client, - ) -> Authenticator> + ) -> Authenticator> where NewC: hyper::client::connect::Connect + 'static, T: AuthFlow, @@ -126,21 +128,17 @@ where Authenticator { client: hyper_client, token_getter: self.token_getter, - store: self.store, + storage_type: self.storage_type, delegate: self.delegate, } } /// Persist tokens to disk in the provided filename. - pub fn persist_tokens_to_disk>( - self, - path: P, - ) -> Authenticator { - let disk_storage = DiskTokenStorage::new(path.as_ref().to_str().unwrap()); + pub fn persist_tokens_to_disk>(self, path: P) -> Authenticator { Authenticator { client: self.client, token_getter: self.token_getter, - store: disk_storage, + storage_type: StorageType::Disk(path.into()), delegate: self.delegate, } } @@ -149,24 +147,29 @@ where pub fn delegate( self, delegate: NewAD, - ) -> Authenticator { + ) -> Authenticator { Authenticator { client: self.client, token_getter: self.token_getter, - store: self.store, + storage_type: self.storage_type, delegate, } } /// Create the authenticator. - pub fn build(self) -> io::Result + pub async fn build(self) -> io::Result where T::TokenGetter: GetToken, C::Connector: hyper::client::connect::Connect + 'static, { let client = self.client.build_hyper_client(); - let store = self.store?; let inner = self.token_getter.build_token_getter(client.clone()); + let store = match self.storage_type { + StorageType::Memory => Storage::Memory { + tokens: Mutex::new(storage::JSONTokens::new()), + }, + StorageType::Disk(path) => Storage::Disk(storage::DiskStorage::new(path).await?), + }; Ok(AuthenticatorImpl { client, @@ -177,10 +180,9 @@ where } } -impl AuthenticatorImpl +impl AuthenticatorImpl where GT: GetToken, - S: TokenStorage, AD: AuthenticatorDelegate, C: hyper::client::connect::Connect + 'static, { @@ -188,83 +190,61 @@ where where T: AsRef + Sync, { - let scope_key = hash_scopes(scopes); + let scope_key = storage::ScopeHash::new(scopes); let store = &self.store; let delegate = &self.delegate; let client = &self.client; let gettoken = &self.inner; let appsecret = gettoken.application_secret(); - loop { - match store.get(scope_key, scopes) { - Ok(Some(t)) if !t.expired() => { - // unexpired token found - return Ok(t); - } - Ok(Some(Token { - refresh_token: Some(refresh_token), - .. - })) => { - // token is expired but has a refresh token. - let rr = RefreshFlow::refresh_token(client, appsecret, &refresh_token).await?; - match rr { - RefreshResult::Error(ref e) => { - delegate.token_refresh_failed( - e.description(), - Some("the request has likely timed out"), + match store.get(scope_key, scopes) { + Some(t) if !t.expired() => { + // unexpired token found + Ok(t) + } + Some(Token { + refresh_token: Some(refresh_token), + .. + }) => { + // token is expired but has a refresh token. + let rr = RefreshFlow::refresh_token(client, appsecret, &refresh_token).await?; + match rr { + RefreshResult::Error(ref e) => { + delegate.token_refresh_failed( + e.description(), + Some("the request has likely timed out"), + ); + Err(RequestError::Refresh(rr)) + } + RefreshResult::RefreshError(ref s, ref ss) => { + delegate.token_refresh_failed( + &format!("{}{}", s, ss.as_ref().map(|s| format!(" ({})", s)).unwrap_or_else(String::new)), + Some("the refresh token is likely invalid and your authorization has been revoked"), ); - return Err(RequestError::Refresh(rr)); - } - RefreshResult::RefreshError(ref s, ref ss) => { - delegate.token_refresh_failed( - &format!("{}{}", s, ss.as_ref().map(|s| format!(" ({})", s)).unwrap_or_else(String::new)), - Some("the refresh token is likely invalid and your authorization has been revoked"), - ); - return Err(RequestError::Refresh(rr)); - } - RefreshResult::Success(t) => { - let x = store.set(scope_key, scopes, Some(t.clone())); - if let Err(e) = x { - match delegate.token_storage_failure(true, &e) { - Retry::Skip => return Ok(t), - Retry::Abort => return Err(RequestError::Cache(Box::new(e))), - Retry::After(d) => tokio::timer::delay_for(d).await, - } - } else { - return Ok(t); - } - } + Err(RequestError::Refresh(rr)) + } + RefreshResult::Success(t) => { + store.set(scope_key, scopes, Some(t.clone())).await; + Ok(t) } } - Ok(None) - | Ok(Some(Token { - refresh_token: None, - .. - })) => { - // no token in the cache or the token returned does not contain a refresh token. - let t = gettoken.token(scopes).await?; - if let Err(e) = store.set(scope_key, scopes, Some(t.clone())) { - match delegate.token_storage_failure(true, &e) { - Retry::Skip => return Ok(t), - Retry::Abort => return Err(RequestError::Cache(Box::new(e))), - Retry::After(d) => tokio::timer::delay_for(d).await, - } - } else { - return Ok(t); - } - } - Err(err) => match delegate.token_storage_failure(false, &err) { - Retry::Abort | Retry::Skip => return Err(RequestError::Cache(Box::new(err))), - Retry::After(d) => tokio::timer::delay_for(d).await, - }, + } + None + | Some(Token { + refresh_token: None, + .. + }) => { + // no token in the cache or the token returned does not contain a refresh token. + let t = gettoken.token(scopes).await?; + store.set(scope_key, scopes, Some(t.clone())).await; + Ok(t) } } } } -impl GetToken for AuthenticatorImpl +impl GetToken for AuthenticatorImpl where GT: GetToken, - S: TokenStorage, AD: AuthenticatorDelegate, C: hyper::client::connect::Connect + 'static, { diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index eee7ec9..b039e3b 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -79,16 +79,6 @@ pub trait AuthenticatorDelegate: Send + Sync { Retry::Abort } - /// Called whenever we failed to retrieve a token or set a token due to a storage error. - /// You may use it to either ignore the incident or retry. - /// This can be useful if the underlying `TokenStorage` may fail occasionally. - /// if `is_set` is true, the failure resulted from `TokenStorage.set(...)`. Otherwise, - /// it was `TokenStorage.get(...)` - fn token_storage_failure(&self, is_set: bool, _: &(dyn Error + Send + Sync)) -> Retry { - let _ = is_set; - Retry::Abort - } - /// The server denied the attempt to obtain a request code fn request_failure(&self, _: RequestError) {} diff --git a/src/device.rs b/src/device.rs index 51d97b3..a72fddd 100644 --- a/src/device.rs +++ b/src/device.rs @@ -1,7 +1,7 @@ use std::pin::Pin; use std::time::Duration; -use ::log::{error, log}; +use ::log::error; use chrono::{DateTime, Utc}; use futures::prelude::*; use hyper; diff --git a/src/lib.rs b/src/lib.rs index f53d54c..549ea28 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -61,6 +61,7 @@ //! ) //! .persist_tokens_to_disk("tokencache.json") //! .build() +//! .await //! .unwrap(); //! //! let scopes = &["https://www.googleapis.com/auth/drive.file"]; @@ -96,7 +97,6 @@ pub use crate::device::{DeviceFlow, GOOGLE_DEVICE_CODE_URL}; pub use crate::helper::*; pub use crate::installed::{InstalledFlow, InstalledFlowReturnMethod}; pub use crate::service_account::*; -pub use crate::storage::{DiskTokenStorage, MemoryStorage, NullStorage, TokenStorage}; pub use crate::types::{ ApplicationSecret, ConsoleApplicationSecret, GetToken, PollError, RefreshResult, RequestError, Scheme, Token, TokenType, diff --git a/src/service_account.rs b/src/service_account.rs index 03525a5..1b10de8 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -11,12 +11,11 @@ //! Copyright (c) 2016 Google Inc (lewinb@google.com). //! -use std::default::Default; use std::pin::Pin; -use std::sync::{Arc, Mutex}; +use std::sync::Mutex; use crate::authenticator::{DefaultHyperClient, HyperClientBuilder}; -use crate::storage::{hash_scopes, MemoryStorage, TokenStorage}; +use crate::storage::{self, Storage}; use crate::types::{ApplicationSecret, GetToken, JsonErrorOr, RequestError, Token}; use futures::prelude::*; @@ -210,7 +209,7 @@ where struct ServiceAccountAccessImpl { client: hyper::Client, key: ServiceAccountKey, - cache: Arc>, + cache: Storage, subject: Option, signer: JWTSigner, } @@ -228,7 +227,9 @@ where Ok(ServiceAccountAccessImpl { client, key, - cache: Arc::new(Mutex::new(MemoryStorage::default())), + cache: Storage::Memory { + tokens: Mutex::new(storage::JSONTokens::new()), + }, subject, signer, }) @@ -309,10 +310,10 @@ where where T: AsRef, { - let hash = hash_scopes(scopes); + let hash = storage::ScopeHash::new(scopes); let cache = &self.cache; - match cache.lock().unwrap().get(hash, scopes) { - Ok(Some(token)) if !token.expired() => return Ok(token), + match cache.get(hash, scopes) { + Some(token) if !token.expired() => return Ok(token), _ => {} } let token = Self::request_token( @@ -323,7 +324,7 @@ where scopes, ) .await?; - let _ = cache.lock().unwrap().set(hash, scopes, Some(token.clone())); + cache.set(hash, scopes, Some(token.clone())).await; Ok(token) } } @@ -425,13 +426,12 @@ mod tests { assert!(acc .cache - .lock() - .unwrap() .get( - 3502164897243251857, + dbg!(storage::ScopeHash::new(&[ + "https://www.googleapis.com/auth/pubsub" + ])), &["https://www.googleapis.com/auth/pubsub"], ) - .unwrap() .is_some()); // Test that token is in cache (otherwise mock will tell us) let fut = async { diff --git a/src/storage.rs b/src/storage.rs index fef3c0a..17cf2db 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -5,127 +5,65 @@ use std::cmp::Ordering; use std::collections::hash_map::DefaultHasher; -use std::error::Error; -use std::fs; use std::hash::{Hash, Hasher}; use std::io; -use std::io::Write; use std::path::{Path, PathBuf}; use std::sync::Mutex; use crate::types::Token; -use itertools::Itertools; -/// Implements a specialized storage to set and retrieve `Token` instances. -/// The `scope_hash` represents the signature of the scopes for which the given token -/// should be stored or retrieved. -/// For completeness, the underlying, sorted scopes are provided as well. They might be -/// useful for presentation to the user. -pub trait TokenStorage: Send + Sync { - type Error: 'static + Error + Send + Sync; +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct ScopeHash(u64); - /// If `token` is None, it is invalid or revoked and should be removed from storage. - /// Otherwise, it should be saved. - fn set( - &self, - scope_hash: u64, - scopes: &[T], - token: Option, - ) -> Result<(), Self::Error> - where - T: AsRef; - - /// A `None` result indicates that there is no token for the given scope_hash. - fn get(&self, scope_hash: u64, scopes: &[T]) -> Result, Self::Error> - where - T: AsRef; -} - -/// Calculate a hash value describing the scopes. The order of the scopes in the -/// list does not change the hash value. i.e. two lists that contains the exact -/// same scopes, but in different order will return the same hash value. -pub fn hash_scopes(scopes: &[T]) -> u64 -where - T: AsRef, -{ - let mut hash_sum = DefaultHasher::new().finish(); - for scope in scopes { - let mut hasher = DefaultHasher::new(); - scope.as_ref().hash(&mut hasher); - hash_sum ^= hasher.finish(); - } - hash_sum -} - -/// A storage that remembers nothing. -#[derive(Default)] -pub struct NullStorage; - -impl TokenStorage for NullStorage { - type Error = std::convert::Infallible; - fn set(&self, _: u64, _: &[T], _: Option) -> Result<(), Self::Error> +impl ScopeHash { + /// Calculate a hash value describing the scopes. The order of the scopes in the + /// list does not change the hash value. i.e. two lists that contains the exact + /// same scopes, but in different order will return the same hash value. + pub fn new(scopes: &[T]) -> Self where T: AsRef, { - Ok(()) - } - - fn get(&self, _: u64, _: &[T]) -> Result, Self::Error> - where - T: AsRef, - { - Ok(None) - } -} - -/// A storage that remembers values for one session only. -#[derive(Debug, Default)] -pub struct MemoryStorage { - tokens: Mutex>, -} - -impl MemoryStorage { - pub fn new() -> MemoryStorage { - Default::default() - } -} - -impl TokenStorage for MemoryStorage { - type Error = std::convert::Infallible; - - fn set(&self, scope_hash: u64, scopes: &[T], token: Option) -> Result<(), Self::Error> - where - T: AsRef, - { - let mut tokens = self.tokens.lock().expect("poisoned mutex"); - let matched = tokens.iter().find_position(|x| x.hash == scope_hash); - if let Some((idx, _)) = matched { - self.tokens.retain(|x| x.hash != scope_hash); + let mut hash_sum = DefaultHasher::new().finish(); + for scope in scopes { + let mut hasher = DefaultHasher::new(); + scope.as_ref().hash(&mut hasher); + hash_sum ^= hasher.finish(); } - - if let Some(t) = token { - tokens.push(JSONToken { - hash: scope_hash, - scopes: Some(scopes.iter().map(|x| x.as_ref().to_string()).collect()), - token: t, - }); - } - Ok(()) + ScopeHash(hash_sum) } +} - fn get(&self, scope_hash: u64, scopes: &[T]) -> Result, Self::Error> +pub(crate) enum Storage { + Memory { tokens: Mutex }, + Disk(DiskStorage), +} + +impl Storage { + pub(crate) async fn set(&self, h: ScopeHash, scopes: &[T], token: Option) where T: AsRef, { - let tokens = self.tokens.lock().expect("poisoned mutex"); - Ok(token_for_scopes(&tokens, scope_hash, scopes)) + match self { + Storage::Memory { tokens } => tokens.lock().unwrap().set(h, scopes, token), + Storage::Disk(disk_storage) => disk_storage.set(h, scopes, token).await, + } + } + + pub(crate) fn get(&self, h: ScopeHash, scopes: &[T]) -> Option + where + T: AsRef, + { + match self { + Storage::Memory { tokens } => tokens.lock().unwrap().get(h, scopes), + Storage::Disk(disk_storage) => disk_storage.get(h, scopes), + } } } /// A single stored token. #[derive(Debug, Clone, Serialize, Deserialize)] struct JSONToken { - pub hash: u64, + pub hash: ScopeHash, pub scopes: Option>, pub token: Token, } @@ -151,121 +89,123 @@ impl Ord for JSONToken { } /// List of tokens in a JSON object -#[derive(Serialize, Deserialize)] -struct JSONTokens { - pub tokens: Vec, +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct JSONTokens { + tokens: Vec, } -/// Serializes tokens to a JSON file on disk. -#[derive(Default)] -pub struct DiskTokenStorage { - location: PathBuf, - tokens: Mutex>, +impl JSONTokens { + pub(crate) fn new() -> Self { + JSONTokens { tokens: Vec::new() } + } + + pub(crate) async fn load_from_file(filename: &Path) -> Result { + let contents = tokio::fs::read(filename).await?; + let container: JSONTokens = serde_json::from_slice(&contents) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok(container) + } + + fn get(&self, h: ScopeHash, scopes: &[T]) -> Option + where + T: AsRef, + { + for t in self.tokens.iter() { + if let Some(token_scopes) = &t.scopes { + if scopes + .iter() + .all(|s| token_scopes.iter().any(|t| t == s.as_ref())) + { + return Some(t.token.clone()); + } + } else if h == t.hash { + return Some(t.token.clone()); + } + } + None + } + + fn set(&mut self, h: ScopeHash, scopes: &[T], token: Option) + where + T: AsRef, + { + eprintln!("setting: {:?}, {:?}", h, token); + self.tokens.retain(|x| x.hash != h); + + match token { + None => (), + Some(t) => { + self.tokens.push(JSONToken { + hash: h, + scopes: Some(scopes.iter().map(|x| x.as_ref().to_string()).collect()), + token: t, + }); + } + } + } + + // TODO: ideally this function would accept &Path, but tokio requires the + // path be 'static. Revisit this and ask why tokio::fs::write has that + // limitation. + async fn dump_to_file(&self, path: PathBuf) -> Result<(), io::Error> { + let serialized = serde_json::to_string(self) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + tokio::fs::write(path, &serialized).await + } } -impl DiskTokenStorage { - pub fn new>(location: S) -> Result { - let filename = location.into(); - let tokens = match load_from_file(&filename) { - Ok(tokens) => tokens, - Err(e) if e.kind() == io::ErrorKind::NotFound => Vec::new(), - Err(e) => return Err(e), - }; - Ok(DiskTokenStorage { - location: filename, +pub(crate) struct DiskStorage { + tokens: Mutex, + write_tx: tokio::sync::mpsc::Sender, +} + +impl DiskStorage { + pub(crate) async fn new(path: PathBuf) -> Result { + let tokens = JSONTokens::load_from_file(&path).await?; + // Writing to disk will happen in a separate task. This means in the + // common case returning a token to the user will not be required to + // wait for disk i/o. We communicate with a dedicated writer task via a + // buffered channel. This ensures that the writes happen in the order + // received, and if writes fall too far behind we will block GetToken + // requests until disk i/o completes. + let (write_tx, mut write_rx) = tokio::sync::mpsc::channel::(2); + tokio::spawn(async move { + while let Some(tokens) = write_rx.recv().await { + if let Err(e) = tokens.dump_to_file(path.to_path_buf()).await { + log::error!("Failed to write token storage to disk: {}", e); + } + } + }); + Ok(DiskStorage { tokens: Mutex::new(tokens), + write_tx, }) } - pub fn dump_to_file(&self) -> Result<(), io::Error> { - let mut jsontokens = JSONTokens { tokens: Vec::new() }; - - { - let tokens = self.tokens.lock().expect("mutex poisoned"); - for token in tokens.iter() { - jsontokens.tokens.push((*token).clone()); - } - } - - let serialized; - - match serde_json::to_string(&jsontokens) { - Result::Err(e) => return Result::Err(io::Error::new(io::ErrorKind::InvalidData, e)), - Result::Ok(s) => serialized = s, - } - - // TODO: Write to disk asynchronously so that we don't stall the eventloop if invoked in async context. - let mut f = fs::OpenOptions::new() - .create(true) - .write(true) - .truncate(true) - .open(&self.location)?; - f.write(serialized.as_ref()).map(|_| ()) - } -} - -fn load_from_file(filename: &Path) -> Result, io::Error> { - let contents = std::fs::read_to_string(filename)?; - let container: JSONTokens = serde_json::from_str(&contents) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - Ok(container.tokens) -} - -impl TokenStorage for DiskTokenStorage { - type Error = io::Error; - fn set(&self, scope_hash: u64, scopes: &[T], token: Option) -> Result<(), Self::Error> + async fn set(&self, h: ScopeHash, scopes: &[T], token: Option) where T: AsRef, { - { - let mut tokens = self.tokens.lock().expect("poisoned mutex"); - let matched = tokens.iter().find_position(|x| x.hash == scope_hash); - if let Some((idx, _)) = matched { - self.tokens.retain(|x| x.hash != scope_hash); - } - - match token { - None => (), - Some(t) => { - tokens.push(JSONToken { - hash: scope_hash, - scopes: Some(scopes.iter().map(|x| x.as_ref().to_string()).collect()), - token: t, - }); - } - } - } - self.dump_to_file() + let cloned_tokens = { + let mut tokens = self.tokens.lock().unwrap(); + tokens.set(h, scopes, token); + tokens.clone() + }; + self.write_tx + .clone() + .send(cloned_tokens) + .await + .expect("disk storage task not running"); } - fn get(&self, scope_hash: u64, scopes: &[T]) -> Result, Self::Error> + pub(crate) fn get(&self, h: ScopeHash, scopes: &[T]) -> Option where T: AsRef, { - let tokens = self.tokens.lock().expect("poisoned mutex"); - Ok(token_for_scopes(&tokens, scope_hash, scopes)) + self.tokens.lock().unwrap().get(h, scopes) } } -fn token_for_scopes(tokens: &[JSONToken], scope_hash: u64, scopes: &[T]) -> Option -where - T: AsRef, -{ - for t in tokens.iter() { - if let Some(token_scopes) = &t.scopes { - if scopes - .iter() - .all(|s| token_scopes.iter().any(|t| t == s.as_ref())) - { - return Some(t.token.clone()); - } - } else if scope_hash == t.hash { - return Some(t.token.clone()); - } - } - None -} - #[cfg(test)] mod tests { use super::*; @@ -273,19 +213,25 @@ mod tests { #[test] fn test_hash_scopes() { // Idential list should hash equal. - assert_eq!(hash_scopes(&["foo", "bar"]), hash_scopes(&["foo", "bar"])); - // The hash should be order independent. - assert_eq!(hash_scopes(&["bar", "foo"]), hash_scopes(&["foo", "bar"])); assert_eq!( - hash_scopes(&["bar", "baz", "bat"]), - hash_scopes(&["baz", "bar", "bat"]) + ScopeHash::new(&["foo", "bar"]), + ScopeHash::new(&["foo", "bar"]) + ); + // The hash should be order independent. + assert_eq!( + ScopeHash::new(&["bar", "foo"]), + ScopeHash::new(&["foo", "bar"]) + ); + assert_eq!( + ScopeHash::new(&["bar", "baz", "bat"]), + ScopeHash::new(&["baz", "bar", "bat"]) ); // Ensure hashes differ when the contents are different by more than // just order. assert_ne!( - hash_scopes(&["foo", "bar", "baz"]), - hash_scopes(&["foo", "bar"]) + ScopeHash::new(&["foo", "bar", "baz"]), + ScopeHash::new(&["foo", "bar"]) ); } } From c0919bee86cda7e7f752ed8655d95124ed821a86 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Tue, 12 Nov 2019 13:23:37 -0800 Subject: [PATCH 25/71] allow setting grant_type for device code --- src/device.rs | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/src/device.rs b/src/device.rs index a72fddd..45e117e 100644 --- a/src/device.rs +++ b/src/device.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::pin::Pin; use std::time::Duration; @@ -14,6 +15,9 @@ use crate::types::{ApplicationSecret, GetToken, JsonErrorOr, PollError, RequestE pub const GOOGLE_DEVICE_CODE_URL: &str = "https://accounts.google.com/o/oauth2/device/code"; +// https://developers.google.com/identity/protocols/OAuth2ForDevices#step-4:-poll-googles-authorization-server +pub const GOOGLE_GRANT_TYPE: &str = "http://oauth.net/grant_type/device/1.0"; + /// Implements the [Oauth2 Device Flow](https://developers.google.com/youtube/v3/guides/authentication#devices) /// It operates in two steps: /// * obtain a code to show to the user @@ -21,9 +25,10 @@ pub const GOOGLE_DEVICE_CODE_URL: &str = "https://accounts.google.com/o/oauth2/d #[derive(Clone)] pub struct DeviceFlow { application_secret: ApplicationSecret, - device_code_url: String, + device_code_url: Cow<'static, str>, flow_delegate: FD, wait: Duration, + grant_type: Cow<'static, str>, } impl DeviceFlow { @@ -32,9 +37,10 @@ impl DeviceFlow { pub fn new(secret: ApplicationSecret) -> DeviceFlow { DeviceFlow { application_secret: secret, - device_code_url: GOOGLE_DEVICE_CODE_URL.to_string(), + device_code_url: GOOGLE_DEVICE_CODE_URL.into(), flow_delegate: DefaultFlowDelegate, wait: Duration::from_secs(120), + grant_type: GOOGLE_GRANT_TYPE.into(), } } } @@ -43,7 +49,7 @@ impl DeviceFlow { /// Use the provided device code url. pub fn device_code_url(self, url: String) -> Self { DeviceFlow { - device_code_url: url, + device_code_url: url.into(), ..self } } @@ -55,6 +61,7 @@ impl DeviceFlow { device_code_url: self.device_code_url, flow_delegate: delegate, wait: self.wait, + grant_type: self.grant_type, } } @@ -65,6 +72,13 @@ impl DeviceFlow { ..self } } + + pub fn grant_type(self, grant_type: String) -> Self { + DeviceFlow { + grant_type: grant_type.into(), + ..self + } + } } impl crate::authenticator::AuthFlow for DeviceFlow @@ -81,6 +95,7 @@ where device_code_url: self.device_code_url, fd: self.flow_delegate, wait: Duration::from_secs(1200), + grant_type: self.grant_type, } } } @@ -90,9 +105,10 @@ pub struct DeviceFlowImpl { client: hyper::Client, application_secret: ApplicationSecret, /// Usually GOOGLE_DEVICE_CODE_URL - device_code_url: String, + device_code_url: Cow<'static, str>, fd: FD, wait: Duration, + grant_type: Cow<'static, str>, } impl GetToken for DeviceFlowImpl @@ -138,7 +154,7 @@ where .await?; self.fd.present_user_code(&pollinf); tokio::timer::Timeout::new( - self.wait_for_device_token(&pollinf, &device_code), + self.wait_for_device_token(&pollinf, &device_code, &self.grant_type), self.wait, ) .await @@ -149,6 +165,7 @@ where &self, pollinf: &PollInformation, device_code: &str, + grant_type: &str, ) -> Result { let mut interval = pollinf.interval; loop { @@ -157,6 +174,7 @@ where &self.application_secret, &self.client, device_code, + grant_type, pollinf.expires_at, &self.fd, ) @@ -274,6 +292,7 @@ where application_secret: &ApplicationSecret, client: &hyper::Client, device_code: &str, + grant_type: &str, expires_at: DateTime, fd: &FD, ) -> Result, PollError> { @@ -288,7 +307,7 @@ where ("client_id", application_secret.client_id.as_str()), ("client_secret", application_secret.client_secret.as_str()), ("code", device_code), - ("grant_type", "http://oauth.net/grant_type/device/1.0"), + ("grant_type", grant_type), ]) .finish(); From 911fec82f1229168009ecd45d94211815439c96c Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Tue, 12 Nov 2019 16:24:21 -0800 Subject: [PATCH 26/71] Make FlowDelegate object safe. --- src/authenticator_delegate.rs | 8 ++++---- src/installed.rs | 5 ++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index b039e3b..b8fe106 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -143,9 +143,9 @@ pub trait FlowDelegate: Send + Sync { /// We need the user to navigate to a URL using their browser and potentially paste back a code /// (or maybe not). Whether they have to enter a code depends on the InstalledFlowReturnMethod /// used. - fn present_user_url<'a, S: AsRef + fmt::Display + Send + Sync + 'a>( + fn present_user_url<'a>( &'a self, - url: S, + url: &'a str, need_code: bool, ) -> Pin>> + Send + 'a>> { @@ -153,8 +153,8 @@ pub trait FlowDelegate: Send + Sync { } } -async fn present_user_url + fmt::Display>( - url: S, +async fn present_user_url( + url: &str, need_code: bool, ) -> Result> { if need_code { diff --git a/src/installed.rs b/src/installed.rs index d100fd0..fa72489 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -447,7 +447,6 @@ mod installed_flow_server { #[cfg(test)] mod tests { use std::error::Error; - use std::fmt; use std::str::FromStr; use hyper::client::connect::HttpConnector; @@ -472,9 +471,9 @@ mod tests { impl FlowDelegate for FD { /// Depending on need_code, return the pre-set code or send the code to the server at /// the redirect_uri given in the url. - fn present_user_url<'a, S: AsRef + fmt::Display + Send + Sync + 'a>( + fn present_user_url<'a>( &'a self, - url: S, + url: &'a str, need_code: bool, ) -> Pin< Box>> + Send + 'a>, From 3aadc6b0efb339ebcd2975e4b90a87b7a6bb743f Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Wed, 13 Nov 2019 11:51:28 -0800 Subject: [PATCH 27/71] Major refactor of the public API. 1) Remove the GetToken trait. The trait seemed to be organically designed. It appeared to be mostly tailored for simplifying the implementation since there was no way for users to provide their own implementation to Authenticator. It sadly seemed to get in the way of implementations more than it helped. An enum representing the known implementations is a more straightforward way to accomplish the goal and also has the benefit of not requiring Boxing when returning features (which admittedly is a minor concern for this use case). 2) Reduce the number of type parameters by using trait object for delegates. This simplifies the code considerably and the performance impact of virtual dispatch for the delegate calls is a non-factor. 3) With the above two simplifications it became easier to unify the public interface for building an authenticator. See the examples for how InstalledFlow, DeviceFlow, and ServiceAccount authenticators are now created. --- examples/test-device/src/main.rs | 6 +- examples/test-installed/src/main.rs | 19 +- examples/test-svc-acct/src/main.rs | 11 +- src/authenticator.rs | 505 ++++++++++++++++------------ src/device.rs | 222 +++++------- src/installed.rs | 203 +++++------ src/lib.rs | 29 +- src/service_account.rs | 156 ++++----- src/types.rs | 21 -- 9 files changed, 519 insertions(+), 653 deletions(-) diff --git a/examples/test-device/src/main.rs b/examples/test-device/src/main.rs index 62bdf10..82a4741 100644 --- a/examples/test-device/src/main.rs +++ b/examples/test-device/src/main.rs @@ -1,13 +1,13 @@ -use yup_oauth2::{self, Authenticator, DeviceFlow, GetToken}; +use yup_oauth2::DeviceFlowAuthenticator; use std::path; use tokio; #[tokio::main] async fn main() { - let creds = yup_oauth2::read_application_secret(path::Path::new("clientsecret.json")) + let app_secret = yup_oauth2::read_application_secret(path::Path::new("clientsecret.json")) .expect("clientsecret"); - let auth = Authenticator::new(DeviceFlow::new(creds)) + let auth = DeviceFlowAuthenticator::builder(app_secret) .persist_tokens_to_disk("tokenstorage.json") .build() .await diff --git a/examples/test-installed/src/main.rs b/examples/test-installed/src/main.rs index 1bdfee0..3febb75 100644 --- a/examples/test-installed/src/main.rs +++ b/examples/test-installed/src/main.rs @@ -1,21 +1,18 @@ -use yup_oauth2::GetToken; -use yup_oauth2::{Authenticator, InstalledFlow}; +use yup_oauth2::{InstalledFlowAuthenticator, InstalledFlowReturnMethod}; use std::path::Path; #[tokio::main] async fn main() { - let secret = yup_oauth2::read_application_secret(Path::new("clientsecret.json")) + let app_secret = yup_oauth2::read_application_secret(Path::new("clientsecret.json")) .expect("clientsecret.json"); - let auth = Authenticator::new(InstalledFlow::new( - secret, - yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect, - )) - .persist_tokens_to_disk("tokencache.json") - .build() - .await - .unwrap(); + let auth = + InstalledFlowAuthenticator::builder(app_secret, InstalledFlowReturnMethod::HTTPRedirect) + .persist_tokens_to_disk("tokencache.json") + .build() + .await + .unwrap(); let scopes = &["https://www.googleapis.com/auth/drive.file"]; match auth.token(scopes).await { diff --git a/examples/test-svc-acct/src/main.rs b/examples/test-svc-acct/src/main.rs index e5ef33c..8ff8db0 100644 --- a/examples/test-svc-acct/src/main.rs +++ b/examples/test-svc-acct/src/main.rs @@ -1,15 +1,10 @@ -use std::path; use tokio; -use yup_oauth2; -use yup_oauth2::GetToken; +use yup_oauth2::ServiceAccountAuthenticator; #[tokio::main] async fn main() { - let creds = - yup_oauth2::service_account_key_from_file(path::Path::new("serviceaccount.json")).unwrap(); - let sa = yup_oauth2::ServiceAccountAccess::new(creds) - .build() - .unwrap(); + let creds = yup_oauth2::service_account_key_from_file("serviceaccount.json").unwrap(); + let sa = ServiceAccountAuthenticator::builder(creds).build().unwrap(); let scopes = &["https://www.googleapis.com/auth/pubsub"]; let tok = sa.token(scopes).await.unwrap(); diff --git a/src/authenticator.rs b/src/authenticator.rs index b332e5a..1b5a2de 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -1,32 +1,293 @@ -use crate::authenticator_delegate::{AuthenticatorDelegate, DefaultAuthenticatorDelegate}; +use crate::authenticator_delegate::{ + AuthenticatorDelegate, DefaultAuthenticatorDelegate, FlowDelegate, +}; +use crate::device::DeviceFlow; +use crate::installed::{InstalledFlow, InstalledFlowReturnMethod}; use crate::refresh::RefreshFlow; use crate::storage::{self, Storage}; -use crate::types::{ApplicationSecret, GetToken, RefreshResult, RequestError, Token}; - -use futures::prelude::*; +use crate::types::{ApplicationSecret, RefreshResult, RequestError, Token}; +use private::AuthFlow; +use std::borrow::Cow; use std::error::Error; use std::io; use std::path::PathBuf; -use std::pin::Pin; use std::sync::Mutex; +use std::time::Duration; -/// Authenticator abstracts different `GetToken` implementations behind one type and handles -/// caching received tokens. It's important to use it (instead of the flows directly) because -/// otherwise the user needs to be asked for new authorization every time a token is generated. -/// -/// `ServiceAccountAccess` does not need (and does not work) with `Authenticator`, given that it -/// does not require interaction and implements its own caching. Use it directly. -/// -/// NOTE: It is recommended to use a client constructed like this in order to prevent functions -/// like `hyper::run()` from hanging: `let client = hyper::Client::builder().keep_alive(false);`. -/// Due to token requests being rare, this should not result in a too bad performance problem. -struct AuthenticatorImpl +pub struct Authenticator { + hyper_client: hyper::Client, + app_secret: ApplicationSecret, + auth_delegate: Box, + storage: Storage, + auth_flow: AuthFlow, +} + +impl Authenticator +where + C: hyper::client::connect::Connect + 'static, { - client: hyper::Client, - inner: T, - store: Storage, - delegate: AD, + pub async fn token<'a, T>(&'a self, scopes: &'a [T]) -> Result + where + T: AsRef, + { + let scope_key = storage::ScopeHash::new(scopes); + match self.storage.get(scope_key, scopes) { + Some(t) if !t.expired() => { + // unexpired token found + Ok(t) + } + Some(Token { + refresh_token: Some(refresh_token), + .. + }) => { + // token is expired but has a refresh token. + let rr = RefreshFlow::refresh_token( + &self.hyper_client, + &self.app_secret, + &refresh_token, + ) + .await?; + match rr { + RefreshResult::Error(ref e) => { + self.auth_delegate.token_refresh_failed( + e.description(), + Some("the request has likely timed out"), + ); + Err(RequestError::Refresh(rr)) + } + RefreshResult::RefreshError(ref s, ref ss) => { + self.auth_delegate.token_refresh_failed( + &format!("{}{}", s, ss.as_ref().map(|s| format!(" ({})", s)).unwrap_or_else(String::new)), + Some("the refresh token is likely invalid and your authorization has been revoked"), + ); + Err(RequestError::Refresh(rr)) + } + RefreshResult::Success(t) => { + self.storage.set(scope_key, scopes, Some(t.clone())).await; + Ok(t) + } + } + } + None + | Some(Token { + refresh_token: None, + .. + }) => { + // no token in the cache or the token returned does not contain a refresh token. + let t = self + .auth_flow + .token(&self.hyper_client, &self.app_secret, scopes) + .await?; + self.storage.set(scope_key, scopes, Some(t.clone())).await; + Ok(t) + } + } + } +} + +pub struct AuthenticatorBuilder { + hyper_client_builder: C, + app_secret: ApplicationSecret, + auth_delegate: Box, + storage_type: StorageType, + auth_flow: F, +} + +pub struct InstalledFlowAuthenticator; +impl InstalledFlowAuthenticator { + pub fn builder( + app_secret: ApplicationSecret, + method: InstalledFlowReturnMethod, + ) -> AuthenticatorBuilder { + AuthenticatorBuilder::::with_auth_flow( + app_secret, + InstalledFlow::new(method), + ) + } +} + +pub struct DeviceFlowAuthenticator; +impl DeviceFlowAuthenticator { + pub fn builder( + app_secret: ApplicationSecret, + ) -> AuthenticatorBuilder { + AuthenticatorBuilder::::with_auth_flow(app_secret, DeviceFlow::new()) + } +} + +impl AuthenticatorBuilder { + fn with_auth_flow( + app_secret: ApplicationSecret, + auth_flow: F, + ) -> AuthenticatorBuilder { + AuthenticatorBuilder { + hyper_client_builder: DefaultHyperClient, + app_secret, + auth_delegate: Box::new(DefaultAuthenticatorDelegate), + storage_type: StorageType::Memory, + auth_flow, + } + } + + /// Use the provided hyper client. + pub fn hyper_client( + self, + hyper_client: hyper::Client, + ) -> AuthenticatorBuilder, F> { + AuthenticatorBuilder { + hyper_client_builder: hyper_client, + app_secret: self.app_secret, + auth_delegate: self.auth_delegate, + storage_type: self.storage_type, + auth_flow: self.auth_flow, + } + } + + /// Persist tokens to disk in the provided filename. + pub fn persist_tokens_to_disk>(self, path: P) -> AuthenticatorBuilder { + AuthenticatorBuilder { + storage_type: StorageType::Disk(path.into()), + ..self + } + } + + /// Use the provided authenticator delegate. + pub fn auth_delegate( + self, + auth_delegate: Box, + ) -> AuthenticatorBuilder { + AuthenticatorBuilder { + auth_delegate, + ..self + } + } + + /// Create the authenticator. + pub async fn build(self) -> io::Result> + where + C: HyperClientBuilder, + F: Into, + { + let hyper_client = self.hyper_client_builder.build_hyper_client(); + let storage = match self.storage_type { + StorageType::Memory => Storage::Memory { + tokens: Mutex::new(storage::JSONTokens::new()), + }, + StorageType::Disk(path) => Storage::Disk(storage::DiskStorage::new(path).await?), + }; + + Ok(Authenticator { + hyper_client, + app_secret: self.app_secret, + storage, + auth_delegate: self.auth_delegate, + auth_flow: self.auth_flow.into(), + }) + } +} + +impl AuthenticatorBuilder { + /// Use the provided device code url. + pub fn device_code_url(self, url: impl Into>) -> Self { + AuthenticatorBuilder { + auth_flow: DeviceFlow { + device_code_url: url.into(), + ..self.auth_flow + }, + ..self + } + } + + /// Use the provided FlowDelegate. + pub fn flow_delegate(self, flow_delegate: Box) -> Self { + AuthenticatorBuilder { + auth_flow: DeviceFlow { + flow_delegate, + ..self.auth_flow + }, + ..self + } + } + + /// Use the provided wait duration. + pub fn wait_duration(self, wait_duration: Duration) -> Self { + AuthenticatorBuilder { + auth_flow: DeviceFlow { + wait_duration, + ..self.auth_flow + }, + ..self + } + } + + /// Use the provided grant type. + pub fn grant_type(self, grant_type: impl Into>) -> Self { + AuthenticatorBuilder { + auth_flow: DeviceFlow { + grant_type: grant_type.into(), + ..self.auth_flow + }, + ..self + } + } +} + +impl AuthenticatorBuilder { + /// Use the provided FlowDelegate. + pub fn flow_delegate(self, flow_delegate: Box) -> Self { + AuthenticatorBuilder { + auth_flow: InstalledFlow { + flow_delegate, + ..self.auth_flow + }, + ..self + } + } +} + +mod private { + use crate::device::DeviceFlow; + use crate::installed::InstalledFlow; + use crate::types::{ApplicationSecret, RequestError, Token}; + pub enum AuthFlow { + DeviceFlow(DeviceFlow), + InstalledFlow(InstalledFlow), + } + + impl From for AuthFlow { + fn from(device_flow: DeviceFlow) -> AuthFlow { + AuthFlow::DeviceFlow(device_flow) + } + } + + impl From for AuthFlow { + fn from(installed_flow: InstalledFlow) -> AuthFlow { + AuthFlow::InstalledFlow(installed_flow) + } + } + + impl AuthFlow { + pub(crate) async fn token<'a, C, T>( + &'a self, + hyper_client: &'a hyper::Client, + app_secret: &'a ApplicationSecret, + scopes: &'a [T], + ) -> Result + where + T: AsRef, + C: hyper::client::connect::Connect + 'static, + { + match self { + AuthFlow::DeviceFlow(device_flow) => { + device_flow.token(hyper_client, app_secret, scopes).await + } + AuthFlow::InstalledFlow(installed_flow) => { + installed_flow.token(hyper_client, app_secret, scopes).await + } + } + } + } } /// A trait implemented for any hyper::Client as well as teh DefaultHyperClient. @@ -59,211 +320,7 @@ where } } -/// An internal trait implemented by flows to be used by an authenticator. -pub trait AuthFlow { - type TokenGetter: GetToken; - - fn build_token_getter(self, client: hyper::Client) -> Self::TokenGetter; -} - enum StorageType { Memory, Disk(PathBuf), } - -/// An authenticator can be used with `InstalledFlow`'s or `DeviceFlow`'s and -/// will refresh tokens as they expire as well as optionally persist tokens to -/// disk. -pub struct Authenticator { - client: C, - token_getter: T, - storage_type: StorageType, - delegate: AD, -} - -impl Authenticator -where - T: AuthFlow<::Connector>, -{ - /// Create a new authenticator with the provided flow. By default a new - /// hyper::Client will be created the default authenticator delegate will be - /// used, and tokens will not be persisted to disk. - /// Accepted flow types are DeviceFlow and InstalledFlow. - /// - /// Examples - /// ``` - /// # #[tokio::main] - /// # async fn main() { - /// use std::path::Path; - /// use yup_oauth2::{ApplicationSecret, Authenticator, DeviceFlow}; - /// let creds = ApplicationSecret::default(); - /// let auth = Authenticator::new(DeviceFlow::new(creds)).build().await.unwrap(); - /// # } - /// ``` - pub fn new(flow: T) -> Authenticator { - Authenticator { - client: DefaultHyperClient, - token_getter: flow, - storage_type: StorageType::Memory, - delegate: DefaultAuthenticatorDelegate, - } - } -} - -impl Authenticator -where - T: AuthFlow, - AD: AuthenticatorDelegate, - C: HyperClientBuilder, -{ - /// Use the provided hyper client. - pub fn hyper_client( - self, - hyper_client: hyper::Client, - ) -> Authenticator> - where - NewC: hyper::client::connect::Connect + 'static, - T: AuthFlow, - { - Authenticator { - client: hyper_client, - token_getter: self.token_getter, - storage_type: self.storage_type, - delegate: self.delegate, - } - } - - /// Persist tokens to disk in the provided filename. - pub fn persist_tokens_to_disk>(self, path: P) -> Authenticator { - Authenticator { - client: self.client, - token_getter: self.token_getter, - storage_type: StorageType::Disk(path.into()), - delegate: self.delegate, - } - } - - /// Use the provided authenticator delegate. - pub fn delegate( - self, - delegate: NewAD, - ) -> Authenticator { - Authenticator { - client: self.client, - token_getter: self.token_getter, - storage_type: self.storage_type, - delegate, - } - } - - /// Create the authenticator. - pub async fn build(self) -> io::Result - where - T::TokenGetter: GetToken, - C::Connector: hyper::client::connect::Connect + 'static, - { - let client = self.client.build_hyper_client(); - let inner = self.token_getter.build_token_getter(client.clone()); - let store = match self.storage_type { - StorageType::Memory => Storage::Memory { - tokens: Mutex::new(storage::JSONTokens::new()), - }, - StorageType::Disk(path) => Storage::Disk(storage::DiskStorage::new(path).await?), - }; - - Ok(AuthenticatorImpl { - client, - inner, - store, - delegate: self.delegate, - }) - } -} - -impl AuthenticatorImpl -where - GT: GetToken, - AD: AuthenticatorDelegate, - C: hyper::client::connect::Connect + 'static, -{ - async fn get_token(&self, scopes: &[T]) -> Result - where - T: AsRef + Sync, - { - let scope_key = storage::ScopeHash::new(scopes); - let store = &self.store; - let delegate = &self.delegate; - let client = &self.client; - let gettoken = &self.inner; - let appsecret = gettoken.application_secret(); - match store.get(scope_key, scopes) { - Some(t) if !t.expired() => { - // unexpired token found - Ok(t) - } - Some(Token { - refresh_token: Some(refresh_token), - .. - }) => { - // token is expired but has a refresh token. - let rr = RefreshFlow::refresh_token(client, appsecret, &refresh_token).await?; - match rr { - RefreshResult::Error(ref e) => { - delegate.token_refresh_failed( - e.description(), - Some("the request has likely timed out"), - ); - Err(RequestError::Refresh(rr)) - } - RefreshResult::RefreshError(ref s, ref ss) => { - delegate.token_refresh_failed( - &format!("{}{}", s, ss.as_ref().map(|s| format!(" ({})", s)).unwrap_or_else(String::new)), - Some("the refresh token is likely invalid and your authorization has been revoked"), - ); - Err(RequestError::Refresh(rr)) - } - RefreshResult::Success(t) => { - store.set(scope_key, scopes, Some(t.clone())).await; - Ok(t) - } - } - } - None - | Some(Token { - refresh_token: None, - .. - }) => { - // no token in the cache or the token returned does not contain a refresh token. - let t = gettoken.token(scopes).await?; - store.set(scope_key, scopes, Some(t.clone())).await; - Ok(t) - } - } - } -} - -impl GetToken for AuthenticatorImpl -where - GT: GetToken, - AD: AuthenticatorDelegate, - C: hyper::client::connect::Connect + 'static, -{ - /// Returns the API Key of the inner flow. - fn api_key(&self) -> Option { - self.inner.api_key() - } - /// Returns the application secret of the inner flow. - fn application_secret(&self) -> &ApplicationSecret { - self.inner.application_secret() - } - - fn token<'a, T>( - &'a self, - scopes: &'a [T], - ) -> Pin> + Send + 'a>> - where - T: AsRef + Sync, - { - Box::pin(self.get_token(scopes)) - } -} diff --git a/src/device.rs b/src/device.rs index 45e117e..373c36a 100644 --- a/src/device.rs +++ b/src/device.rs @@ -1,5 +1,4 @@ use std::borrow::Cow; -use std::pin::Pin; use std::time::Duration; use ::log::error; @@ -11,7 +10,7 @@ use serde_json as json; use url::form_urlencoded; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate, PollInformation, Retry}; -use crate::types::{ApplicationSecret, GetToken, JsonErrorOr, PollError, RequestError, Token}; +use crate::types::{ApplicationSecret, JsonErrorOr, PollError, RequestError, Token}; pub const GOOGLE_DEVICE_CODE_URL: &str = "https://accounts.google.com/o/oauth2/device/code"; @@ -22,165 +21,77 @@ pub const GOOGLE_GRANT_TYPE: &str = "http://oauth.net/grant_type/device/1.0"; /// It operates in two steps: /// * obtain a code to show to the user // * (repeatedly) poll for the user to authenticate your application -#[derive(Clone)] -pub struct DeviceFlow { - application_secret: ApplicationSecret, - device_code_url: Cow<'static, str>, - flow_delegate: FD, - wait: Duration, - grant_type: Cow<'static, str>, +pub struct DeviceFlow { + pub(crate) device_code_url: Cow<'static, str>, + pub(crate) flow_delegate: Box, + pub(crate) wait_duration: Duration, + pub(crate) grant_type: Cow<'static, str>, } -impl DeviceFlow { +impl DeviceFlow { /// Create a new DeviceFlow. The default FlowDelegate will be used and the /// default wait time is 120 seconds. - pub fn new(secret: ApplicationSecret) -> DeviceFlow { + pub(crate) fn new() -> Self { DeviceFlow { - application_secret: secret, device_code_url: GOOGLE_DEVICE_CODE_URL.into(), - flow_delegate: DefaultFlowDelegate, - wait: Duration::from_secs(120), + flow_delegate: Box::new(DefaultFlowDelegate), + wait_duration: Duration::from_secs(120), grant_type: GOOGLE_GRANT_TYPE.into(), } } -} -impl DeviceFlow { - /// Use the provided device code url. - pub fn device_code_url(self, url: String) -> Self { - DeviceFlow { - device_code_url: url.into(), - ..self - } - } - - /// Use the provided FlowDelegate. - pub fn delegate(self, delegate: NewFD) -> DeviceFlow { - DeviceFlow { - application_secret: self.application_secret, - device_code_url: self.device_code_url, - flow_delegate: delegate, - wait: self.wait, - grant_type: self.grant_type, - } - } - - /// Use the provided wait duration. - pub fn wait_duration(self, duration: Duration) -> Self { - DeviceFlow { - wait: duration, - ..self - } - } - - pub fn grant_type(self, grant_type: String) -> Self { - DeviceFlow { - grant_type: grant_type.into(), - ..self - } - } -} - -impl crate::authenticator::AuthFlow for DeviceFlow -where - FD: FlowDelegate, - C: hyper::client::connect::Connect + 'static, -{ - type TokenGetter = DeviceFlowImpl; - - fn build_token_getter(self, client: hyper::Client) -> Self::TokenGetter { - DeviceFlowImpl { - client, - application_secret: self.application_secret, - device_code_url: self.device_code_url, - fd: self.flow_delegate, - wait: Duration::from_secs(1200), - grant_type: self.grant_type, - } - } -} - -/// The DeviceFlow implementation. -pub struct DeviceFlowImpl { - client: hyper::Client, - application_secret: ApplicationSecret, - /// Usually GOOGLE_DEVICE_CODE_URL - device_code_url: Cow<'static, str>, - fd: FD, - wait: Duration, - grant_type: Cow<'static, str>, -} - -impl GetToken for DeviceFlowImpl -where - FD: FlowDelegate, - C: hyper::client::connect::Connect + 'static, -{ - fn token<'a, T>( - &'a self, - scopes: &'a [T], - ) -> Pin> + Send + 'a>> - where - T: AsRef + Sync, - { - Box::pin(self.retrieve_device_token(scopes)) - } - fn api_key(&self) -> Option { - None - } - fn application_secret(&self) -> &ApplicationSecret { - &self.application_secret - } -} - -impl DeviceFlowImpl -where - C: hyper::client::connect::Connect + 'static, - FD: FlowDelegate, -{ - /// Essentially what `GetToken::token` does: Retrieve a token for the given scopes without - /// caching. - pub async fn retrieve_device_token(&self, scopes: &[T]) -> Result + pub(crate) async fn token( + &self, + hyper_client: &hyper::Client, + app_secret: &ApplicationSecret, + scopes: &[T], + ) -> Result where T: AsRef, + C: hyper::client::connect::Connect + 'static, { - let application_secret = &self.application_secret; - let (pollinf, device_code) = Self::request_code( - application_secret, - &self.client, - &self.device_code_url, - scopes, - ) - .await?; - self.fd.present_user_code(&pollinf); + let (pollinf, device_code) = + Self::request_code(app_secret, hyper_client, &self.device_code_url, scopes).await?; + self.flow_delegate.present_user_code(&pollinf); tokio::timer::Timeout::new( - self.wait_for_device_token(&pollinf, &device_code, &self.grant_type), - self.wait, + self.wait_for_device_token( + hyper_client, + app_secret, + &pollinf, + &device_code, + &self.grant_type, + ), + self.wait_duration, ) .await .map_err(|_| RequestError::Poll(PollError::TimedOut))? } - async fn wait_for_device_token( + async fn wait_for_device_token( &self, + hyper_client: &hyper::Client, + app_secret: &ApplicationSecret, pollinf: &PollInformation, device_code: &str, grant_type: &str, - ) -> Result { + ) -> Result + where + C: hyper::client::connect::Connect + 'static, + { let mut interval = pollinf.interval; loop { tokio::timer::delay_for(interval).await; - let r = Self::poll_token( - &self.application_secret, - &self.client, + interval = match Self::poll_token( + &app_secret, + hyper_client, device_code, grant_type, pollinf.expires_at, - &self.fd, + &*self.flow_delegate as &dyn FlowDelegate, ) - .await; - interval = match r { - Ok(None) => match self.fd.pending(&pollinf) { + .await + { + Ok(None) => match self.flow_delegate.pending(&pollinf) { Retry::Abort | Retry::Skip => { return Err(RequestError::Poll(PollError::TimedOut)) } @@ -213,7 +124,7 @@ where /// * If called after a successful result was returned at least once. /// # Examples /// See test-cases in source code for a more complete example. - async fn request_code( + async fn request_code( application_secret: &ApplicationSecret, client: &hyper::Client, device_code_url: &str, @@ -221,6 +132,7 @@ where ) -> Result<(PollInformation, String), RequestError> where T: AsRef, + C: hyper::client::connect::Connect + 'static, { let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ @@ -288,16 +200,19 @@ where /// /// # Examples /// See test-cases in source code for a more complete example. - async fn poll_token<'a>( + async fn poll_token<'a, C>( application_secret: &ApplicationSecret, client: &hyper::Client, device_code: &str, grant_type: &str, expires_at: DateTime, - fd: &FD, - ) -> Result, PollError> { + flow_delegate: &dyn FlowDelegate, + ) -> Result, PollError> + where + C: hyper::client::connect::Connect + 'static, + { if expires_at <= Utc::now() { - fd.expired(expires_at); + flow_delegate.expired(expires_at); return Err(PollError::Expired(expires_at)); } @@ -334,7 +249,7 @@ where Ok(res) => { match res.error.as_ref() { "access_denied" => { - fd.denied(); + flow_delegate.denied(); return Err(PollError::AccessDenied); } "authorization_pending" => return Ok(None), @@ -364,7 +279,6 @@ mod tests { use tokio; use super::*; - use crate::authenticator::AuthFlow; use crate::helper::parse_application_secret; #[test] @@ -388,10 +302,12 @@ mod tests { .keep_alive(false) .build::<_, hyper::Body>(https); - let flow = DeviceFlow::new(app_secret) - .delegate(FD) - .device_code_url(device_code_url) - .build_token_getter(client); + let flow = DeviceFlow { + device_code_url: device_code_url.into(), + flow_delegate: Box::new(FD), + wait_duration: Duration::from_secs(5), + grant_type: GOOGLE_GRANT_TYPE.into(), + }; let rt = tokio::runtime::Builder::new() .core_threads(1) @@ -420,7 +336,11 @@ mod tests { let fut = async { let token = flow - .token(&["https://www.googleapis.com/scope/1"]) + .token( + &client, + &app_secret, + &["https://www.googleapis.com/scope/1"], + ) .await .unwrap(); assert_eq!("accesstoken", token.access_token); @@ -452,7 +372,13 @@ mod tests { .create(); let fut = async { - let res = flow.token(&["https://www.googleapis.com/scope/1"]).await; + let res = flow + .token( + &client, + &app_secret, + &["https://www.googleapis.com/scope/1"], + ) + .await; assert!(res.is_err()); assert!(format!("{}", res.unwrap_err()).contains("invalid_client_id")); Ok(()) as Result<(), ()> @@ -482,7 +408,13 @@ mod tests { .create(); let fut = async { - let res = flow.token(&["https://www.googleapis.com/scope/1"]).await; + let res = flow + .token( + &client, + &app_secret, + &["https://www.googleapis.com/scope/1"], + ) + .await; assert!(res.is_err()); assert!(format!("{}", res.unwrap_err()).contains("Access denied by user")); Ok(()) as Result<(), ()> diff --git a/src/installed.rs b/src/installed.rs index fa72489..8e1bf3d 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -17,7 +17,7 @@ use url::form_urlencoded; use url::percent_encoding::{percent_encode, QUERY_ENCODE_SET}; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate}; -use crate::types::{ApplicationSecret, GetToken, JsonErrorOr, RequestError, Token}; +use crate::types::{ApplicationSecret, JsonErrorOr, RequestError, Token}; const OOB_REDIRECT_URI: &str = "urn:ietf:wg:oauth:2.0:oob"; @@ -51,40 +51,6 @@ where }) } -impl GetToken for InstalledFlowImpl -where - FD: FlowDelegate, - C: hyper::client::connect::Connect + 'static, -{ - fn token<'a, T>( - &'a self, - scopes: &'a [T], - ) -> Pin> + Send + 'a>> - where - T: AsRef + Sync, - { - Box::pin(self.obtain_token(scopes)) - } - fn api_key(&self) -> Option { - None - } - fn application_secret(&self) -> &ApplicationSecret { - &self.appsecret - } -} - -/// The InstalledFlow implementation. -pub struct InstalledFlowImpl -where - FD: FlowDelegate, - C: hyper::client::connect::Connect, -{ - method: InstalledFlowReturnMethod, - client: hyper::client::Client, - fd: FD, - appsecret: ApplicationSecret, -} - /// cf. https://developers.google.com/identity/protocols/OAuth2InstalledApp#choosingredirecturi pub enum InstalledFlowReturnMethod { /// Involves showing a URL to the user and asking to copy a code from their browser @@ -98,151 +64,133 @@ pub enum InstalledFlowReturnMethod { /// InstalledFlowImpl provides tokens for services that follow the "Installed" OAuth flow. (See /// https://www.oauth.com/oauth2-servers/authorization/, /// https://developers.google.com/identity/protocols/OAuth2InstalledApp). -pub struct InstalledFlow { - method: InstalledFlowReturnMethod, - flow_delegate: FD, - appsecret: ApplicationSecret, +pub struct InstalledFlow { + pub(crate) method: InstalledFlowReturnMethod, + pub(crate) flow_delegate: Box, } -impl InstalledFlow { +impl InstalledFlow { /// Create a new InstalledFlow with the provided secret and method. - pub fn new( - secret: ApplicationSecret, - method: InstalledFlowReturnMethod, - ) -> InstalledFlow { + pub(crate) fn new(method: InstalledFlowReturnMethod) -> InstalledFlow { InstalledFlow { method, - flow_delegate: DefaultFlowDelegate, - appsecret: secret, + flow_delegate: Box::new(DefaultFlowDelegate), } } -} -impl InstalledFlow -where - FD: FlowDelegate, -{ - /// Use the provided FlowDelegate. - pub fn delegate(self, delegate: NewFD) -> InstalledFlow { - InstalledFlow { - method: self.method, - flow_delegate: delegate, - appsecret: self.appsecret, - } - } -} - -impl crate::authenticator::AuthFlow for InstalledFlow -where - FD: FlowDelegate, - C: hyper::client::connect::Connect + 'static, -{ - type TokenGetter = InstalledFlowImpl; - - fn build_token_getter(self, client: hyper::Client) -> Self::TokenGetter { - InstalledFlowImpl { - method: self.method, - fd: self.flow_delegate, - appsecret: self.appsecret, - client, - } - } -} - -impl InstalledFlowImpl -where - FD: FlowDelegate, - C: hyper::client::connect::Connect + 'static, -{ /// Handles the token request flow; it consists of the following steps: /// . Obtain a authorization code with user cooperation or internal redirect. /// . Obtain a token and refresh token using that code. /// . Return that token /// /// It's recommended not to use the DefaultFlowDelegate, but a specialized one. - async fn obtain_token(&self, scopes: &[T]) -> Result + pub(crate) async fn token( + &self, + hyper_client: &hyper::Client, + app_secret: &ApplicationSecret, + scopes: &[T], + ) -> Result where T: AsRef, + C: hyper::client::connect::Connect + 'static, { match self.method { - InstalledFlowReturnMethod::HTTPRedirect => self.ask_auth_code_via_http(scopes).await, + InstalledFlowReturnMethod::HTTPRedirect => { + self.ask_auth_code_via_http(hyper_client, app_secret, scopes) + .await + } InstalledFlowReturnMethod::Interactive => { - self.ask_auth_code_interactively(scopes).await + self.ask_auth_code_interactively(hyper_client, app_secret, scopes) + .await } } } - async fn ask_auth_code_interactively(&self, scopes: &[T]) -> Result + async fn ask_auth_code_interactively( + &self, + hyper_client: &hyper::Client, + app_secret: &ApplicationSecret, + scopes: &[T], + ) -> Result where T: AsRef, + C: hyper::client::connect::Connect + 'static, { - let auth_delegate = &self.fd; - let appsecret = &self.appsecret; let url = build_authentication_request_url( - &appsecret.auth_uri, - &appsecret.client_id, + &app_secret.auth_uri, + &app_secret.client_id, scopes, - auth_delegate.redirect_uri(), + self.flow_delegate.redirect_uri(), ); - let authcode = match auth_delegate + let authcode = match self + .flow_delegate .present_user_url(&url, true /* need code */) .await { Ok(mut code) => { // Partial backwards compatibility in case an implementation adds a new line // due to previous behaviour. - let ends_with_newline = code.chars().last().map(|c| c == '\n').unwrap_or(false); - if ends_with_newline { + if code.ends_with('\n') { code.pop(); } code } _ => return Err(RequestError::UserError("couldn't read code".to_string())), }; - self.exchange_auth_code(&authcode, None).await + self.exchange_auth_code(&authcode, hyper_client, app_secret, None) + .await } - async fn ask_auth_code_via_http(&self, scopes: &[T]) -> Result + async fn ask_auth_code_via_http( + &self, + hyper_client: &hyper::Client, + app_secret: &ApplicationSecret, + scopes: &[T], + ) -> Result where T: AsRef, + C: hyper::client::connect::Connect + 'static, { use std::borrow::Cow; - let auth_delegate = &self.fd; - let appsecret = &self.appsecret; let server = InstalledFlowServer::run()?; let server_addr = server.local_addr(); // Present url to user. // The redirect URI must be this very localhost URL, otherwise authorization is refused // by certain providers. - let redirect_uri: Cow = match auth_delegate.redirect_uri() { + let redirect_uri: Cow = match self.flow_delegate.redirect_uri() { Some(uri) => uri.into(), None => format!("http://{}", server_addr).into(), }; let url = build_authentication_request_url( - &appsecret.auth_uri, - &appsecret.client_id, + &app_secret.auth_uri, + &app_secret.client_id, scopes, Some(redirect_uri.as_ref()), ); - let _ = auth_delegate + let _ = self + .flow_delegate .present_user_url(&url, false /* need code */) .await; let auth_code = server.wait_for_auth_code().await; - self.exchange_auth_code(&auth_code, Some(server_addr)).await + self.exchange_auth_code(&auth_code, hyper_client, app_secret, Some(server_addr)) + .await } - async fn exchange_auth_code( + async fn exchange_auth_code( &self, authcode: &str, + hyper_client: &hyper::Client, + app_secret: &ApplicationSecret, server_addr: Option, - ) -> Result { - let appsec = &self.appsecret; - let redirect_uri = self.fd.redirect_uri(); - let request = Self::request_token(appsec, authcode, redirect_uri, server_addr); - let resp = self - .client + ) -> Result + where + C: hyper::client::connect::Connect + 'static, + { + let redirect_uri = self.flow_delegate.redirect_uri(); + let request = Self::request_token(app_secret, authcode, redirect_uri, server_addr); + let resp = hyper_client .request(request) .await .map_err(RequestError::ClientError)?; @@ -283,7 +231,7 @@ where /// Sends the authorization code to the provider in order to obtain access and refresh tokens. fn request_token( - appsecret: &ApplicationSecret, + app_secret: &ApplicationSecret, authcode: &str, custom_redirect_uri: Option<&str>, server_addr: Option, @@ -298,14 +246,14 @@ where let body = form_urlencoded::Serializer::new(String::new()) .extend_pairs(vec![ ("code", authcode), - ("client_id", appsecret.client_id.as_str()), - ("client_secret", appsecret.client_secret.as_str()), + ("client_id", app_secret.client_id.as_str()), + ("client_secret", app_secret.client_secret.as_str()), ("redirect_uri", redirect_uri.as_ref()), ("grant_type", "authorization_code"), ]) .finish(); - hyper::Request::post(&appsecret.token_uri) + hyper::Request::post(&app_secret.token_uri) .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(body)) .unwrap() // TODO: error check @@ -456,7 +404,6 @@ mod tests { use tokio; use super::*; - use crate::authenticator::AuthFlow; use crate::authenticator_delegate::FlowDelegate; use crate::helper::*; use crate::types::StringError; @@ -523,9 +470,10 @@ mod tests { .build::<_, hyper::Body>(https); let fd = FD("authorizationcode".to_string(), client.clone()); - let inf = InstalledFlow::new(app_secret.clone(), InstalledFlowReturnMethod::Interactive) - .delegate(fd) - .build_token_getter(client.clone()); + let inf = InstalledFlow { + method: InstalledFlowReturnMethod::Interactive, + flow_delegate: Box::new(fd), + }; let rt = tokio::runtime::Builder::new() .core_threads(1) @@ -544,7 +492,7 @@ mod tests { let fut = || { async { let tok = inf - .token(&["https://googleapis.com/some/scope"]) + .token(&client, &app_secret, &["https://googleapis.com/some/scope"]) .await .map_err(|_| ())?; assert_eq!("accesstoken", tok.access_token); @@ -558,12 +506,13 @@ mod tests { } // Successful path with HTTP redirect. { - let inf = InstalledFlow::new(app_secret, InstalledFlowReturnMethod::HTTPRedirect) - .delegate(FD( + let inf = InstalledFlow { + method: InstalledFlowReturnMethod::HTTPRedirect, + flow_delegate: Box::new(FD( "authorizationcodefromlocalserver".to_string(), client.clone(), - )) - .build_token_getter(client.clone()); + )), + }; let _m = mock("POST", "/token") .match_body(mockito::Matcher::Regex(".*code=authorizationcodefromlocalserver.*client_id=9022167.*".to_string())) .with_body(r#"{"access_token": "accesstoken", "refresh_token": "refreshtoken", "token_type": "Bearer", "expires_in": 12345678}"#) @@ -572,7 +521,7 @@ mod tests { let fut = async { let tok = inf - .token(&["https://googleapis.com/some/scope"]) + .token(&client, &app_secret, &["https://googleapis.com/some/scope"]) .await .map_err(|_| ())?; assert_eq!("accesstoken", tok.access_token); @@ -595,7 +544,9 @@ mod tests { .create(); let fut = async { - let tokr = inf.token(&["https://googleapis.com/some/scope"]).await; + let tokr = inf + .token(&client, &app_secret, &["https://googleapis.com/some/scope"]) + .await; assert!(tokr.is_err()); assert!(format!("{}", tokr.unwrap_err()).contains("invalid_code")); Ok(()) as Result<(), ()> diff --git a/src/lib.rs b/src/lib.rs index 549ea28..152f5b4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,12 +38,7 @@ //! `examples/test-installed/`, shows the basics of using this crate: //! //! ```test_harness,no_run -//! use futures::prelude::*; -//! use yup_oauth2::GetToken; -//! use yup_oauth2::{Authenticator, InstalledFlow}; -//! -//! use hyper::client::Client; -//! use hyper_rustls::HttpsConnector; +//! use yup_oauth2::{InstalledFlowAuthenticator, InstalledFlowReturnMethod}; //! //! #[tokio::main] //! async fn main() { @@ -53,12 +48,10 @@ //! .expect("clientsecret.json"); //! //! // Create an authenticator that uses an InstalledFlow to authenticate. The -//! // authentication tokens are persisted to a file named tokencache.json. The -//! // authenticator takes care of caching tokens to disk and refreshing tokens once -//! // they've expired. -//! let mut auth = Authenticator::new( -//! InstalledFlow::new(secret, yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect) -//! ) +//! // authentication tokens are persisted to a file named tokencache.json. The +//! // authenticator takes care of caching tokens to disk and refreshing tokens once +//! // they've expired. +//! let mut auth = InstalledFlowAuthenticator::builder(secret, InstalledFlowReturnMethod::HTTPRedirect) //! .persist_tokens_to_disk("tokencache.json") //! .build() //! .await @@ -88,16 +81,18 @@ mod service_account; mod storage; mod types; -pub use crate::authenticator::{AuthFlow, Authenticator}; +pub use crate::authenticator::{ + Authenticator, AuthenticatorBuilder, DeviceFlowAuthenticator, InstalledFlowAuthenticator, +}; pub use crate::authenticator_delegate::{ AuthenticatorDelegate, DefaultAuthenticatorDelegate, DefaultFlowDelegate, FlowDelegate, PollInformation, }; -pub use crate::device::{DeviceFlow, GOOGLE_DEVICE_CODE_URL}; +pub use crate::device::GOOGLE_DEVICE_CODE_URL; pub use crate::helper::*; -pub use crate::installed::{InstalledFlow, InstalledFlowReturnMethod}; +pub use crate::installed::InstalledFlowReturnMethod; pub use crate::service_account::*; pub use crate::types::{ - ApplicationSecret, ConsoleApplicationSecret, GetToken, PollError, RefreshResult, RequestError, - Scheme, Token, TokenType, + ApplicationSecret, ConsoleApplicationSecret, PollError, RefreshResult, RequestError, Scheme, + Token, TokenType, }; diff --git a/src/service_account.rs b/src/service_account.rs index 1b10de8..eb013e7 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -11,12 +11,11 @@ //! Copyright (c) 2016 Google Inc (lewinb@google.com). //! -use std::pin::Pin; use std::sync::Mutex; use crate::authenticator::{DefaultHyperClient, HyperClientBuilder}; use crate::storage::{self, Storage}; -use crate::types::{ApplicationSecret, GetToken, JsonErrorOr, RequestError, Token}; +use crate::types::{JsonErrorOr, RequestError, Token}; use futures::prelude::*; use hyper::header; @@ -155,20 +154,10 @@ impl JWTSigner { } } -/// A token source (`GetToken`) yielding OAuth tokens for services that use ServiceAccount authorization. -/// This token source caches token and automatically renews expired ones, meaning you do not need -/// (and you also should not) use this with `Authenticator`. Just use it directly. -#[derive(Clone)] -pub struct ServiceAccountAccess { - client: C, - key: ServiceAccountKey, - subject: Option, -} - -impl ServiceAccountAccess { - /// Create a new ServiceAccountAccess with the provided key. - pub fn new(key: ServiceAccountKey) -> Self { - ServiceAccountAccess { +pub struct ServiceAccountAuthenticator; +impl ServiceAccountAuthenticator { + pub fn builder(key: ServiceAccountKey) -> Builder { + Builder { client: DefaultHyperClient, key, subject: None, @@ -176,16 +165,16 @@ impl ServiceAccountAccess { } } -impl ServiceAccountAccess -where - C: HyperClientBuilder, -{ +pub struct Builder { + client: C, + key: ServiceAccountKey, + subject: Option, +} + +impl Builder { /// Use the provided hyper client. - pub fn hyper_client( - self, - hyper_client: NewC, - ) -> ServiceAccountAccess { - ServiceAccountAccess { + pub fn hyper_client(self, hyper_client: NewC) -> Builder { + Builder { client: hyper_client, key: self.key, subject: self.subject, @@ -194,29 +183,32 @@ where /// Use the provided subject. pub fn subject(self, subject: String) -> Self { - ServiceAccountAccess { + Builder { subject: Some(subject), ..self } } /// Build the configured ServiceAccountAccess. - pub fn build(self) -> Result { - ServiceAccountAccessImpl::new(self.client.build_hyper_client(), self.key, self.subject) + pub fn build(self) -> Result, io::Error> + where + C: HyperClientBuilder, + { + ServiceAccountAccess::new(self.client.build_hyper_client(), self.key, self.subject) } } -struct ServiceAccountAccessImpl { - client: hyper::Client, +pub struct ServiceAccountAccess { + client: hyper::Client, key: ServiceAccountKey, cache: Storage, subject: Option, signer: JWTSigner, } -impl ServiceAccountAccessImpl +impl ServiceAccountAccess where - C: hyper::client::connect::Connect, + C: hyper::client::connect::Connect + 'static, { fn new( client: hyper::Client, @@ -224,7 +216,7 @@ where subject: Option, ) -> Result { let signer = JWTSigner::new(&key.private_key)?; - Ok(ServiceAccountAccessImpl { + Ok(ServiceAccountAccess { client, key, cache: Storage::Memory { @@ -234,20 +226,28 @@ where signer, }) } -} -/// This is the schema of the server's response. -#[derive(Deserialize, Debug)] -struct TokenResponse { - access_token: Option, - token_type: Option, - expires_in: Option, -} - -impl ServiceAccountAccessImpl -where - C: hyper::client::connect::Connect + 'static, -{ + pub async fn token(&self, scopes: &[T]) -> Result + where + T: AsRef, + { + let hash = storage::ScopeHash::new(scopes); + let cache = &self.cache; + match cache.get(hash, scopes) { + Some(token) if !token.expired() => return Ok(token), + _ => {} + } + let token = Self::request_token( + &self.client, + &self.signer, + self.subject.as_ref().map(|x| x.as_str()), + &self.key, + scopes, + ) + .await?; + cache.set(hash, scopes, Some(token.clone())).await; + Ok(token) + } /// Send a request for a new Bearer token to the OAuth provider. async fn request_token( client: &hyper::client::Client, @@ -282,6 +282,15 @@ where .try_concat() .await .map_err(RequestError::ClientError)?; + + /// This is the schema of the server's response. + #[derive(Deserialize, Debug)] + struct TokenResponse { + access_token: Option, + token_type: Option, + expires_in: Option, + } + match serde_json::from_slice::>(&body)? { JsonErrorOr::Err(err) => Err(err.into()), JsonErrorOr::Data(TokenResponse { @@ -305,61 +314,12 @@ where ))), } } - - async fn get_token(&self, scopes: &[T]) -> Result - where - T: AsRef, - { - let hash = storage::ScopeHash::new(scopes); - let cache = &self.cache; - match cache.get(hash, scopes) { - Some(token) if !token.expired() => return Ok(token), - _ => {} - } - let token = Self::request_token( - &self.client, - &self.signer, - self.subject.as_ref().map(|x| x.as_str()), - &self.key, - scopes, - ) - .await?; - cache.set(hash, scopes, Some(token.clone())).await; - Ok(token) - } -} - -impl GetToken for ServiceAccountAccessImpl -where - C: hyper::client::connect::Connect + 'static, -{ - fn token<'a, T>( - &'a self, - scopes: &'a [T], - ) -> Pin> + Send + 'a>> - where - T: AsRef + Sync, - { - Box::pin(self.get_token(scopes)) - } - - /// Returns an empty ApplicationSecret as tokens for service accounts don't need to be - /// refreshed (they are simply reissued). - fn application_secret(&self) -> &ApplicationSecret { - static APP_SECRET: ApplicationSecret = ApplicationSecret::empty(); - &APP_SECRET - } - - fn api_key(&self) -> Option { - None - } } #[cfg(test)] mod tests { use super::*; use crate::helper::service_account_key_from_file; - use crate::types::GetToken; use hyper; use hyper_rustls::HttpsConnector; @@ -413,7 +373,7 @@ mod tests { .with_body(json_response) .expect(1) .create(); - let acc = ServiceAccountAccessImpl::new(client.clone(), key.clone(), None).unwrap(); + let acc = ServiceAccountAccess::new(client.clone(), key.clone(), None).unwrap(); let fut = async { let tok = acc .token(&["https://www.googleapis.com/auth/pubsub"]) @@ -453,7 +413,7 @@ mod tests { .with_header("content-type", "text/json") .with_body(bad_json_response) .create(); - let acc = ServiceAccountAccess::new(key.clone()) + let acc = ServiceAccountAuthenticator::builder(key.clone()) .hyper_client(client.clone()) .build() .unwrap(); @@ -478,7 +438,7 @@ mod tests { let key = service_account_key_from_file(&TEST_PRIVATE_KEY_PATH.to_string()).unwrap(); let https = HttpsConnector::new(); let client = hyper::Client::builder().build(https); - let acc = ServiceAccountAccess::new(key) + let acc = ServiceAccountAuthenticator::builder(key) .hyper_client(client) .build() .unwrap(); diff --git a/src/types.rs b/src/types.rs index a9d371c..796cb90 100644 --- a/src/types.rs +++ b/src/types.rs @@ -3,11 +3,8 @@ use hyper; use std::error::Error; use std::fmt; use std::io; -use std::pin::Pin; use std::str::FromStr; -use futures::prelude::*; - #[derive(Deserialize, Debug)] pub struct JsonError { pub error: String, @@ -246,24 +243,6 @@ impl FromStr for Scheme { } } -/// A provider for authorization tokens, yielding tokens valid for a given scope. -/// The `api_key()` method is an alternative in case there are no scopes or -/// if no user is involved. -pub trait GetToken: Send + Sync { - fn token<'a, T>( - &'a self, - scopes: &'a [T], - ) -> Pin> + Send + 'a>> - where - T: AsRef + Sync; - - fn api_key(&self) -> Option; - - /// Return an application secret with at least token_uri, client_secret, and client_id filled - /// in. This is used for refreshing tokens without interaction from the flow. - fn application_secret(&self) -> &ApplicationSecret; -} - /// Represents a token as returned by OAuth2 servers. /// /// It is produced by all authentication flows. From 0fe66619dd7d04fac151de712205225a047ed80a Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Wed, 13 Nov 2019 13:23:37 -0800 Subject: [PATCH 28/71] Minimize the number of items on the rustdoc landing page. Restructure the modules and imports to increase the signal to noise ration on the cargo doc landing page. This includes exposing some modules as public so that they can contain things that need to be public but that users will rarely need to interact with. Most items from types.rs were moved into an error.rs module that is now exposed publicly. --- src/authenticator.rs | 41 +++--- src/authenticator_delegate.rs | 11 +- src/device.rs | 26 ++-- src/error.rs | 174 ++++++++++++++++++++++ src/installed.rs | 41 +++--- src/lib.rs | 30 ++-- src/refresh.rs | 65 +++------ src/service_account.rs | 12 +- src/types.rs | 265 ---------------------------------- 9 files changed, 264 insertions(+), 401 deletions(-) create mode 100644 src/error.rs diff --git a/src/authenticator.rs b/src/authenticator.rs index 1b5a2de..b1824e4 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -2,14 +2,14 @@ use crate::authenticator_delegate::{ AuthenticatorDelegate, DefaultAuthenticatorDelegate, FlowDelegate, }; use crate::device::DeviceFlow; +use crate::error::RequestError; use crate::installed::{InstalledFlow, InstalledFlowReturnMethod}; use crate::refresh::RefreshFlow; use crate::storage::{self, Storage}; -use crate::types::{ApplicationSecret, RefreshResult, RequestError, Token}; +use crate::types::{ApplicationSecret, Token}; use private::AuthFlow; use std::borrow::Cow; -use std::error::Error; use std::io; use std::path::PathBuf; use std::sync::Mutex; @@ -42,32 +42,23 @@ where .. }) => { // token is expired but has a refresh token. - let rr = RefreshFlow::refresh_token( + let token = match RefreshFlow::refresh_token( &self.hyper_client, &self.app_secret, &refresh_token, ) - .await?; - match rr { - RefreshResult::Error(ref e) => { - self.auth_delegate.token_refresh_failed( - e.description(), - Some("the request has likely timed out"), - ); - Err(RequestError::Refresh(rr)) + .await + { + Err(err) => { + self.auth_delegate.token_refresh_failed(&err); + return Err(err.into()); } - RefreshResult::RefreshError(ref s, ref ss) => { - self.auth_delegate.token_refresh_failed( - &format!("{}{}", s, ss.as_ref().map(|s| format!(" ({})", s)).unwrap_or_else(String::new)), - Some("the refresh token is likely invalid and your authorization has been revoked"), - ); - Err(RequestError::Refresh(rr)) - } - RefreshResult::Success(t) => { - self.storage.set(scope_key, scopes, Some(t.clone())).await; - Ok(t) - } - } + Ok(token) => token, + }; + self.storage + .set(scope_key, scopes, Some(token.clone())) + .await; + Ok(token) } None | Some(Token { @@ -248,8 +239,10 @@ impl AuthenticatorBuilder { mod private { use crate::device::DeviceFlow; + use crate::error::RequestError; use crate::installed::InstalledFlow; - use crate::types::{ApplicationSecret, RequestError, Token}; + use crate::types::{ApplicationSecret, Token}; + pub enum AuthFlow { DeviceFlow(DeviceFlow), InstalledFlow(InstalledFlow), diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index b8fe106..92f5329 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -4,7 +4,7 @@ use std::error::Error; use std::fmt; use std::pin::Pin; -use crate::types::{PollError, RequestError}; +use crate::error::{PollError, RefreshError, RequestError}; use chrono::{DateTime, Local, Utc}; use std::time::Duration; @@ -85,14 +85,7 @@ pub trait AuthenticatorDelegate: Send + Sync { /// Called if we could not acquire a refresh token for a reason possibly specified /// by the server. /// This call is made for the delegate's information only. - fn token_refresh_failed(&self, error: &str, error_description: Option<&str>) { - { - let _ = error; - } - { - let _ = error_description; - } - } + fn token_refresh_failed(&self, _: &RefreshError) {} } /// FlowDelegate methods are called when an OAuth flow needs to ask the application what to do in diff --git a/src/device.rs b/src/device.rs index 373c36a..e1bdff6 100644 --- a/src/device.rs +++ b/src/device.rs @@ -10,7 +10,8 @@ use serde_json as json; use url::form_urlencoded; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate, PollInformation, Retry}; -use crate::types::{ApplicationSecret, JsonErrorOr, PollError, RequestError, Token}; +use crate::error::{JsonErrorOr, PollError, RequestError}; +use crate::types::{ApplicationSecret, Token}; pub const GOOGLE_DEVICE_CODE_URL: &str = "https://accounts.google.com/o/oauth2/device/code"; @@ -166,20 +167,15 @@ impl DeviceFlow { } let json_bytes = resp.into_body().try_concat().await?; - match json::from_slice::>(&json_bytes)? { - JsonErrorOr::Err(e) => Err(e.into()), - JsonErrorOr::Data(decoded) => { - let expires_in = decoded.expires_in.unwrap_or(60 * 60); - - let pi = PollInformation { - user_code: decoded.user_code, - verification_url: decoded.verification_uri, - expires_at: Utc::now() + chrono::Duration::seconds(expires_in), - interval: Duration::from_secs(i64::abs(decoded.interval) as u64), - }; - Ok((pi, decoded.device_code)) - } - } + let decoded: JsonData = json::from_slice::>(&json_bytes)?.into_result()?; + let expires_in = decoded.expires_in.unwrap_or(60 * 60); + let pi = PollInformation { + user_code: decoded.user_code, + verification_url: decoded.verification_uri, + expires_at: Utc::now() + chrono::Duration::seconds(expires_in), + interval: Duration::from_secs(i64::abs(decoded.interval) as u64), + }; + Ok((pi, decoded.device_code)) } /// If the first call is successful, this method may be called. diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..191007a --- /dev/null +++ b/src/error.rs @@ -0,0 +1,174 @@ +use std::error::Error; +use std::fmt; +use std::io; + +use chrono::{DateTime, Utc}; + +#[derive(Deserialize, Debug)] +pub(crate) struct JsonError { + pub error: String, + pub error_description: Option, + pub error_uri: Option, +} + +/// A helper type to deserialize either a JsonError or another piece of data. +#[derive(Deserialize, Debug)] +#[serde(untagged)] +pub(crate) enum JsonErrorOr { + Err(JsonError), + Data(T), +} + +impl JsonErrorOr { + pub(crate) fn into_result(self) -> Result { + match self { + JsonErrorOr::Err(err) => Result::Err(err), + JsonErrorOr::Data(value) => Result::Ok(value), + } + } +} + +/// Encapsulates all possible results of a `poll_token(...)` operation in the Device flow. +#[derive(Debug)] +pub enum PollError { + /// Connection failure - retry if you think it's worth it + HttpError(hyper::Error), + /// Indicates we are expired, including the expiration date + Expired(DateTime), + /// Indicates that the user declined access. String is server response + AccessDenied, + /// Indicates that too many attempts failed. + TimedOut, + /// Other type of error. + Other(String), +} + +/// Encapsulates all possible results of the `token(...)` operation +#[derive(Debug)] +pub enum RequestError { + /// Indicates connection failure + ClientError(hyper::Error), + /// The OAuth client was not found + InvalidClient, + /// Some requested scopes were invalid. String contains the scopes as part of + /// the server error message + InvalidScope(String), + /// A 'catch-all' variant containing the server error and description + /// First string is the error code, the second may be a more detailed description + NegativeServerResponse(String, Option), + /// A malformed server response. + BadServerResponse(String), + /// Error while decoding a JSON response. + JSONError(serde_json::Error), + /// Error within user input. + UserError(String), + /// A lower level IO error. + LowLevelError(io::Error), + /// A poll error occurred in the DeviceFlow. + Poll(PollError), + /// An error occurred while refreshing tokens. + Refresh(RefreshError), + /// Error in token cache layer + Cache(Box), +} + +impl From for RequestError { + fn from(error: hyper::Error) -> RequestError { + RequestError::ClientError(error) + } +} + +impl From for RequestError { + fn from(value: JsonError) -> RequestError { + match &*value.error { + "invalid_client" => RequestError::InvalidClient, + "invalid_scope" => RequestError::InvalidScope( + value + .error_description + .unwrap_or_else(|| "no description provided".to_string()), + ), + _ => RequestError::NegativeServerResponse(value.error, value.error_description), + } + } +} + +impl From for RequestError { + fn from(value: serde_json::Error) -> RequestError { + RequestError::JSONError(value) + } +} + +impl From for RequestError { + fn from(value: RefreshError) -> RequestError { + RequestError::Refresh(value) + } +} + +impl fmt::Display for RequestError { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + match *self { + RequestError::ClientError(ref err) => err.fmt(f), + RequestError::InvalidClient => "Invalid Client".fmt(f), + RequestError::InvalidScope(ref scope) => writeln!(f, "Invalid Scope: '{}'", scope), + RequestError::NegativeServerResponse(ref error, ref desc) => { + error.fmt(f)?; + if let Some(ref desc) = *desc { + write!(f, ": {}", desc)?; + } + "\n".fmt(f) + } + RequestError::BadServerResponse(ref s) => s.fmt(f), + RequestError::JSONError(ref e) => format!( + "JSON Error; this might be a bug with unexpected server responses! {}", + e + ) + .fmt(f), + RequestError::UserError(ref s) => s.fmt(f), + RequestError::LowLevelError(ref e) => e.fmt(f), + RequestError::Poll(ref pe) => pe.fmt(f), + RequestError::Refresh(ref rr) => format!("{:?}", rr).fmt(f), + RequestError::Cache(ref e) => e.fmt(f), + } + } +} + +impl Error for RequestError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match *self { + RequestError::ClientError(ref err) => Some(err), + RequestError::LowLevelError(ref err) => Some(err), + RequestError::JSONError(ref err) => Some(err), + _ => None, + } + } +} + +/// All possible outcomes of the refresh flow +#[derive(Debug)] +pub enum RefreshError { + /// Indicates connection failure + ConnectionError(hyper::Error), + /// The server did not answer with a new token, providing the server message + ServerError(String, Option), +} + +impl From for RefreshError { + fn from(value: hyper::Error) -> Self { + RefreshError::ConnectionError(value) + } +} + +impl From for RefreshError { + fn from(value: JsonError) -> Self { + RefreshError::ServerError(value.error, value.error_description) + } +} + +impl From for RefreshError { + fn from(_value: serde_json::Error) -> Self { + RefreshError::ServerError( + "failed to deserialize json token from refresh response".to_owned(), + None, + ) + } +} diff --git a/src/installed.rs b/src/installed.rs index 8e1bf3d..b70979f 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -17,7 +17,8 @@ use url::form_urlencoded; use url::percent_encoding::{percent_encode, QUERY_ENCODE_SET}; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate}; -use crate::types::{ApplicationSecret, JsonErrorOr, RequestError, Token}; +use crate::error::{JsonErrorOr, RequestError}; +use crate::types::{ApplicationSecret, Token}; const OOB_REDIRECT_URI: &str = "urn:ietf:wg:oauth:2.0:oob"; @@ -208,25 +209,21 @@ impl InstalledFlow { expires_in: Option, } - match serde_json::from_slice::>(&body)? { - JsonErrorOr::Err(err) => Err(err.into()), - JsonErrorOr::Data(JSONTokenResponse { - access_token, - refresh_token, - token_type, - expires_in, - }) => { - let mut token = Token { - access_token, - refresh_token, - token_type, - expires_in, - expires_in_timestamp: None, - }; - token.set_expiry_absolute(); - Ok(token) - } - } + let JSONTokenResponse { + access_token, + refresh_token, + token_type, + expires_in, + } = serde_json::from_slice::>(&body)?.into_result()?; + let mut token = Token { + access_token, + refresh_token, + token_type, + expires_in, + expires_in_timestamp: None, + }; + token.set_expiry_absolute(); + Ok(token) } /// Sends the authorization code to the provider in order to obtain access and refresh tokens. @@ -406,7 +403,6 @@ mod tests { use super::*; use crate::authenticator_delegate::FlowDelegate; use crate::helper::*; - use crate::types::StringError; #[test] fn test_end2end() { @@ -442,8 +438,7 @@ mod tests { } } if rduri.is_none() { - return Err(Box::new(StringError::new("no redirect uri!", None)) - as Box); + return Err("no redirect_uri!".into()); } let mut rduri = rduri.unwrap(); rduri.push_str(&format!("?code={}", self.0)); diff --git a/src/lib.rs b/src/lib.rs index 152f5b4..1e63c14 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,28 +71,26 @@ #[macro_use] extern crate serde_derive; -mod authenticator; -mod authenticator_delegate; +pub mod authenticator; +pub mod authenticator_delegate; mod device; +pub mod error; mod helper; mod installed; mod refresh; -mod service_account; +pub mod service_account; mod storage; mod types; -pub use crate::authenticator::{ - Authenticator, AuthenticatorBuilder, DeviceFlowAuthenticator, InstalledFlowAuthenticator, -}; -pub use crate::authenticator_delegate::{ - AuthenticatorDelegate, DefaultAuthenticatorDelegate, DefaultFlowDelegate, FlowDelegate, - PollInformation, -}; -pub use crate::device::GOOGLE_DEVICE_CODE_URL; +#[doc(inline)] +pub use crate::authenticator::{DeviceFlowAuthenticator, InstalledFlowAuthenticator}; + pub use crate::helper::*; pub use crate::installed::InstalledFlowReturnMethod; -pub use crate::service_account::*; -pub use crate::types::{ - ApplicationSecret, ConsoleApplicationSecret, PollError, RefreshResult, RequestError, Scheme, - Token, TokenType, -}; + +#[doc(inline)] +pub use crate::service_account::{ServiceAccountAuthenticator, ServiceAccountKey}; + +#[doc(inline)] +pub use crate::error::RequestError; +pub use crate::types::{ApplicationSecret, ConsoleApplicationSecret, Token}; diff --git a/src/refresh.rs b/src/refresh.rs index 7b33269..f6826be 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -1,4 +1,5 @@ -use crate::types::{ApplicationSecret, JsonErrorOr, RefreshResult, RequestError}; +use crate::error::{JsonErrorOr, RefreshError}; +use crate::types::ApplicationSecret; use super::Token; use chrono::Utc; @@ -33,7 +34,7 @@ impl RefreshFlow { client: &hyper::Client, client_secret: &ApplicationSecret, refresh_token: &str, - ) -> Result { + ) -> Result { // TODO: Does this function ever return RequestError? Maybe have it just return RefreshResult. let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ @@ -49,14 +50,8 @@ impl RefreshFlow { .body(hyper::Body::from(req)) .unwrap(); // TODO: error handling - let resp = match client.request(request).await { - Ok(resp) => resp, - Err(err) => return Ok(RefreshResult::Error(err)), - }; - let body = match resp.into_body().try_concat().await { - Ok(body) => body, - Err(err) => return Ok(RefreshResult::Error(err)), - }; + let resp = client.request(request).await?; + let body = resp.into_body().try_concat().await?; #[derive(Deserialize)] struct JsonToken { @@ -65,27 +60,18 @@ impl RefreshFlow { expires_in: i64, } - match serde_json::from_slice::>(&body) { - Err(_) => Ok(RefreshResult::RefreshError( - "failed to deserialized json token from refresh response".to_owned(), - None, - )), - Ok(JsonErrorOr::Err(json_err)) => Ok(RefreshResult::RefreshError( - json_err.error, - json_err.error_description, - )), - Ok(JsonErrorOr::Data(JsonToken { - access_token, - token_type, - expires_in, - })) => Ok(RefreshResult::Success(Token { - access_token, - token_type, - refresh_token: Some(refresh_token.to_string()), - expires_in: None, - expires_in_timestamp: Some(Utc::now().timestamp() + expires_in), - })), - } + let JsonToken { + access_token, + token_type, + expires_in, + } = serde_json::from_slice::>(&body)?.into_result()?; + Ok(Token { + access_token, + token_type, + refresh_token: Some(refresh_token.to_string()), + expires_in: None, + expires_in_timestamp: Some(Utc::now().timestamp() + expires_in), + }) } } @@ -128,16 +114,11 @@ mod tests { .with_body(r#"{"access_token": "new-access-token", "token_type": "Bearer", "expires_in": 1234567}"#) .create(); let fut = async { - let rr = RefreshFlow::refresh_token(&client, &app_secret, refresh_token) + let token = RefreshFlow::refresh_token(&client, &app_secret, refresh_token) .await .unwrap(); - match rr { - RefreshResult::Success(tok) => { - assert_eq!("new-access-token", tok.access_token); - assert_eq!("Bearer", tok.token_type); - } - _ => panic!(format!("unexpected RefreshResult {:?}", rr)), - } + assert_eq!("new-access-token", token.access_token); + assert_eq!("Bearer", token.token_type); Ok(()) as Result<(), ()> }; @@ -154,11 +135,9 @@ mod tests { .create(); let fut = async { - let rr = RefreshFlow::refresh_token(&client, &app_secret, refresh_token) - .await - .unwrap(); + let rr = RefreshFlow::refresh_token(&client, &app_secret, refresh_token).await; match rr { - RefreshResult::RefreshError(e, None) => { + Err(RefreshError::ServerError(e, None)) => { assert_eq!(e, "invalid_token"); } _ => panic!(format!("unexpected RefreshResult {:?}", rr)), diff --git a/src/service_account.rs b/src/service_account.rs index eb013e7..c2732a2 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -14,8 +14,9 @@ use std::sync::Mutex; use crate::authenticator::{DefaultHyperClient, HyperClientBuilder}; +use crate::error::{JsonErrorOr, RequestError}; use crate::storage::{self, Storage}; -use crate::types::{JsonErrorOr, RequestError, Token}; +use crate::types::Token; use futures::prelude::*; use hyper::header; @@ -291,14 +292,13 @@ where expires_in: Option, } - match serde_json::from_slice::>(&body)? { - JsonErrorOr::Err(err) => Err(err.into()), - JsonErrorOr::Data(TokenResponse { + match serde_json::from_slice::>(&body)?.into_result()? { + TokenResponse { access_token: Some(access_token), token_type: Some(token_type), expires_in: Some(expires_in), .. - }) => { + } => { let expires_ts = chrono::Utc::now().timestamp() + expires_in; Ok(Token { access_token, @@ -308,7 +308,7 @@ where expires_in_timestamp: Some(expires_ts), }) } - JsonErrorOr::Data(token) => Err(RequestError::BadServerResponse(format!( + token => Err(RequestError::BadServerResponse(format!( "Token response lacks fields: {:?}", token ))), diff --git a/src/types.rs b/src/types.rs index 796cb90..cb044a1 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,247 +1,4 @@ use chrono::{DateTime, TimeZone, Utc}; -use hyper; -use std::error::Error; -use std::fmt; -use std::io; -use std::str::FromStr; - -#[derive(Deserialize, Debug)] -pub struct JsonError { - pub error: String, - pub error_description: Option, - pub error_uri: Option, -} - -/// A helper type to deserialize either a JsonError or another piece of data. -#[derive(Deserialize, Debug)] -#[serde(untagged)] -pub enum JsonErrorOr { - Err(JsonError), - Data(T), -} - -/// All possible outcomes of the refresh flow -#[derive(Debug)] -pub enum RefreshResult { - /// Indicates connection failure - Error(hyper::Error), - /// The server did not answer with a new token, providing the server message - RefreshError(String, Option), - /// The refresh operation finished successfully, providing a new `Token` - Success(Token), -} - -/// Encapsulates all possible results of a `poll_token(...)` operation in the Device flow. -#[derive(Debug)] -pub enum PollError { - /// Connection failure - retry if you think it's worth it - HttpError(hyper::Error), - /// Indicates we are expired, including the expiration date - Expired(DateTime), - /// Indicates that the user declined access. String is server response - AccessDenied, - /// Indicates that too many attempts failed. - TimedOut, - /// Other type of error. - Other(String), -} - -/// Encapsulates all possible results of the `token(...)` operation -#[derive(Debug)] -pub enum RequestError { - /// Indicates connection failure - ClientError(hyper::Error), - /// The OAuth client was not found - InvalidClient, - /// Some requested scopes were invalid. String contains the scopes as part of - /// the server error message - InvalidScope(String), - /// A 'catch-all' variant containing the server error and description - /// First string is the error code, the second may be a more detailed description - NegativeServerResponse(String, Option), - /// A malformed server response. - BadServerResponse(String), - /// Error while decoding a JSON response. - JSONError(serde_json::Error), - /// Error within user input. - UserError(String), - /// A lower level IO error. - LowLevelError(io::Error), - /// A poll error occurred in the DeviceFlow. - Poll(PollError), - /// An error occurred while refreshing tokens. - Refresh(RefreshResult), - /// Error in token cache layer - Cache(Box), -} - -impl From for RequestError { - fn from(error: hyper::Error) -> RequestError { - RequestError::ClientError(error) - } -} - -impl From for RequestError { - fn from(value: JsonError) -> RequestError { - match &*value.error { - "invalid_client" => RequestError::InvalidClient, - "invalid_scope" => RequestError::InvalidScope( - value - .error_description - .unwrap_or_else(|| "no description provided".to_string()), - ), - _ => RequestError::NegativeServerResponse(value.error, value.error_description), - } - } -} - -impl From for RequestError { - fn from(value: serde_json::Error) -> RequestError { - RequestError::JSONError(value) - } -} - -impl fmt::Display for RequestError { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - match *self { - RequestError::ClientError(ref err) => err.fmt(f), - RequestError::InvalidClient => "Invalid Client".fmt(f), - RequestError::InvalidScope(ref scope) => writeln!(f, "Invalid Scope: '{}'", scope), - RequestError::NegativeServerResponse(ref error, ref desc) => { - error.fmt(f)?; - if let Some(ref desc) = *desc { - write!(f, ": {}", desc)?; - } - "\n".fmt(f) - } - RequestError::BadServerResponse(ref s) => s.fmt(f), - RequestError::JSONError(ref e) => format!( - "JSON Error; this might be a bug with unexpected server responses! {}", - e - ) - .fmt(f), - RequestError::UserError(ref s) => s.fmt(f), - RequestError::LowLevelError(ref e) => e.fmt(f), - RequestError::Poll(ref pe) => pe.fmt(f), - RequestError::Refresh(ref rr) => format!("{:?}", rr).fmt(f), - RequestError::Cache(ref e) => e.fmt(f), - } - } -} - -impl Error for RequestError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - match *self { - RequestError::ClientError(ref err) => Some(err), - RequestError::LowLevelError(ref err) => Some(err), - RequestError::JSONError(ref err) => Some(err), - _ => None, - } - } -} - -#[derive(Debug)] -pub struct StringError { - error: String, -} - -impl fmt::Display for StringError { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - self.description().fmt(f) - } -} - -impl StringError { - pub fn new>(error: S, desc: Option) -> StringError { - let mut error = error.as_ref().to_string(); - if let Some(d) = desc { - error.push_str(": "); - error.push_str(d.as_ref()); - } - - StringError { error } - } -} - -impl<'a> From<&'a dyn Error> for StringError { - fn from(err: &'a dyn Error) -> StringError { - StringError::new(err.description().to_string(), None) - } -} - -impl From for StringError { - fn from(value: String) -> StringError { - StringError::new(value, None) - } -} - -impl Error for StringError { - fn description(&self) -> &str { - &self.error - } -} - -/// Represents all implemented token types -#[derive(Clone, PartialEq, Debug)] -pub enum TokenType { - /// Means that whoever bears the access token will be granted access - Bearer, -} - -impl AsRef for TokenType { - fn as_ref(&self) -> &'static str { - match *self { - TokenType::Bearer => "Bearer", - } - } -} - -impl FromStr for TokenType { - type Err = (); - fn from_str(s: &str) -> Result { - match s { - "Bearer" => Ok(TokenType::Bearer), - _ => Err(()), - } - } -} - -/// A scheme for use in `hyper::header::Authorization` -#[derive(Clone, PartialEq, Debug)] -pub struct Scheme { - /// The type of our access token - pub token_type: TokenType, - /// The token returned by one of the Authorization Flows - pub access_token: String, -} - -impl std::convert::Into for Scheme { - fn into(self) -> hyper::header::HeaderValue { - hyper::header::HeaderValue::from_str(&format!( - "{} {}", - self.token_type.as_ref(), - self.access_token - )) - .expect("Invalid Scheme header value") - } -} - -impl FromStr for Scheme { - type Err = &'static str; - fn from_str(s: &str) -> Result { - let parts: Vec<&str> = s.split(' ').collect(); - if parts.len() != 2 { - return Err("Expected two parts: "); - } - match ::from_str(parts[0]) { - Ok(t) => Ok(Scheme { - token_type: t, - access_token: parts[1].to_string(), - }), - Err(_) => Err("Couldn't parse token type"), - } - } -} /// Represents a token as returned by OAuth2 servers. /// @@ -362,7 +119,6 @@ pub struct ConsoleApplicationSecret { #[cfg(test)] pub mod tests { use super::*; - use hyper; pub const SECRET: &'static str = "{\"installed\":{\"auth_uri\":\"https://accounts.google.com/o/oauth2/auth\",\ @@ -380,25 +136,4 @@ pub mod tests { Err(err) => panic!(err), } } - - #[test] - fn schema() { - let s = Scheme { - token_type: TokenType::Bearer, - access_token: "foo".to_string(), - }; - let mut headers = hyper::HeaderMap::new(); - headers.insert(hyper::header::AUTHORIZATION, s.into()); - assert_eq!( - format!("{:?}", headers), - "{\"authorization\": \"Bearer foo\"}".to_string() - ); - } - - #[test] - fn parse_schema() { - let auth = Scheme::from_str("Bearer foo").unwrap(); - assert_eq!(auth.token_type, TokenType::Bearer); - assert_eq!(auth.access_token, "foo".to_string()); - } } From ba0b8f366ad3d972379f8a18892670a80fcf1a34 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Wed, 13 Nov 2019 13:55:32 -0800 Subject: [PATCH 29/71] Rename RequestError to Error RequestError is the error value that encompasses all errors from the authenticators. Their is an established convention of using Error as the name for those types. --- src/authenticator.rs | 8 ++--- src/authenticator_delegate.rs | 16 ++++----- src/device.rs | 21 +++++------ src/error.rs | 68 +++++++++++++++++------------------ src/installed.rs | 18 +++++----- src/lib.rs | 2 +- src/refresh.rs | 2 +- src/service_account.rs | 21 +++++------ 8 files changed, 74 insertions(+), 82 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index b1824e4..a54cede 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -2,7 +2,7 @@ use crate::authenticator_delegate::{ AuthenticatorDelegate, DefaultAuthenticatorDelegate, FlowDelegate, }; use crate::device::DeviceFlow; -use crate::error::RequestError; +use crate::error::Error; use crate::installed::{InstalledFlow, InstalledFlowReturnMethod}; use crate::refresh::RefreshFlow; use crate::storage::{self, Storage}; @@ -27,7 +27,7 @@ impl Authenticator where C: hyper::client::connect::Connect + 'static, { - pub async fn token<'a, T>(&'a self, scopes: &'a [T]) -> Result + pub async fn token<'a, T>(&'a self, scopes: &'a [T]) -> Result where T: AsRef, { @@ -239,7 +239,7 @@ impl AuthenticatorBuilder { mod private { use crate::device::DeviceFlow; - use crate::error::RequestError; + use crate::error::Error; use crate::installed::InstalledFlow; use crate::types::{ApplicationSecret, Token}; @@ -266,7 +266,7 @@ mod private { hyper_client: &'a hyper::Client, app_secret: &'a ApplicationSecret, scopes: &'a [T], - ) -> Result + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index 92f5329..86c00dc 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -1,10 +1,10 @@ use hyper; -use std::error::Error; +use std::error::Error as StdError; use std::fmt; use std::pin::Pin; -use crate::error::{PollError, RefreshError, RequestError}; +use crate::error::{Error, PollError, RefreshError}; use chrono::{DateTime, Local, Utc}; use std::time::Duration; @@ -58,8 +58,8 @@ impl fmt::Display for PollError { } } -impl Error for PollError { - fn source(&self) -> Option<&(dyn Error + 'static)> { +impl StdError for PollError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { match *self { PollError::HttpError(ref e) => Some(e), _ => None, @@ -80,7 +80,7 @@ pub trait AuthenticatorDelegate: Send + Sync { } /// The server denied the attempt to obtain a request code - fn request_failure(&self, _: RequestError) {} + fn request_failure(&self, _: Error) {} /// Called if we could not acquire a refresh token for a reason possibly specified /// by the server. @@ -140,7 +140,7 @@ pub trait FlowDelegate: Send + Sync { &'a self, url: &'a str, need_code: bool, - ) -> Pin>> + Send + 'a>> + ) -> Pin>> + Send + 'a>> { Box::pin(present_user_url(url, need_code)) } @@ -149,7 +149,7 @@ pub trait FlowDelegate: Send + Sync { async fn present_user_url( url: &str, need_code: bool, -) -> Result> { +) -> Result> { if need_code { println!( "Please direct your browser to {}, follow the instructions and enter the \ @@ -163,7 +163,7 @@ async fn present_user_url( { Err(err) => { println!("{:?}", err); - Err(Box::new(err) as Box) + Err(Box::new(err) as Box) } Ok(_) => Ok(user_input), } diff --git a/src/device.rs b/src/device.rs index e1bdff6..1943d4e 100644 --- a/src/device.rs +++ b/src/device.rs @@ -10,7 +10,7 @@ use serde_json as json; use url::form_urlencoded; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate, PollInformation, Retry}; -use crate::error::{JsonErrorOr, PollError, RequestError}; +use crate::error::{Error, JsonErrorOr, PollError}; use crate::types::{ApplicationSecret, Token}; pub const GOOGLE_DEVICE_CODE_URL: &str = "https://accounts.google.com/o/oauth2/device/code"; @@ -46,7 +46,7 @@ impl DeviceFlow { hyper_client: &hyper::Client, app_secret: &ApplicationSecret, scopes: &[T], - ) -> Result + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, @@ -65,7 +65,7 @@ impl DeviceFlow { self.wait_duration, ) .await - .map_err(|_| RequestError::Poll(PollError::TimedOut))? + .map_err(|_| Error::Poll(PollError::TimedOut))? } async fn wait_for_device_token( @@ -75,7 +75,7 @@ impl DeviceFlow { pollinf: &PollInformation, device_code: &str, grant_type: &str, - ) -> Result + ) -> Result where C: hyper::client::connect::Connect + 'static, { @@ -93,15 +93,13 @@ impl DeviceFlow { .await { Ok(None) => match self.flow_delegate.pending(&pollinf) { - Retry::Abort | Retry::Skip => { - return Err(RequestError::Poll(PollError::TimedOut)) - } + Retry::Abort | Retry::Skip => return Err(Error::Poll(PollError::TimedOut)), Retry::After(d) => d, }, Ok(Some(tok)) => return Ok(tok), Err(e @ PollError::AccessDenied) | Err(e @ PollError::TimedOut) - | Err(e @ PollError::Expired(_)) => return Err(RequestError::Poll(e)), + | Err(e @ PollError::Expired(_)) => return Err(Error::Poll(e)), Err(ref e) => { error!("Unknown error from poll token api: {}", e); pollinf.interval @@ -130,7 +128,7 @@ impl DeviceFlow { client: &hyper::Client, device_code_url: &str, scopes: &[T], - ) -> Result<(PollInformation, String), RequestError> + ) -> Result<(PollInformation, String), Error> where T: AsRef, C: hyper::client::connect::Connect + 'static, @@ -148,10 +146,7 @@ impl DeviceFlow { .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(req)) .unwrap(); - let resp = client - .request(req) - .await - .map_err(RequestError::ClientError)?; + let resp = client.request(req).await.map_err(Error::ClientError)?; // This return type is defined in https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.2 // The alias is present as Google use a non-standard name for verification_uri. // According to the standard interval is optional, however, all tested implementations provide it. diff --git a/src/error.rs b/src/error.rs index 191007a..d1f1aa9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,4 @@ -use std::error::Error; +use std::error::Error as StdError; use std::fmt; use std::io; @@ -45,7 +45,7 @@ pub enum PollError { /// Encapsulates all possible results of the `token(...)` operation #[derive(Debug)] -pub enum RequestError { +pub enum Error { /// Indicates connection failure ClientError(hyper::Error), /// The OAuth client was not found @@ -69,75 +69,75 @@ pub enum RequestError { /// An error occurred while refreshing tokens. Refresh(RefreshError), /// Error in token cache layer - Cache(Box), + Cache(Box), } -impl From for RequestError { - fn from(error: hyper::Error) -> RequestError { - RequestError::ClientError(error) +impl From for Error { + fn from(error: hyper::Error) -> Error { + Error::ClientError(error) } } -impl From for RequestError { - fn from(value: JsonError) -> RequestError { +impl From for Error { + fn from(value: JsonError) -> Error { match &*value.error { - "invalid_client" => RequestError::InvalidClient, - "invalid_scope" => RequestError::InvalidScope( + "invalid_client" => Error::InvalidClient, + "invalid_scope" => Error::InvalidScope( value .error_description .unwrap_or_else(|| "no description provided".to_string()), ), - _ => RequestError::NegativeServerResponse(value.error, value.error_description), + _ => Error::NegativeServerResponse(value.error, value.error_description), } } } -impl From for RequestError { - fn from(value: serde_json::Error) -> RequestError { - RequestError::JSONError(value) +impl From for Error { + fn from(value: serde_json::Error) -> Error { + Error::JSONError(value) } } -impl From for RequestError { - fn from(value: RefreshError) -> RequestError { - RequestError::Refresh(value) +impl From for Error { + fn from(value: RefreshError) -> Error { + Error::Refresh(value) } } -impl fmt::Display for RequestError { +impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { match *self { - RequestError::ClientError(ref err) => err.fmt(f), - RequestError::InvalidClient => "Invalid Client".fmt(f), - RequestError::InvalidScope(ref scope) => writeln!(f, "Invalid Scope: '{}'", scope), - RequestError::NegativeServerResponse(ref error, ref desc) => { + Error::ClientError(ref err) => err.fmt(f), + Error::InvalidClient => "Invalid Client".fmt(f), + Error::InvalidScope(ref scope) => writeln!(f, "Invalid Scope: '{}'", scope), + Error::NegativeServerResponse(ref error, ref desc) => { error.fmt(f)?; if let Some(ref desc) = *desc { write!(f, ": {}", desc)?; } "\n".fmt(f) } - RequestError::BadServerResponse(ref s) => s.fmt(f), - RequestError::JSONError(ref e) => format!( + Error::BadServerResponse(ref s) => s.fmt(f), + Error::JSONError(ref e) => format!( "JSON Error; this might be a bug with unexpected server responses! {}", e ) .fmt(f), - RequestError::UserError(ref s) => s.fmt(f), - RequestError::LowLevelError(ref e) => e.fmt(f), - RequestError::Poll(ref pe) => pe.fmt(f), - RequestError::Refresh(ref rr) => format!("{:?}", rr).fmt(f), - RequestError::Cache(ref e) => e.fmt(f), + Error::UserError(ref s) => s.fmt(f), + Error::LowLevelError(ref e) => e.fmt(f), + Error::Poll(ref pe) => pe.fmt(f), + Error::Refresh(ref rr) => format!("{:?}", rr).fmt(f), + Error::Cache(ref e) => e.fmt(f), } } } -impl Error for RequestError { - fn source(&self) -> Option<&(dyn Error + 'static)> { +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { match *self { - RequestError::ClientError(ref err) => Some(err), - RequestError::LowLevelError(ref err) => Some(err), - RequestError::JSONError(ref err) => Some(err), + Error::ClientError(ref err) => Some(err), + Error::LowLevelError(ref err) => Some(err), + Error::JSONError(ref err) => Some(err), _ => None, } } diff --git a/src/installed.rs b/src/installed.rs index b70979f..4fbdfb6 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -17,7 +17,7 @@ use url::form_urlencoded; use url::percent_encoding::{percent_encode, QUERY_ENCODE_SET}; use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate}; -use crate::error::{JsonErrorOr, RequestError}; +use crate::error::{Error, JsonErrorOr}; use crate::types::{ApplicationSecret, Token}; const OOB_REDIRECT_URI: &str = "urn:ietf:wg:oauth:2.0:oob"; @@ -90,7 +90,7 @@ impl InstalledFlow { hyper_client: &hyper::Client, app_secret: &ApplicationSecret, scopes: &[T], - ) -> Result + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, @@ -112,7 +112,7 @@ impl InstalledFlow { hyper_client: &hyper::Client, app_secret: &ApplicationSecret, scopes: &[T], - ) -> Result + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, @@ -136,7 +136,7 @@ impl InstalledFlow { } code } - _ => return Err(RequestError::UserError("couldn't read code".to_string())), + _ => return Err(Error::UserError("couldn't read code".to_string())), }; self.exchange_auth_code(&authcode, hyper_client, app_secret, None) .await @@ -147,7 +147,7 @@ impl InstalledFlow { hyper_client: &hyper::Client, app_secret: &ApplicationSecret, scopes: &[T], - ) -> Result + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, @@ -185,7 +185,7 @@ impl InstalledFlow { hyper_client: &hyper::Client, app_secret: &ApplicationSecret, server_addr: Option, - ) -> Result + ) -> Result where C: hyper::client::connect::Connect + 'static, { @@ -194,12 +194,12 @@ impl InstalledFlow { let resp = hyper_client .request(request) .await - .map_err(RequestError::ClientError)?; + .map_err(Error::ClientError)?; let body = resp .into_body() .try_concat() .await - .map_err(RequestError::ClientError)?; + .map_err(Error::ClientError)?; #[derive(Deserialize)] struct JSONTokenResponse { @@ -276,7 +276,7 @@ struct InstalledFlowServer { } impl InstalledFlowServer { - fn run() -> Result { + fn run() -> Result { use hyper::service::{make_service_fn, service_fn}; let (auth_code_tx, auth_code_rx) = oneshot::channel::(); let (trigger_shutdown_tx, trigger_shutdown_rx) = oneshot::channel::<()>(); diff --git a/src/lib.rs b/src/lib.rs index 1e63c14..bb82e6c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -92,5 +92,5 @@ pub use crate::installed::InstalledFlowReturnMethod; pub use crate::service_account::{ServiceAccountAuthenticator, ServiceAccountKey}; #[doc(inline)] -pub use crate::error::RequestError; +pub use crate::error::Error; pub use crate::types::{ApplicationSecret, ConsoleApplicationSecret, Token}; diff --git a/src/refresh.rs b/src/refresh.rs index f6826be..03847e2 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -35,7 +35,7 @@ impl RefreshFlow { client_secret: &ApplicationSecret, refresh_token: &str, ) -> Result { - // TODO: Does this function ever return RequestError? Maybe have it just return RefreshResult. + // TODO: Does this function ever return Error? Maybe have it just return RefreshResult. let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ ("client_id", client_secret.client_id.as_str()), diff --git a/src/service_account.rs b/src/service_account.rs index c2732a2..66ab1df 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -14,7 +14,7 @@ use std::sync::Mutex; use crate::authenticator::{DefaultHyperClient, HyperClientBuilder}; -use crate::error::{JsonErrorOr, RequestError}; +use crate::error::{Error, JsonErrorOr}; use crate::storage::{self, Storage}; use crate::types::Token; @@ -228,7 +228,7 @@ where }) } - pub async fn token(&self, scopes: &[T]) -> Result + pub async fn token(&self, scopes: &[T]) -> Result where T: AsRef, { @@ -256,13 +256,13 @@ where subject: Option<&str>, key: &ServiceAccountKey, scopes: &[T], - ) -> Result + ) -> Result where T: AsRef, { let claims = Claims::new(key, scopes, subject); let signed = signer.sign_claims(&claims).map_err(|_| { - RequestError::LowLevelError(io::Error::new( + Error::LowLevelError(io::Error::new( io::ErrorKind::Other, "unable to sign claims", )) @@ -274,15 +274,12 @@ where .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(rqbody)) .unwrap(); - let response = client - .request(request) - .await - .map_err(RequestError::ClientError)?; + let response = client.request(request).await.map_err(Error::ClientError)?; let body = response .into_body() .try_concat() .await - .map_err(RequestError::ClientError)?; + .map_err(Error::ClientError)?; /// This is the schema of the server's response. #[derive(Deserialize, Debug)] @@ -308,7 +305,7 @@ where expires_in_timestamp: Some(expires_ts), }) } - token => Err(RequestError::BadServerResponse(format!( + token => Err(Error::BadServerResponse(format!( "Token response lacks fields: {:?}", token ))), @@ -380,7 +377,7 @@ mod tests { .await?; assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); assert_eq!(Some(3600), tok.expires_in); - Ok(()) as Result<(), RequestError> + Ok(()) as Result<(), Error> }; rt.block_on(fut).expect("block_on"); @@ -400,7 +397,7 @@ mod tests { .await?; assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); assert_eq!(Some(3600), tok.expires_in); - Ok(()) as Result<(), RequestError> + Ok(()) as Result<(), Error> }; rt.block_on(fut).expect("block_on 2"); From e5aa32b3cf83ce282f2fa916d71ac31b179db0b8 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Wed, 13 Nov 2019 14:32:39 -0800 Subject: [PATCH 30/71] Tidy up some imports. No more need to macro_use serde. Order the imports consistently (albeit somewhat arbitrary), starting with items from this crate, followed by std, followed by external crates. --- Cargo.toml | 3 +-- src/authenticator_delegate.rs | 12 ++++-------- src/device.rs | 18 +++++++++--------- src/error.rs | 1 + src/helper.rs | 9 ++------- src/installed.rs | 10 +++++----- src/lib.rs | 3 --- src/refresh.rs | 5 ++--- src/service_account.rs | 15 +++++---------- src/storage.rs | 3 ++- src/types.rs | 1 + 11 files changed, 32 insertions(+), 48 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 88618be..7f0ed0c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,9 +18,8 @@ hyper = {version = "0.13.0-alpha.4", features = ["unstable-stream"]} hyper-rustls = "=0.18.0-alpha.2" log = "0.4" rustls = "0.16" -serde = "1.0" +serde = {version = "1.0", features = ["derive"]} serde_json = "1.0" -serde_derive = "1.0" url = "1" futures-preview = "=0.3.0-alpha.19" tokio = "=0.2.0-alpha.6" diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index 86c00dc..e5ad5e3 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -1,17 +1,12 @@ -use hyper; +use crate::error::{Error, PollError, RefreshError}; use std::error::Error as StdError; use std::fmt; use std::pin::Pin; - -use crate::error::{Error, PollError, RefreshError}; - -use chrono::{DateTime, Local, Utc}; use std::time::Duration; +use chrono::{DateTime, Local, Utc}; use futures::prelude::*; -use tio::AsyncBufReadExt; -use tokio::io as tio; /// A utility type to indicate how operations DeviceFlowHelper operations should be retried pub enum Retry { @@ -150,6 +145,7 @@ async fn present_user_url( url: &str, need_code: bool, ) -> Result> { + use tokio::io::AsyncBufReadExt; if need_code { println!( "Please direct your browser to {}, follow the instructions and enter the \ @@ -157,7 +153,7 @@ async fn present_user_url( url ); let mut user_input = String::new(); - match tio::BufReader::new(tio::stdin()) + match tokio::io::BufReader::new(tokio::io::stdin()) .read_line(&mut user_input) .await { diff --git a/src/device.rs b/src/device.rs index 1943d4e..90ec9f3 100644 --- a/src/device.rs +++ b/src/device.rs @@ -1,18 +1,17 @@ +use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate, PollInformation, Retry}; +use crate::error::{Error, JsonErrorOr, PollError}; +use crate::types::{ApplicationSecret, Token}; + use std::borrow::Cow; use std::time::Duration; use ::log::error; use chrono::{DateTime, Utc}; use futures::prelude::*; -use hyper; use hyper::header; -use serde_json as json; +use serde::Deserialize; use url::form_urlencoded; -use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate, PollInformation, Retry}; -use crate::error::{Error, JsonErrorOr, PollError}; -use crate::types::{ApplicationSecret, Token}; - pub const GOOGLE_DEVICE_CODE_URL: &str = "https://accounts.google.com/o/oauth2/device/code"; // https://developers.google.com/identity/protocols/OAuth2ForDevices#step-4:-poll-googles-authorization-server @@ -162,7 +161,8 @@ impl DeviceFlow { } let json_bytes = resp.into_body().try_concat().await?; - let decoded: JsonData = json::from_slice::>(&json_bytes)?.into_result()?; + let decoded: JsonData = + serde_json::from_slice::>(&json_bytes)?.into_result()?; let expires_in = decoded.expires_in.unwrap_or(60 * 60); let pi = PollInformation { user_code: decoded.user_code, @@ -235,7 +235,7 @@ impl DeviceFlow { error: String, } - match json::from_slice::(&body) { + match serde_json::from_slice::(&body) { Err(_) => {} // ignore, move on, it's not an error Ok(res) => { match res.error.as_ref() { @@ -255,7 +255,7 @@ impl DeviceFlow { } // yes, we expect that ! - let mut t: Token = json::from_slice(&body).unwrap(); + let mut t: Token = serde_json::from_slice(&body).unwrap(); t.set_expiry_absolute(); Ok(Some(t)) diff --git a/src/error.rs b/src/error.rs index d1f1aa9..a8b66e1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,6 +3,7 @@ use std::fmt; use std::io; use chrono::{DateTime, Utc}; +use serde::Deserialize; #[derive(Deserialize, Debug)] pub(crate) struct JsonError { diff --git a/src/helper.rs b/src/helper.rs index cd3e3e2..e5465ac 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -1,20 +1,15 @@ -#![allow(dead_code)] - //! Helper functions allowing you to avoid writing boilerplate code for common operations, such as //! parsing JSON or reading files. // Copyright (c) 2016 Google Inc (lewinb@google.com). // // Refer to the project root for licensing information. - -use serde_json; +use crate::service_account::ServiceAccountKey; +use crate::types::{ApplicationSecret, ConsoleApplicationSecret}; use std::io; use std::path::Path; -use crate::service_account::ServiceAccountKey; -use crate::types::{ApplicationSecret, ConsoleApplicationSecret}; - /// Read an application secret from a file. pub fn read_application_secret>(path: P) -> io::Result { parse_application_secret(std::fs::read_to_string(path)?) diff --git a/src/installed.rs b/src/installed.rs index 4fbdfb6..8fd64ae 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -2,6 +2,10 @@ // // Refer to the project root for licensing information. // +use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate}; +use crate::error::{Error, JsonErrorOr}; +use crate::types::{ApplicationSecret, Token}; + use std::convert::AsRef; use std::future::Future; use std::net::SocketAddr; @@ -10,16 +14,12 @@ use std::sync::{Arc, Mutex}; use futures::future::FutureExt; use futures_util::try_stream::TryStreamExt; -use hyper; use hyper::header; +use serde::Deserialize; use tokio::sync::oneshot; use url::form_urlencoded; use url::percent_encoding::{percent_encode, QUERY_ENCODE_SET}; -use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate}; -use crate::error::{Error, JsonErrorOr}; -use crate::types::{ApplicationSecret, Token}; - const OOB_REDIRECT_URI: &str = "urn:ietf:wg:oauth:2.0:oob"; /// Assembles a URL to request an authorization token (with user interaction). diff --git a/src/lib.rs b/src/lib.rs index bb82e6c..79670c2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -68,9 +68,6 @@ //! } //! ``` //! -#[macro_use] -extern crate serde_derive; - pub mod authenticator; pub mod authenticator_delegate; mod device; diff --git a/src/refresh.rs b/src/refresh.rs index 03847e2..0a4a225 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -1,11 +1,10 @@ use crate::error::{JsonErrorOr, RefreshError}; -use crate::types::ApplicationSecret; +use crate::types::{ApplicationSecret, Token}; -use super::Token; use chrono::Utc; use futures_util::try_stream::TryStreamExt; -use hyper; use hyper::header; +use serde::Deserialize; use url::form_urlencoded; /// Implements the [OAuth2 Refresh Token Flow](https://developers.google.com/youtube/v3/guides/authentication#devices). diff --git a/src/service_account.rs b/src/service_account.rs index 66ab1df..2327d1f 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -11,29 +11,24 @@ //! Copyright (c) 2016 Google Inc (lewinb@google.com). //! -use std::sync::Mutex; - use crate::authenticator::{DefaultHyperClient, HyperClientBuilder}; use crate::error::{Error, JsonErrorOr}; use crate::storage::{self, Storage}; use crate::types::Token; +use std::io; +use std::sync::Mutex; + use futures::prelude::*; use hyper::header; -use url::form_urlencoded; - use rustls::{ self, internal::pemfile, sign::{self, SigningKey}, PrivateKey, }; -use std::io; - -use base64; -use chrono; -use hyper; -use serde_json; +use serde::{Deserialize, Serialize}; +use url::form_urlencoded; const GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:jwt-bearer"; const GOOGLE_RS256_HEAD: &str = r#"{"alg":"RS256","typ":"JWT"}"#; diff --git a/src/storage.rs b/src/storage.rs index 17cf2db..7192f77 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -2,6 +2,7 @@ // // See project root for licensing information. // +use crate::types::Token; use std::cmp::Ordering; use std::collections::hash_map::DefaultHasher; @@ -10,7 +11,7 @@ use std::io; use std::path::{Path, PathBuf}; use std::sync::Mutex; -use crate::types::Token; +use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct ScopeHash(u64); diff --git a/src/types.rs b/src/types.rs index cb044a1..c8d5510 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,4 +1,5 @@ use chrono::{DateTime, TimeZone, Utc}; +use serde::{Deserialize, Serialize}; /// Represents a token as returned by OAuth2 servers. /// From ca453c056c2ff2238c1020ac07f2edb1936cb54b Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Wed, 13 Nov 2019 15:43:16 -0800 Subject: [PATCH 31/71] Improve documentation --- src/authenticator.rs | 126 +++++++++++++++++++++++++++------- src/authenticator_delegate.rs | 2 + src/error.rs | 2 + src/lib.rs | 15 ++-- src/service_account.rs | 40 ++++++++++- src/types.rs | 20 +----- 6 files changed, 156 insertions(+), 49 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index a54cede..bdade8f 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -1,3 +1,4 @@ +//! Module contianing the core functionality for OAuth2 Authentication. use crate::authenticator_delegate::{ AuthenticatorDelegate, DefaultAuthenticatorDelegate, FlowDelegate, }; @@ -15,6 +16,8 @@ use std::path::PathBuf; use std::sync::Mutex; use std::time::Duration; +/// Authenticator is responsible for fetching tokens, handling refreshing tokens, +/// and optionally persisting tokens to disk. pub struct Authenticator { hyper_client: hyper::Client, app_secret: ApplicationSecret, @@ -27,6 +30,7 @@ impl Authenticator where C: hyper::client::connect::Connect + 'static, { + /// Return the current token for the provided scopes. pub async fn token<'a, T>(&'a self, scopes: &'a [T]) -> Result where T: AsRef, @@ -77,6 +81,7 @@ where } } +/// Configure an Authenticator using the builder pattern. pub struct AuthenticatorBuilder { hyper_client_builder: C, app_secret: ApplicationSecret, @@ -85,8 +90,24 @@ pub struct AuthenticatorBuilder { auth_flow: F, } +/// Create an authenticator that uses the installed flow. +/// ``` +/// # async fn foo() { +/// # use yup_oauth2::InstalledFlowReturnMethod; +/// # let custom_flow_delegate = yup_oauth2::authenticator_delegate::DefaultFlowDelegate; +/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").unwrap(); +/// let authenticator = yup_oauth2::InstalledFlowAuthenticator::builder( +/// app_secret, +/// InstalledFlowReturnMethod::HTTPRedirect, +/// ) +/// .build() +/// .await +/// .expect("failed to create authenticator"); +/// # } +/// ``` pub struct InstalledFlowAuthenticator; impl InstalledFlowAuthenticator { + /// Use the builder pattern to create an Authenticator that uses the installed flow. pub fn builder( app_secret: ApplicationSecret, method: InstalledFlowReturnMethod, @@ -98,8 +119,19 @@ impl InstalledFlowAuthenticator { } } +/// Create an authenticator that uses the device flow. +/// ``` +/// # async fn foo() { +/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").unwrap(); +/// let authenticator = yup_oauth2::DeviceFlowAuthenticator::builder(app_secret) +/// .build() +/// .await +/// .expect("failed to create authenticator"); +/// # } +/// ``` pub struct DeviceFlowAuthenticator; impl DeviceFlowAuthenticator { + /// Use the builder pattern to create an Authenticator that uses the device flow. pub fn builder( app_secret: ApplicationSecret, ) -> AuthenticatorBuilder { @@ -107,7 +139,45 @@ impl DeviceFlowAuthenticator { } } +/// Methods available when building any Authenticator. +/// ``` +/// # async fn foo() { +/// # let custom_hyper_client = hyper::Client::new(); +/// # let custom_auth_delegate = yup_oauth2::authenticator_delegate::DefaultAuthenticatorDelegate; +/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").unwrap(); +/// let authenticator = yup_oauth2::DeviceFlowAuthenticator::builder(app_secret) +/// .hyper_client(custom_hyper_client) +/// .persist_tokens_to_disk("/tmp/tokenfile.json") +/// .auth_delegate(Box::new(custom_auth_delegate)) +/// .build() +/// .await +/// .expect("failed to create authenticator"); +/// # } +/// ``` impl AuthenticatorBuilder { + /// Create the authenticator. + pub async fn build(self) -> io::Result> + where + C: HyperClientBuilder, + F: Into, + { + let hyper_client = self.hyper_client_builder.build_hyper_client(); + let storage = match self.storage_type { + StorageType::Memory => Storage::Memory { + tokens: Mutex::new(storage::JSONTokens::new()), + }, + StorageType::Disk(path) => Storage::Disk(storage::DiskStorage::new(path).await?), + }; + + Ok(Authenticator { + hyper_client, + app_secret: self.app_secret, + storage, + auth_delegate: self.auth_delegate, + auth_flow: self.auth_flow.into(), + }) + } + fn with_auth_flow( app_secret: ApplicationSecret, auth_flow: F, @@ -153,31 +223,23 @@ impl AuthenticatorBuilder { ..self } } - - /// Create the authenticator. - pub async fn build(self) -> io::Result> - where - C: HyperClientBuilder, - F: Into, - { - let hyper_client = self.hyper_client_builder.build_hyper_client(); - let storage = match self.storage_type { - StorageType::Memory => Storage::Memory { - tokens: Mutex::new(storage::JSONTokens::new()), - }, - StorageType::Disk(path) => Storage::Disk(storage::DiskStorage::new(path).await?), - }; - - Ok(Authenticator { - hyper_client, - app_secret: self.app_secret, - storage, - auth_delegate: self.auth_delegate, - auth_flow: self.auth_flow.into(), - }) - } } +/// Methods available when building a device flow Authenticator. +/// ``` +/// # async fn foo() { +/// # let custom_flow_delegate = yup_oauth2::authenticator_delegate::DefaultFlowDelegate; +/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").unwrap(); +/// let authenticator = yup_oauth2::DeviceFlowAuthenticator::builder(app_secret) +/// .device_code_url("foo") +/// .flow_delegate(Box::new(custom_flow_delegate)) +/// .wait_duration(std::time::Duration::from_secs(120)) +/// .grant_type("foo") +/// .build() +/// .await +/// .expect("failed to create authenticator"); +/// # } +/// ``` impl AuthenticatorBuilder { /// Use the provided device code url. pub fn device_code_url(self, url: impl Into>) -> Self { @@ -224,6 +286,22 @@ impl AuthenticatorBuilder { } } +/// Methods available when building an installed flow Authenticator. +/// ``` +/// # async fn foo() { +/// # use yup_oauth2::InstalledFlowReturnMethod; +/// # let custom_flow_delegate = yup_oauth2::authenticator_delegate::DefaultFlowDelegate; +/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").unwrap(); +/// let authenticator = yup_oauth2::InstalledFlowAuthenticator::builder( +/// app_secret, +/// InstalledFlowReturnMethod::HTTPRedirect, +/// ) +/// .flow_delegate(Box::new(custom_flow_delegate)) +/// .build() +/// .await +/// .expect("failed to create authenticator"); +/// # } +/// ``` impl AuthenticatorBuilder { /// Use the provided FlowDelegate. pub fn flow_delegate(self, flow_delegate: Box) -> Self { @@ -285,8 +363,10 @@ mod private { /// A trait implemented for any hyper::Client as well as teh DefaultHyperClient. pub trait HyperClientBuilder { + /// The hyper connector that the resulting hyper client will use. type Connector: hyper::client::connect::Connect + 'static; + /// Create a hyper::Client fn build_hyper_client(self) -> hyper::Client; } diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index e5ad5e3..c55c2c7 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -1,3 +1,5 @@ +//! Module containing types related to delegates. + use crate::error::{Error, PollError, RefreshError}; use std::error::Error as StdError; diff --git a/src/error.rs b/src/error.rs index a8b66e1..619b984 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,3 +1,5 @@ +//! Module containing various error types. + use std::error::Error as StdError; use std::fmt; use std::io; diff --git a/src/lib.rs b/src/lib.rs index 79670c2..824eb9f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,19 +20,19 @@ //! based on the Google APIs; it may or may not work with other providers. //! //! # Installed Flow Usage -//! The `InstalledFlow` involves showing a URL to the user (or opening it in a browser) +//! The installed flow involves showing a URL to the user (or opening it in a browser) //! and then either prompting the user to enter a displayed code, or make the authorizing //! website redirect to a web server spun up by this library and running on localhost. //! -//! In order to use the interactive method, use the `InstalledInteractive` `FlowType`; -//! for the redirect method, use `InstalledRedirect`, with the port number to let the -//! server listen on. +//! In order to use the interactive method, use the `Interactive` `InstalledFlowReturnMethod`; +//! for the redirect method, use `HTTPRedirect`. //! //! You can implement your own `AuthenticatorDelegate` in order to customize the flow; -//! the `InstalledFlow` uses the `present_user_url` method. +//! the installed flow uses the `present_user_url` method. //! -//! The returned `Token` is stored permanently in the given token storage in order to -//! authorize future API requests to the same scopes. +//! The returned `Token` will be stored in memory in order to authorize future +//! API requests to the same scopes. The tokens can optionally be persisted to +//! disk by using `persist_tokens_to_disk` when creating the authenticator. //! //! The following example, which is derived from the (actual and runnable) example in //! `examples/test-installed/`, shows the basics of using this crate: @@ -68,6 +68,7 @@ //! } //! ``` //! +#![deny(missing_docs)] pub mod authenticator; pub mod authenticator_delegate; mod device; diff --git a/src/service_account.rs b/src/service_account.rs index 2327d1f..b88d5c8 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -69,15 +69,25 @@ fn decode_rsa_key(pem_pkcs8: &str) -> Result { #[derive(Serialize, Deserialize, Debug, Clone)] pub struct ServiceAccountKey { #[serde(rename = "type")] + /// key_type pub key_type: Option, + /// project_id pub project_id: Option, + /// private_key_id pub private_key_id: Option, + /// private_key pub private_key: String, + /// client_email pub client_email: String, + /// client_id pub client_id: Option, + /// auth_uri pub auth_uri: Option, + /// token_uri pub token_uri: String, + /// auth_provider_x509_cert_url pub auth_provider_x509_cert_url: Option, + /// client_x509_cert_url pub client_x509_cert_url: Option, } @@ -150,8 +160,19 @@ impl JWTSigner { } } +/// Create an authenticator that uses a service account. +/// ``` +/// # async fn foo() { +/// # let service_key = yup_oauth2::service_account_key_from_file("/tmp/foo").unwrap(); +/// let authenticator = yup_oauth2::ServiceAccountAuthenticator::builder(service_key) +/// .build() +/// .expect("failed to create authenticator"); +/// # } +/// ``` pub struct ServiceAccountAuthenticator; impl ServiceAccountAuthenticator { + /// Use the builder pattern to create an authenticator that uses a service + /// account. pub fn builder(key: ServiceAccountKey) -> Builder { Builder { client: DefaultHyperClient, @@ -161,12 +182,25 @@ impl ServiceAccountAuthenticator { } } +/// Configure a service account authenticator using the builder pattern. pub struct Builder { client: C, key: ServiceAccountKey, subject: Option, } +/// Methods available when building a service account authenticator. +/// ``` +/// # async fn foo() { +/// # let custom_hyper_client = hyper::Client::new(); +/// # let service_key = yup_oauth2::service_account_key_from_file("/tmp/foo").unwrap(); +/// let authenticator = yup_oauth2::ServiceAccountAuthenticator::builder(service_key) +/// .hyper_client(custom_hyper_client) +/// .subject("foo") +/// .build() +/// .expect("failed to create authenticator"); +/// # } +/// ``` impl Builder { /// Use the provided hyper client. pub fn hyper_client(self, hyper_client: NewC) -> Builder { @@ -178,9 +212,9 @@ impl Builder { } /// Use the provided subject. - pub fn subject(self, subject: String) -> Self { + pub fn subject(self, subject: impl Into) -> Self { Builder { - subject: Some(subject), + subject: Some(subject.into()), ..self } } @@ -194,6 +228,7 @@ impl Builder { } } +/// ServiceAccountAccess can fetch oauth tokens using a service account. pub struct ServiceAccountAccess { client: hyper::Client, key: ServiceAccountKey, @@ -223,6 +258,7 @@ where }) } + /// Return the current token for the provided scopes. pub async fn token(&self, scopes: &[T]) -> Result where T: AsRef, diff --git a/src/types.rs b/src/types.rs index c8d5510..34218b3 100644 --- a/src/types.rs +++ b/src/types.rs @@ -80,8 +80,8 @@ pub struct ApplicationSecret { pub token_uri: String, /// The authorization server endpoint URI. pub auth_uri: String, + /// The redirect uris. pub redirect_uris: Vec, - /// Name of the google project the credentials are associated with pub project_id: Option, /// The service account email associated with the client. @@ -93,27 +93,13 @@ pub struct ApplicationSecret { pub client_x509_cert_url: Option, } -impl ApplicationSecret { - pub const fn empty() -> Self { - ApplicationSecret { - client_id: String::new(), - client_secret: String::new(), - token_uri: String::new(), - auth_uri: String::new(), - redirect_uris: Vec::new(), - project_id: None, - client_email: None, - auth_provider_x509_cert_url: None, - client_x509_cert_url: None, - } - } -} - /// A type to facilitate reading and writing the json secret file /// as returned by the [google developer console](https://code.google.com/apis/console) #[derive(Deserialize, Serialize, Default)] pub struct ConsoleApplicationSecret { + /// web app secret pub web: Option, + /// installed app secret pub installed: Option, } From 68a30ea0fea7a375e215d2511ab33828a37fbb9d Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Thu, 14 Nov 2019 14:07:11 -0800 Subject: [PATCH 32/71] Tidy up tests. --- src/device.rs | 147 +++++++++++++++++++++-------------------- src/helper.rs | 10 +++ src/installed.rs | 119 +++++++++++++++++---------------- src/refresh.rs | 46 ++++--------- src/service_account.rs | 139 ++++++++++++++------------------------ 5 files changed, 214 insertions(+), 247 deletions(-) diff --git a/src/device.rs b/src/device.rs index 90ec9f3..6815963 100644 --- a/src/device.rs +++ b/src/device.rs @@ -264,16 +264,12 @@ impl DeviceFlow { #[cfg(test)] mod tests { - use hyper; use hyper_rustls::HttpsConnector; - use mockito; - use tokio; use super::*; - use crate::helper::parse_application_secret; - #[test] - fn test_device_end2end() { + #[tokio::test] + async fn test_device_end2end() { #[derive(Clone)] struct FD; impl FlowDelegate for FD { @@ -283,9 +279,15 @@ mod tests { } let server_url = mockito::server_url(); - let app_secret = r#"{"installed":{"client_id":"902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com","project_id":"yup-test-243420","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://oauth2.googleapis.com/token","auth_provider_x509_cert_url":"https://www.googleapis.com/oauth2/v1/certs","client_secret":"iuMPN6Ne1PD7cos29Tk9rlqH","redirect_uris":["urn:ietf:wg:oauth:2.0:oob","http://localhost"]}}"#; - let mut app_secret = parse_application_secret(app_secret).unwrap(); - app_secret.token_uri = format!("{}/token", server_url); + let app_secret: ApplicationSecret = crate::parse_json!({ + "client_id": "902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com", + "project_id": "yup-test-243420", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": format!("{}/token", server_url), + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_secret": "iuMPN6Ne1PD7cos29Tk9rlqH", + "redirect_uris": ["urn:ietf:wg:oauth:2.0:oob","http://localhost"], + }); let device_code_url = format!("{}/code", server_url); let https = HttpsConnector::new(); @@ -300,118 +302,123 @@ mod tests { grant_type: GOOGLE_GRANT_TYPE.into(), }; - let rt = tokio::runtime::Builder::new() - .core_threads(1) - .panic_handler(|e| std::panic::resume_unwind(e)) - .build() - .unwrap(); - // Successful path { - let code_response = r#"{"device_code": "devicecode", "user_code": "usercode", "verification_url": "https://example.com/verify", "expires_in": 1234567, "interval": 1}"#; + let code_response = serde_json::json!({ + "device_code": "devicecode", + "user_code": "usercode", + "verification_url": "https://example.com/verify", + "expires_in": 1234567, + "interval": 1 + }); let _m = mockito::mock("POST", "/code") .match_body(mockito::Matcher::Regex( ".*client_id=902216714886.*".to_string(), )) .with_status(200) - .with_body(code_response) + .with_body(code_response.to_string()) .create(); - let token_response = r#"{"access_token": "accesstoken", "refresh_token": "refreshtoken", "token_type": "Bearer", "expires_in": 1234567}"#; + let token_response = serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 1234567 + }); let _m = mockito::mock("POST", "/token") .match_body(mockito::Matcher::Regex( ".*client_secret=iuMPN6Ne1PD7cos29Tk9rlqH&code=devicecode.*".to_string(), )) .with_status(200) - .with_body(token_response) + .with_body(token_response.to_string()) .create(); - let fut = async { - let token = flow - .token( - &client, - &app_secret, - &["https://www.googleapis.com/scope/1"], - ) - .await - .unwrap(); - assert_eq!("accesstoken", token.access_token); - Ok(()) as Result<(), ()> - }; - rt.block_on(fut).expect("block_on"); - + let token = flow + .token( + &client, + &app_secret, + &["https://www.googleapis.com/scope/1"], + ) + .await + .expect("token failed"); + assert_eq!("accesstoken", token.access_token); _m.assert(); } + // Code is not delivered. { - let code_response = - r#"{"error": "invalid_client_id", "error_description": "description"}"#; + let code_response = serde_json::json!({ + "error": "invalid_client_id", + "error_description": "description" + }); let _m = mockito::mock("POST", "/code") .match_body(mockito::Matcher::Regex( ".*client_id=902216714886.*".to_string(), )) .with_status(400) - .with_body(code_response) + .with_body(code_response.to_string()) .create(); - let token_response = r#"{"access_token": "accesstoken", "refresh_token": "refreshtoken", "token_type": "Bearer", "expires_in": 1234567}"#; + let token_response = serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 1234567 + }); let _m = mockito::mock("POST", "/token") .match_body(mockito::Matcher::Regex( ".*client_secret=iuMPN6Ne1PD7cos29Tk9rlqH&code=devicecode.*".to_string(), )) .with_status(200) - .with_body(token_response) + .with_body(token_response.to_string()) .expect(0) // Never called! .create(); - let fut = async { - let res = flow - .token( - &client, - &app_secret, - &["https://www.googleapis.com/scope/1"], - ) - .await; - assert!(res.is_err()); - assert!(format!("{}", res.unwrap_err()).contains("invalid_client_id")); - Ok(()) as Result<(), ()> - }; - rt.block_on(fut).expect("block_on"); - + let res = flow + .token( + &client, + &app_secret, + &["https://www.googleapis.com/scope/1"], + ) + .await; + assert!(res.is_err()); + assert!(format!("{}", res.unwrap_err()).contains("invalid_client_id")); _m.assert(); } + // Token is not delivered. { - let code_response = r#"{"device_code": "devicecode", "user_code": "usercode", "verification_url": "https://example.com/verify", "expires_in": 1234567, "interval": 1}"#; + let code_response = serde_json::json!({ + "device_code": "devicecode", + "user_code": "usercode", + "verification_url": "https://example.com/verify", + "expires_in": 1234567, + "interval": 1 + }); let _m = mockito::mock("POST", "/code") .match_body(mockito::Matcher::Regex( ".*client_id=902216714886.*".to_string(), )) .with_status(200) - .with_body(code_response) + .with_body(code_response.to_string()) .create(); - let token_response = r#"{"error": "access_denied"}"#; + let token_response = serde_json::json!({"error": "access_denied"}); let _m = mockito::mock("POST", "/token") .match_body(mockito::Matcher::Regex( ".*client_secret=iuMPN6Ne1PD7cos29Tk9rlqH&code=devicecode.*".to_string(), )) .with_status(400) - .with_body(token_response) + .with_body(token_response.to_string()) .expect(1) .create(); - let fut = async { - let res = flow - .token( - &client, - &app_secret, - &["https://www.googleapis.com/scope/1"], - ) - .await; - assert!(res.is_err()); - assert!(format!("{}", res.unwrap_err()).contains("Access denied by user")); - Ok(()) as Result<(), ()> - }; - rt.block_on(fut).expect("block_on"); - + let res = flow + .token( + &client, + &app_secret, + &["https://www.googleapis.com/scope/1"], + ) + .await; + assert!(res.is_err()); + assert!(format!("{}", res.unwrap_err()).contains("Access denied by user")); _m.assert(); } } diff --git a/src/helper.rs b/src/helper.rs index e5465ac..18cced0 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -65,3 +65,13 @@ where debug_assert_eq!(size, result.len()); result } + +#[cfg(test)] +#[macro_export] +/// Utility function for parsing json. Useful in unit tests. Simply wrap the +/// json! macro in a from_value to deserialize the contents to arbitrary structs. +macro_rules! parse_json { + ($($json:tt)+) => { + ::serde_json::from_value(::serde_json::json!($($json)+)).expect("failed to deserialize") + } +} diff --git a/src/installed.rs b/src/installed.rs index 8fd64ae..d0805d0 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -397,15 +397,13 @@ mod tests { use hyper::client::connect::HttpConnector; use hyper::Uri; use hyper_rustls::HttpsConnector; - use mockito::{self, mock}; - use tokio; + use mockito::mock; use super::*; use crate::authenticator_delegate::FlowDelegate; - use crate::helper::*; - #[test] - fn test_end2end() { + #[tokio::test] + async fn test_end2end() { #[derive(Clone)] struct FD( String, @@ -455,9 +453,15 @@ mod tests { } let server_url = mockito::server_url(); - let app_secret = r#"{"installed":{"client_id":"902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com","project_id":"yup-test-243420","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://oauth2.googleapis.com/token","auth_provider_x509_cert_url":"https://www.googleapis.com/oauth2/v1/certs","client_secret":"iuMPN6Ne1PD7cos29Tk9rlqH","redirect_uris":["urn:ietf:wg:oauth:2.0:oob","http://localhost"]}}"#; - let mut app_secret = parse_application_secret(app_secret).unwrap(); - app_secret.token_uri = format!("{}/token", server_url); + let app_secret: ApplicationSecret = crate::parse_json!({ + "client_id": "902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com", + "project_id": "yup-test-243420", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": format!("{}/token", server_url), + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_secret": "iuMPN6Ne1PD7cos29Tk9rlqH", + "redirect_uris": ["urn:ietf:wg:oauth:2.0:oob","http://localhost"] + }); let https = HttpsConnector::new(); let client = hyper::Client::builder() @@ -470,35 +474,34 @@ mod tests { flow_delegate: Box::new(fd), }; - let rt = tokio::runtime::Builder::new() - .core_threads(1) - .panic_handler(|e| std::panic::resume_unwind(e)) - .build() - .unwrap(); - // Successful path. { let _m = mock("POST", "/token") - .match_body(mockito::Matcher::Regex(".*code=authorizationcode.*client_id=9022167.*".to_string())) - .with_body(r#"{"access_token": "accesstoken", "refresh_token": "refreshtoken", "token_type": "Bearer", "expires_in": 12345678}"#) - .expect(1) - .create(); + .match_body(mockito::Matcher::Regex( + ".*code=authorizationcode.*client_id=9022167.*".to_string(), + )) + .with_body( + serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 12345678 + }) + .to_string(), + ) + .expect(1) + .create(); - let fut = || { - async { - let tok = inf - .token(&client, &app_secret, &["https://googleapis.com/some/scope"]) - .await - .map_err(|_| ())?; - assert_eq!("accesstoken", tok.access_token); - assert_eq!("refreshtoken", tok.refresh_token.unwrap()); - assert_eq!("Bearer", tok.token_type); - Ok(()) as Result<(), ()> - } - }; - rt.block_on(fut()).expect("block on"); + let tok = inf + .token(&client, &app_secret, &["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!("accesstoken", tok.access_token); + assert_eq!("refreshtoken", tok.refresh_token.unwrap()); + assert_eq!("Bearer", tok.token_type); _m.assert(); } + // Successful path with HTTP redirect. { let inf = InstalledFlow { @@ -509,24 +512,31 @@ mod tests { )), }; let _m = mock("POST", "/token") - .match_body(mockito::Matcher::Regex(".*code=authorizationcodefromlocalserver.*client_id=9022167.*".to_string())) - .with_body(r#"{"access_token": "accesstoken", "refresh_token": "refreshtoken", "token_type": "Bearer", "expires_in": 12345678}"#) - .expect(1) - .create(); + .match_body(mockito::Matcher::Regex( + ".*code=authorizationcodefromlocalserver.*client_id=9022167.*".to_string(), + )) + .with_body( + serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 12345678 + }) + .to_string(), + ) + .expect(1) + .create(); - let fut = async { - let tok = inf - .token(&client, &app_secret, &["https://googleapis.com/some/scope"]) - .await - .map_err(|_| ())?; - assert_eq!("accesstoken", tok.access_token); - assert_eq!("refreshtoken", tok.refresh_token.unwrap()); - assert_eq!("Bearer", tok.token_type); - Ok(()) as Result<(), ()> - }; - rt.block_on(fut).expect("block on"); + let tok = inf + .token(&client, &app_secret, &["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!("accesstoken", tok.access_token); + assert_eq!("refreshtoken", tok.refresh_token.unwrap()); + assert_eq!("Bearer", tok.token_type); _m.assert(); } + // Error from server. { let _m = mock("POST", "/token") @@ -534,22 +544,17 @@ mod tests { ".*code=authorizationcode.*client_id=9022167.*".to_string(), )) .with_status(400) - .with_body(r#"{"error": "invalid_code"}"#) + .with_body(serde_json::json!({"error": "invalid_code"}).to_string()) .expect(1) .create(); - let fut = async { - let tokr = inf - .token(&client, &app_secret, &["https://googleapis.com/some/scope"]) - .await; - assert!(tokr.is_err()); - assert!(format!("{}", tokr.unwrap_err()).contains("invalid_code")); - Ok(()) as Result<(), ()> - }; - rt.block_on(fut).expect("block on"); + let tokr = inf + .token(&client, &app_secret, &["https://googleapis.com/some/scope"]) + .await; + assert!(tokr.is_err()); + assert!(format!("{}", tokr.unwrap_err()).contains("invalid_code")); _m.assert(); } - rt.shutdown_on_idle(); } #[test] diff --git a/src/refresh.rs b/src/refresh.rs index 0a4a225..970cfe1 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -79,13 +79,10 @@ mod tests { use super::*; use crate::helper; - use hyper; use hyper_rustls::HttpsConnector; - use mockito; - use tokio; - #[test] - fn test_refresh_end2end() { + #[tokio::test] + async fn test_refresh_end2end() { let server_url = mockito::server_url(); let app_secret = r#"{"installed":{"client_id":"902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com","project_id":"yup-test-243420","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://oauth2.googleapis.com/token","auth_provider_x509_cert_url":"https://www.googleapis.com/oauth2/v1/certs","client_secret":"iuMPN6Ne1PD7cos29Tk9rlqH","redirect_uris":["urn:ietf:wg:oauth:2.0:oob","http://localhost"]}}"#; @@ -98,12 +95,6 @@ mod tests { .keep_alive(false) .build::<_, hyper::Body>(https); - let rt = tokio::runtime::Builder::new() - .core_threads(1) - .panic_handler(|e| std::panic::resume_unwind(e)) - .build() - .unwrap(); - // Success { let _m = mockito::mock("POST", "/token") @@ -112,18 +103,14 @@ mod tests { .with_status(200) .with_body(r#"{"access_token": "new-access-token", "token_type": "Bearer", "expires_in": 1234567}"#) .create(); - let fut = async { - let token = RefreshFlow::refresh_token(&client, &app_secret, refresh_token) - .await - .unwrap(); - assert_eq!("new-access-token", token.access_token); - assert_eq!("Bearer", token.token_type); - Ok(()) as Result<(), ()> - }; - - rt.block_on(fut).expect("block_on"); + let token = RefreshFlow::refresh_token(&client, &app_secret, refresh_token) + .await + .expect("token failed"); + assert_eq!("new-access-token", token.access_token); + assert_eq!("Bearer", token.token_type); _m.assert(); } + // Refresh error. { let _m = mockito::mock("POST", "/token") @@ -133,18 +120,13 @@ mod tests { .with_body(r#"{"error": "invalid_token"}"#) .create(); - let fut = async { - let rr = RefreshFlow::refresh_token(&client, &app_secret, refresh_token).await; - match rr { - Err(RefreshError::ServerError(e, None)) => { - assert_eq!(e, "invalid_token"); - } - _ => panic!(format!("unexpected RefreshResult {:?}", rr)), + let rr = RefreshFlow::refresh_token(&client, &app_secret, refresh_token).await; + match rr { + Err(RefreshError::ServerError(e, None)) => { + assert_eq!(e, "invalid_token"); } - Ok(()) as Result<(), ()> - }; - - rt.block_on(fut).expect("block_on"); + _ => panic!(format!("unexpected RefreshResult {:?}", rr)), + } _m.assert(); } } diff --git a/src/service_account.rs b/src/service_account.rs index b88d5c8..e88022a 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -348,69 +348,54 @@ where mod tests { use super::*; use crate::helper::service_account_key_from_file; + use crate::parse_json; - use hyper; - use hyper_rustls::HttpsConnector; - use mockito::{self, mock}; - use tokio; + use mockito::mock; - #[test] - fn test_mocked_http() { + #[tokio::test] + async fn test_mocked_http() { env_logger::try_init().unwrap(); let server_url = &mockito::server_url(); - let client_secret = r#"{ - "type": "service_account", - "project_id": "yup-test-243420", - "private_key_id": "26de294916614a5ebdf7a065307ed3ea9941902b", - "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDemmylrvp1KcOn\n9yTAVVKPpnpYznvBvcAU8Qjwr2fSKylpn7FQI54wCk5VJVom0jHpAmhxDmNiP8yv\nHaqsef+87Oc0n1yZ71/IbeRcHZc2OBB33/LCFqf272kThyJo3qspEqhuAw0e8neg\nLQb4jpm9PsqR8IjOoAtXQSu3j0zkXemMYFy93PWHjVpPEUX16NGfsWH7oxspBHOk\n9JPGJL8VJdbiAoDSDgF0y9RjJY5I52UeHNhMsAkTYs6mIG4kKXt2+T9tAyHw8aho\nwmuytQAfydTflTfTG8abRtliF3nil2taAc5VB07dP1b4dVYy/9r6M8Z0z4XM7aP+\nNdn2TKm3AgMBAAECggEAWi54nqTlXcr2M5l535uRb5Xz0f+Q/pv3ceR2iT+ekXQf\n+mUSShOr9e1u76rKu5iDVNE/a7H3DGopa7ZamzZvp2PYhSacttZV2RbAIZtxU6th\n7JajPAM+t9klGh6wj4jKEcE30B3XVnbHhPJI9TCcUyFZoscuPXt0LLy/z8Uz0v4B\nd5JARwyxDMb53VXwukQ8nNY2jP7WtUig6zwE5lWBPFMbi8GwGkeGZOruAK5sPPwY\nGBAlfofKANI7xKx9UXhRwisB4+/XI1L0Q6xJySv9P+IAhDUI6z6kxR+WkyT/YpG3\nX9gSZJc7qEaxTIuDjtep9GTaoEqiGntjaFBRKoe+VQKBgQDzM1+Ii+REQqrGlUJo\nx7KiVNAIY/zggu866VyziU6h5wjpsoW+2Npv6Dv7nWvsvFodrwe50Y3IzKtquIal\nVd8aa50E72JNImtK/o5Nx6xK0VySjHX6cyKENxHRDnBmNfbALRM+vbD9zMD0lz2q\nmns/RwRGq3/98EqxP+nHgHSr9QKBgQDqUYsFAAfvfT4I75Glc9svRv8IsaemOm07\nW1LCwPnj1MWOhsTxpNF23YmCBupZGZPSBFQobgmHVjQ3AIo6I2ioV6A+G2Xq/JCF\nmzfbvZfqtbbd+nVgF9Jr1Ic5T4thQhAvDHGUN77BpjEqZCQLAnUWJx9x7e2xvuBl\n1A6XDwH/ewKBgQDv4hVyNyIR3nxaYjFd7tQZYHTOQenVffEAd9wzTtVbxuo4sRlR\nNM7JIRXBSvaATQzKSLHjLHqgvJi8LITLIlds1QbNLl4U3UVddJbiy3f7WGTqPFfG\nkLhUF4mgXpCpkMLxrcRU14Bz5vnQiDmQRM4ajS7/kfwue00BZpxuZxst3QKBgQCI\nRI3FhaQXyc0m4zPfdYYVc4NjqfVmfXoC1/REYHey4I1XetbT9Nb/+ow6ew0UbgSC\nUZQjwwJ1m1NYXU8FyovVwsfk9ogJ5YGiwYb1msfbbnv/keVq0c/Ed9+AG9th30qM\nIf93hAfClITpMz2mzXIMRQpLdmQSR4A2l+E4RjkSOwKBgQCB78AyIdIHSkDAnCxz\nupJjhxEhtQ88uoADxRoEga7H/2OFmmPsqfytU4+TWIdal4K+nBCBWRvAX1cU47vH\nJOlSOZI0gRKe0O4bRBQc8GXJn/ubhYSxI02IgkdGrIKpOb5GG10m85ZvqsXw3bKn\nRVHMD0ObF5iORjZUqD0yRitAdg==\n-----END PRIVATE KEY-----\n", - "client_email": "yup-test-sa-1@yup-test-243420.iam.gserviceaccount.com", - "client_id": "102851967901799660408", - "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": "", - "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", - "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/yup-test-sa-1%40yup-test-243420.iam.gserviceaccount.com" -}"#; - let mut key: ServiceAccountKey = serde_json::from_str(client_secret).unwrap(); - key.token_uri = format!("{}/token", server_url); + let key: ServiceAccountKey = parse_json!({ + "type": "service_account", + "project_id": "yup-test-243420", + "private_key_id": "26de294916614a5ebdf7a065307ed3ea9941902b", + "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDemmylrvp1KcOn\n9yTAVVKPpnpYznvBvcAU8Qjwr2fSKylpn7FQI54wCk5VJVom0jHpAmhxDmNiP8yv\nHaqsef+87Oc0n1yZ71/IbeRcHZc2OBB33/LCFqf272kThyJo3qspEqhuAw0e8neg\nLQb4jpm9PsqR8IjOoAtXQSu3j0zkXemMYFy93PWHjVpPEUX16NGfsWH7oxspBHOk\n9JPGJL8VJdbiAoDSDgF0y9RjJY5I52UeHNhMsAkTYs6mIG4kKXt2+T9tAyHw8aho\nwmuytQAfydTflTfTG8abRtliF3nil2taAc5VB07dP1b4dVYy/9r6M8Z0z4XM7aP+\nNdn2TKm3AgMBAAECggEAWi54nqTlXcr2M5l535uRb5Xz0f+Q/pv3ceR2iT+ekXQf\n+mUSShOr9e1u76rKu5iDVNE/a7H3DGopa7ZamzZvp2PYhSacttZV2RbAIZtxU6th\n7JajPAM+t9klGh6wj4jKEcE30B3XVnbHhPJI9TCcUyFZoscuPXt0LLy/z8Uz0v4B\nd5JARwyxDMb53VXwukQ8nNY2jP7WtUig6zwE5lWBPFMbi8GwGkeGZOruAK5sPPwY\nGBAlfofKANI7xKx9UXhRwisB4+/XI1L0Q6xJySv9P+IAhDUI6z6kxR+WkyT/YpG3\nX9gSZJc7qEaxTIuDjtep9GTaoEqiGntjaFBRKoe+VQKBgQDzM1+Ii+REQqrGlUJo\nx7KiVNAIY/zggu866VyziU6h5wjpsoW+2Npv6Dv7nWvsvFodrwe50Y3IzKtquIal\nVd8aa50E72JNImtK/o5Nx6xK0VySjHX6cyKENxHRDnBmNfbALRM+vbD9zMD0lz2q\nmns/RwRGq3/98EqxP+nHgHSr9QKBgQDqUYsFAAfvfT4I75Glc9svRv8IsaemOm07\nW1LCwPnj1MWOhsTxpNF23YmCBupZGZPSBFQobgmHVjQ3AIo6I2ioV6A+G2Xq/JCF\nmzfbvZfqtbbd+nVgF9Jr1Ic5T4thQhAvDHGUN77BpjEqZCQLAnUWJx9x7e2xvuBl\n1A6XDwH/ewKBgQDv4hVyNyIR3nxaYjFd7tQZYHTOQenVffEAd9wzTtVbxuo4sRlR\nNM7JIRXBSvaATQzKSLHjLHqgvJi8LITLIlds1QbNLl4U3UVddJbiy3f7WGTqPFfG\nkLhUF4mgXpCpkMLxrcRU14Bz5vnQiDmQRM4ajS7/kfwue00BZpxuZxst3QKBgQCI\nRI3FhaQXyc0m4zPfdYYVc4NjqfVmfXoC1/REYHey4I1XetbT9Nb/+ow6ew0UbgSC\nUZQjwwJ1m1NYXU8FyovVwsfk9ogJ5YGiwYb1msfbbnv/keVq0c/Ed9+AG9th30qM\nIf93hAfClITpMz2mzXIMRQpLdmQSR4A2l+E4RjkSOwKBgQCB78AyIdIHSkDAnCxz\nupJjhxEhtQ88uoADxRoEga7H/2OFmmPsqfytU4+TWIdal4K+nBCBWRvAX1cU47vH\nJOlSOZI0gRKe0O4bRBQc8GXJn/ubhYSxI02IgkdGrIKpOb5GG10m85ZvqsXw3bKn\nRVHMD0ObF5iORjZUqD0yRitAdg==\n-----END PRIVATE KEY-----\n", + "client_email": "yup-test-sa-1@yup-test-243420.iam.gserviceaccount.com", + "client_id": "102851967901799660408", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": format!("{}/token", server_url), + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/yup-test-sa-1%40yup-test-243420.iam.gserviceaccount.com" + }); - let json_response = r#"{ - "access_token": "ya29.c.ElouBywiys0LyNaZoLPJcp1Fdi2KjFMxzvYKLXkTdvM-rDfqKlvEq6PiMhGoGHx97t5FAvz3eb_ahdwlBjSStxHtDVQB4ZPRJQ_EOi-iS7PnayahU2S9Jp8S6rk", - "expires_in": 3600, - "token_type": "Bearer" -}"#; - let bad_json_response = r#"{ - "access_token": "ya29.c.ElouBywiys0LyNaZoLPJcp1Fdi2KjFMxzvYKLXkTdvM-rDfqKlvEq6PiMhGoGHx97t5FAvz3eb_ahdwlBjSStxHtDVQB4ZPRJQ_EOi-iS7PnayahU2S9Jp8S6rk", - "token_type": "Bearer" -}"#; - - let https = HttpsConnector::new(); - let client = hyper::Client::builder() - .keep_alive(false) - .build::<_, hyper::Body>(https); - let rt = tokio::runtime::Builder::new() - .core_threads(1) - .panic_handler(|e| std::panic::resume_unwind(e)) - .build() - .unwrap(); + let json_response = serde_json::json!({ + "access_token": "ya29.c.ElouBywiys0LyNaZoLPJcp1Fdi2KjFMxzvYKLXkTdvM-rDfqKlvEq6PiMhGoGHx97t5FAvz3eb_ahdwlBjSStxHtDVQB4ZPRJQ_EOi-iS7PnayahU2S9Jp8S6rk", + "expires_in": 3600, + "token_type": "Bearer" + }); + let bad_json_response = serde_json::json!({ + "access_token": "ya29.c.ElouBywiys0LyNaZoLPJcp1Fdi2KjFMxzvYKLXkTdvM-rDfqKlvEq6PiMhGoGHx97t5FAvz3eb_ahdwlBjSStxHtDVQB4ZPRJQ_EOi-iS7PnayahU2S9Jp8S6rk", + "token_type": "Bearer" + }); // Successful path. { let _m = mock("POST", "/token") .with_status(200) .with_header("content-type", "text/json") - .with_body(json_response) + .with_body(json_response.to_string()) .expect(1) .create(); - let acc = ServiceAccountAccess::new(client.clone(), key.clone(), None).unwrap(); - let fut = async { - let tok = acc - .token(&["https://www.googleapis.com/auth/pubsub"]) - .await?; - assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); - assert_eq!(Some(3600), tok.expires_in); - Ok(()) as Result<(), Error> - }; - rt.block_on(fut).expect("block_on"); + let acc = ServiceAccountAuthenticator::builder(key.clone()) + .build() + .unwrap(); + let tok = acc + .token(&["https://www.googleapis.com/auth/pubsub"]) + .await + .expect("token failed"); + assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); + assert_eq!(Some(3600), tok.expires_in); assert!(acc .cache @@ -422,16 +407,12 @@ mod tests { ) .is_some()); // Test that token is in cache (otherwise mock will tell us) - let fut = async { - let tok = acc - .token(&["https://www.googleapis.com/auth/pubsub"]) - .await?; - assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); - assert_eq!(Some(3600), tok.expires_in); - Ok(()) as Result<(), Error> - }; - rt.block_on(fut).expect("block_on 2"); - + let tok = acc + .token(&["https://www.googleapis.com/auth/pubsub"]) + .await + .expect("token failed"); + assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); + assert_eq!(Some(3600), tok.expires_in); _m.assert(); } // Malformed response. @@ -439,48 +420,30 @@ mod tests { let _m = mock("POST", "/token") .with_status(200) .with_header("content-type", "text/json") - .with_body(bad_json_response) + .with_body(bad_json_response.to_string()) .create(); let acc = ServiceAccountAuthenticator::builder(key.clone()) - .hyper_client(client.clone()) .build() .unwrap(); - let fut = async { - let result = acc.token(&["https://www.googleapis.com/auth/pubsub"]).await; - assert!(result.is_err()); - Ok(()) as Result<(), ()> - }; - rt.block_on(fut).expect("block_on"); + let result = acc.token(&["https://www.googleapis.com/auth/pubsub"]).await; + assert!(result.is_err()); _m.assert(); } - rt.shutdown_on_idle(); } // Valid but deactivated key. const TEST_PRIVATE_KEY_PATH: &'static str = "examples/Sanguine-69411a0c0eea.json"; // Uncomment this test to verify that we can successfully obtain tokens. - //#[test] + //#[tokio::test] #[allow(dead_code)] - fn test_service_account_e2e() { + async fn test_service_account_e2e() { let key = service_account_key_from_file(&TEST_PRIVATE_KEY_PATH.to_string()).unwrap(); - let https = HttpsConnector::new(); - let client = hyper::Client::builder().build(https); - let acc = ServiceAccountAuthenticator::builder(key) - .hyper_client(client) - .build() - .unwrap(); - let rt = tokio::runtime::Builder::new() - .core_threads(1) - .panic_handler(|e| std::panic::resume_unwind(e)) - .build() - .unwrap(); - rt.block_on(async { - println!( - "{:?}", - acc.token(&["https://www.googleapis.com/auth/pubsub"]).await - ); - }); + let acc = ServiceAccountAuthenticator::builder(key).build().unwrap(); + println!( + "{:?}", + acc.token(&["https://www.googleapis.com/auth/pubsub"]).await + ); } #[test] From d17c7602761033735bd1d24982e8a750f459d142 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Thu, 14 Nov 2019 14:10:11 -0800 Subject: [PATCH 33/71] Remove an obsolete todo. --- src/refresh.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/refresh.rs b/src/refresh.rs index 970cfe1..e312287 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -34,7 +34,6 @@ impl RefreshFlow { client_secret: &ApplicationSecret, refresh_token: &str, ) -> Result { - // TODO: Does this function ever return Error? Maybe have it just return RefreshResult. let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ ("client_id", client_secret.client_id.as_str()), From b6b48594b9fef5096d2bff7326376a280513da02 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Thu, 14 Nov 2019 14:12:30 -0800 Subject: [PATCH 34/71] Remove dev-dependencies that are no longer used. These appear to only be used by examples in the old/ directory which has not compiled for a long time. Not sure why the contents of that directory are still around. --- Cargo.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7f0ed0c..5503630 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,9 +26,6 @@ tokio = "=0.2.0-alpha.6" futures-util-preview = "=0.3.0-alpha.19" [dev-dependencies] -getopts = "0.2" -open = "1.1" -yup-hyper-mock = "3.14" mockito = "0.17" env_logger = "0.6" From f76dea5dbdf6e197477abd7ee32f308810b2c3bd Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Thu, 14 Nov 2019 17:03:34 -0800 Subject: [PATCH 35/71] Add header styling to the AuthenticatorBuilder rustdoc --- src/authenticator.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index bdade8f..ae66fda 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -139,7 +139,7 @@ impl DeviceFlowAuthenticator { } } -/// Methods available when building any Authenticator. +/// ## Methods available when building any Authenticator. /// ``` /// # async fn foo() { /// # let custom_hyper_client = hyper::Client::new(); @@ -225,7 +225,7 @@ impl AuthenticatorBuilder { } } -/// Methods available when building a device flow Authenticator. +/// ## Methods available when building a device flow Authenticator. /// ``` /// # async fn foo() { /// # let custom_flow_delegate = yup_oauth2::authenticator_delegate::DefaultFlowDelegate; @@ -286,7 +286,7 @@ impl AuthenticatorBuilder { } } -/// Methods available when building an installed flow Authenticator. +/// ## Methods available when building an installed flow Authenticator. /// ``` /// # async fn foo() { /// # use yup_oauth2::InstalledFlowReturnMethod; From 4b4b2fe3f4cabc9cf8f922ee3bfd7c9d7bdd10fb Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 15 Nov 2019 09:39:07 -0800 Subject: [PATCH 36/71] refactor storage get and set methods. These previously accepted a hash and scopes. The hash was required to be a hash of the provided scopes but that wasn't enforced by the compiler. We now have the compiler enforce that by creating a HashedScopes type that ties the scopes and the hash together and pass that into the storage methods. --- src/authenticator.rs | 10 ++-- src/service_account.rs | 15 +++--- src/storage.rs | 114 +++++++++++++++++++++++++++-------------- 3 files changed, 85 insertions(+), 54 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index ae66fda..75f6888 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -35,8 +35,8 @@ where where T: AsRef, { - let scope_key = storage::ScopeHash::new(scopes); - match self.storage.get(scope_key, scopes) { + let hashed_scopes = storage::HashedScopes::from(scopes); + match self.storage.get(hashed_scopes) { Some(t) if !t.expired() => { // unexpired token found Ok(t) @@ -59,9 +59,7 @@ where } Ok(token) => token, }; - self.storage - .set(scope_key, scopes, Some(token.clone())) - .await; + self.storage.set(hashed_scopes, Some(token.clone())).await; Ok(token) } None @@ -74,7 +72,7 @@ where .auth_flow .token(&self.hyper_client, &self.app_secret, scopes) .await?; - self.storage.set(scope_key, scopes, Some(t.clone())).await; + self.storage.set(hashed_scopes, Some(t.clone())).await; Ok(t) } } diff --git a/src/service_account.rs b/src/service_account.rs index e88022a..c81b6cc 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -263,9 +263,9 @@ where where T: AsRef, { - let hash = storage::ScopeHash::new(scopes); + let hashed_scopes = storage::HashedScopes::from(scopes); let cache = &self.cache; - match cache.get(hash, scopes) { + match cache.get(hashed_scopes) { Some(token) if !token.expired() => return Ok(token), _ => {} } @@ -277,7 +277,7 @@ where scopes, ) .await?; - cache.set(hash, scopes, Some(token.clone())).await; + cache.set(hashed_scopes, Some(token.clone())).await; Ok(token) } /// Send a request for a new Bearer token to the OAuth provider. @@ -399,12 +399,9 @@ mod tests { assert!(acc .cache - .get( - dbg!(storage::ScopeHash::new(&[ - "https://www.googleapis.com/auth/pubsub" - ])), - &["https://www.googleapis.com/auth/pubsub"], - ) + .get(storage::HashedScopes::from(&[ + "https://www.googleapis.com/auth/pubsub" + ])) .is_some()); // Test that token is in cache (otherwise mock will tell us) let tok = acc diff --git a/src/storage.rs b/src/storage.rs index 7192f77..0dd8720 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -13,24 +13,53 @@ use std::sync::Mutex; use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] -pub struct ScopeHash(u64); +#[derive(Debug)] +pub struct HashedScopes<'a, T> { + hash: u64, + scopes: &'a [T], +} -impl ScopeHash { - /// Calculate a hash value describing the scopes. The order of the scopes in the - /// list does not change the hash value. i.e. two lists that contains the exact - /// same scopes, but in different order will return the same hash value. - pub fn new(scopes: &[T]) -> Self - where - T: AsRef, - { - let mut hash_sum = DefaultHasher::new().finish(); +// Implement Clone manually. Auto derive fails to work correctly because we want +// Clone to be implemented regardless of whether T is Clone or not. +impl<'a, T> Clone for HashedScopes<'a, T> { + fn clone(&self) -> Self { + HashedScopes { + hash: self.hash, + scopes: self.scopes, + } + } +} +impl<'a, T> Copy for HashedScopes<'a, T> {} + +impl<'a, T> From<&'a [T]> for HashedScopes<'a, T> +where + T: AsRef, +{ + fn from(scopes: &'a [T]) -> Self { + // Calculate a hash value describing the scopes. The order of the scopes in the + // list does not change the hash value. i.e. two lists that contains the exact + // same scopes, but in different order will return the same hash value. + let mut hash = DefaultHasher::new().finish(); for scope in scopes { let mut hasher = DefaultHasher::new(); scope.as_ref().hash(&mut hasher); - hash_sum ^= hasher.finish(); + hash ^= hasher.finish(); } - ScopeHash(hash_sum) + HashedScopes { hash, scopes } + } +} + +impl<'a, T> HashedScopes<'a, T> +where + T: AsRef, +{ + // implement an inherent from method even though From is implemented. This + // is because passing an array ref like &[&str; 1] (&["foo"]) will be auto + // deref'd to a slice on function boundaries, but it will not implement the + // From trait. This inherent method just serves to auto deref from array + // refs to slices and proxy to the From impl. + pub fn from(scopes: &'a [T]) -> Self { + >::from(scopes) } } @@ -40,23 +69,23 @@ pub(crate) enum Storage { } impl Storage { - pub(crate) async fn set(&self, h: ScopeHash, scopes: &[T], token: Option) + pub(crate) async fn set(&self, scopes: HashedScopes<'_, T>, token: Option) where T: AsRef, { match self { - Storage::Memory { tokens } => tokens.lock().unwrap().set(h, scopes, token), - Storage::Disk(disk_storage) => disk_storage.set(h, scopes, token).await, + Storage::Memory { tokens } => tokens.lock().unwrap().set(scopes, token), + Storage::Disk(disk_storage) => disk_storage.set(scopes, token).await, } } - pub(crate) fn get(&self, h: ScopeHash, scopes: &[T]) -> Option + pub(crate) fn get(&self, scopes: HashedScopes) -> Option where T: AsRef, { match self { - Storage::Memory { tokens } => tokens.lock().unwrap().get(h, scopes), - Storage::Disk(disk_storage) => disk_storage.get(h, scopes), + Storage::Memory { tokens } => tokens.lock().unwrap().get(scopes), + Storage::Disk(disk_storage) => disk_storage.get(scopes), } } } @@ -64,7 +93,7 @@ impl Storage { /// A single stored token. #[derive(Debug, Clone, Serialize, Deserialize)] struct JSONToken { - pub hash: ScopeHash, + pub hash: u64, pub scopes: Option>, pub token: Token, } @@ -107,38 +136,45 @@ impl JSONTokens { Ok(container) } - fn get(&self, h: ScopeHash, scopes: &[T]) -> Option + fn get(&self, scopes: HashedScopes) -> Option where T: AsRef, { for t in self.tokens.iter() { if let Some(token_scopes) = &t.scopes { if scopes + .scopes .iter() .all(|s| token_scopes.iter().any(|t| t == s.as_ref())) { return Some(t.token.clone()); } - } else if h == t.hash { + } else if scopes.hash == t.hash { return Some(t.token.clone()); } } None } - fn set(&mut self, h: ScopeHash, scopes: &[T], token: Option) + fn set(&mut self, scopes: HashedScopes, token: Option) where T: AsRef, { - eprintln!("setting: {:?}, {:?}", h, token); - self.tokens.retain(|x| x.hash != h); + eprintln!("setting: {:?}, {:?}", scopes.hash, token); + self.tokens.retain(|x| x.hash != scopes.hash); match token { None => (), Some(t) => { self.tokens.push(JSONToken { - hash: h, - scopes: Some(scopes.iter().map(|x| x.as_ref().to_string()).collect()), + hash: scopes.hash, + scopes: Some( + scopes + .scopes + .iter() + .map(|x| x.as_ref().to_string()) + .collect(), + ), token: t, }); } @@ -183,13 +219,13 @@ impl DiskStorage { }) } - async fn set(&self, h: ScopeHash, scopes: &[T], token: Option) + async fn set(&self, scopes: HashedScopes<'_, T>, token: Option) where T: AsRef, { let cloned_tokens = { let mut tokens = self.tokens.lock().unwrap(); - tokens.set(h, scopes, token); + tokens.set(scopes, token); tokens.clone() }; self.write_tx @@ -199,11 +235,11 @@ impl DiskStorage { .expect("disk storage task not running"); } - pub(crate) fn get(&self, h: ScopeHash, scopes: &[T]) -> Option + pub(crate) fn get(&self, scopes: HashedScopes) -> Option where T: AsRef, { - self.tokens.lock().unwrap().get(h, scopes) + self.tokens.lock().unwrap().get(scopes) } } @@ -215,24 +251,24 @@ mod tests { fn test_hash_scopes() { // Idential list should hash equal. assert_eq!( - ScopeHash::new(&["foo", "bar"]), - ScopeHash::new(&["foo", "bar"]) + HashedScopes::from(&["foo", "bar"]).hash, + HashedScopes::from(&["foo", "bar"]).hash, ); // The hash should be order independent. assert_eq!( - ScopeHash::new(&["bar", "foo"]), - ScopeHash::new(&["foo", "bar"]) + HashedScopes::from(&["bar", "foo"]).hash, + HashedScopes::from(&["foo", "bar"]).hash, ); assert_eq!( - ScopeHash::new(&["bar", "baz", "bat"]), - ScopeHash::new(&["baz", "bar", "bat"]) + HashedScopes::from(&["bar", "baz", "bat"]).hash, + HashedScopes::from(&["baz", "bar", "bat"]).hash, ); // Ensure hashes differ when the contents are different by more than // just order. assert_ne!( - ScopeHash::new(&["foo", "bar", "baz"]), - ScopeHash::new(&["foo", "bar"]) + HashedScopes::from(&["foo", "bar", "baz"]).hash, + HashedScopes::from(&["foo", "bar"]).hash, ); } } From b70d07aac2c088f9259a860802a89dea1b78d6c3 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 15 Nov 2019 09:48:18 -0800 Subject: [PATCH 37/71] storage set method should just accept a Token rather than Option. No caller ever provided a None value. Presumably a None value should delete the token, but it didn't do that and that would be more clearly done with a remove or delete method. --- src/authenticator.rs | 4 ++-- src/service_account.rs | 2 +- src/storage.rs | 33 ++++++++++++++------------------- 3 files changed, 17 insertions(+), 22 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index 75f6888..67b2295 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -59,7 +59,7 @@ where } Ok(token) => token, }; - self.storage.set(hashed_scopes, Some(token.clone())).await; + self.storage.set(hashed_scopes, token.clone()).await; Ok(token) } None @@ -72,7 +72,7 @@ where .auth_flow .token(&self.hyper_client, &self.app_secret, scopes) .await?; - self.storage.set(hashed_scopes, Some(t.clone())).await; + self.storage.set(hashed_scopes, t.clone()).await; Ok(t) } } diff --git a/src/service_account.rs b/src/service_account.rs index c81b6cc..c672357 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -277,7 +277,7 @@ where scopes, ) .await?; - cache.set(hashed_scopes, Some(token.clone())).await; + cache.set(hashed_scopes, token.clone()).await; Ok(token) } /// Send a request for a new Bearer token to the OAuth provider. diff --git a/src/storage.rs b/src/storage.rs index 0dd8720..5481700 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -69,7 +69,7 @@ pub(crate) enum Storage { } impl Storage { - pub(crate) async fn set(&self, scopes: HashedScopes<'_, T>, token: Option) + pub(crate) async fn set(&self, scopes: HashedScopes<'_, T>, token: Token) where T: AsRef, { @@ -156,29 +156,24 @@ impl JSONTokens { None } - fn set(&mut self, scopes: HashedScopes, token: Option) + fn set(&mut self, scopes: HashedScopes, token: Token) where T: AsRef, { eprintln!("setting: {:?}, {:?}", scopes.hash, token); self.tokens.retain(|x| x.hash != scopes.hash); - match token { - None => (), - Some(t) => { - self.tokens.push(JSONToken { - hash: scopes.hash, - scopes: Some( - scopes - .scopes - .iter() - .map(|x| x.as_ref().to_string()) - .collect(), - ), - token: t, - }); - } - } + self.tokens.push(JSONToken { + hash: scopes.hash, + scopes: Some( + scopes + .scopes + .iter() + .map(|x| x.as_ref().to_string()) + .collect(), + ), + token, + }); } // TODO: ideally this function would accept &Path, but tokio requires the @@ -219,7 +214,7 @@ impl DiskStorage { }) } - async fn set(&self, scopes: HashedScopes<'_, T>, token: Option) + async fn set(&self, scopes: HashedScopes<'_, T>, token: Token) where T: AsRef, { From baa8d566532776366af6dffbdf2427f8c2941639 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 15 Nov 2019 09:57:28 -0800 Subject: [PATCH 38/71] JSONToken should always contain scopes. This is already the case when writing a token file. Presumably the only reason it was an Option was for backwards compatibility, but we're already breaking compatibility with the change to the hash value so this seems like an appropriate time to make the change. This change also highlights how unused the hash value has been previously. Future changes plan to use the hash value for more efficient handling. --- src/storage.rs | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/src/storage.rs b/src/storage.rs index 5481700..c8a2696 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -94,7 +94,7 @@ impl Storage { #[derive(Debug, Clone, Serialize, Deserialize)] struct JSONToken { pub hash: u64, - pub scopes: Option>, + pub scopes: Vec, pub token: Token, } @@ -141,15 +141,11 @@ impl JSONTokens { T: AsRef, { for t in self.tokens.iter() { - if let Some(token_scopes) = &t.scopes { - if scopes - .scopes - .iter() - .all(|s| token_scopes.iter().any(|t| t == s.as_ref())) - { - return Some(t.token.clone()); - } - } else if scopes.hash == t.hash { + if scopes + .scopes + .iter() + .all(|s| t.scopes.iter().any(|t| t == s.as_ref())) + { return Some(t.token.clone()); } } @@ -165,13 +161,11 @@ impl JSONTokens { self.tokens.push(JSONToken { hash: scopes.hash, - scopes: Some( - scopes - .scopes - .iter() - .map(|x| x.as_ref().to_string()) - .collect(), - ), + scopes: scopes + .scopes + .iter() + .map(|x| x.as_ref().to_string()) + .collect(), token, }); } From 089c6ba212b3ecc81af61a54273820cc42bdea0e Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 15 Nov 2019 12:47:02 -0800 Subject: [PATCH 39/71] Use seahash rather that DefaultHasher. DefaultHasher is not documented as being consistent. It's best to not trust that the resulting hash value is consistent even across different executions of the same binary and even more so across different versions. --- Cargo.toml | 1 + src/storage.rs | 34 +++++----------------------------- 2 files changed, 6 insertions(+), 29 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5503630..5edf915 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ url = "1" futures-preview = "=0.3.0-alpha.19" tokio = "=0.2.0-alpha.6" futures-util-preview = "=0.3.0-alpha.19" +seahash = "3.0.6" [dev-dependencies] mockito = "0.17" diff --git a/src/storage.rs b/src/storage.rs index c8a2696..051ac96 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -4,9 +4,6 @@ // use crate::types::Token; -use std::cmp::Ordering; -use std::collections::hash_map::DefaultHasher; -use std::hash::{Hash, Hasher}; use std::io; use std::path::{Path, PathBuf}; use std::sync::Mutex; @@ -39,12 +36,11 @@ where // Calculate a hash value describing the scopes. The order of the scopes in the // list does not change the hash value. i.e. two lists that contains the exact // same scopes, but in different order will return the same hash value. - let mut hash = DefaultHasher::new().finish(); - for scope in scopes { - let mut hasher = DefaultHasher::new(); - scope.as_ref().hash(&mut hasher); - hash ^= hasher.finish(); - } + // Use seahash because it's fast and guaranteed to remain consistent, + // even across different executions and versions. + let hash = scopes.iter().fold(0u64, |h, scope| { + h ^ seahash::hash(scope.as_ref().as_bytes()) + }); HashedScopes { hash, scopes } } } @@ -98,26 +94,6 @@ struct JSONToken { pub token: Token, } -impl PartialEq for JSONToken { - fn eq(&self, other: &Self) -> bool { - self.hash == other.hash - } -} - -impl Eq for JSONToken {} - -impl PartialOrd for JSONToken { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for JSONToken { - fn cmp(&self, other: &Self) -> Ordering { - self.hash.cmp(&other.hash) - } -} - /// List of tokens in a JSON object #[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct JSONTokens { From 5be2eadecadd606352feae7d475946639d6f47bd Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 15 Nov 2019 13:01:11 -0800 Subject: [PATCH 40/71] Use the storage token hash more effectively. Use a BTreeMap to key the tokens by the hash value. On retrieval first look for a matching hash value and return it if it exists. Only if it does not exist does it fallback to the subset matching. This makes the common case where an application uses a consistent set of scopes more efficient without detrimentally impacting the less common cases. --- src/storage.rs | 42 +++++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/src/storage.rs b/src/storage.rs index 051ac96..fbef74c 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -4,6 +4,7 @@ // use crate::types::Token; +use std::collections::BTreeMap; use std::io; use std::path::{Path, PathBuf}; use std::sync::Mutex; @@ -89,7 +90,6 @@ impl Storage { /// A single stored token. #[derive(Debug, Clone, Serialize, Deserialize)] struct JSONToken { - pub hash: u64, pub scopes: Vec, pub token: Token, } @@ -97,12 +97,14 @@ struct JSONToken { /// List of tokens in a JSON object #[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct JSONTokens { - tokens: Vec, + token_map: BTreeMap, } impl JSONTokens { pub(crate) fn new() -> Self { - JSONTokens { tokens: Vec::new() } + JSONTokens { + token_map: BTreeMap::new(), + } } pub(crate) async fn load_from_file(filename: &Path) -> Result { @@ -112,13 +114,20 @@ impl JSONTokens { Ok(container) } - fn get(&self, scopes: HashedScopes) -> Option + fn get(&self, HashedScopes { hash, scopes }: HashedScopes) -> Option where T: AsRef, { - for t in self.tokens.iter() { + // Check for existing hash first. This will match if we already have a + // token for the exact set of scopes requested. + if let Some(json_token) = self.token_map.get(&hash) { + return Some(json_token.token.clone()); + } + + // No exact match for the scopes provided. Search for any tokens that + // exist for a superset of the scopes requested. + for t in self.token_map.values() { if scopes - .scopes .iter() .all(|s| t.scopes.iter().any(|t| t == s.as_ref())) { @@ -128,22 +137,17 @@ impl JSONTokens { None } - fn set(&mut self, scopes: HashedScopes, token: Token) + fn set(&mut self, HashedScopes { hash, scopes }: HashedScopes, token: Token) where T: AsRef, { - eprintln!("setting: {:?}, {:?}", scopes.hash, token); - self.tokens.retain(|x| x.hash != scopes.hash); - - self.tokens.push(JSONToken { - hash: scopes.hash, - scopes: scopes - .scopes - .iter() - .map(|x| x.as_ref().to_string()) - .collect(), - token, - }); + self.token_map.insert( + hash, + JSONToken { + scopes: scopes.iter().map(|x| x.as_ref().to_string()).collect(), + token, + }, + ); } // TODO: ideally this function would accept &Path, but tokio requires the From 7c1664142c446a9be0d828557773c2234365ab6c Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Mon, 18 Nov 2019 09:35:11 -0800 Subject: [PATCH 41/71] Don't serialize the scope hash. Seahash is a stable hash, but there isn't any value in serializing it's value. Instead calculate the value of the hash when deserializing and only serialize the scopes and tokens. This provides flexibility to change the hash value in the future without breaking the on-disk format. --- src/storage.rs | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/src/storage.rs b/src/storage.rs index fbef74c..3ef0b34 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -109,9 +109,16 @@ impl JSONTokens { pub(crate) async fn load_from_file(filename: &Path) -> Result { let contents = tokio::fs::read(filename).await?; - let container: JSONTokens = serde_json::from_slice(&contents) + let token_vec: Vec = serde_json::from_slice(&contents) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - Ok(container) + let token_map: BTreeMap = token_vec + .into_iter() + .map(|json_token| { + let hash = HashedScopes::from(&json_token.scopes).hash; + (hash, json_token) + }) + .collect(); + Ok(JSONTokens { token_map }) } fn get(&self, HashedScopes { hash, scopes }: HashedScopes) -> Option @@ -150,19 +157,14 @@ impl JSONTokens { ); } - // TODO: ideally this function would accept &Path, but tokio requires the - // path be 'static. Revisit this and ask why tokio::fs::write has that - // limitation. - async fn dump_to_file(&self, path: PathBuf) -> Result<(), io::Error> { - let serialized = serde_json::to_string(self) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - tokio::fs::write(path, &serialized).await + fn all_tokens(&self) -> Vec { + self.token_map.values().cloned().collect() } } pub(crate) struct DiskStorage { tokens: Mutex, - write_tx: tokio::sync::mpsc::Sender, + write_tx: tokio::sync::mpsc::Sender>, } impl DiskStorage { @@ -174,11 +176,16 @@ impl DiskStorage { // buffered channel. This ensures that the writes happen in the order // received, and if writes fall too far behind we will block GetToken // requests until disk i/o completes. - let (write_tx, mut write_rx) = tokio::sync::mpsc::channel::(2); + let (write_tx, mut write_rx) = tokio::sync::mpsc::channel::>(2); tokio::spawn(async move { while let Some(tokens) = write_rx.recv().await { - if let Err(e) = tokens.dump_to_file(path.to_path_buf()).await { - log::error!("Failed to write token storage to disk: {}", e); + match serde_json::to_string(&tokens) { + Err(e) => log::error!("Failed to serialize tokens: {}", e), + Ok(ser) => { + if let Err(e) = tokio::fs::write(path.clone(), &ser).await { + log::error!("Failed to write tokens to disk: {}", e); + } + } } } }); @@ -195,7 +202,7 @@ impl DiskStorage { let cloned_tokens = { let mut tokens = self.tokens.lock().unwrap(); tokens.set(scopes, token); - tokens.clone() + tokens.all_tokens() }; self.write_tx .clone() From 8f8455376999b574b8c2b76094593e715ac71152 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Mon, 18 Nov 2019 12:34:59 -0800 Subject: [PATCH 42/71] Use a bloom filter to track scopes. Each token is stored along with a 64bit bloom filter that is created from the set of scopes associated with that token. When retrieving tokens for a set of scopes a new bloom filter is calculated for the requested scopes and compared to the filters of all previously fetched scopes. The bloom filter allows for efficiently skipping entries that are definitely not a superset. --- src/authenticator.rs | 2 +- src/service_account.rs | 4 +- src/storage.rs | 193 ++++++++++++++++++++++++++--------------- 3 files changed, 126 insertions(+), 73 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index 67b2295..1402448 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -35,7 +35,7 @@ where where T: AsRef, { - let hashed_scopes = storage::HashedScopes::from(scopes); + let hashed_scopes = storage::ScopesAndFilter::from(scopes); match self.storage.get(hashed_scopes) { Some(t) if !t.expired() => { // unexpired token found diff --git a/src/service_account.rs b/src/service_account.rs index c672357..211357d 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -263,7 +263,7 @@ where where T: AsRef, { - let hashed_scopes = storage::HashedScopes::from(scopes); + let hashed_scopes = storage::ScopesAndFilter::from(scopes); let cache = &self.cache; match cache.get(hashed_scopes) { Some(token) if !token.expired() => return Ok(token), @@ -399,7 +399,7 @@ mod tests { assert!(acc .cache - .get(storage::HashedScopes::from(&[ + .get(storage::ScopesAndFilter::from(&[ "https://www.googleapis.com/auth/pubsub" ])) .is_some()); diff --git a/src/storage.rs b/src/storage.rs index 3ef0b34..92acfc2 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -11,42 +11,88 @@ use std::sync::Mutex; use serde::{Deserialize, Serialize}; +// The storage layer allows retrieving tokens for scopes that have been +// previously granted tokens. One wrinkle is that a token granted for a set +// of scopes X is also valid for any subset of X's scopes. So when retrieving a +// token for a set of scopes provided by the caller it's beneficial to compare +// that set to all previously stored tokens to see if it is a subset of any +// existing set. To do this efficiently we store a bloom filter along with each +// token that represents the set of scopes the token is associated with. The +// bloom filter allows for efficiently skipping any entries that are +// definitively not a superset. +// The current implementation uses a 64bit bloom filter with 4 hash functions. + +/// ScopeFilter represents a filter for a set of scopes. It can definitively +/// prove that a given list of scopes is not a subset of another. +#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] +struct ScopeFilter { + bitmask: u64, +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +enum FilterResponse { + Maybe, + No, +} + +impl ScopeFilter { + fn new(scopes: &[T]) -> Self + where + T: AsRef, + { + let mut bitmask = 0u64; + for scope in scopes { + let scope_hash = seahash::hash(scope.as_ref().as_bytes()); + // Use the first 4 6-bit chunks of the seahash as the 4 hash values + // in the bloom filter. + for i in 0..4 { + // h is a hash derived value in the range 0..64 + let h = (scope_hash >> (6 * i)) & 0b11_1111; + bitmask |= 1 << h; + } + } + ScopeFilter { bitmask } + } + + /// Determine if this ScopeFilter could be a subset of the provided filter. + fn is_subset_of(self, filter: ScopeFilter) -> FilterResponse { + if self.bitmask & filter.bitmask == self.bitmask { + FilterResponse::Maybe + } else { + FilterResponse::No + } + } +} + #[derive(Debug)] -pub struct HashedScopes<'a, T> { - hash: u64, +pub struct ScopesAndFilter<'a, T> { + filter: ScopeFilter, scopes: &'a [T], } // Implement Clone manually. Auto derive fails to work correctly because we want // Clone to be implemented regardless of whether T is Clone or not. -impl<'a, T> Clone for HashedScopes<'a, T> { +impl<'a, T> Clone for ScopesAndFilter<'a, T> { fn clone(&self) -> Self { - HashedScopes { - hash: self.hash, + ScopesAndFilter { + filter: self.filter, scopes: self.scopes, } } } -impl<'a, T> Copy for HashedScopes<'a, T> {} +impl<'a, T> Copy for ScopesAndFilter<'a, T> {} -impl<'a, T> From<&'a [T]> for HashedScopes<'a, T> +impl<'a, T> From<&'a [T]> for ScopesAndFilter<'a, T> where T: AsRef, { fn from(scopes: &'a [T]) -> Self { - // Calculate a hash value describing the scopes. The order of the scopes in the - // list does not change the hash value. i.e. two lists that contains the exact - // same scopes, but in different order will return the same hash value. - // Use seahash because it's fast and guaranteed to remain consistent, - // even across different executions and versions. - let hash = scopes.iter().fold(0u64, |h, scope| { - h ^ seahash::hash(scope.as_ref().as_bytes()) - }); - HashedScopes { hash, scopes } + let filter = ScopeFilter::new(scopes); + ScopesAndFilter { filter, scopes } } } -impl<'a, T> HashedScopes<'a, T> +impl<'a, T> ScopesAndFilter<'a, T> where T: AsRef, { @@ -66,7 +112,7 @@ pub(crate) enum Storage { } impl Storage { - pub(crate) async fn set(&self, scopes: HashedScopes<'_, T>, token: Token) + pub(crate) async fn set(&self, scopes: ScopesAndFilter<'_, T>, token: Token) where T: AsRef, { @@ -76,7 +122,7 @@ impl Storage { } } - pub(crate) fn get(&self, scopes: HashedScopes) -> Option + pub(crate) fn get(&self, scopes: ScopesAndFilter) -> Option where T: AsRef, { @@ -95,9 +141,9 @@ struct JSONToken { } /// List of tokens in a JSON object -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] pub(crate) struct JSONTokens { - token_map: BTreeMap, + token_map: BTreeMap>, } impl JSONTokens { @@ -111,54 +157,66 @@ impl JSONTokens { let contents = tokio::fs::read(filename).await?; let token_vec: Vec = serde_json::from_slice(&contents) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - let token_map: BTreeMap = token_vec - .into_iter() - .map(|json_token| { - let hash = HashedScopes::from(&json_token.scopes).hash; - (hash, json_token) - }) - .collect(); + let mut token_map: BTreeMap> = BTreeMap::new(); + for token in token_vec { + let filter = ScopesAndFilter::from(&token.scopes).filter; + token_map.entry(filter).or_default().push(token); + } Ok(JSONTokens { token_map }) } - fn get(&self, HashedScopes { hash, scopes }: HashedScopes) -> Option + fn get(&self, ScopesAndFilter { filter, scopes }: ScopesAndFilter) -> Option where T: AsRef, { - // Check for existing hash first. This will match if we already have a - // token for the exact set of scopes requested. - if let Some(json_token) = self.token_map.get(&hash) { - return Some(json_token.token.clone()); + let requested_scopes_are_subset_of = |other_scopes: &[String]| { + scopes + .iter() + .all(|s| other_scopes.iter().any(|t| t.as_str() == s.as_ref())) + }; + // Check for exact match of bloom filter first. In the common case an + // application will provide the same set of scopes repeatedly. If a + // token exists for the exact scope list requested a lookup of the + // ScopeFilter will return a list that would contain it. + if let Some(tokens) = self.token_map.get(&filter) { + for t in tokens { + if requested_scopes_are_subset_of(t.scopes.as_slice()) { + return Some(t.token.clone()); + } + } } // No exact match for the scopes provided. Search for any tokens that // exist for a superset of the scopes requested. - for t in self.token_map.values() { - if scopes - .iter() - .all(|s| t.scopes.iter().any(|t| t == s.as_ref())) - { + for t in self + .token_map + .iter() + .filter(|(k, _v)| filter.is_subset_of(**k) == FilterResponse::Maybe) + .flat_map(|(_, v)| v.iter()) + { + if requested_scopes_are_subset_of(&t.scopes) { return Some(t.token.clone()); } } None } - fn set(&mut self, HashedScopes { hash, scopes }: HashedScopes, token: Token) + fn set(&mut self, ScopesAndFilter { filter, scopes }: ScopesAndFilter, token: Token) where T: AsRef, { - self.token_map.insert( - hash, - JSONToken { - scopes: scopes.iter().map(|x| x.as_ref().to_string()).collect(), - token, - }, - ); + self.token_map.entry(filter).or_default().push(JSONToken { + scopes: scopes.iter().map(|x| x.as_ref().to_string()).collect(), + token, + }); } fn all_tokens(&self) -> Vec { - self.token_map.values().cloned().collect() + self.token_map + .values() + .flat_map(|v| v.iter()) + .cloned() + .collect() } } @@ -195,7 +253,7 @@ impl DiskStorage { }) } - async fn set(&self, scopes: HashedScopes<'_, T>, token: Token) + async fn set(&self, scopes: ScopesAndFilter<'_, T>, token: Token) where T: AsRef, { @@ -211,7 +269,7 @@ impl DiskStorage { .expect("disk storage task not running"); } - pub(crate) fn get(&self, scopes: HashedScopes) -> Option + pub(crate) fn get(&self, scopes: ScopesAndFilter) -> Option where T: AsRef, { @@ -224,27 +282,22 @@ mod tests { use super::*; #[test] - fn test_hash_scopes() { - // Idential list should hash equal. - assert_eq!( - HashedScopes::from(&["foo", "bar"]).hash, - HashedScopes::from(&["foo", "bar"]).hash, - ); - // The hash should be order independent. - assert_eq!( - HashedScopes::from(&["bar", "foo"]).hash, - HashedScopes::from(&["foo", "bar"]).hash, - ); - assert_eq!( - HashedScopes::from(&["bar", "baz", "bat"]).hash, - HashedScopes::from(&["baz", "bar", "bat"]).hash, - ); + fn test_scope_filter() { + let foo = ScopeFilter::new(&["foo"]); + let bar = ScopeFilter::new(&["bar"]); + let foobar = ScopeFilter::new(&["foo", "bar"]); - // Ensure hashes differ when the contents are different by more than - // just order. - assert_ne!( - HashedScopes::from(&["foo", "bar", "baz"]).hash, - HashedScopes::from(&["foo", "bar"]).hash, - ); + // foo and bar are both subsets of foobar. This condition should hold no + // matter what changes are made to the bloom filter implementation. + assert!(foo.is_subset_of(foobar) == FilterResponse::Maybe); + assert!(bar.is_subset_of(foobar) == FilterResponse::Maybe); + + // These conditions hold under the current bloom filter implementation + // because "foo" and "bar" don't collide, but if the bloom filter + // implementations change it could be valid for them to return Maybe. + assert!(foo.is_subset_of(bar) == FilterResponse::No); + assert!(bar.is_subset_of(foo) == FilterResponse::No); + assert!(foobar.is_subset_of(foo) == FilterResponse::No); + assert!(foobar.is_subset_of(bar) == FilterResponse::No); } } From d4b80a0c5c9c434f5ea511b5f33ca159d28216aa Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Wed, 20 Nov 2019 11:19:48 -0800 Subject: [PATCH 43/71] Fix a bug in refactoring the storage layer. Attempting to load from disk when the file does not exist should not return an error and should continue with an empty set of tokens. --- src/storage.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/storage.rs b/src/storage.rs index 92acfc2..35e123a 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -227,7 +227,12 @@ pub(crate) struct DiskStorage { impl DiskStorage { pub(crate) async fn new(path: PathBuf) -> Result { - let tokens = JSONTokens::load_from_file(&path).await?; + let tokens = match JSONTokens::load_from_file(&path).await { + Ok(tokens) => tokens, + Err(e) if e.kind() == io::ErrorKind::NotFound => JSONTokens::new(), + Err(e) => return Err(e), + }; + // Writing to disk will happen in a separate task. This means in the // common case returning a token to the user will not be required to // wait for disk i/o. We communicate with a dedicated writer task via a From 5256f642d78ab286ed2e4b2c83cbfc52b4b8d8e4 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Wed, 20 Nov 2019 14:01:17 -0800 Subject: [PATCH 44/71] Tie ServiceAccount's into Authenticator. Prior to this change DeviceFlow and InstalledFlow were used within Authenticator, while ServiceAccountAccess was used on it's own. AFAICT this was the case because ServiceAccountAccess never used refresh tokens and Authenticator assumed all tokens contained refresh tokens. Authenticator was recently modified to handle the case where a token does not contain a refresh token so I don't see any reason to keep the service account access separate anymore. Folding it into the authenticator provides a nice consistent interface, and the service account implementation no longer needs to provide it's own caching since it is now handled by Authenticator. --- examples/test-svc-acct/src/main.rs | 2 +- src/authenticator.rs | 193 +++++++++++++++++++--------- src/device.rs | 35 +++--- src/installed.rs | 20 +-- src/lib.rs | 9 +- src/service_account.rs | 195 ++++++++--------------------- src/storage.rs | 29 ++--- 7 files changed, 233 insertions(+), 250 deletions(-) diff --git a/examples/test-svc-acct/src/main.rs b/examples/test-svc-acct/src/main.rs index 8ff8db0..1b73a69 100644 --- a/examples/test-svc-acct/src/main.rs +++ b/examples/test-svc-acct/src/main.rs @@ -4,7 +4,7 @@ use yup_oauth2::ServiceAccountAuthenticator; #[tokio::main] async fn main() { let creds = yup_oauth2::service_account_key_from_file("serviceaccount.json").unwrap(); - let sa = ServiceAccountAuthenticator::builder(creds).build().unwrap(); + let sa = ServiceAccountAuthenticator::builder(creds).build().await.unwrap(); let scopes = &["https://www.googleapis.com/auth/pubsub"]; let tok = sa.token(scopes).await.unwrap(); diff --git a/src/authenticator.rs b/src/authenticator.rs index 1402448..b74a3c8 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -6,6 +6,7 @@ use crate::device::DeviceFlow; use crate::error::Error; use crate::installed::{InstalledFlow, InstalledFlowReturnMethod}; use crate::refresh::RefreshFlow; +use crate::service_account::{ServiceAccountFlow, ServiceAccountFlowOpts, ServiceAccountKey}; use crate::storage::{self, Storage}; use crate::types::{ApplicationSecret, Token}; use private::AuthFlow; @@ -20,7 +21,6 @@ use std::time::Duration; /// and optionally persisting tokens to disk. pub struct Authenticator { hyper_client: hyper::Client, - app_secret: ApplicationSecret, auth_delegate: Box, storage: Storage, auth_flow: AuthFlow, @@ -36,19 +36,22 @@ where T: AsRef, { let hashed_scopes = storage::ScopesAndFilter::from(scopes); - match self.storage.get(hashed_scopes) { - Some(t) if !t.expired() => { + match (self.storage.get(hashed_scopes), self.auth_flow.app_secret()) { + (Some(t), _) if !t.expired() => { // unexpired token found Ok(t) } - Some(Token { - refresh_token: Some(refresh_token), - .. - }) => { + ( + Some(Token { + refresh_token: Some(refresh_token), + .. + }), + Some(app_secret), + ) => { // token is expired but has a refresh token. let token = match RefreshFlow::refresh_token( &self.hyper_client, - &self.app_secret, + app_secret, &refresh_token, ) .await @@ -62,16 +65,9 @@ where self.storage.set(hashed_scopes, token.clone()).await; Ok(token) } - None - | Some(Token { - refresh_token: None, - .. - }) => { - // no token in the cache or the token returned does not contain a refresh token. - let t = self - .auth_flow - .token(&self.hyper_client, &self.app_secret, scopes) - .await?; + _ => { + // no token in the cache or the token returned can't be refreshed. + let t = self.auth_flow.token(&self.hyper_client, scopes).await?; self.storage.set(hashed_scopes, t.clone()).await; Ok(t) } @@ -82,7 +78,6 @@ where /// Configure an Authenticator using the builder pattern. pub struct AuthenticatorBuilder { hyper_client_builder: C, - app_secret: ApplicationSecret, auth_delegate: Box, storage_type: StorageType, auth_flow: F, @@ -110,10 +105,9 @@ impl InstalledFlowAuthenticator { app_secret: ApplicationSecret, method: InstalledFlowReturnMethod, ) -> AuthenticatorBuilder { - AuthenticatorBuilder::::with_auth_flow( - app_secret, - InstalledFlow::new(method), - ) + AuthenticatorBuilder::::with_auth_flow(InstalledFlow::new( + app_secret, method, + )) } } @@ -133,7 +127,30 @@ impl DeviceFlowAuthenticator { pub fn builder( app_secret: ApplicationSecret, ) -> AuthenticatorBuilder { - AuthenticatorBuilder::::with_auth_flow(app_secret, DeviceFlow::new()) + AuthenticatorBuilder::::with_auth_flow(DeviceFlow::new(app_secret)) + } +} + +/// Create an authenticator that uses a service account. +/// ``` +/// # async fn foo() { +/// # let service_account_key = yup_oauth2::service_account_key_from_file("/tmp/foo").unwrap(); +/// let authenticator = yup_oauth2::ServiceAccountAuthenticator::builder(service_account_key) +/// .build() +/// .await +/// .expect("failed to create authenticator"); +/// # } +/// ``` +pub struct ServiceAccountAuthenticator; +impl ServiceAccountAuthenticator { + /// Use the builder pattern to create an Authenticator that uses a service account. + pub fn builder( + service_account_key: ServiceAccountKey, + ) -> AuthenticatorBuilder { + AuthenticatorBuilder::::with_auth_flow(ServiceAccountFlowOpts { + key: service_account_key, + subject: None, + }) } } @@ -153,14 +170,17 @@ impl DeviceFlowAuthenticator { /// # } /// ``` impl AuthenticatorBuilder { - /// Create the authenticator. - pub async fn build(self) -> io::Result> + async fn common_build( + hyper_client_builder: C, + storage_type: StorageType, + auth_delegate: Box, + auth_flow: AuthFlow, + ) -> io::Result> where C: HyperClientBuilder, - F: Into, { - let hyper_client = self.hyper_client_builder.build_hyper_client(); - let storage = match self.storage_type { + let hyper_client = hyper_client_builder.build_hyper_client(); + let storage = match storage_type { StorageType::Memory => Storage::Memory { tokens: Mutex::new(storage::JSONTokens::new()), }, @@ -169,20 +189,15 @@ impl AuthenticatorBuilder { Ok(Authenticator { hyper_client, - app_secret: self.app_secret, storage, - auth_delegate: self.auth_delegate, - auth_flow: self.auth_flow.into(), + auth_delegate, + auth_flow, }) } - fn with_auth_flow( - app_secret: ApplicationSecret, - auth_flow: F, - ) -> AuthenticatorBuilder { + fn with_auth_flow(auth_flow: F) -> AuthenticatorBuilder { AuthenticatorBuilder { hyper_client_builder: DefaultHyperClient, - app_secret, auth_delegate: Box::new(DefaultAuthenticatorDelegate), storage_type: StorageType::Memory, auth_flow, @@ -196,7 +211,6 @@ impl AuthenticatorBuilder { ) -> AuthenticatorBuilder, F> { AuthenticatorBuilder { hyper_client_builder: hyper_client, - app_secret: self.app_secret, auth_delegate: self.auth_delegate, storage_type: self.storage_type, auth_flow: self.auth_flow, @@ -282,6 +296,20 @@ impl AuthenticatorBuilder { ..self } } + + /// Create the authenticator. + pub async fn build(self) -> io::Result> + where + C: HyperClientBuilder, + { + Self::common_build( + self.hyper_client_builder, + self.storage_type, + self.auth_delegate, + AuthFlow::DeviceFlow(self.auth_flow), + ) + .await + } } /// ## Methods available when building an installed flow Authenticator. @@ -311,36 +339,88 @@ impl AuthenticatorBuilder { ..self } } + + /// Create the authenticator. + pub async fn build(self) -> io::Result> + where + C: HyperClientBuilder, + { + Self::common_build( + self.hyper_client_builder, + self.storage_type, + self.auth_delegate, + AuthFlow::InstalledFlow(self.auth_flow), + ) + .await + } +} + +/// ## Methods available when building a service account authenticator. +/// ``` +/// # async fn foo() { +/// # let service_account_key = yup_oauth2::service_account_key_from_file("/tmp/foo").unwrap(); +/// let authenticator = yup_oauth2::ServiceAccountAuthenticator::builder( +/// service_account_key, +/// ) +/// .subject("mysubject") +/// .build() +/// .await +/// .expect("failed to create authenticator"); +/// # } +/// ``` +impl AuthenticatorBuilder { + /// Use the provided subject. + pub fn subject(self, subject: impl Into) -> Self { + AuthenticatorBuilder { + auth_flow: ServiceAccountFlowOpts { + subject: Some(subject.into()), + ..self.auth_flow + }, + ..self + } + } + + /// Create the authenticator. + pub async fn build(self) -> io::Result> + where + C: HyperClientBuilder, + { + let service_account_auth_flow = ServiceAccountFlow::new(self.auth_flow)?; + Self::common_build( + self.hyper_client_builder, + self.storage_type, + self.auth_delegate, + AuthFlow::ServiceAccountFlow(service_account_auth_flow), + ) + .await + } } mod private { use crate::device::DeviceFlow; use crate::error::Error; use crate::installed::InstalledFlow; + use crate::service_account::ServiceAccountFlow; use crate::types::{ApplicationSecret, Token}; pub enum AuthFlow { DeviceFlow(DeviceFlow), InstalledFlow(InstalledFlow), - } - - impl From for AuthFlow { - fn from(device_flow: DeviceFlow) -> AuthFlow { - AuthFlow::DeviceFlow(device_flow) - } - } - - impl From for AuthFlow { - fn from(installed_flow: InstalledFlow) -> AuthFlow { - AuthFlow::InstalledFlow(installed_flow) - } + ServiceAccountFlow(ServiceAccountFlow), } impl AuthFlow { + pub(crate) fn app_secret(&self) -> Option<&ApplicationSecret> { + match self { + AuthFlow::DeviceFlow(device_flow) => Some(&device_flow.app_secret), + AuthFlow::InstalledFlow(installed_flow) => Some(&installed_flow.app_secret), + AuthFlow::ServiceAccountFlow(_) => None, + } + } + pub(crate) async fn token<'a, C, T>( &'a self, hyper_client: &'a hyper::Client, - app_secret: &'a ApplicationSecret, scopes: &'a [T], ) -> Result where @@ -348,18 +428,19 @@ mod private { C: hyper::client::connect::Connect + 'static, { match self { - AuthFlow::DeviceFlow(device_flow) => { - device_flow.token(hyper_client, app_secret, scopes).await - } + AuthFlow::DeviceFlow(device_flow) => device_flow.token(hyper_client, scopes).await, AuthFlow::InstalledFlow(installed_flow) => { - installed_flow.token(hyper_client, app_secret, scopes).await + installed_flow.token(hyper_client, scopes).await + } + AuthFlow::ServiceAccountFlow(service_account_flow) => { + service_account_flow.token(hyper_client, scopes).await } } } } } -/// A trait implemented for any hyper::Client as well as teh DefaultHyperClient. +/// A trait implemented for any hyper::Client as well as the DefaultHyperClient. pub trait HyperClientBuilder { /// The hyper connector that the resulting hyper client will use. type Connector: hyper::client::connect::Connect + 'static; diff --git a/src/device.rs b/src/device.rs index 6815963..d2758a6 100644 --- a/src/device.rs +++ b/src/device.rs @@ -22,6 +22,7 @@ pub const GOOGLE_GRANT_TYPE: &str = "http://oauth.net/grant_type/device/1.0"; /// * obtain a code to show to the user // * (repeatedly) poll for the user to authenticate your application pub struct DeviceFlow { + pub(crate) app_secret: ApplicationSecret, pub(crate) device_code_url: Cow<'static, str>, pub(crate) flow_delegate: Box, pub(crate) wait_duration: Duration, @@ -31,8 +32,9 @@ pub struct DeviceFlow { impl DeviceFlow { /// Create a new DeviceFlow. The default FlowDelegate will be used and the /// default wait time is 120 seconds. - pub(crate) fn new() -> Self { + pub(crate) fn new(app_secret: ApplicationSecret) -> Self { DeviceFlow { + app_secret, device_code_url: GOOGLE_DEVICE_CODE_URL.into(), flow_delegate: Box::new(DefaultFlowDelegate), wait_duration: Duration::from_secs(120), @@ -43,20 +45,24 @@ impl DeviceFlow { pub(crate) async fn token( &self, hyper_client: &hyper::Client, - app_secret: &ApplicationSecret, scopes: &[T], ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, { - let (pollinf, device_code) = - Self::request_code(app_secret, hyper_client, &self.device_code_url, scopes).await?; + let (pollinf, device_code) = Self::request_code( + &self.app_secret, + hyper_client, + &self.device_code_url, + scopes, + ) + .await?; self.flow_delegate.present_user_code(&pollinf); tokio::timer::Timeout::new( self.wait_for_device_token( hyper_client, - app_secret, + &self.app_secret, &pollinf, &device_code, &self.grant_type, @@ -296,6 +302,7 @@ mod tests { .build::<_, hyper::Body>(https); let flow = DeviceFlow { + app_secret, device_code_url: device_code_url.into(), flow_delegate: Box::new(FD), wait_duration: Duration::from_secs(5), @@ -333,11 +340,7 @@ mod tests { .create(); let token = flow - .token( - &client, - &app_secret, - &["https://www.googleapis.com/scope/1"], - ) + .token(&client, &["https://www.googleapis.com/scope/1"]) .await .expect("token failed"); assert_eq!("accesstoken", token.access_token); @@ -373,11 +376,7 @@ mod tests { .create(); let res = flow - .token( - &client, - &app_secret, - &["https://www.googleapis.com/scope/1"], - ) + .token(&client, &["https://www.googleapis.com/scope/1"]) .await; assert!(res.is_err()); assert!(format!("{}", res.unwrap_err()).contains("invalid_client_id")); @@ -411,11 +410,7 @@ mod tests { .create(); let res = flow - .token( - &client, - &app_secret, - &["https://www.googleapis.com/scope/1"], - ) + .token(&client, &["https://www.googleapis.com/scope/1"]) .await; assert!(res.is_err()); assert!(format!("{}", res.unwrap_err()).contains("Access denied by user")); diff --git a/src/installed.rs b/src/installed.rs index d0805d0..81528f9 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -66,14 +66,19 @@ pub enum InstalledFlowReturnMethod { /// https://www.oauth.com/oauth2-servers/authorization/, /// https://developers.google.com/identity/protocols/OAuth2InstalledApp). pub struct InstalledFlow { + pub(crate) app_secret: ApplicationSecret, pub(crate) method: InstalledFlowReturnMethod, pub(crate) flow_delegate: Box, } impl InstalledFlow { /// Create a new InstalledFlow with the provided secret and method. - pub(crate) fn new(method: InstalledFlowReturnMethod) -> InstalledFlow { + pub(crate) fn new( + app_secret: ApplicationSecret, + method: InstalledFlowReturnMethod, + ) -> InstalledFlow { InstalledFlow { + app_secret, method, flow_delegate: Box::new(DefaultFlowDelegate), } @@ -88,7 +93,6 @@ impl InstalledFlow { pub(crate) async fn token( &self, hyper_client: &hyper::Client, - app_secret: &ApplicationSecret, scopes: &[T], ) -> Result where @@ -97,11 +101,11 @@ impl InstalledFlow { { match self.method { InstalledFlowReturnMethod::HTTPRedirect => { - self.ask_auth_code_via_http(hyper_client, app_secret, scopes) + self.ask_auth_code_via_http(hyper_client, &self.app_secret, scopes) .await } InstalledFlowReturnMethod::Interactive => { - self.ask_auth_code_interactively(hyper_client, app_secret, scopes) + self.ask_auth_code_interactively(hyper_client, &self.app_secret, scopes) .await } } @@ -470,6 +474,7 @@ mod tests { let fd = FD("authorizationcode".to_string(), client.clone()); let inf = InstalledFlow { + app_secret: app_secret.clone(), method: InstalledFlowReturnMethod::Interactive, flow_delegate: Box::new(fd), }; @@ -493,7 +498,7 @@ mod tests { .create(); let tok = inf - .token(&client, &app_secret, &["https://googleapis.com/some/scope"]) + .token(&client, &["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); assert_eq!("accesstoken", tok.access_token); @@ -505,6 +510,7 @@ mod tests { // Successful path with HTTP redirect. { let inf = InstalledFlow { + app_secret: app_secret.clone(), method: InstalledFlowReturnMethod::HTTPRedirect, flow_delegate: Box::new(FD( "authorizationcodefromlocalserver".to_string(), @@ -528,7 +534,7 @@ mod tests { .create(); let tok = inf - .token(&client, &app_secret, &["https://googleapis.com/some/scope"]) + .token(&client, &["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); assert_eq!("accesstoken", tok.access_token); @@ -549,7 +555,7 @@ mod tests { .create(); let tokr = inf - .token(&client, &app_secret, &["https://googleapis.com/some/scope"]) + .token(&client, &["https://googleapis.com/some/scope"]) .await; assert!(tokr.is_err()); assert!(format!("{}", tokr.unwrap_err()).contains("invalid_code")); diff --git a/src/lib.rs b/src/lib.rs index 824eb9f..81520f9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,18 +76,19 @@ pub mod error; mod helper; mod installed; mod refresh; -pub mod service_account; +mod service_account; mod storage; mod types; #[doc(inline)] -pub use crate::authenticator::{DeviceFlowAuthenticator, InstalledFlowAuthenticator}; +pub use crate::authenticator::{ + DeviceFlowAuthenticator, InstalledFlowAuthenticator, ServiceAccountAuthenticator, +}; pub use crate::helper::*; pub use crate::installed::InstalledFlowReturnMethod; -#[doc(inline)] -pub use crate::service_account::{ServiceAccountAuthenticator, ServiceAccountKey}; +pub use crate::service_account::ServiceAccountKey; #[doc(inline)] pub use crate::error::Error; diff --git a/src/service_account.rs b/src/service_account.rs index 211357d..abb4eda 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -11,13 +11,10 @@ //! Copyright (c) 2016 Google Inc (lewinb@google.com). //! -use crate::authenticator::{DefaultHyperClient, HyperClientBuilder}; use crate::error::{Error, JsonErrorOr}; -use crate::storage::{self, Storage}; use crate::types::Token; use std::io; -use std::sync::Mutex; use futures::prelude::*; use hyper::header; @@ -124,7 +121,7 @@ impl<'a> Claims<'a> { } /// A JSON Web Token ready for signing. -struct JWTSigner { +pub(crate) struct JWTSigner { signer: Box, } @@ -160,139 +157,40 @@ impl JWTSigner { } } -/// Create an authenticator that uses a service account. -/// ``` -/// # async fn foo() { -/// # let service_key = yup_oauth2::service_account_key_from_file("/tmp/foo").unwrap(); -/// let authenticator = yup_oauth2::ServiceAccountAuthenticator::builder(service_key) -/// .build() -/// .expect("failed to create authenticator"); -/// # } -/// ``` -pub struct ServiceAccountAuthenticator; -impl ServiceAccountAuthenticator { - /// Use the builder pattern to create an authenticator that uses a service - /// account. - pub fn builder(key: ServiceAccountKey) -> Builder { - Builder { - client: DefaultHyperClient, - key, - subject: None, - } - } +pub struct ServiceAccountFlowOpts { + pub(crate) key: ServiceAccountKey, + pub(crate) subject: Option, } -/// Configure a service account authenticator using the builder pattern. -pub struct Builder { - client: C, +/// ServiceAccountFlow can fetch oauth tokens using a service account. +pub struct ServiceAccountFlow { key: ServiceAccountKey, subject: Option, -} - -/// Methods available when building a service account authenticator. -/// ``` -/// # async fn foo() { -/// # let custom_hyper_client = hyper::Client::new(); -/// # let service_key = yup_oauth2::service_account_key_from_file("/tmp/foo").unwrap(); -/// let authenticator = yup_oauth2::ServiceAccountAuthenticator::builder(service_key) -/// .hyper_client(custom_hyper_client) -/// .subject("foo") -/// .build() -/// .expect("failed to create authenticator"); -/// # } -/// ``` -impl Builder { - /// Use the provided hyper client. - pub fn hyper_client(self, hyper_client: NewC) -> Builder { - Builder { - client: hyper_client, - key: self.key, - subject: self.subject, - } - } - - /// Use the provided subject. - pub fn subject(self, subject: impl Into) -> Self { - Builder { - subject: Some(subject.into()), - ..self - } - } - - /// Build the configured ServiceAccountAccess. - pub fn build(self) -> Result, io::Error> - where - C: HyperClientBuilder, - { - ServiceAccountAccess::new(self.client.build_hyper_client(), self.key, self.subject) - } -} - -/// ServiceAccountAccess can fetch oauth tokens using a service account. -pub struct ServiceAccountAccess { - client: hyper::Client, - key: ServiceAccountKey, - cache: Storage, - subject: Option, signer: JWTSigner, } -impl ServiceAccountAccess -where - C: hyper::client::connect::Connect + 'static, -{ - fn new( - client: hyper::Client, - key: ServiceAccountKey, - subject: Option, - ) -> Result { - let signer = JWTSigner::new(&key.private_key)?; - Ok(ServiceAccountAccess { - client, - key, - cache: Storage::Memory { - tokens: Mutex::new(storage::JSONTokens::new()), - }, - subject, +impl ServiceAccountFlow { + pub(crate) fn new(opts: ServiceAccountFlowOpts) -> Result { + let signer = JWTSigner::new(&opts.key.private_key)?; + Ok(ServiceAccountFlow { + key: opts.key, + subject: opts.subject, signer, }) } - /// Return the current token for the provided scopes. - pub async fn token(&self, scopes: &[T]) -> Result - where - T: AsRef, - { - let hashed_scopes = storage::ScopesAndFilter::from(scopes); - let cache = &self.cache; - match cache.get(hashed_scopes) { - Some(token) if !token.expired() => return Ok(token), - _ => {} - } - let token = Self::request_token( - &self.client, - &self.signer, - self.subject.as_ref().map(|x| x.as_str()), - &self.key, - scopes, - ) - .await?; - cache.set(hashed_scopes, token.clone()).await; - Ok(token) - } /// Send a request for a new Bearer token to the OAuth provider. - async fn request_token( - client: &hyper::client::Client, - signer: &JWTSigner, - subject: Option<&str>, - key: &ServiceAccountKey, + pub(crate) async fn token( + &self, + hyper_client: &hyper::Client, scopes: &[T], ) -> Result where T: AsRef, + C: hyper::client::connect::Connect + 'static, { - let claims = Claims::new(key, scopes, subject); - let signed = signer.sign_claims(&claims).map_err(|_| { + let claims = Claims::new(&self.key, scopes, self.subject.as_ref().map(|x| x.as_str())); + let signed = self.signer.sign_claims(&claims).map_err(|_| { Error::LowLevelError(io::Error::new( io::ErrorKind::Other, "unable to sign claims", @@ -301,11 +199,14 @@ where let rqbody = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[("grant_type", GRANT_TYPE), ("assertion", signed.as_str())]) .finish(); - let request = hyper::Request::post(&key.token_uri) + let request = hyper::Request::post(&self.key.token_uri) .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(rqbody)) .unwrap(); - let response = client.request(request).await.map_err(Error::ClientError)?; + let response = hyper_client + .request(request) + .await + .map_err(Error::ClientError)?; let body = response .into_body() .try_concat() @@ -349,12 +250,17 @@ mod tests { use super::*; use crate::helper::service_account_key_from_file; use crate::parse_json; + use hyper_rustls::HttpsConnector; use mockito::mock; #[tokio::test] async fn test_mocked_http() { env_logger::try_init().unwrap(); + let https = HttpsConnector::new(); + let client = hyper::Client::builder() + .keep_alive(false) + .build::<_, hyper::Body>(https); let server_url = &mockito::server_url(); let key: ServiceAccountKey = parse_json!({ "type": "service_account", @@ -387,25 +293,13 @@ mod tests { .with_body(json_response.to_string()) .expect(1) .create(); - let acc = ServiceAccountAuthenticator::builder(key.clone()) - .build() - .unwrap(); + let acc = ServiceAccountFlow::new(ServiceAccountFlowOpts { + key: key.clone(), + subject: None, + }) + .unwrap(); let tok = acc - .token(&["https://www.googleapis.com/auth/pubsub"]) - .await - .expect("token failed"); - assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); - assert_eq!(Some(3600), tok.expires_in); - - assert!(acc - .cache - .get(storage::ScopesAndFilter::from(&[ - "https://www.googleapis.com/auth/pubsub" - ])) - .is_some()); - // Test that token is in cache (otherwise mock will tell us) - let tok = acc - .token(&["https://www.googleapis.com/auth/pubsub"]) + .token(&client, &["https://www.googleapis.com/auth/pubsub"]) .await .expect("token failed"); assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); @@ -419,10 +313,14 @@ mod tests { .with_header("content-type", "text/json") .with_body(bad_json_response.to_string()) .create(); - let acc = ServiceAccountAuthenticator::builder(key.clone()) - .build() - .unwrap(); - let result = acc.token(&["https://www.googleapis.com/auth/pubsub"]).await; + let acc = ServiceAccountFlow::new(ServiceAccountFlowOpts { + key: key.clone(), + subject: None, + }) + .unwrap(); + let result = acc + .token(&client, &["https://www.googleapis.com/auth/pubsub"]) + .await; assert!(result.is_err()); _m.assert(); } @@ -436,10 +334,15 @@ mod tests { #[allow(dead_code)] async fn test_service_account_e2e() { let key = service_account_key_from_file(&TEST_PRIVATE_KEY_PATH.to_string()).unwrap(); - let acc = ServiceAccountAuthenticator::builder(key).build().unwrap(); + let acc = ServiceAccountFlow::new(ServiceAccountFlowOpts { key, subject: None }).unwrap(); + let https = HttpsConnector::new(); + let client = hyper::Client::builder() + .keep_alive(false) + .build::<_, hyper::Body>(https); println!( "{:?}", - acc.token(&["https://www.googleapis.com/auth/pubsub"]).await + acc.token(&client, &["https://www.googleapis.com/auth/pubsub"]) + .await ); } diff --git a/src/storage.rs b/src/storage.rs index 35e123a..3172c85 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -178,27 +178,24 @@ impl JSONTokens { // application will provide the same set of scopes repeatedly. If a // token exists for the exact scope list requested a lookup of the // ScopeFilter will return a list that would contain it. - if let Some(tokens) = self.token_map.get(&filter) { - for t in tokens { - if requested_scopes_are_subset_of(t.scopes.as_slice()) { - return Some(t.token.clone()); - } - } + if let Some(t) = self + .token_map + .get(&filter) + .into_iter() + .flat_map(|tokens_matching_filter| tokens_matching_filter.iter()) + .find(|js_token: &&JSONToken| requested_scopes_are_subset_of(&js_token.scopes)) + { + return Some(t.token.clone()); } // No exact match for the scopes provided. Search for any tokens that // exist for a superset of the scopes requested. - for t in self - .token_map + self.token_map .iter() - .filter(|(k, _v)| filter.is_subset_of(**k) == FilterResponse::Maybe) - .flat_map(|(_, v)| v.iter()) - { - if requested_scopes_are_subset_of(&t.scopes) { - return Some(t.token.clone()); - } - } - None + .filter(|(k, _)| filter.is_subset_of(**k) == FilterResponse::Maybe) + .flat_map(|(_, tokens_matching_filter)| tokens_matching_filter.iter()) + .find(|v: &&JSONToken| requested_scopes_are_subset_of(&v.scopes)) + .map(|t: &JSONToken| t.token.clone()) } fn set(&mut self, ScopesAndFilter { filter, scopes }: ScopesAndFilter, token: Token) From e72b4c233558a2759983c8263b3173123ced72e3 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Wed, 20 Nov 2019 14:51:47 -0800 Subject: [PATCH 45/71] Rename service_account_key_from_file to read_service_account_key This makes the name consistent with the other helper read_application_secret. --- examples/test-svc-acct/src/main.rs | 7 ++++-- src/authenticator.rs | 4 ++-- src/helper.rs | 10 ++++++--- src/service_account.rs | 36 ++++++++++++++---------------- 4 files changed, 31 insertions(+), 26 deletions(-) diff --git a/examples/test-svc-acct/src/main.rs b/examples/test-svc-acct/src/main.rs index 1b73a69..bf8f564 100644 --- a/examples/test-svc-acct/src/main.rs +++ b/examples/test-svc-acct/src/main.rs @@ -3,8 +3,11 @@ use yup_oauth2::ServiceAccountAuthenticator; #[tokio::main] async fn main() { - let creds = yup_oauth2::service_account_key_from_file("serviceaccount.json").unwrap(); - let sa = ServiceAccountAuthenticator::builder(creds).build().await.unwrap(); + let creds = yup_oauth2::read_service_account_key("serviceaccount.json").unwrap(); + let sa = ServiceAccountAuthenticator::builder(creds) + .build() + .await + .unwrap(); let scopes = &["https://www.googleapis.com/auth/pubsub"]; let tok = sa.token(scopes).await.unwrap(); diff --git a/src/authenticator.rs b/src/authenticator.rs index b74a3c8..e412862 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -134,7 +134,7 @@ impl DeviceFlowAuthenticator { /// Create an authenticator that uses a service account. /// ``` /// # async fn foo() { -/// # let service_account_key = yup_oauth2::service_account_key_from_file("/tmp/foo").unwrap(); +/// # let service_account_key = yup_oauth2::read_service_account_key("/tmp/foo").unwrap(); /// let authenticator = yup_oauth2::ServiceAccountAuthenticator::builder(service_account_key) /// .build() /// .await @@ -358,7 +358,7 @@ impl AuthenticatorBuilder { /// ## Methods available when building a service account authenticator. /// ``` /// # async fn foo() { -/// # let service_account_key = yup_oauth2::service_account_key_from_file("/tmp/foo").unwrap(); +/// # let service_account_key = yup_oauth2::read_service_account_key("/tmp/foo").unwrap(); /// let authenticator = yup_oauth2::ServiceAccountAuthenticator::builder( /// service_account_key, /// ) diff --git a/src/helper.rs b/src/helper.rs index 18cced0..a8f0061 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -38,10 +38,14 @@ pub fn parse_application_secret>(secret: S) -> io::Result>(path: S) -> io::Result { +pub fn read_service_account_key>(path: P) -> io::Result { let key = std::fs::read_to_string(path)?; - serde_json::from_str(&key) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("{}", e))) + serde_json::from_str(&key).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("Bad service account key: {}", e), + ) + }) } pub(crate) fn join(pieces: &[T], separator: &str) -> String diff --git a/src/service_account.rs b/src/service_account.rs index abb4eda..b9cffaf 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -37,31 +37,29 @@ fn append_base64 + ?Sized>(s: &T, out: &mut String) { /// Decode a PKCS8 formatted RSA key. fn decode_rsa_key(pem_pkcs8: &str) -> Result { - let private = pem_pkcs8.to_string().replace("\\n", "\n").into_bytes(); - let mut private_reader: &[u8] = private.as_ref(); - let private_keys = pemfile::pkcs8_private_keys(&mut private_reader); + let private = pem_pkcs8.replace("\\n", "\n"); + let private_keys = pemfile::pkcs8_private_keys(&mut private.as_bytes()); - if let Ok(pk) = private_keys { - if !pk.is_empty() { - Ok(pk[0].clone()) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Not enough private keys in PEM", - )) + match private_keys { + Ok(mut keys) if !keys.is_empty() => { + keys.truncate(1); + Ok(keys.remove(0)) } - } else { - Err(io::Error::new( + Ok(_) => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Not enough private keys in PEM", + )), + Err(_) => Err(io::Error::new( io::ErrorKind::InvalidInput, "Error reading key from PEM", - )) + )), } } /// JSON schema of secret service account key. You can obtain the key from /// the Cloud Console at https://console.cloud.google.com/. /// -/// You can use `helpers::service_account_key_from_file()` as a quick way to read a JSON client +/// You can use `helpers::read_service_account_key()` as a quick way to read a JSON client /// secret into a ServiceAccountKey. #[derive(Serialize, Deserialize, Debug, Clone)] pub struct ServiceAccountKey { @@ -248,7 +246,7 @@ impl ServiceAccountFlow { #[cfg(test)] mod tests { use super::*; - use crate::helper::service_account_key_from_file; + use crate::helper::read_service_account_key; use crate::parse_json; use hyper_rustls::HttpsConnector; @@ -333,7 +331,7 @@ mod tests { //#[tokio::test] #[allow(dead_code)] async fn test_service_account_e2e() { - let key = service_account_key_from_file(&TEST_PRIVATE_KEY_PATH.to_string()).unwrap(); + let key = read_service_account_key(&TEST_PRIVATE_KEY_PATH.to_string()).unwrap(); let acc = ServiceAccountFlow::new(ServiceAccountFlowOpts { key, subject: None }).unwrap(); let https = HttpsConnector::new(); let client = hyper::Client::builder() @@ -348,7 +346,7 @@ mod tests { #[test] fn test_jwt_initialize_claims() { - let key = service_account_key_from_file(TEST_PRIVATE_KEY_PATH).unwrap(); + let key = read_service_account_key(TEST_PRIVATE_KEY_PATH).unwrap(); let scopes = vec!["scope1", "scope2", "scope3"]; let claims = Claims::new(&key, &scopes, None); @@ -368,7 +366,7 @@ mod tests { #[test] fn test_jwt_sign() { - let key = service_account_key_from_file(TEST_PRIVATE_KEY_PATH).unwrap(); + let key = read_service_account_key(TEST_PRIVATE_KEY_PATH).unwrap(); let scopes = vec!["scope1", "scope2", "scope3"]; let signer = JWTSigner::new(&key.private_key).unwrap(); let claims = Claims::new(&key, &scopes, None); From 73af51bab62bcccc5bb7fbcb5572ac6f40223723 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Wed, 20 Nov 2019 14:53:48 -0800 Subject: [PATCH 46/71] Remove what appears to be an unnecessary replace. decode_rsa_key does a .replace('\\n', '\n') which replaces any literal 2 byte sequence '\n' with a newline character. The original commit that added it is 38fd8514933bedaa5b23d7ba0b02c35353e3f05b but there's no mention of why it's needed and there are no test cases that fail when it's omitted. Any file that has a literal 2 byte sequence of '\\n' is surely not a valid pkcs8 file so it seems perfectly valid to skip the replace (and allocation) and return an error if one is encountered. If it's determined that this check is needed for some reason please add a unit test that explains why for future contributors. --- src/service_account.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/service_account.rs b/src/service_account.rs index b9cffaf..36bb994 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -37,8 +37,7 @@ fn append_base64 + ?Sized>(s: &T, out: &mut String) { /// Decode a PKCS8 formatted RSA key. fn decode_rsa_key(pem_pkcs8: &str) -> Result { - let private = pem_pkcs8.replace("\\n", "\n"); - let private_keys = pemfile::pkcs8_private_keys(&mut private.as_bytes()); + let private_keys = pemfile::pkcs8_private_keys(&mut pem_pkcs8.as_bytes()); match private_keys { Ok(mut keys) if !keys.is_empty() => { From 1d25341c662df8d7b881b03df05f38a7f7afed65 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Wed, 20 Nov 2019 15:03:08 -0800 Subject: [PATCH 47/71] Remove AuthenticatorDelegate::{client_error, request_failure} These methods are never called. --- src/authenticator_delegate.rs | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index c55c2c7..3ebb225 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -1,6 +1,6 @@ //! Module containing types related to delegates. -use crate::error::{Error, PollError, RefreshError}; +use crate::error::{PollError, RefreshError}; use std::error::Error as StdError; use std::fmt; @@ -69,16 +69,6 @@ impl StdError for PollError { /// The only method that needs to be implemented manually is `present_user_code(...)`, /// as no assumptions are made on how this presentation should happen. pub trait AuthenticatorDelegate: Send + Sync { - /// Called whenever there is an client, usually if there are network problems. - /// - /// Return retry information. - fn client_error(&self, _: &hyper::Error) -> Retry { - Retry::Abort - } - - /// The server denied the attempt to obtain a request code - fn request_failure(&self, _: Error) {} - /// Called if we could not acquire a refresh token for a reason possibly specified /// by the server. /// This call is made for the delegate's information only. From 8e38d3976b0f3fdb640065ce639ff044af6a52b6 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Thu, 21 Nov 2019 09:14:30 -0800 Subject: [PATCH 48/71] Make helpers that read from disk async --- examples/test-device/src/main.rs | 1 + examples/test-installed/src/main.rs | 1 + examples/test-svc-acct/src/main.rs | 4 +++- src/authenticator.rs | 14 +++++++------- src/helper.rs | 25 +++++++++++++------------ src/lib.rs | 1 + src/service_account.rs | 20 +++++++++++++------- 7 files changed, 39 insertions(+), 27 deletions(-) diff --git a/examples/test-device/src/main.rs b/examples/test-device/src/main.rs index 82a4741..6d505b4 100644 --- a/examples/test-device/src/main.rs +++ b/examples/test-device/src/main.rs @@ -6,6 +6,7 @@ use tokio; #[tokio::main] async fn main() { let app_secret = yup_oauth2::read_application_secret(path::Path::new("clientsecret.json")) + .await .expect("clientsecret"); let auth = DeviceFlowAuthenticator::builder(app_secret) .persist_tokens_to_disk("tokenstorage.json") diff --git a/examples/test-installed/src/main.rs b/examples/test-installed/src/main.rs index 3febb75..c59f9c9 100644 --- a/examples/test-installed/src/main.rs +++ b/examples/test-installed/src/main.rs @@ -5,6 +5,7 @@ use std::path::Path; #[tokio::main] async fn main() { let app_secret = yup_oauth2::read_application_secret(Path::new("clientsecret.json")) + .await .expect("clientsecret.json"); let auth = diff --git a/examples/test-svc-acct/src/main.rs b/examples/test-svc-acct/src/main.rs index bf8f564..ee79ece 100644 --- a/examples/test-svc-acct/src/main.rs +++ b/examples/test-svc-acct/src/main.rs @@ -3,7 +3,9 @@ use yup_oauth2::ServiceAccountAuthenticator; #[tokio::main] async fn main() { - let creds = yup_oauth2::read_service_account_key("serviceaccount.json").unwrap(); + let creds = yup_oauth2::read_service_account_key("serviceaccount.json") + .await + .unwrap(); let sa = ServiceAccountAuthenticator::builder(creds) .build() .await diff --git a/src/authenticator.rs b/src/authenticator.rs index e412862..5546739 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -88,7 +88,7 @@ pub struct AuthenticatorBuilder { /// # async fn foo() { /// # use yup_oauth2::InstalledFlowReturnMethod; /// # let custom_flow_delegate = yup_oauth2::authenticator_delegate::DefaultFlowDelegate; -/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").unwrap(); +/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").await.unwrap(); /// let authenticator = yup_oauth2::InstalledFlowAuthenticator::builder( /// app_secret, /// InstalledFlowReturnMethod::HTTPRedirect, @@ -114,7 +114,7 @@ impl InstalledFlowAuthenticator { /// Create an authenticator that uses the device flow. /// ``` /// # async fn foo() { -/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").unwrap(); +/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").await.unwrap(); /// let authenticator = yup_oauth2::DeviceFlowAuthenticator::builder(app_secret) /// .build() /// .await @@ -134,7 +134,7 @@ impl DeviceFlowAuthenticator { /// Create an authenticator that uses a service account. /// ``` /// # async fn foo() { -/// # let service_account_key = yup_oauth2::read_service_account_key("/tmp/foo").unwrap(); +/// # let service_account_key = yup_oauth2::read_service_account_key("/tmp/foo").await.unwrap(); /// let authenticator = yup_oauth2::ServiceAccountAuthenticator::builder(service_account_key) /// .build() /// .await @@ -159,7 +159,7 @@ impl ServiceAccountAuthenticator { /// # async fn foo() { /// # let custom_hyper_client = hyper::Client::new(); /// # let custom_auth_delegate = yup_oauth2::authenticator_delegate::DefaultAuthenticatorDelegate; -/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").unwrap(); +/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").await.unwrap(); /// let authenticator = yup_oauth2::DeviceFlowAuthenticator::builder(app_secret) /// .hyper_client(custom_hyper_client) /// .persist_tokens_to_disk("/tmp/tokenfile.json") @@ -241,7 +241,7 @@ impl AuthenticatorBuilder { /// ``` /// # async fn foo() { /// # let custom_flow_delegate = yup_oauth2::authenticator_delegate::DefaultFlowDelegate; -/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").unwrap(); +/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").await.unwrap(); /// let authenticator = yup_oauth2::DeviceFlowAuthenticator::builder(app_secret) /// .device_code_url("foo") /// .flow_delegate(Box::new(custom_flow_delegate)) @@ -317,7 +317,7 @@ impl AuthenticatorBuilder { /// # async fn foo() { /// # use yup_oauth2::InstalledFlowReturnMethod; /// # let custom_flow_delegate = yup_oauth2::authenticator_delegate::DefaultFlowDelegate; -/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").unwrap(); +/// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").await.unwrap(); /// let authenticator = yup_oauth2::InstalledFlowAuthenticator::builder( /// app_secret, /// InstalledFlowReturnMethod::HTTPRedirect, @@ -358,7 +358,7 @@ impl AuthenticatorBuilder { /// ## Methods available when building a service account authenticator. /// ``` /// # async fn foo() { -/// # let service_account_key = yup_oauth2::read_service_account_key("/tmp/foo").unwrap(); +/// # let service_account_key = yup_oauth2::read_service_account_key("/tmp/foo").await.unwrap(); /// let authenticator = yup_oauth2::ServiceAccountAuthenticator::builder( /// service_account_key, /// ) diff --git a/src/helper.rs b/src/helper.rs index a8f0061..81200f6 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -11,18 +11,19 @@ use std::io; use std::path::Path; /// Read an application secret from a file. -pub fn read_application_secret>(path: P) -> io::Result { - parse_application_secret(std::fs::read_to_string(path)?) +pub async fn read_application_secret>(path: P) -> io::Result { + parse_application_secret(tokio::fs::read(path).await?) } /// Read an application secret from a JSON string. -pub fn parse_application_secret>(secret: S) -> io::Result { - let decoded: ConsoleApplicationSecret = serde_json::from_str(secret.as_ref()).map_err(|e| { - io::Error::new( - io::ErrorKind::InvalidData, - format!("Bad application secret: {}", e), - ) - })?; +pub fn parse_application_secret>(secret: S) -> io::Result { + let decoded: ConsoleApplicationSecret = + serde_json::from_slice(secret.as_ref()).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("Bad application secret: {}", e), + ) + })?; if let Some(web) = decoded.web { Ok(web) @@ -38,9 +39,9 @@ pub fn parse_application_secret>(secret: S) -> io::Result>(path: P) -> io::Result { - let key = std::fs::read_to_string(path)?; - serde_json::from_str(&key).map_err(|e| { +pub async fn read_service_account_key>(path: P) -> io::Result { + let key = tokio::fs::read(path).await?; + serde_json::from_slice(&key).map_err(|e| { io::Error::new( io::ErrorKind::InvalidData, format!("Bad service account key: {}", e), diff --git a/src/lib.rs b/src/lib.rs index 81520f9..aa7c0fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,6 +45,7 @@ //! // Read application secret from a file. Sometimes it's easier to compile it directly into //! // the binary. The clientsecret file contains JSON like `{"installed":{"client_id": ... }}` //! let secret = yup_oauth2::read_application_secret("clientsecret.json") +//! .await //! .expect("clientsecret.json"); //! //! // Create an authenticator that uses an InstalledFlow to authenticate. The diff --git a/src/service_account.rs b/src/service_account.rs index 36bb994..d584b26 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -330,7 +330,9 @@ mod tests { //#[tokio::test] #[allow(dead_code)] async fn test_service_account_e2e() { - let key = read_service_account_key(&TEST_PRIVATE_KEY_PATH.to_string()).unwrap(); + let key = read_service_account_key(TEST_PRIVATE_KEY_PATH) + .await + .unwrap(); let acc = ServiceAccountFlow::new(ServiceAccountFlowOpts { key, subject: None }).unwrap(); let https = HttpsConnector::new(); let client = hyper::Client::builder() @@ -343,9 +345,11 @@ mod tests { ); } - #[test] - fn test_jwt_initialize_claims() { - let key = read_service_account_key(TEST_PRIVATE_KEY_PATH).unwrap(); + #[tokio::test] + async fn test_jwt_initialize_claims() { + let key = read_service_account_key(TEST_PRIVATE_KEY_PATH) + .await + .unwrap(); let scopes = vec!["scope1", "scope2", "scope3"]; let claims = Claims::new(&key, &scopes, None); @@ -363,9 +367,11 @@ mod tests { assert_eq!(claims.exp - claims.iat, 3595); } - #[test] - fn test_jwt_sign() { - let key = read_service_account_key(TEST_PRIVATE_KEY_PATH).unwrap(); + #[tokio::test] + async fn test_jwt_sign() { + let key = read_service_account_key(TEST_PRIVATE_KEY_PATH) + .await + .unwrap(); let scopes = vec!["scope1", "scope2", "scope3"]; let signer = JWTSigner::new(&key.private_key).unwrap(); let claims = Claims::new(&key, &scopes, None); From d63396a7409feeebb68da828fb0cec71dd5df1bf Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Thu, 21 Nov 2019 09:29:43 -0800 Subject: [PATCH 49/71] Split FlowDelegate into DeviceFlowDelegate and InstalledFlowDelegate. Each flow invokes a non-overlapping set of methods. There doesn't appear to be any benefit in having both flows use a common trait. The benefit of splitting the traits is that it makes it clear which methods need to be updated for each flow type where previously comments were required to communicate that information. --- src/authenticator.rs | 16 +++++++------- src/authenticator_delegate.rs | 39 ++++++++++++++++++++--------------- src/device.rs | 14 +++++++------ src/installed.rs | 12 +++++------ 4 files changed, 44 insertions(+), 37 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index 5546739..af43e14 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -1,6 +1,6 @@ //! Module contianing the core functionality for OAuth2 Authentication. use crate::authenticator_delegate::{ - AuthenticatorDelegate, DefaultAuthenticatorDelegate, FlowDelegate, + AuthenticatorDelegate, DefaultAuthenticatorDelegate, DeviceFlowDelegate, InstalledFlowDelegate, }; use crate::device::DeviceFlow; use crate::error::Error; @@ -87,7 +87,7 @@ pub struct AuthenticatorBuilder { /// ``` /// # async fn foo() { /// # use yup_oauth2::InstalledFlowReturnMethod; -/// # let custom_flow_delegate = yup_oauth2::authenticator_delegate::DefaultFlowDelegate; +/// # let custom_flow_delegate = yup_oauth2::authenticator_delegate::DefaultInstalledFlowDelegate; /// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").await.unwrap(); /// let authenticator = yup_oauth2::InstalledFlowAuthenticator::builder( /// app_secret, @@ -240,7 +240,7 @@ impl AuthenticatorBuilder { /// ## Methods available when building a device flow Authenticator. /// ``` /// # async fn foo() { -/// # let custom_flow_delegate = yup_oauth2::authenticator_delegate::DefaultFlowDelegate; +/// # let custom_flow_delegate = yup_oauth2::authenticator_delegate::DefaultDeviceFlowDelegate; /// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").await.unwrap(); /// let authenticator = yup_oauth2::DeviceFlowAuthenticator::builder(app_secret) /// .device_code_url("foo") @@ -264,8 +264,8 @@ impl AuthenticatorBuilder { } } - /// Use the provided FlowDelegate. - pub fn flow_delegate(self, flow_delegate: Box) -> Self { + /// Use the provided DeviceFlowDelegate. + pub fn flow_delegate(self, flow_delegate: Box) -> Self { AuthenticatorBuilder { auth_flow: DeviceFlow { flow_delegate, @@ -316,7 +316,7 @@ impl AuthenticatorBuilder { /// ``` /// # async fn foo() { /// # use yup_oauth2::InstalledFlowReturnMethod; -/// # let custom_flow_delegate = yup_oauth2::authenticator_delegate::DefaultFlowDelegate; +/// # let custom_flow_delegate = yup_oauth2::authenticator_delegate::DefaultInstalledFlowDelegate; /// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").await.unwrap(); /// let authenticator = yup_oauth2::InstalledFlowAuthenticator::builder( /// app_secret, @@ -329,8 +329,8 @@ impl AuthenticatorBuilder { /// # } /// ``` impl AuthenticatorBuilder { - /// Use the provided FlowDelegate. - pub fn flow_delegate(self, flow_delegate: Box) -> Self { + /// Use the provided InstalledFlowDelegate. + pub fn flow_delegate(self, flow_delegate: Box) -> Self { AuthenticatorBuilder { auth_flow: InstalledFlow { flow_delegate, diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index 3ebb225..51650c4 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -75,9 +75,9 @@ pub trait AuthenticatorDelegate: Send + Sync { fn token_refresh_failed(&self, _: &RefreshError) {} } -/// FlowDelegate methods are called when an OAuth flow needs to ask the application what to do in -/// certain cases. -pub trait FlowDelegate: Send + Sync { +/// DeviceFlowDelegate methods are called when a device flow needs to ask the +/// application what to do in certain cases. +pub trait DeviceFlowDelegate: Send + Sync { /// Called if the request code is expired. You will have to start over in this case. /// This will be the last call the delegate receives. /// Given `DateTime` is the expiration date @@ -91,22 +91,14 @@ pub trait FlowDelegate: Send + Sync { /// Can be used to print progress information, or decide to time-out. /// /// If the returned `Retry` variant is a duration. - /// # Notes - /// * Only used in `DeviceFlow`. Return value will only be used if it - /// is larger than the interval desired by the server. fn pending(&self, _: &PollInformation) -> Retry { Retry::After(Duration::from_secs(5)) } - /// Configure a custom redirect uri if needed. - fn redirect_uri(&self) -> Option<&str> { - None - } /// The server has returned a `user_code` which must be shown to the user, /// along with the `verification_url`. /// # Notes /// * Will be called exactly once, provided we didn't abort during `request_code` phase. - /// * Will only be called if the Authenticator's flow_type is `DeviceFlow`. fn present_user_code(&self, pi: &PollInformation) { println!( "Please enter {} at {} and grant access to this application", @@ -118,8 +110,16 @@ pub trait FlowDelegate: Send + Sync { pi.expires_at.with_timezone(&Local) ); } +} + +/// InstalledFlowDelegate methods are called when an installed flow needs to ask +/// the application what to do in certain cases. +pub trait InstalledFlowDelegate: Send + Sync { + /// Configure a custom redirect uri if needed. + fn redirect_uri(&self) -> Option<&str> { + None + } - /// This method is used by the InstalledFlow. /// We need the user to navigate to a URL using their browser and potentially paste back a code /// (or maybe not). Whether they have to enter a code depends on the InstalledFlowReturnMethod /// used. @@ -167,11 +167,16 @@ async fn present_user_url( /// Uses all default implementations by AuthenticatorDelegate, and makes the trait's /// implementation usable in the first place. -#[derive(Clone)] +#[derive(Copy, Clone)] pub struct DefaultAuthenticatorDelegate; impl AuthenticatorDelegate for DefaultAuthenticatorDelegate {} -/// Uses all default implementations in the FlowDelegate trait. -#[derive(Clone)] -pub struct DefaultFlowDelegate; -impl FlowDelegate for DefaultFlowDelegate {} +/// Uses all default implementations in the DeviceFlowDelegate trait. +#[derive(Copy, Clone)] +pub struct DefaultDeviceFlowDelegate; +impl DeviceFlowDelegate for DefaultDeviceFlowDelegate {} + +/// Uses all default implementations in the DeviceFlowDelegate trait. +#[derive(Copy, Clone)] +pub struct DefaultInstalledFlowDelegate; +impl InstalledFlowDelegate for DefaultInstalledFlowDelegate {} diff --git a/src/device.rs b/src/device.rs index d2758a6..f61e092 100644 --- a/src/device.rs +++ b/src/device.rs @@ -1,4 +1,6 @@ -use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate, PollInformation, Retry}; +use crate::authenticator_delegate::{ + DefaultDeviceFlowDelegate, DeviceFlowDelegate, PollInformation, Retry, +}; use crate::error::{Error, JsonErrorOr, PollError}; use crate::types::{ApplicationSecret, Token}; @@ -24,7 +26,7 @@ pub const GOOGLE_GRANT_TYPE: &str = "http://oauth.net/grant_type/device/1.0"; pub struct DeviceFlow { pub(crate) app_secret: ApplicationSecret, pub(crate) device_code_url: Cow<'static, str>, - pub(crate) flow_delegate: Box, + pub(crate) flow_delegate: Box, pub(crate) wait_duration: Duration, pub(crate) grant_type: Cow<'static, str>, } @@ -36,7 +38,7 @@ impl DeviceFlow { DeviceFlow { app_secret, device_code_url: GOOGLE_DEVICE_CODE_URL.into(), - flow_delegate: Box::new(DefaultFlowDelegate), + flow_delegate: Box::new(DefaultDeviceFlowDelegate), wait_duration: Duration::from_secs(120), grant_type: GOOGLE_GRANT_TYPE.into(), } @@ -93,7 +95,7 @@ impl DeviceFlow { device_code, grant_type, pollinf.expires_at, - &*self.flow_delegate as &dyn FlowDelegate, + &*self.flow_delegate as &dyn DeviceFlowDelegate, ) .await { @@ -203,7 +205,7 @@ impl DeviceFlow { device_code: &str, grant_type: &str, expires_at: DateTime, - flow_delegate: &dyn FlowDelegate, + flow_delegate: &dyn DeviceFlowDelegate, ) -> Result, PollError> where C: hyper::client::connect::Connect + 'static, @@ -278,7 +280,7 @@ mod tests { async fn test_device_end2end() { #[derive(Clone)] struct FD; - impl FlowDelegate for FD { + impl DeviceFlowDelegate for FD { fn present_user_code(&self, pi: &PollInformation) { assert_eq!("https://example.com/verify", pi.verification_url); } diff --git a/src/installed.rs b/src/installed.rs index 81528f9..25f6549 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -2,7 +2,7 @@ // // Refer to the project root for licensing information. // -use crate::authenticator_delegate::{DefaultFlowDelegate, FlowDelegate}; +use crate::authenticator_delegate::{DefaultInstalledFlowDelegate, InstalledFlowDelegate}; use crate::error::{Error, JsonErrorOr}; use crate::types::{ApplicationSecret, Token}; @@ -68,7 +68,7 @@ pub enum InstalledFlowReturnMethod { pub struct InstalledFlow { pub(crate) app_secret: ApplicationSecret, pub(crate) method: InstalledFlowReturnMethod, - pub(crate) flow_delegate: Box, + pub(crate) flow_delegate: Box, } impl InstalledFlow { @@ -80,7 +80,7 @@ impl InstalledFlow { InstalledFlow { app_secret, method, - flow_delegate: Box::new(DefaultFlowDelegate), + flow_delegate: Box::new(DefaultInstalledFlowDelegate), } } @@ -89,7 +89,7 @@ impl InstalledFlow { /// . Obtain a token and refresh token using that code. /// . Return that token /// - /// It's recommended not to use the DefaultFlowDelegate, but a specialized one. + /// It's recommended not to use the DefaultInstalledFlowDelegate, but a specialized one. pub(crate) async fn token( &self, hyper_client: &hyper::Client, @@ -404,7 +404,7 @@ mod tests { use mockito::mock; use super::*; - use crate::authenticator_delegate::FlowDelegate; + use crate::authenticator_delegate::InstalledFlowDelegate; #[tokio::test] async fn test_end2end() { @@ -413,7 +413,7 @@ mod tests { String, hyper::Client, hyper::Body>, ); - impl FlowDelegate for FD { + impl InstalledFlowDelegate for FD { /// Depending on need_code, return the pre-set code or send the code to the server at /// the redirect_uri given in the url. fn present_user_url<'a>( From 25ba7f0b1f65108b11794be35241ebcb02b9777b Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Thu, 21 Nov 2019 10:29:33 -0800 Subject: [PATCH 50/71] Move the impls for PollError into error.rs These should have been moved when error.rs was created but were missed. --- src/authenticator_delegate.rs | 23 +---------------------- src/error.rs | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index 51650c4..e32f88e 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -1,6 +1,6 @@ //! Module containing types related to delegates. -use crate::error::{PollError, RefreshError}; +use crate::error::{RefreshError}; use std::error::Error as StdError; use std::fmt; @@ -43,27 +43,6 @@ impl fmt::Display for PollInformation { } } -impl fmt::Display for PollError { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - match *self { - PollError::HttpError(ref err) => err.fmt(f), - PollError::Expired(ref date) => writeln!(f, "Authentication expired at {}", date), - PollError::AccessDenied => "Access denied by user".fmt(f), - PollError::TimedOut => "Timed out waiting for token".fmt(f), - PollError::Other(ref s) => format!("Unknown server error: {}", s).fmt(f), - } - } -} - -impl StdError for PollError { - fn source(&self) -> Option<&(dyn StdError + 'static)> { - match *self { - PollError::HttpError(ref e) => Some(e), - _ => None, - } - } -} - /// A partially implemented trait to interact with the `Authenticator` /// /// The only method that needs to be implemented manually is `present_user_code(...)`, diff --git a/src/error.rs b/src/error.rs index 619b984..ae0fca8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -46,6 +46,27 @@ pub enum PollError { Other(String), } +impl fmt::Display for PollError { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + match *self { + PollError::HttpError(ref err) => err.fmt(f), + PollError::Expired(ref date) => writeln!(f, "Authentication expired at {}", date), + PollError::AccessDenied => "Access denied by user".fmt(f), + PollError::TimedOut => "Timed out waiting for token".fmt(f), + PollError::Other(ref s) => format!("Unknown server error: {}", s).fmt(f), + } + } +} + +impl StdError for PollError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match *self { + PollError::HttpError(ref e) => Some(e), + _ => None, + } + } +} + /// Encapsulates all possible results of the `token(...)` operation #[derive(Debug)] pub enum Error { From 8030d31da98118027b3e2ebf093c0727169362a3 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Thu, 21 Nov 2019 10:42:08 -0800 Subject: [PATCH 51/71] Reconsider Error variants BadServerResponse didn't seem to be adequately different from NegativeServerResponse so it's been removed. NegativeServerResponse is now a struct variant with field names 'error' and 'error_description' to self document what it contains. InvalidClient and InvalidScope were only ever created based on string parsing of the returned error message from the server. This seemed overly specific to a particular implementation and didn't provide much value to callers. Printing a NegativeServerResponse would provide users the same information. The caching layer never returns errors anymore so remove that variant. --- src/authenticator_delegate.rs | 2 +- src/error.rs | 41 +++++++++++++---------------------- src/service_account.rs | 40 ++++++++++++++-------------------- 3 files changed, 32 insertions(+), 51 deletions(-) diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index e32f88e..af111db 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -1,6 +1,6 @@ //! Module containing types related to delegates. -use crate::error::{RefreshError}; +use crate::error::RefreshError; use std::error::Error as StdError; use std::fmt; diff --git a/src/error.rs b/src/error.rs index ae0fca8..34b4570 100644 --- a/src/error.rs +++ b/src/error.rs @@ -72,16 +72,13 @@ impl StdError for PollError { pub enum Error { /// Indicates connection failure ClientError(hyper::Error), - /// The OAuth client was not found - InvalidClient, - /// Some requested scopes were invalid. String contains the scopes as part of - /// the server error message - InvalidScope(String), - /// A 'catch-all' variant containing the server error and description - /// First string is the error code, the second may be a more detailed description - NegativeServerResponse(String, Option), - /// A malformed server response. - BadServerResponse(String), + /// The server returned an error. + NegativeServerResponse { + /// The error code + error: String, + /// Detailed description + error_description: Option, + }, /// Error while decoding a JSON response. JSONError(serde_json::Error), /// Error within user input. @@ -92,8 +89,6 @@ pub enum Error { Poll(PollError), /// An error occurred while refreshing tokens. Refresh(RefreshError), - /// Error in token cache layer - Cache(Box), } impl From for Error { @@ -104,14 +99,9 @@ impl From for Error { impl From for Error { fn from(value: JsonError) -> Error { - match &*value.error { - "invalid_client" => Error::InvalidClient, - "invalid_scope" => Error::InvalidScope( - value - .error_description - .unwrap_or_else(|| "no description provided".to_string()), - ), - _ => Error::NegativeServerResponse(value.error, value.error_description), + Error::NegativeServerResponse { + error: value.error, + error_description: value.error_description, } } } @@ -132,16 +122,16 @@ impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { match *self { Error::ClientError(ref err) => err.fmt(f), - Error::InvalidClient => "Invalid Client".fmt(f), - Error::InvalidScope(ref scope) => writeln!(f, "Invalid Scope: '{}'", scope), - Error::NegativeServerResponse(ref error, ref desc) => { + Error::NegativeServerResponse { + ref error, + ref error_description, + } => { error.fmt(f)?; - if let Some(ref desc) = *desc { + if let Some(ref desc) = *error_description { write!(f, ": {}", desc)?; } "\n".fmt(f) } - Error::BadServerResponse(ref s) => s.fmt(f), Error::JSONError(ref e) => format!( "JSON Error; this might be a bug with unexpected server responses! {}", e @@ -151,7 +141,6 @@ impl fmt::Display for Error { Error::LowLevelError(ref e) => e.fmt(f), Error::Poll(ref pe) => pe.fmt(f), Error::Refresh(ref rr) => format!("{:?}", rr).fmt(f), - Error::Cache(ref e) => e.fmt(f), } } } diff --git a/src/service_account.rs b/src/service_account.rs index d584b26..6f4fc3a 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -213,32 +213,24 @@ impl ServiceAccountFlow { /// This is the schema of the server's response. #[derive(Deserialize, Debug)] struct TokenResponse { - access_token: Option, - token_type: Option, - expires_in: Option, + access_token: String, + token_type: String, + expires_in: i64, } - match serde_json::from_slice::>(&body)?.into_result()? { - TokenResponse { - access_token: Some(access_token), - token_type: Some(token_type), - expires_in: Some(expires_in), - .. - } => { - let expires_ts = chrono::Utc::now().timestamp() + expires_in; - Ok(Token { - access_token, - token_type, - refresh_token: None, - expires_in: Some(expires_in), - expires_in_timestamp: Some(expires_ts), - }) - } - token => Err(Error::BadServerResponse(format!( - "Token response lacks fields: {:?}", - token - ))), - } + let TokenResponse { + access_token, + token_type, + expires_in, + } = serde_json::from_slice::>(&body)?.into_result()?; + let expires_ts = chrono::Utc::now().timestamp() + expires_in; + Ok(Token { + access_token, + token_type, + refresh_token: None, + expires_in: Some(expires_in), + expires_in_timestamp: Some(expires_ts), + }) } } From 2253c60b8954d2227d94845e7c6622f65a0e5f42 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Thu, 21 Nov 2019 11:19:00 -0800 Subject: [PATCH 52/71] InstalledFlowDelegate::present_user_url should return a String error. Prior to this change the only place present_user_url is called overwrote the error with a static string. After this change the error returned is appended to the message. No need to make the signature more complicated when the error is always going to be flattened to a string anyway. --- src/authenticator_delegate.rs | 22 +++++++--------------- src/installed.rs | 26 +++++++++----------------- 2 files changed, 16 insertions(+), 32 deletions(-) diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index af111db..9907178 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -2,7 +2,6 @@ use crate::error::RefreshError; -use std::error::Error as StdError; use std::fmt; use std::pin::Pin; use std::time::Duration; @@ -106,16 +105,12 @@ pub trait InstalledFlowDelegate: Send + Sync { &'a self, url: &'a str, need_code: bool, - ) -> Pin>> + Send + 'a>> - { + ) -> Pin> + Send + 'a>> { Box::pin(present_user_url(url, need_code)) } } -async fn present_user_url( - url: &str, - need_code: bool, -) -> Result> { +async fn present_user_url(url: &str, need_code: bool) -> Result { use tokio::io::AsyncBufReadExt; if need_code { println!( @@ -124,16 +119,13 @@ async fn present_user_url( url ); let mut user_input = String::new(); - match tokio::io::BufReader::new(tokio::io::stdin()) + tokio::io::BufReader::new(tokio::io::stdin()) .read_line(&mut user_input) .await - { - Err(err) => { - println!("{:?}", err); - Err(Box::new(err) as Box) - } - Ok(_) => Ok(user_input), - } + .map_err(|e| format!("couldn't read code: {}", e))?; + // remove trailing whitespace. + user_input.truncate(user_input.trim_end().len()); + Ok(user_input) } else { println!( "Please direct your browser to {} and follow the instructions displayed \ diff --git a/src/installed.rs b/src/installed.rs index 25f6549..701a3b4 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -127,21 +127,16 @@ impl InstalledFlow { scopes, self.flow_delegate.redirect_uri(), ); - let authcode = match self + let mut authcode = self .flow_delegate .present_user_url(&url, true /* need code */) .await - { - Ok(mut code) => { - // Partial backwards compatibility in case an implementation adds a new line - // due to previous behaviour. - if code.ends_with('\n') { - code.pop(); - } - code - } - _ => return Err(Error::UserError("couldn't read code".to_string())), - }; + .map_err(Error::UserError)?; + // Partial backwards compatibility in case an implementation adds a new line + // due to previous behaviour. + if authcode.ends_with('\n') { + authcode.pop(); + } self.exchange_auth_code(&authcode, hyper_client, app_secret, None) .await } @@ -395,7 +390,6 @@ mod installed_flow_server { #[cfg(test)] mod tests { - use std::error::Error; use std::str::FromStr; use hyper::client::connect::HttpConnector; @@ -420,9 +414,7 @@ mod tests { &'a self, url: &'a str, need_code: bool, - ) -> Pin< - Box>> + Send + 'a>, - > { + ) -> Pin> + Send + 'a>> { Box::pin(async move { if need_code { Ok(self.0.clone()) @@ -449,7 +441,7 @@ mod tests { self.1 .get(rduri) .await - .map_err(|e| Box::new(e) as Box) + .map_err(|e| e.to_string()) .map(|_| "".to_string()) } }) From ae2258bc7a5b529cea60e1c13d0986ffb1c025a5 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Thu, 21 Nov 2019 11:21:39 -0800 Subject: [PATCH 53/71] Remove code to strip trailing newlines for backwards compatibility. Based on the comment in the code the justification for the change was because old FlowDelegates used to contain newlines and changing how the returned string from the delegate was handled would be a breaking change. In this case it should be safe to remove the hack because we're breaking compatibility. All users that once implemented FlowDelegate will now need to implement InstalledFlowDelegate and uphold the new contract which implicitly means the authcode returned should represent the authcode and nothing more. No manipulation of the returned string will be done. --- src/installed.rs | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/installed.rs b/src/installed.rs index 701a3b4..a5aec17 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -127,16 +127,11 @@ impl InstalledFlow { scopes, self.flow_delegate.redirect_uri(), ); - let mut authcode = self + let authcode = self .flow_delegate .present_user_url(&url, true /* need code */) .await .map_err(Error::UserError)?; - // Partial backwards compatibility in case an implementation adds a new line - // due to previous behaviour. - if authcode.ends_with('\n') { - authcode.pop(); - } self.exchange_auth_code(&authcode, hyper_client, app_secret, None) .await } From fe5ea9bdb2a41bb154ca6a77d45ce19a12b4077f Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Thu, 21 Nov 2019 12:53:02 -0800 Subject: [PATCH 54/71] Rename Error::ClientError and RefreshError::ConnectionError to HttpError. PollError already contained an HttpError variant so this makes all variants that contain a hyper::Error consistently named. --- src/device.rs | 2 +- src/error.rs | 12 ++++++------ src/installed.rs | 11 ++--------- src/service_account.rs | 11 ++--------- 4 files changed, 11 insertions(+), 25 deletions(-) diff --git a/src/device.rs b/src/device.rs index f61e092..1ab9efb 100644 --- a/src/device.rs +++ b/src/device.rs @@ -153,7 +153,7 @@ impl DeviceFlow { .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(req)) .unwrap(); - let resp = client.request(req).await.map_err(Error::ClientError)?; + let resp = client.request(req).await?; // This return type is defined in https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.2 // The alias is present as Google use a non-standard name for verification_uri. // According to the standard interval is optional, however, all tested implementations provide it. diff --git a/src/error.rs b/src/error.rs index 34b4570..10b262f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -71,7 +71,7 @@ impl StdError for PollError { #[derive(Debug)] pub enum Error { /// Indicates connection failure - ClientError(hyper::Error), + HttpError(hyper::Error), /// The server returned an error. NegativeServerResponse { /// The error code @@ -93,7 +93,7 @@ pub enum Error { impl From for Error { fn from(error: hyper::Error) -> Error { - Error::ClientError(error) + Error::HttpError(error) } } @@ -121,7 +121,7 @@ impl From for Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { match *self { - Error::ClientError(ref err) => err.fmt(f), + Error::HttpError(ref err) => err.fmt(f), Error::NegativeServerResponse { ref error, ref error_description, @@ -148,7 +148,7 @@ impl fmt::Display for Error { impl StdError for Error { fn source(&self) -> Option<&(dyn StdError + 'static)> { match *self { - Error::ClientError(ref err) => Some(err), + Error::HttpError(ref err) => Some(err), Error::LowLevelError(ref err) => Some(err), Error::JSONError(ref err) => Some(err), _ => None, @@ -160,14 +160,14 @@ impl StdError for Error { #[derive(Debug)] pub enum RefreshError { /// Indicates connection failure - ConnectionError(hyper::Error), + HttpError(hyper::Error), /// The server did not answer with a new token, providing the server message ServerError(String, Option), } impl From for RefreshError { fn from(value: hyper::Error) -> Self { - RefreshError::ConnectionError(value) + RefreshError::HttpError(value) } } diff --git a/src/installed.rs b/src/installed.rs index a5aec17..54ae4de 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -185,15 +185,8 @@ impl InstalledFlow { { let redirect_uri = self.flow_delegate.redirect_uri(); let request = Self::request_token(app_secret, authcode, redirect_uri, server_addr); - let resp = hyper_client - .request(request) - .await - .map_err(Error::ClientError)?; - let body = resp - .into_body() - .try_concat() - .await - .map_err(Error::ClientError)?; + let resp = hyper_client.request(request).await?; + let body = resp.into_body().try_concat().await?; #[derive(Deserialize)] struct JSONTokenResponse { diff --git a/src/service_account.rs b/src/service_account.rs index 6f4fc3a..71ac5c2 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -200,15 +200,8 @@ impl ServiceAccountFlow { .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(rqbody)) .unwrap(); - let response = hyper_client - .request(request) - .await - .map_err(Error::ClientError)?; - let body = response - .into_body() - .try_concat() - .await - .map_err(Error::ClientError)?; + let response = hyper_client.request(request).await?; + let body = response.into_body().try_concat().await?; /// This is the schema of the server's response. #[derive(Deserialize, Debug)] From d0880d07dbb2800e7119ad71f414ba8044c62a35 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Thu, 21 Nov 2019 16:49:10 -0800 Subject: [PATCH 55/71] Refactor error handling and as a consequence delegates. This Removes RefreshError and PollError. Both those types can be fully represented within Error and there seems little value in distinguishing that they were resulting from device polling or refreshes. In either case the user will need to handle the response from token() calls similarly. This also removes the AuthenticatorDelegate since it only served to notify users when refreshes failed, which can already be done by looking at the return code from token. DeviceFlow no longer has the ability to set a wait_timeout. This is trivial to do by wrapping the token() call in a tokio::Timeout future so there's little benefit for users specifying this value. The DeviceFlowDelegate also no longer has the ability to specify when to abort, or alter the interval polling happens on, but it does gain understanding of the 'slow_down' response as documented in the oauth rfc. It seemed very unlikely the delegate was going to do anything other that timeout after a given time and that's already possible using tokio::Timeout so it needlessly complicated the implementation. --- src/authenticator.rs | 55 +------ src/authenticator_delegate.rs | 54 ------- src/device.rs | 111 ++++---------- src/error.rs | 274 +++++++++++++++++++++------------- src/installed.rs | 4 +- src/refresh.rs | 17 ++- src/service_account.rs | 4 +- 7 files changed, 214 insertions(+), 305 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index af43e14..e639f99 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -1,7 +1,5 @@ //! Module contianing the core functionality for OAuth2 Authentication. -use crate::authenticator_delegate::{ - AuthenticatorDelegate, DefaultAuthenticatorDelegate, DeviceFlowDelegate, InstalledFlowDelegate, -}; +use crate::authenticator_delegate::{DeviceFlowDelegate, InstalledFlowDelegate}; use crate::device::DeviceFlow; use crate::error::Error; use crate::installed::{InstalledFlow, InstalledFlowReturnMethod}; @@ -15,13 +13,11 @@ use std::borrow::Cow; use std::io; use std::path::PathBuf; use std::sync::Mutex; -use std::time::Duration; /// Authenticator is responsible for fetching tokens, handling refreshing tokens, /// and optionally persisting tokens to disk. pub struct Authenticator { hyper_client: hyper::Client, - auth_delegate: Box, storage: Storage, auth_flow: AuthFlow, } @@ -49,19 +45,9 @@ where Some(app_secret), ) => { // token is expired but has a refresh token. - let token = match RefreshFlow::refresh_token( - &self.hyper_client, - app_secret, - &refresh_token, - ) - .await - { - Err(err) => { - self.auth_delegate.token_refresh_failed(&err); - return Err(err.into()); - } - Ok(token) => token, - }; + let token = + RefreshFlow::refresh_token(&self.hyper_client, app_secret, &refresh_token) + .await?; self.storage.set(hashed_scopes, token.clone()).await; Ok(token) } @@ -78,7 +64,6 @@ where /// Configure an Authenticator using the builder pattern. pub struct AuthenticatorBuilder { hyper_client_builder: C, - auth_delegate: Box, storage_type: StorageType, auth_flow: F, } @@ -158,12 +143,10 @@ impl ServiceAccountAuthenticator { /// ``` /// # async fn foo() { /// # let custom_hyper_client = hyper::Client::new(); -/// # let custom_auth_delegate = yup_oauth2::authenticator_delegate::DefaultAuthenticatorDelegate; /// # let app_secret = yup_oauth2::read_application_secret("/tmp/foo").await.unwrap(); /// let authenticator = yup_oauth2::DeviceFlowAuthenticator::builder(app_secret) /// .hyper_client(custom_hyper_client) /// .persist_tokens_to_disk("/tmp/tokenfile.json") -/// .auth_delegate(Box::new(custom_auth_delegate)) /// .build() /// .await /// .expect("failed to create authenticator"); @@ -173,7 +156,6 @@ impl AuthenticatorBuilder { async fn common_build( hyper_client_builder: C, storage_type: StorageType, - auth_delegate: Box, auth_flow: AuthFlow, ) -> io::Result> where @@ -190,7 +172,6 @@ impl AuthenticatorBuilder { Ok(Authenticator { hyper_client, storage, - auth_delegate, auth_flow, }) } @@ -198,7 +179,6 @@ impl AuthenticatorBuilder { fn with_auth_flow(auth_flow: F) -> AuthenticatorBuilder { AuthenticatorBuilder { hyper_client_builder: DefaultHyperClient, - auth_delegate: Box::new(DefaultAuthenticatorDelegate), storage_type: StorageType::Memory, auth_flow, } @@ -211,7 +191,6 @@ impl AuthenticatorBuilder { ) -> AuthenticatorBuilder, F> { AuthenticatorBuilder { hyper_client_builder: hyper_client, - auth_delegate: self.auth_delegate, storage_type: self.storage_type, auth_flow: self.auth_flow, } @@ -224,17 +203,6 @@ impl AuthenticatorBuilder { ..self } } - - /// Use the provided authenticator delegate. - pub fn auth_delegate( - self, - auth_delegate: Box, - ) -> AuthenticatorBuilder { - AuthenticatorBuilder { - auth_delegate, - ..self - } - } } /// ## Methods available when building a device flow Authenticator. @@ -245,7 +213,6 @@ impl AuthenticatorBuilder { /// let authenticator = yup_oauth2::DeviceFlowAuthenticator::builder(app_secret) /// .device_code_url("foo") /// .flow_delegate(Box::new(custom_flow_delegate)) -/// .wait_duration(std::time::Duration::from_secs(120)) /// .grant_type("foo") /// .build() /// .await @@ -275,17 +242,6 @@ impl AuthenticatorBuilder { } } - /// Use the provided wait duration. - pub fn wait_duration(self, wait_duration: Duration) -> Self { - AuthenticatorBuilder { - auth_flow: DeviceFlow { - wait_duration, - ..self.auth_flow - }, - ..self - } - } - /// Use the provided grant type. pub fn grant_type(self, grant_type: impl Into>) -> Self { AuthenticatorBuilder { @@ -305,7 +261,6 @@ impl AuthenticatorBuilder { Self::common_build( self.hyper_client_builder, self.storage_type, - self.auth_delegate, AuthFlow::DeviceFlow(self.auth_flow), ) .await @@ -348,7 +303,6 @@ impl AuthenticatorBuilder { Self::common_build( self.hyper_client_builder, self.storage_type, - self.auth_delegate, AuthFlow::InstalledFlow(self.auth_flow), ) .await @@ -389,7 +343,6 @@ impl AuthenticatorBuilder { Self::common_build( self.hyper_client_builder, self.storage_type, - self.auth_delegate, AuthFlow::ServiceAccountFlow(service_account_auth_flow), ) .await diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index 9907178..9374c1d 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -1,25 +1,11 @@ //! Module containing types related to delegates. -use crate::error::RefreshError; - -use std::fmt; use std::pin::Pin; use std::time::Duration; use chrono::{DateTime, Local, Utc}; use futures::prelude::*; -/// A utility type to indicate how operations DeviceFlowHelper operations should be retried -pub enum Retry { - /// Signal you don't want to retry - Abort, - /// Signals you want to retry after the given duration - After(Duration), - /// Instruct the caller to attempt to keep going, or choose an alternate path. - /// If this is not supported, it will have the same effect as `Abort` - Skip, -} - /// Contains state of pending authentication requests #[derive(Clone, Debug, PartialEq)] pub struct PollInformation { @@ -36,43 +22,9 @@ pub struct PollInformation { pub interval: Duration, } -impl fmt::Display for PollInformation { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - writeln!(f, "Proceed with polling until {}", self.expires_at) - } -} - -/// A partially implemented trait to interact with the `Authenticator` -/// -/// The only method that needs to be implemented manually is `present_user_code(...)`, -/// as no assumptions are made on how this presentation should happen. -pub trait AuthenticatorDelegate: Send + Sync { - /// Called if we could not acquire a refresh token for a reason possibly specified - /// by the server. - /// This call is made for the delegate's information only. - fn token_refresh_failed(&self, _: &RefreshError) {} -} - /// DeviceFlowDelegate methods are called when a device flow needs to ask the /// application what to do in certain cases. pub trait DeviceFlowDelegate: Send + Sync { - /// Called if the request code is expired. You will have to start over in this case. - /// This will be the last call the delegate receives. - /// Given `DateTime` is the expiration date - fn expired(&self, _: DateTime) {} - - /// Called if the user denied access. You would have to start over. - /// This will be the last call the delegate receives. - fn denied(&self) {} - - /// Called as long as we are waiting for the user to authorize us. - /// Can be used to print progress information, or decide to time-out. - /// - /// If the returned `Retry` variant is a duration. - fn pending(&self, _: &PollInformation) -> Retry { - Retry::After(Duration::from_secs(5)) - } - /// The server has returned a `user_code` which must be shown to the user, /// along with the `verification_url`. /// # Notes @@ -136,12 +88,6 @@ async fn present_user_url(url: &str, need_code: bool) -> Result } } -/// Uses all default implementations by AuthenticatorDelegate, and makes the trait's -/// implementation usable in the first place. -#[derive(Copy, Clone)] -pub struct DefaultAuthenticatorDelegate; -impl AuthenticatorDelegate for DefaultAuthenticatorDelegate {} - /// Uses all default implementations in the DeviceFlowDelegate trait. #[derive(Copy, Clone)] pub struct DefaultDeviceFlowDelegate; diff --git a/src/device.rs b/src/device.rs index 1ab9efb..f3beaa7 100644 --- a/src/device.rs +++ b/src/device.rs @@ -1,14 +1,13 @@ use crate::authenticator_delegate::{ - DefaultDeviceFlowDelegate, DeviceFlowDelegate, PollInformation, Retry, + DefaultDeviceFlowDelegate, DeviceFlowDelegate, PollInformation, }; -use crate::error::{Error, JsonErrorOr, PollError}; +use crate::error::{AuthError, AuthErrorOr, Error}; use crate::types::{ApplicationSecret, Token}; use std::borrow::Cow; use std::time::Duration; -use ::log::error; -use chrono::{DateTime, Utc}; +use chrono::Utc; use futures::prelude::*; use hyper::header; use serde::Deserialize; @@ -27,7 +26,6 @@ pub struct DeviceFlow { pub(crate) app_secret: ApplicationSecret, pub(crate) device_code_url: Cow<'static, str>, pub(crate) flow_delegate: Box, - pub(crate) wait_duration: Duration, pub(crate) grant_type: Cow<'static, str>, } @@ -39,7 +37,6 @@ impl DeviceFlow { app_secret, device_code_url: GOOGLE_DEVICE_CODE_URL.into(), flow_delegate: Box::new(DefaultDeviceFlowDelegate), - wait_duration: Duration::from_secs(120), grant_type: GOOGLE_GRANT_TYPE.into(), } } @@ -61,18 +58,14 @@ impl DeviceFlow { ) .await?; self.flow_delegate.present_user_code(&pollinf); - tokio::timer::Timeout::new( - self.wait_for_device_token( - hyper_client, - &self.app_secret, - &pollinf, - &device_code, - &self.grant_type, - ), - self.wait_duration, + self.wait_for_device_token( + hyper_client, + &self.app_secret, + &pollinf, + &device_code, + &self.grant_type, ) .await - .map_err(|_| Error::Poll(PollError::TimedOut))? } async fn wait_for_device_token( @@ -89,28 +82,19 @@ impl DeviceFlow { let mut interval = pollinf.interval; loop { tokio::timer::delay_for(interval).await; - interval = match Self::poll_token( - &app_secret, - hyper_client, - device_code, - grant_type, - pollinf.expires_at, - &*self.flow_delegate as &dyn DeviceFlowDelegate, - ) - .await + interval = match Self::poll_token(&app_secret, hyper_client, device_code, grant_type) + .await { - Ok(None) => match self.flow_delegate.pending(&pollinf) { - Retry::Abort | Retry::Skip => return Err(Error::Poll(PollError::TimedOut)), - Retry::After(d) => d, - }, - Ok(Some(tok)) => return Ok(tok), - Err(e @ PollError::AccessDenied) - | Err(e @ PollError::TimedOut) - | Err(e @ PollError::Expired(_)) => return Err(Error::Poll(e)), - Err(ref e) => { - error!("Unknown error from poll token api: {}", e); - pollinf.interval + Ok(token) => return Ok(token), + Err(Error::AuthError(AuthError { error, .. })) + if error.as_str() == "authorization_pending" => + { + interval } + Err(Error::AuthError(AuthError { error, .. })) if error.as_str() == "slow_down" => { + interval + Duration::from_secs(5) + } + Err(err) => return Err(err), } } } @@ -170,7 +154,7 @@ impl DeviceFlow { let json_bytes = resp.into_body().try_concat().await?; let decoded: JsonData = - serde_json::from_slice::>(&json_bytes)?.into_result()?; + serde_json::from_slice::>(&json_bytes)?.into_result()?; let expires_in = decoded.expires_in.unwrap_or(60 * 60); let pi = PollInformation { user_code: decoded.user_code, @@ -204,17 +188,10 @@ impl DeviceFlow { client: &hyper::Client, device_code: &str, grant_type: &str, - expires_at: DateTime, - flow_delegate: &dyn DeviceFlowDelegate, - ) -> Result, PollError> + ) -> Result where C: hyper::client::connect::Connect + 'static, { - if expires_at <= Utc::now() { - flow_delegate.expired(expires_at); - return Err(PollError::Expired(expires_at)); - } - // We should be ready for a new request let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ @@ -229,44 +206,11 @@ impl DeviceFlow { .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(req)) .unwrap(); // TODO: Error checking - let res = client - .request(request) - .await - .map_err(PollError::HttpError)?; - let body = res - .into_body() - .try_concat() - .await - .map_err(PollError::HttpError)?; - #[derive(Deserialize)] - struct JsonError { - error: String, - } - - match serde_json::from_slice::(&body) { - Err(_) => {} // ignore, move on, it's not an error - Ok(res) => { - match res.error.as_ref() { - "access_denied" => { - flow_delegate.denied(); - return Err(PollError::AccessDenied); - } - "authorization_pending" => return Ok(None), - s => { - return Err(PollError::Other(format!( - "server message '{}' not understood", - s - ))) - } - }; - } - } - - // yes, we expect that ! - let mut t: Token = serde_json::from_slice(&body).unwrap(); + let res = client.request(request).await?; + let body = res.into_body().try_concat().await?; + let mut t = serde_json::from_slice::>(&body)?.into_result()?; t.set_expiry_absolute(); - - Ok(Some(t)) + Ok(t) } } @@ -307,7 +251,6 @@ mod tests { app_secret, device_code_url: device_code_url.into(), flow_delegate: Box::new(FD), - wait_duration: Duration::from_secs(5), grant_type: GOOGLE_GRANT_TYPE.into(), }; @@ -415,7 +358,7 @@ mod tests { .token(&client, &["https://www.googleapis.com/scope/1"]) .await; assert!(res.is_err()); - assert!(format!("{}", res.unwrap_err()).contains("Access denied by user")); + assert!(format!("{}", res.unwrap_err()).contains("access_denied")); _m.assert(); } } diff --git a/src/error.rs b/src/error.rs index 10b262f..0e132f5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,68 +1,142 @@ //! Module containing various error types. +use std::borrow::Cow; use std::error::Error as StdError; use std::fmt; use std::io; -use chrono::{DateTime, Utc}; use serde::Deserialize; +/// Error returned by the authorization server. +/// https://tools.ietf.org/html/rfc6749#section-5.2 +/// https://tools.ietf.org/html/rfc8628#section-3.5 #[derive(Deserialize, Debug)] -pub(crate) struct JsonError { - pub error: String, +pub struct AuthError { + /// Error code from the server. + pub error: AuthErrorCode, + /// Human-readable text providing additional information. pub error_description: Option, + /// A URI identifying a human-readable web page with information about the error. pub error_uri: Option, } -/// A helper type to deserialize either a JsonError or another piece of data. -#[derive(Deserialize, Debug)] -#[serde(untagged)] -pub(crate) enum JsonErrorOr { - Err(JsonError), - Data(T), -} - -impl JsonErrorOr { - pub(crate) fn into_result(self) -> Result { - match self { - JsonErrorOr::Err(err) => Result::Err(err), - JsonErrorOr::Data(value) => Result::Ok(value), +impl fmt::Display for AuthError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", &self.error.as_str())?; + if let Some(desc) = &self.error_description { + write!(f, ": {}", desc)?; } + if let Some(uri) = &self.error_uri { + write!(f, "; See {} for more info", uri)?; + } + Ok(()) } } +impl StdError for AuthError {} -/// Encapsulates all possible results of a `poll_token(...)` operation in the Device flow. -#[derive(Debug)] -pub enum PollError { - /// Connection failure - retry if you think it's worth it - HttpError(hyper::Error), - /// Indicates we are expired, including the expiration date - Expired(DateTime), - /// Indicates that the user declined access. String is server response +/// The error code returned by the authorization server. +#[derive(Debug, Clone, Eq, PartialEq)] + +pub enum AuthErrorCode { + /// invalid_request + InvalidRequest, + /// invalid_client + InvalidClient, + /// invalid_grant + InvalidGrant, + /// unauthorized_client + UnauthorizedClient, + /// unsupported_grant_type + UnsupportedGrantType, + /// invalid_scope + InvalidScope, + /// access_denied AccessDenied, - /// Indicates that too many attempts failed. - TimedOut, - /// Other type of error. + /// expired_token + ExpiredToken, + /// other error Other(String), } -impl fmt::Display for PollError { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - match *self { - PollError::HttpError(ref err) => err.fmt(f), - PollError::Expired(ref date) => writeln!(f, "Authentication expired at {}", date), - PollError::AccessDenied => "Access denied by user".fmt(f), - PollError::TimedOut => "Timed out waiting for token".fmt(f), - PollError::Other(ref s) => format!("Unknown server error: {}", s).fmt(f), +impl AuthErrorCode { + /// The error code as a &str + pub fn as_str(&self) -> &str { + match self { + AuthErrorCode::InvalidRequest => "invalid_request", + AuthErrorCode::InvalidClient => "invalid_client", + AuthErrorCode::InvalidGrant => "invalid_grant", + AuthErrorCode::UnauthorizedClient => "unauthorized_client", + AuthErrorCode::UnsupportedGrantType => "unsupported_grant_type", + AuthErrorCode::InvalidScope => "invalid_scope", + AuthErrorCode::AccessDenied => "access_denied", + AuthErrorCode::ExpiredToken => "expired_token", + AuthErrorCode::Other(s) => s.as_str(), + } + } + + fn from_string<'a>(s: impl Into>) -> AuthErrorCode { + let s = s.into(); + match s.as_ref() { + "invalid_request" => AuthErrorCode::InvalidRequest, + "invalid_client" => AuthErrorCode::InvalidClient, + "invalid_grant" => AuthErrorCode::InvalidGrant, + "unauthorized_client" => AuthErrorCode::UnauthorizedClient, + "unsupported_grant_type" => AuthErrorCode::UnsupportedGrantType, + "invalid_scope" => AuthErrorCode::InvalidScope, + "access_denied" => AuthErrorCode::AccessDenied, + "expired_token" => AuthErrorCode::ExpiredToken, + _ => AuthErrorCode::Other(s.into_owned()), } } } -impl StdError for PollError { - fn source(&self) -> Option<&(dyn StdError + 'static)> { - match *self { - PollError::HttpError(ref e) => Some(e), - _ => None, +impl From for AuthErrorCode { + fn from(s: String) -> Self { + AuthErrorCode::from_string(s) + } +} + +impl<'a> From<&'a str> for AuthErrorCode { + fn from(s: &str) -> Self { + AuthErrorCode::from_string(s) + } +} + +impl<'de> Deserialize<'de> for AuthErrorCode { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct V; + impl<'de> serde::de::Visitor<'de> for V { + type Value = AuthErrorCode; + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str("any string") + } + fn visit_string(self, value: String) -> Result { + Ok(value.into()) + } + fn visit_str(self, value: &str) -> Result { + Ok(value.into()) + } + } + deserializer.deserialize_string(V) + } +} + +/// A helper type to deserialize either an AuthError or another piece of data. +#[derive(Deserialize, Debug)] +#[serde(untagged)] +pub(crate) enum AuthErrorOr { + AuthError(AuthError), + Data(T), +} + +impl AuthErrorOr { + pub(crate) fn into_result(self) -> Result { + match self { + AuthErrorOr::AuthError(err) => Result::Err(err), + AuthErrorOr::Data(value) => Result::Ok(value), } } } @@ -73,22 +147,13 @@ pub enum Error { /// Indicates connection failure HttpError(hyper::Error), /// The server returned an error. - NegativeServerResponse { - /// The error code - error: String, - /// Detailed description - error_description: Option, - }, + AuthError(AuthError), /// Error while decoding a JSON response. JSONError(serde_json::Error), /// Error within user input. UserError(String), /// A lower level IO error. LowLevelError(io::Error), - /// A poll error occurred in the DeviceFlow. - Poll(PollError), - /// An error occurred while refreshing tokens. - Refresh(RefreshError), } impl From for Error { @@ -97,12 +162,9 @@ impl From for Error { } } -impl From for Error { - fn from(value: JsonError) -> Error { - Error::NegativeServerResponse { - error: value.error, - error_description: value.error_description, - } +impl From for Error { + fn from(value: AuthError) -> Error { + Error::AuthError(value) } } @@ -112,35 +174,21 @@ impl From for Error { } } -impl From for Error { - fn from(value: RefreshError) -> Error { - Error::Refresh(value) - } -} - impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { match *self { Error::HttpError(ref err) => err.fmt(f), - Error::NegativeServerResponse { - ref error, - ref error_description, - } => { - error.fmt(f)?; - if let Some(ref desc) = *error_description { - write!(f, ": {}", desc)?; - } - "\n".fmt(f) + Error::AuthError(ref err) => err.fmt(f), + Error::JSONError(ref e) => { + write!( + f, + "JSON Error; this might be a bug with unexpected server responses! {}", + e + )?; + Ok(()) } - Error::JSONError(ref e) => format!( - "JSON Error; this might be a bug with unexpected server responses! {}", - e - ) - .fmt(f), Error::UserError(ref s) => s.fmt(f), Error::LowLevelError(ref e) => e.fmt(f), - Error::Poll(ref pe) => pe.fmt(f), - Error::Refresh(ref rr) => format!("{:?}", rr).fmt(f), } } } @@ -149,39 +197,55 @@ impl StdError for Error { fn source(&self) -> Option<&(dyn StdError + 'static)> { match *self { Error::HttpError(ref err) => Some(err), - Error::LowLevelError(ref err) => Some(err), + Error::AuthError(ref err) => Some(err), Error::JSONError(ref err) => Some(err), + Error::LowLevelError(ref err) => Some(err), _ => None, } } } -/// All possible outcomes of the refresh flow -#[derive(Debug)] -pub enum RefreshError { - /// Indicates connection failure - HttpError(hyper::Error), - /// The server did not answer with a new token, providing the server message - ServerError(String, Option), -} +#[cfg(test)] +mod tests { + use super::*; -impl From for RefreshError { - fn from(value: hyper::Error) -> Self { - RefreshError::HttpError(value) - } -} - -impl From for RefreshError { - fn from(value: JsonError) -> Self { - RefreshError::ServerError(value.error, value.error_description) - } -} - -impl From for RefreshError { - fn from(_value: serde_json::Error) -> Self { - RefreshError::ServerError( - "failed to deserialize json token from refresh response".to_owned(), - None, - ) + #[test] + fn test_auth_error_code_deserialize() { + assert_eq!( + AuthErrorCode::InvalidRequest, + serde_json::from_str(r#""invalid_request""#).unwrap() + ); + assert_eq!( + AuthErrorCode::InvalidClient, + serde_json::from_str(r#""invalid_client""#).unwrap() + ); + assert_eq!( + AuthErrorCode::InvalidGrant, + serde_json::from_str(r#""invalid_grant""#).unwrap() + ); + assert_eq!( + AuthErrorCode::UnauthorizedClient, + serde_json::from_str(r#""unauthorized_client""#).unwrap() + ); + assert_eq!( + AuthErrorCode::UnsupportedGrantType, + serde_json::from_str(r#""unsupported_grant_type""#).unwrap() + ); + assert_eq!( + AuthErrorCode::InvalidScope, + serde_json::from_str(r#""invalid_scope""#).unwrap() + ); + assert_eq!( + AuthErrorCode::AccessDenied, + serde_json::from_str(r#""access_denied""#).unwrap() + ); + assert_eq!( + AuthErrorCode::ExpiredToken, + serde_json::from_str(r#""expired_token""#).unwrap() + ); + assert_eq!( + AuthErrorCode::Other("undefined".to_owned()), + serde_json::from_str(r#""undefined""#).unwrap() + ); } } diff --git a/src/installed.rs b/src/installed.rs index 54ae4de..5affa37 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -3,7 +3,7 @@ // Refer to the project root for licensing information. // use crate::authenticator_delegate::{DefaultInstalledFlowDelegate, InstalledFlowDelegate}; -use crate::error::{Error, JsonErrorOr}; +use crate::error::{AuthErrorOr, Error}; use crate::types::{ApplicationSecret, Token}; use std::convert::AsRef; @@ -201,7 +201,7 @@ impl InstalledFlow { refresh_token, token_type, expires_in, - } = serde_json::from_slice::>(&body)?.into_result()?; + } = serde_json::from_slice::>(&body)?.into_result()?; let mut token = Token { access_token, refresh_token, diff --git a/src/refresh.rs b/src/refresh.rs index e312287..ccf2d4d 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -1,4 +1,4 @@ -use crate::error::{JsonErrorOr, RefreshError}; +use crate::error::{AuthErrorOr, Error}; use crate::types::{ApplicationSecret, Token}; use chrono::Utc; @@ -33,7 +33,7 @@ impl RefreshFlow { client: &hyper::Client, client_secret: &ApplicationSecret, refresh_token: &str, - ) -> Result { + ) -> Result { let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ ("client_id", client_secret.client_id.as_str()), @@ -46,7 +46,7 @@ impl RefreshFlow { let request = hyper::Request::post(&client_secret.token_uri) .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(req)) - .unwrap(); // TODO: error handling + .unwrap(); let resp = client.request(request).await?; let body = resp.into_body().try_concat().await?; @@ -62,7 +62,7 @@ impl RefreshFlow { access_token, token_type, expires_in, - } = serde_json::from_slice::>(&body)?.into_result()?; + } = serde_json::from_slice::>(&body)?.into_result()?; Ok(Token { access_token, token_type, @@ -116,13 +116,16 @@ mod tests { .match_body( mockito::Matcher::Regex(".*client_id=902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com.*refresh_token=my-refresh-token.*".to_string())) .with_status(400) - .with_body(r#"{"error": "invalid_token"}"#) + .with_body(r#"{"error": "invalid_request"}"#) .create(); let rr = RefreshFlow::refresh_token(&client, &app_secret, refresh_token).await; match rr { - Err(RefreshError::ServerError(e, None)) => { - assert_eq!(e, "invalid_token"); + Err(Error::AuthError(auth_error)) => { + assert_eq!( + auth_error.error, + crate::error::AuthErrorCode::InvalidRequest + ); } _ => panic!(format!("unexpected RefreshResult {:?}", rr)), } diff --git a/src/service_account.rs b/src/service_account.rs index 71ac5c2..89140bd 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -11,7 +11,7 @@ //! Copyright (c) 2016 Google Inc (lewinb@google.com). //! -use crate::error::{Error, JsonErrorOr}; +use crate::error::{AuthErrorOr, Error}; use crate::types::Token; use std::io; @@ -215,7 +215,7 @@ impl ServiceAccountFlow { access_token, token_type, expires_in, - } = serde_json::from_slice::>(&body)?.into_result()?; + } = serde_json::from_slice::>(&body)?.into_result()?; let expires_ts = chrono::Utc::now().timestamp() + expires_in; Ok(Token { access_token, From 0525926bb2bb2144e87502cb4c71173f8cd62fa4 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 22 Nov 2019 09:50:31 -0800 Subject: [PATCH 56/71] Improve Token Remove expires_in in favor of only having an expires_at DateTime field. Add a from_json method that deserializes from json data into the appropriate Token (or Error) and use that consistently throughout the codebase. --- Cargo.toml | 2 +- src/device.rs | 4 +-- src/installed.rs | 28 ++-------------- src/refresh.rs | 25 ++------------ src/service_account.rs | 31 +++-------------- src/types.rs | 75 ++++++++++++++++++++---------------------- 6 files changed, 46 insertions(+), 119 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5edf915..e75d053 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ edition = "2018" [dependencies] base64 = "0.10" -chrono = "0.4" +chrono = { version = "0.4", features = ["serde"] } http = "0.1" hyper = {version = "0.13.0-alpha.4", features = ["unstable-stream"]} hyper-rustls = "=0.18.0-alpha.2" diff --git a/src/device.rs b/src/device.rs index f3beaa7..31ff37f 100644 --- a/src/device.rs +++ b/src/device.rs @@ -208,9 +208,7 @@ impl DeviceFlow { .unwrap(); // TODO: Error checking let res = client.request(request).await?; let body = res.into_body().try_concat().await?; - let mut t = serde_json::from_slice::>(&body)?.into_result()?; - t.set_expiry_absolute(); - Ok(t) + Token::from_json(&body) } } diff --git a/src/installed.rs b/src/installed.rs index 5affa37..22034d8 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -3,7 +3,7 @@ // Refer to the project root for licensing information. // use crate::authenticator_delegate::{DefaultInstalledFlowDelegate, InstalledFlowDelegate}; -use crate::error::{AuthErrorOr, Error}; +use crate::error::Error; use crate::types::{ApplicationSecret, Token}; use std::convert::AsRef; @@ -15,7 +15,6 @@ use std::sync::{Arc, Mutex}; use futures::future::FutureExt; use futures_util::try_stream::TryStreamExt; use hyper::header; -use serde::Deserialize; use tokio::sync::oneshot; use url::form_urlencoded; use url::percent_encoding::{percent_encode, QUERY_ENCODE_SET}; @@ -187,30 +186,7 @@ impl InstalledFlow { let request = Self::request_token(app_secret, authcode, redirect_uri, server_addr); let resp = hyper_client.request(request).await?; let body = resp.into_body().try_concat().await?; - - #[derive(Deserialize)] - struct JSONTokenResponse { - access_token: String, - refresh_token: Option, - token_type: String, - expires_in: Option, - } - - let JSONTokenResponse { - access_token, - refresh_token, - token_type, - expires_in, - } = serde_json::from_slice::>(&body)?.into_result()?; - let mut token = Token { - access_token, - refresh_token, - token_type, - expires_in, - expires_in_timestamp: None, - }; - token.set_expiry_absolute(); - Ok(token) + Token::from_json(&body) } /// Sends the authorization code to the provider in order to obtain access and refresh tokens. diff --git a/src/refresh.rs b/src/refresh.rs index ccf2d4d..059e1da 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -1,10 +1,8 @@ -use crate::error::{AuthErrorOr, Error}; +use crate::error::Error; use crate::types::{ApplicationSecret, Token}; -use chrono::Utc; use futures_util::try_stream::TryStreamExt; use hyper::header; -use serde::Deserialize; use url::form_urlencoded; /// Implements the [OAuth2 Refresh Token Flow](https://developers.google.com/youtube/v3/guides/authentication#devices). @@ -50,26 +48,7 @@ impl RefreshFlow { let resp = client.request(request).await?; let body = resp.into_body().try_concat().await?; - - #[derive(Deserialize)] - struct JsonToken { - access_token: String, - token_type: String, - expires_in: i64, - } - - let JsonToken { - access_token, - token_type, - expires_in, - } = serde_json::from_slice::>(&body)?.into_result()?; - Ok(Token { - access_token, - token_type, - refresh_token: Some(refresh_token.to_string()), - expires_in: None, - expires_in_timestamp: Some(Utc::now().timestamp() + expires_in), - }) + Token::from_json(&body) } } diff --git a/src/service_account.rs b/src/service_account.rs index 89140bd..168edc2 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -11,7 +11,7 @@ //! Copyright (c) 2016 Google Inc (lewinb@google.com). //! -use crate::error::{AuthErrorOr, Error}; +use crate::error::Error; use crate::types::Token; use std::io; @@ -202,28 +202,7 @@ impl ServiceAccountFlow { .unwrap(); let response = hyper_client.request(request).await?; let body = response.into_body().try_concat().await?; - - /// This is the schema of the server's response. - #[derive(Deserialize, Debug)] - struct TokenResponse { - access_token: String, - token_type: String, - expires_in: i64, - } - - let TokenResponse { - access_token, - token_type, - expires_in, - } = serde_json::from_slice::>(&body)?.into_result()?; - let expires_ts = chrono::Utc::now().timestamp() + expires_in; - Ok(Token { - access_token, - token_type, - refresh_token: None, - expires_in: Some(expires_in), - expires_in_timestamp: Some(expires_ts), - }) + Token::from_json(&body) } } @@ -232,6 +211,7 @@ mod tests { use super::*; use crate::helper::read_service_account_key; use crate::parse_json; + use chrono::Utc; use hyper_rustls::HttpsConnector; use mockito::mock; @@ -263,8 +243,7 @@ mod tests { "token_type": "Bearer" }); let bad_json_response = serde_json::json!({ - "access_token": "ya29.c.ElouBywiys0LyNaZoLPJcp1Fdi2KjFMxzvYKLXkTdvM-rDfqKlvEq6PiMhGoGHx97t5FAvz3eb_ahdwlBjSStxHtDVQB4ZPRJQ_EOi-iS7PnayahU2S9Jp8S6rk", - "token_type": "Bearer" + "error": "access_denied", }); // Successful path. @@ -285,7 +264,7 @@ mod tests { .await .expect("token failed"); assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); - assert_eq!(Some(3600), tok.expires_in); + assert!(Utc::now() + chrono::Duration::seconds(3600) >= tok.expires_at.unwrap()); _m.assert(); } // Malformed response. diff --git a/src/types.rs b/src/types.rs index 34218b3..70a7aeb 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,4 +1,6 @@ -use chrono::{DateTime, TimeZone, Utc}; +use crate::error::{AuthErrorOr, Error}; + +use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; /// Represents a token as returned by OAuth2 servers. @@ -21,50 +23,43 @@ pub struct Token { pub refresh_token: Option, /// The token type as string - usually 'Bearer'. pub token_type: String, - /// access_token will expire after this amount of time. - /// Prefer using expiry_date() - pub expires_in: Option, - /// timestamp is seconds since epoch indicating when the token will expire in absolute terms. - /// use expiry_date() to convert to DateTime. - pub expires_in_timestamp: Option, + /// The time when the token expires. + pub expires_at: Option>, } impl Token { + pub(crate) fn from_json(json_data: &[u8]) -> Result { + #[derive(Deserialize)] + struct RawToken { + access_token: String, + refresh_token: Option, + token_type: String, + expires_in: Option, + } + + let RawToken { + access_token, + refresh_token, + token_type, + expires_in, + } = serde_json::from_slice::>(json_data)?.into_result()?; + + let expires_at = expires_in + .map(|seconds_from_now| Utc::now() + chrono::Duration::seconds(seconds_from_now)); + + Ok(Token { + access_token, + refresh_token, + token_type, + expires_at, + }) + } + /// Returns true if we are expired. - /// - /// # Panics - /// * if our access_token is unset pub fn expired(&self) -> bool { - if self.access_token.is_empty() { - panic!("called expired() on unset token"); - } - if let Some(expiry_date) = self.expiry_date() { - expiry_date - chrono::Duration::minutes(1) <= Utc::now() - } else { - false - } - } - - /// Returns a DateTime object representing our expiry date. - pub fn expiry_date(&self) -> Option> { - let expires_in_timestamp = self.expires_in_timestamp?; - - Utc.timestamp(expires_in_timestamp, 0).into() - } - - /// Adjust our stored expiry format to be absolute, using the current time. - pub fn set_expiry_absolute(&mut self) -> &mut Token { - if self.expires_in_timestamp.is_some() { - assert!(self.expires_in.is_none()); - return self; - } - - if let Some(expires_in) = self.expires_in { - self.expires_in_timestamp = Some(Utc::now().timestamp() + expires_in); - self.expires_in = None; - } - - self + self.expires_at + .map(|expiration_time| expiration_time - chrono::Duration::minutes(1) <= Utc::now()) + .unwrap_or(false) } } From 4521e2f246e6e3013bc3ae6d633a8d9d83a98763 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 22 Nov 2019 10:29:55 -0800 Subject: [PATCH 57/71] Rename PollInformation DeviceAuthResponse. Have it correctly handle either verification_uri or verification_url and deserialize into a struct that has the data types desired. --- src/authenticator_delegate.rs | 65 +++++++++++++++++++++++++++++++---- src/device.rs | 62 +++++++++++---------------------- 2 files changed, 78 insertions(+), 49 deletions(-) diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index 9374c1d..7537bc2 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -1,4 +1,5 @@ //! Module containing types related to delegates. +use crate::error::{AuthErrorOr, Error}; use std::pin::Pin; use std::time::Duration; @@ -8,12 +9,13 @@ use futures::prelude::*; /// Contains state of pending authentication requests #[derive(Clone, Debug, PartialEq)] -pub struct PollInformation { +pub struct DeviceAuthResponse { + /// The device verification code. + pub device_code: String, /// Code the user must enter ... pub user_code: String, - /// ... at the verification URL - pub verification_url: String, - + /// ... at the verification URI + pub verification_uri: String, /// The `user_code` expires at the given time /// It's the time the user has left to authenticate your application pub expires_at: DateTime, @@ -22,17 +24,66 @@ pub struct PollInformation { pub interval: Duration, } +impl DeviceAuthResponse { + pub(crate) fn from_json(json_data: &[u8]) -> Result { + Ok(serde_json::from_slice::>(json_data)?.into_result()?) + } +} + +impl<'de> serde::Deserialize<'de> for DeviceAuthResponse { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(serde::Deserialize)] + struct RawDeviceAuthResponse { + device_code: String, + user_code: String, + // The standard dictates that verification_uri is required, but + // sadly google uses verification_url currently. One of these two + // fields need to be set and verification_uri takes precedence if + // they both are set. + verification_uri: Option, + verification_url: Option, + expires_in: i64, + interval: Option, + } + + let RawDeviceAuthResponse { + device_code, + user_code, + verification_uri, + verification_url, + expires_in, + interval, + } = RawDeviceAuthResponse::deserialize(deserializer)?; + + let verification_uri = verification_uri.or(verification_url).ok_or_else(|| { + serde::de::Error::custom("neither verification_uri nor verification_url specified") + })?; + let expires_at = Utc::now() + chrono::Duration::seconds(expires_in); + let interval = Duration::from_secs(interval.unwrap_or(5)); + Ok(DeviceAuthResponse { + device_code, + user_code, + verification_uri, + expires_at, + interval, + }) + } +} + /// DeviceFlowDelegate methods are called when a device flow needs to ask the /// application what to do in certain cases. pub trait DeviceFlowDelegate: Send + Sync { /// The server has returned a `user_code` which must be shown to the user, - /// along with the `verification_url`. + /// along with the `verification_uri`. /// # Notes /// * Will be called exactly once, provided we didn't abort during `request_code` phase. - fn present_user_code(&self, pi: &PollInformation) { + fn present_user_code(&self, pi: &DeviceAuthResponse) { println!( "Please enter {} at {} and grant access to this application", - pi.user_code, pi.verification_url + pi.user_code, pi.verification_uri ); println!("Do not close this application until you either denied or granted access."); println!( diff --git a/src/device.rs b/src/device.rs index 31ff37f..e32e6fa 100644 --- a/src/device.rs +++ b/src/device.rs @@ -1,16 +1,14 @@ use crate::authenticator_delegate::{ - DefaultDeviceFlowDelegate, DeviceFlowDelegate, PollInformation, + DefaultDeviceFlowDelegate, DeviceAuthResponse, DeviceFlowDelegate, }; -use crate::error::{AuthError, AuthErrorOr, Error}; +use crate::error::{AuthError, Error}; use crate::types::{ApplicationSecret, Token}; use std::borrow::Cow; use std::time::Duration; -use chrono::Utc; use futures::prelude::*; use hyper::header; -use serde::Deserialize; use url::form_urlencoded; pub const GOOGLE_DEVICE_CODE_URL: &str = "https://accounts.google.com/o/oauth2/device/code"; @@ -50,19 +48,18 @@ impl DeviceFlow { T: AsRef, C: hyper::client::connect::Connect + 'static, { - let (pollinf, device_code) = Self::request_code( + let device_auth_resp = Self::request_code( &self.app_secret, hyper_client, &self.device_code_url, scopes, ) .await?; - self.flow_delegate.present_user_code(&pollinf); + self.flow_delegate.present_user_code(&device_auth_resp); self.wait_for_device_token( hyper_client, &self.app_secret, - &pollinf, - &device_code, + &device_auth_resp, &self.grant_type, ) .await @@ -72,18 +69,22 @@ impl DeviceFlow { &self, hyper_client: &hyper::Client, app_secret: &ApplicationSecret, - pollinf: &PollInformation, - device_code: &str, + device_auth_resp: &DeviceAuthResponse, grant_type: &str, ) -> Result where C: hyper::client::connect::Connect + 'static, { - let mut interval = pollinf.interval; + let mut interval = device_auth_resp.interval; loop { tokio::timer::delay_for(interval).await; - interval = match Self::poll_token(&app_secret, hyper_client, device_code, grant_type) - .await + interval = match Self::poll_token( + &app_secret, + hyper_client, + &device_auth_resp.device_code, + grant_type, + ) + .await { Ok(token) => return Ok(token), Err(Error::AuthError(AuthError { error, .. })) @@ -119,7 +120,7 @@ impl DeviceFlow { client: &hyper::Client, device_code_url: &str, scopes: &[T], - ) -> Result<(PollInformation, String), Error> + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, @@ -138,37 +139,14 @@ impl DeviceFlow { .body(hyper::Body::from(req)) .unwrap(); let resp = client.request(req).await?; - // This return type is defined in https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.2 - // The alias is present as Google use a non-standard name for verification_uri. - // According to the standard interval is optional, however, all tested implementations provide it. - // verification_uri_complete is optional in the standard but not provided in tested implementations. - #[derive(Deserialize)] - struct JsonData { - device_code: String, - user_code: String, - #[serde(alias = "verification_url")] - verification_uri: String, - expires_in: Option, - interval: i64, - } - - let json_bytes = resp.into_body().try_concat().await?; - let decoded: JsonData = - serde_json::from_slice::>(&json_bytes)?.into_result()?; - let expires_in = decoded.expires_in.unwrap_or(60 * 60); - let pi = PollInformation { - user_code: decoded.user_code, - verification_url: decoded.verification_uri, - expires_at: Utc::now() + chrono::Duration::seconds(expires_in), - interval: Duration::from_secs(i64::abs(decoded.interval) as u64), - }; - Ok((pi, decoded.device_code)) + let body = resp.into_body().try_concat().await?; + DeviceAuthResponse::from_json(&body) } /// If the first call is successful, this method may be called. /// As long as we are waiting for authentication, it will return `Ok(None)`. /// You should call it within the interval given the previously returned - /// `PollInformation.interval` field. + /// `DeviceAuthResponse.interval` field. /// /// The operation was successful once you receive an Ok(Some(Token)) for the first time. /// Subsequent calls will return the previous result, which may also be an error state. @@ -223,8 +201,8 @@ mod tests { #[derive(Clone)] struct FD; impl DeviceFlowDelegate for FD { - fn present_user_code(&self, pi: &PollInformation) { - assert_eq!("https://example.com/verify", pi.verification_url); + fn present_user_code(&self, pi: &DeviceAuthResponse) { + assert_eq!("https://example.com/verify", pi.verification_uri); } } From 0a4c1e79d244503680e2b1ac8c74ed092419c898 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 22 Nov 2019 11:09:53 -0800 Subject: [PATCH 58/71] Make DeviceFlowDelegate::present_user_code return a Future. This is to allow for implementations to use async code. The returned Future will be awaited before polling for the token begins. --- src/authenticator_delegate.rs | 27 +++++++++++++++++---------- src/device.rs | 11 +++++++++-- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index 7537bc2..2f3933f 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -80,19 +80,26 @@ pub trait DeviceFlowDelegate: Send + Sync { /// along with the `verification_uri`. /// # Notes /// * Will be called exactly once, provided we didn't abort during `request_code` phase. - fn present_user_code(&self, pi: &DeviceAuthResponse) { - println!( - "Please enter {} at {} and grant access to this application", - pi.user_code, pi.verification_uri - ); - println!("Do not close this application until you either denied or granted access."); - println!( - "You have time until {}.", - pi.expires_at.with_timezone(&Local) - ); + fn present_user_code<'a>( + &'a self, + device_auth_resp: &'a DeviceAuthResponse, + ) -> Pin + Send + 'a>> { + Box::pin(present_user_code(device_auth_resp)) } } +async fn present_user_code(device_auth_resp: &DeviceAuthResponse) { + println!( + "Please enter {} at {} and grant access to this application", + device_auth_resp.user_code, device_auth_resp.verification_uri + ); + println!("Do not close this application until you either denied or granted access."); + println!( + "You have time until {}.", + device_auth_resp.expires_at.with_timezone(&Local) + ); +} + /// InstalledFlowDelegate methods are called when an installed flow needs to ask /// the application what to do in certain cases. pub trait InstalledFlowDelegate: Send + Sync { diff --git a/src/device.rs b/src/device.rs index e32e6fa..6888091 100644 --- a/src/device.rs +++ b/src/device.rs @@ -55,7 +55,9 @@ impl DeviceFlow { scopes, ) .await?; - self.flow_delegate.present_user_code(&device_auth_resp); + self.flow_delegate + .present_user_code(&device_auth_resp) + .await; self.wait_for_device_token( hyper_client, &self.app_secret, @@ -193,6 +195,7 @@ impl DeviceFlow { #[cfg(test)] mod tests { use hyper_rustls::HttpsConnector; + use std::pin::Pin; use super::*; @@ -201,8 +204,12 @@ mod tests { #[derive(Clone)] struct FD; impl DeviceFlowDelegate for FD { - fn present_user_code(&self, pi: &DeviceAuthResponse) { + fn present_user_code<'a>( + &'a self, + pi: &'a DeviceAuthResponse, + ) -> Pin + 'a + Send>> { assert_eq!("https://example.com/verify", pi.verification_uri); + Box::pin(futures::future::ready(())) } } From 635bd5e21a9c3f666bab040f6abe052c224c0f65 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 22 Nov 2019 16:30:15 -0800 Subject: [PATCH 59/71] Fix a bug introduced in the storage layer. When bloom filters were added the btreemap values changed to be a vector of tokens to accomodate the possibility of bloom filter collisions. The implementation naively just pushed new tokens onto the vec even if they were replacing previous tokens meaning old tokens were still kept around even after a refresh has replaced it. To fix this efficiently the storage layer now tracks both a hash value and a bloom filter along with each token. Their is a map keyed by hash for every token that points to a reference counted version of the token, and each token also exists in a separate vector. Updates to existing tokens happens in place, when new entries are added they are added to both data structures. --- src/authenticator.rs | 2 +- src/refresh.rs | 8 +- src/storage.rs | 237 +++++++++++++++++++++++++++---------------- src/types.rs | 7 -- 4 files changed, 159 insertions(+), 95 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index e639f99..132d2fa 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -31,7 +31,7 @@ where where T: AsRef, { - let hashed_scopes = storage::ScopesAndFilter::from(scopes); + let hashed_scopes = storage::ScopeSet::from(scopes); match (self.storage.get(hashed_scopes), self.auth_flow.app_secret()) { (Some(t), _) if !t.expired() => { // unexpired token found diff --git a/src/refresh.rs b/src/refresh.rs index 059e1da..910a38b 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -48,7 +48,13 @@ impl RefreshFlow { let resp = client.request(request).await?; let body = resp.into_body().try_concat().await?; - Token::from_json(&body) + let mut token = Token::from_json(&body)?; + // If the refresh result contains a refresh_token use it, otherwise + // continue using our previous refresh_token. + token + .refresh_token + .get_or_insert_with(|| refresh_token.to_owned()); + Ok(token) } } diff --git a/src/storage.rs b/src/storage.rs index 3172c85..b3e0c58 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -4,9 +4,11 @@ // use crate::types::Token; +use std::cell::RefCell; use std::collections::BTreeMap; use std::io; use std::path::{Path, PathBuf}; +use std::rc::Rc; use std::sync::Mutex; use serde::{Deserialize, Serialize}; @@ -22,12 +24,15 @@ use serde::{Deserialize, Serialize}; // definitively not a superset. // The current implementation uses a 64bit bloom filter with 4 hash functions. +/// ScopeHash is a hash value derived from a list of scopes. The hash value +/// represents a fingerprint of the set of scopes *independent* of the ordering. +#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] +struct ScopeHash(u64); + /// ScopeFilter represents a filter for a set of scopes. It can definitively /// prove that a given list of scopes is not a subset of another. #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] -struct ScopeFilter { - bitmask: u64, -} +struct ScopeFilter(u64); #[derive(Debug, Copy, Clone, Eq, PartialEq)] enum FilterResponse { @@ -36,27 +41,9 @@ enum FilterResponse { } impl ScopeFilter { - fn new(scopes: &[T]) -> Self - where - T: AsRef, - { - let mut bitmask = 0u64; - for scope in scopes { - let scope_hash = seahash::hash(scope.as_ref().as_bytes()); - // Use the first 4 6-bit chunks of the seahash as the 4 hash values - // in the bloom filter. - for i in 0..4 { - // h is a hash derived value in the range 0..64 - let h = (scope_hash >> (6 * i)) & 0b11_1111; - bitmask |= 1 << h; - } - } - ScopeFilter { bitmask } - } - /// Determine if this ScopeFilter could be a subset of the provided filter. fn is_subset_of(self, filter: ScopeFilter) -> FilterResponse { - if self.bitmask & filter.bitmask == self.bitmask { + if self.0 & filter.0 == self.0 { FilterResponse::Maybe } else { FilterResponse::No @@ -65,34 +52,26 @@ impl ScopeFilter { } #[derive(Debug)] -pub struct ScopesAndFilter<'a, T> { +pub(crate) struct ScopeSet<'a, T> { + hash: ScopeHash, filter: ScopeFilter, scopes: &'a [T], } // Implement Clone manually. Auto derive fails to work correctly because we want // Clone to be implemented regardless of whether T is Clone or not. -impl<'a, T> Clone for ScopesAndFilter<'a, T> { +impl<'a, T> Clone for ScopeSet<'a, T> { fn clone(&self) -> Self { - ScopesAndFilter { + ScopeSet { + hash: self.hash, filter: self.filter, scopes: self.scopes, } } } -impl<'a, T> Copy for ScopesAndFilter<'a, T> {} +impl<'a, T> Copy for ScopeSet<'a, T> {} -impl<'a, T> From<&'a [T]> for ScopesAndFilter<'a, T> -where - T: AsRef, -{ - fn from(scopes: &'a [T]) -> Self { - let filter = ScopeFilter::new(scopes); - ScopesAndFilter { filter, scopes } - } -} - -impl<'a, T> ScopesAndFilter<'a, T> +impl<'a, T> ScopeSet<'a, T> where T: AsRef, { @@ -102,7 +81,29 @@ where // From trait. This inherent method just serves to auto deref from array // refs to slices and proxy to the From impl. pub fn from(scopes: &'a [T]) -> Self { - >::from(scopes) + let (hash, filter) = scopes.iter().fold( + (ScopeHash(0), ScopeFilter(0)), + |(mut scope_hash, mut scope_filter), scope| { + let h = seahash::hash(scope.as_ref().as_bytes()); + + // Use the first 4 6-bit chunks of the seahash as the 4 hash values + // in the bloom filter. + for i in 0..4 { + // h is a hash derived value in the range 0..64 + let h = (h >> (6 * i)) & 0b11_1111; + scope_filter.0 |= 1 << h; + } + + // xor the hashes together to get an order independent fingerprint. + scope_hash.0 ^= h; + (scope_hash, scope_filter) + }, + ); + ScopeSet { + hash, + filter, + scopes, + } } } @@ -112,7 +113,7 @@ pub(crate) enum Storage { } impl Storage { - pub(crate) async fn set(&self, scopes: ScopesAndFilter<'_, T>, token: Token) + pub(crate) async fn set(&self, scopes: ScopeSet<'_, T>, token: Token) where T: AsRef, { @@ -122,7 +123,7 @@ impl Storage { } } - pub(crate) fn get(&self, scopes: ScopesAndFilter) -> Option + pub(crate) fn get(&self, scopes: ScopeSet) -> Option where T: AsRef, { @@ -134,85 +135,149 @@ impl Storage { } /// A single stored token. -#[derive(Debug, Clone, Serialize, Deserialize)] + +#[derive(Debug, Clone)] struct JSONToken { - pub scopes: Vec, - pub token: Token, + scopes: Vec, + token: Token, + hash: ScopeHash, + filter: ScopeFilter, +} + +impl<'de> Deserialize<'de> for JSONToken { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + struct RawJSONToken { + scopes: Vec, + token: Token, + } + let RawJSONToken { scopes, token } = RawJSONToken::deserialize(deserializer)?; + let ScopeSet { hash, filter, .. } = ScopeSet::from(&scopes); + Ok(JSONToken { + scopes, + token, + hash, + filter, + }) + } +} + +impl Serialize for JSONToken { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + #[derive(Serialize)] + struct RawJSONToken<'a> { + scopes: &'a [String], + token: &'a Token, + } + RawJSONToken { + scopes: &self.scopes, + token: &self.token, + } + .serialize(serializer) + } } /// List of tokens in a JSON object #[derive(Debug, Clone)] pub(crate) struct JSONTokens { - token_map: BTreeMap>, + token_map: BTreeMap>>, + tokens: Vec>>, } impl JSONTokens { pub(crate) fn new() -> Self { JSONTokens { token_map: BTreeMap::new(), + tokens: Vec::new(), } } pub(crate) async fn load_from_file(filename: &Path) -> Result { let contents = tokio::fs::read(filename).await?; - let token_vec: Vec = serde_json::from_slice(&contents) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - let mut token_map: BTreeMap> = BTreeMap::new(); - for token in token_vec { - let filter = ScopesAndFilter::from(&token.scopes).filter; - token_map.entry(filter).or_default().push(token); + let tokens: Vec>> = + serde_json::from_slice::>(&contents) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? + .into_iter() + .map(|json_token| Rc::new(RefCell::new(json_token))) + .collect(); + let mut token_map: BTreeMap>> = BTreeMap::new(); + for token in tokens.iter().cloned() { + let hash = token.borrow().hash; + token_map.insert(hash, token); } - Ok(JSONTokens { token_map }) + Ok(JSONTokens { token_map, tokens }) } - fn get(&self, ScopesAndFilter { filter, scopes }: ScopesAndFilter) -> Option + fn get( + &self, + ScopeSet { + hash, + filter, + scopes, + }: ScopeSet, + ) -> Option where T: AsRef, { + if let Some(json_token) = self.token_map.get(&hash) { + return Some(json_token.borrow().token.clone()); + } + let requested_scopes_are_subset_of = |other_scopes: &[String]| { scopes .iter() .all(|s| other_scopes.iter().any(|t| t.as_str() == s.as_ref())) }; - // Check for exact match of bloom filter first. In the common case an - // application will provide the same set of scopes repeatedly. If a - // token exists for the exact scope list requested a lookup of the - // ScopeFilter will return a list that would contain it. - if let Some(t) = self - .token_map - .get(&filter) - .into_iter() - .flat_map(|tokens_matching_filter| tokens_matching_filter.iter()) - .find(|js_token: &&JSONToken| requested_scopes_are_subset_of(&js_token.scopes)) - { - return Some(t.token.clone()); - } - // No exact match for the scopes provided. Search for any tokens that // exist for a superset of the scopes requested. - self.token_map + self.tokens .iter() - .filter(|(k, _)| filter.is_subset_of(**k) == FilterResponse::Maybe) - .flat_map(|(_, tokens_matching_filter)| tokens_matching_filter.iter()) - .find(|v: &&JSONToken| requested_scopes_are_subset_of(&v.scopes)) - .map(|t: &JSONToken| t.token.clone()) + .filter(|json_token| { + filter.is_subset_of(json_token.borrow().filter) == FilterResponse::Maybe + }) + .find(|v: &&Rc>| requested_scopes_are_subset_of(&v.borrow().scopes)) + .map(|t: &Rc>| t.borrow().token.clone()) } - fn set(&mut self, ScopesAndFilter { filter, scopes }: ScopesAndFilter, token: Token) - where + fn set( + &mut self, + ScopeSet { + hash, + filter, + scopes, + }: ScopeSet, + token: Token, + ) where T: AsRef, { - self.token_map.entry(filter).or_default().push(JSONToken { - scopes: scopes.iter().map(|x| x.as_ref().to_string()).collect(), - token, - }); + use std::collections::btree_map::Entry; + match self.token_map.entry(hash) { + Entry::Occupied(entry) => { + entry.get().borrow_mut().token = token; + } + Entry::Vacant(entry) => { + let json_token = Rc::new(RefCell::new(JSONToken { + scopes: scopes.iter().map(|x| x.as_ref().to_owned()).collect(), + token, + hash, + filter, + })); + entry.insert(json_token.clone()); + self.tokens.push(json_token); + } + } } fn all_tokens(&self) -> Vec { - self.token_map - .values() - .flat_map(|v| v.iter()) - .cloned() + self.tokens + .iter() + .map(|t: &Rc>| t.borrow().clone()) .collect() } } @@ -255,7 +320,7 @@ impl DiskStorage { }) } - async fn set(&self, scopes: ScopesAndFilter<'_, T>, token: Token) + async fn set(&self, scopes: ScopeSet<'_, T>, token: Token) where T: AsRef, { @@ -271,7 +336,7 @@ impl DiskStorage { .expect("disk storage task not running"); } - pub(crate) fn get(&self, scopes: ScopesAndFilter) -> Option + pub(crate) fn get(&self, scopes: ScopeSet) -> Option where T: AsRef, { @@ -285,9 +350,9 @@ mod tests { #[test] fn test_scope_filter() { - let foo = ScopeFilter::new(&["foo"]); - let bar = ScopeFilter::new(&["bar"]); - let foobar = ScopeFilter::new(&["foo", "bar"]); + let foo = ScopeSet::from(&["foo"]).filter; + let bar = ScopeSet::from(&["bar"]).filter; + let foobar = ScopeSet::from(&["foo", "bar"]).filter; // foo and bar are both subsets of foobar. This condition should hold no // matter what changes are made to the bloom filter implementation. diff --git a/src/types.rs b/src/types.rs index 70a7aeb..c9b3225 100644 --- a/src/types.rs +++ b/src/types.rs @@ -8,13 +8,6 @@ use serde::{Deserialize, Serialize}; /// It is produced by all authentication flows. /// It authenticates certain operations, and must be refreshed once /// it reached it's expiry date. -/// -/// The type is tuned to be suitable for direct de-serialization from server -/// replies, as well as for serialization for later reuse. This is the reason -/// for the two fields dealing with expiry - once in relative in and once in -/// absolute terms. -/// -/// Utility methods make common queries easier, see `expired()`. #[derive(Clone, PartialEq, Debug, Deserialize, Serialize)] pub struct Token { /// used when authenticating calls to oauth2 enabled services. From 50824c7777719bef058e0c6248c7235630e0376a Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Sat, 23 Nov 2019 14:52:26 -0800 Subject: [PATCH 60/71] Use Arc> rather than Rc> in DiskStorage. This keeps DiskStorage Sync + Send and therefore Authenticator Sync + Send. The DiskStorage was threadsafe because JSONTokens contains a Mutex around all the Rc> objects, but there's no way to prove to the type system that none of the Rc's get cloned to an alias used outside the Mutex so it's not provably safe. I'll probably reevaluate the design here, but in the meantime the double locking is fine. --- src/storage.rs | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/storage.rs b/src/storage.rs index b3e0c58..67ea4e1 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -4,12 +4,10 @@ // use crate::types::Token; -use std::cell::RefCell; use std::collections::BTreeMap; use std::io; use std::path::{Path, PathBuf}; -use std::rc::Rc; -use std::sync::Mutex; +use std::sync::{Arc, Mutex}; use serde::{Deserialize, Serialize}; @@ -186,8 +184,8 @@ impl Serialize for JSONToken { /// List of tokens in a JSON object #[derive(Debug, Clone)] pub(crate) struct JSONTokens { - token_map: BTreeMap>>, - tokens: Vec>>, + token_map: BTreeMap>>, + tokens: Vec>>, } impl JSONTokens { @@ -200,15 +198,15 @@ impl JSONTokens { pub(crate) async fn load_from_file(filename: &Path) -> Result { let contents = tokio::fs::read(filename).await?; - let tokens: Vec>> = + let tokens: Vec>> = serde_json::from_slice::>(&contents) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? .into_iter() - .map(|json_token| Rc::new(RefCell::new(json_token))) + .map(|json_token| Arc::new(Mutex::new(json_token))) .collect(); - let mut token_map: BTreeMap>> = BTreeMap::new(); + let mut token_map: BTreeMap>> = BTreeMap::new(); for token in tokens.iter().cloned() { - let hash = token.borrow().hash; + let hash = token.lock().unwrap().hash; token_map.insert(hash, token); } Ok(JSONTokens { token_map, tokens }) @@ -226,7 +224,7 @@ impl JSONTokens { T: AsRef, { if let Some(json_token) = self.token_map.get(&hash) { - return Some(json_token.borrow().token.clone()); + return Some(json_token.lock().unwrap().token.clone()); } let requested_scopes_are_subset_of = |other_scopes: &[String]| { @@ -239,10 +237,10 @@ impl JSONTokens { self.tokens .iter() .filter(|json_token| { - filter.is_subset_of(json_token.borrow().filter) == FilterResponse::Maybe + filter.is_subset_of(json_token.lock().unwrap().filter) == FilterResponse::Maybe }) - .find(|v: &&Rc>| requested_scopes_are_subset_of(&v.borrow().scopes)) - .map(|t: &Rc>| t.borrow().token.clone()) + .find(|v: &&Arc>| requested_scopes_are_subset_of(&v.lock().unwrap().scopes)) + .map(|t: &Arc>| t.lock().unwrap().token.clone()) } fn set( @@ -259,10 +257,10 @@ impl JSONTokens { use std::collections::btree_map::Entry; match self.token_map.entry(hash) { Entry::Occupied(entry) => { - entry.get().borrow_mut().token = token; + entry.get().lock().unwrap().token = token; } Entry::Vacant(entry) => { - let json_token = Rc::new(RefCell::new(JSONToken { + let json_token = Arc::new(Mutex::new(JSONToken { scopes: scopes.iter().map(|x| x.as_ref().to_owned()).collect(), token, hash, @@ -277,7 +275,7 @@ impl JSONTokens { fn all_tokens(&self) -> Vec { self.tokens .iter() - .map(|t: &Rc>| t.borrow().clone()) + .map(|t: &Arc>| t.lock().unwrap().clone()) .collect() } } @@ -287,8 +285,11 @@ pub(crate) struct DiskStorage { write_tx: tokio::sync::mpsc::Sender>, } +fn is_send() {} + impl DiskStorage { pub(crate) async fn new(path: PathBuf) -> Result { + is_send::(); let tokens = match JSONTokens::load_from_file(&path).await { Ok(tokens) => tokens, Err(e) if e.kind() == io::ErrorKind::NotFound => JSONTokens::new(), From c829fb453dbd30d40fd306c95a503b91ba12fcec Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Sat, 23 Nov 2019 15:00:31 -0800 Subject: [PATCH 61/71] cargo fmt --- src/storage.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/storage.rs b/src/storage.rs index 67ea4e1..25942c2 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -239,7 +239,9 @@ impl JSONTokens { .filter(|json_token| { filter.is_subset_of(json_token.lock().unwrap().filter) == FilterResponse::Maybe }) - .find(|v: &&Arc>| requested_scopes_are_subset_of(&v.lock().unwrap().scopes)) + .find(|v: &&Arc>| { + requested_scopes_are_subset_of(&v.lock().unwrap().scopes) + }) .map(|t: &Arc>| t.lock().unwrap().token.clone()) } From 497ebf61c525086b355a95511ec53566252a8ea3 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Sat, 23 Nov 2019 15:00:38 -0800 Subject: [PATCH 62/71] Add a test to ensure that Authenticator is Send+Sync --- src/authenticator.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/authenticator.rs b/src/authenticator.rs index 132d2fa..082bb52 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -429,3 +429,14 @@ enum StorageType { Memory, Disk(PathBuf), } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn ensure_send_sync() { + fn is_send_sync() {} + is_send_sync::::Connector>>() + } +} From 1b39ce4413e9398e9cdb2e6658341f0ad7253a96 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Sat, 23 Nov 2019 15:19:19 -0800 Subject: [PATCH 63/71] Refactor storage to only use a BTreeMap. Keeping the same tokens in a Vec and BTreeMap created more overhead than was warranted. It makes much more sense to simply iterator over the BTreeMap than keep a separate Vec. --- src/storage.rs | 48 ++++++++++++++++++------------------------------ 1 file changed, 18 insertions(+), 30 deletions(-) diff --git a/src/storage.rs b/src/storage.rs index 25942c2..bc4fbc7 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -7,7 +7,7 @@ use crate::types::Token; use std::collections::BTreeMap; use std::io; use std::path::{Path, PathBuf}; -use std::sync::{Arc, Mutex}; +use std::sync::Mutex; use serde::{Deserialize, Serialize}; @@ -184,32 +184,25 @@ impl Serialize for JSONToken { /// List of tokens in a JSON object #[derive(Debug, Clone)] pub(crate) struct JSONTokens { - token_map: BTreeMap>>, - tokens: Vec>>, + token_map: BTreeMap, } impl JSONTokens { pub(crate) fn new() -> Self { JSONTokens { token_map: BTreeMap::new(), - tokens: Vec::new(), } } pub(crate) async fn load_from_file(filename: &Path) -> Result { let contents = tokio::fs::read(filename).await?; - let tokens: Vec>> = + let token_map: BTreeMap = serde_json::from_slice::>(&contents) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? .into_iter() - .map(|json_token| Arc::new(Mutex::new(json_token))) + .map(|json_token| (json_token.hash, json_token)) .collect(); - let mut token_map: BTreeMap>> = BTreeMap::new(); - for token in tokens.iter().cloned() { - let hash = token.lock().unwrap().hash; - token_map.insert(hash, token); - } - Ok(JSONTokens { token_map, tokens }) + Ok(JSONTokens { token_map }) } fn get( @@ -224,7 +217,7 @@ impl JSONTokens { T: AsRef, { if let Some(json_token) = self.token_map.get(&hash) { - return Some(json_token.lock().unwrap().token.clone()); + return Some(json_token.token.clone()); } let requested_scopes_are_subset_of = |other_scopes: &[String]| { @@ -234,15 +227,11 @@ impl JSONTokens { }; // No exact match for the scopes provided. Search for any tokens that // exist for a superset of the scopes requested. - self.tokens - .iter() - .filter(|json_token| { - filter.is_subset_of(json_token.lock().unwrap().filter) == FilterResponse::Maybe - }) - .find(|v: &&Arc>| { - requested_scopes_are_subset_of(&v.lock().unwrap().scopes) - }) - .map(|t: &Arc>| t.lock().unwrap().token.clone()) + self.token_map + .values() + .filter(|json_token| filter.is_subset_of(json_token.filter) == FilterResponse::Maybe) + .find(|v: &&JSONToken| requested_scopes_are_subset_of(&v.scopes)) + .map(|t: &JSONToken| t.token.clone()) } fn set( @@ -258,26 +247,25 @@ impl JSONTokens { { use std::collections::btree_map::Entry; match self.token_map.entry(hash) { - Entry::Occupied(entry) => { - entry.get().lock().unwrap().token = token; + Entry::Occupied(mut entry) => { + entry.get_mut().token = token; } Entry::Vacant(entry) => { - let json_token = Arc::new(Mutex::new(JSONToken { + let json_token = JSONToken { scopes: scopes.iter().map(|x| x.as_ref().to_owned()).collect(), token, hash, filter, - })); + }; entry.insert(json_token.clone()); - self.tokens.push(json_token); } } } fn all_tokens(&self) -> Vec { - self.tokens - .iter() - .map(|t: &Arc>| t.lock().unwrap().clone()) + self.token_map + .values() + .map(|t: &JSONToken| t.clone()) .collect() } } From 5e39a818944b9293051d1034904041480f586e45 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Tue, 26 Nov 2019 15:19:57 -0800 Subject: [PATCH 64/71] Go back to waiting for disk writes on every token set. Defering disk writes is still probably a good idea, but unfortunately there are some tradeoffs with rust's async story that make it non-ideal. Ideally we would defer writes, but have a Drop impl on DiskStorage that waited for all the deferred writes to complete. While it's trival to create a future that waits for all deferred writes to finish it's not currently possible to write a Drop impl that waits on a future. It would be possible to write an inherent async fn that takes self by value and waits for the writes, but that method would need to be propogated up all the way to users of the library and they would need to remember to invoke it before dropping the Authenticator. --- Cargo.toml | 1 + src/authenticator.rs | 4 +- src/error.rs | 6 ++ src/storage.rs | 159 +++++++++++++++++++++++++++---------------- 4 files changed, 109 insertions(+), 61 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e75d053..d9a377a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ seahash = "3.0.6" [dev-dependencies] mockito = "0.17" env_logger = "0.6" +tempfile = "3.1" [workspace] members = ["examples/test-installed/", "examples/test-svc-acct/", "examples/test-device/"] diff --git a/src/authenticator.rs b/src/authenticator.rs index 082bb52..6eea182 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -48,13 +48,13 @@ where let token = RefreshFlow::refresh_token(&self.hyper_client, app_secret, &refresh_token) .await?; - self.storage.set(hashed_scopes, token.clone()).await; + self.storage.set(hashed_scopes, token.clone()).await?; Ok(token) } _ => { // no token in the cache or the token returned can't be refreshed. let t = self.auth_flow.token(&self.hyper_client, scopes).await?; - self.storage.set(hashed_scopes, t.clone()).await; + self.storage.set(hashed_scopes, t.clone()).await?; Ok(t) } } diff --git a/src/error.rs b/src/error.rs index 0e132f5..d83a6aa 100644 --- a/src/error.rs +++ b/src/error.rs @@ -174,6 +174,12 @@ impl From for Error { } } +impl From for Error { + fn from(value: io::Error) -> Error { + Error::LowLevelError(value) + } +} + impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { match *self { diff --git a/src/storage.rs b/src/storage.rs index bc4fbc7..d838ba1 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -4,7 +4,7 @@ // use crate::types::Token; -use std::collections::BTreeMap; +use std::collections::HashMap; use std::io; use std::path::{Path, PathBuf}; use std::sync::Mutex; @@ -24,12 +24,12 @@ use serde::{Deserialize, Serialize}; /// ScopeHash is a hash value derived from a list of scopes. The hash value /// represents a fingerprint of the set of scopes *independent* of the ordering. -#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] struct ScopeHash(u64); /// ScopeFilter represents a filter for a set of scopes. It can definitively /// prove that a given list of scopes is not a subset of another. -#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] struct ScopeFilter(u64); #[derive(Debug, Copy, Clone, Eq, PartialEq)] @@ -111,7 +111,11 @@ pub(crate) enum Storage { } impl Storage { - pub(crate) async fn set(&self, scopes: ScopeSet<'_, T>, token: Token) + pub(crate) async fn set( + &self, + scopes: ScopeSet<'_, T>, + token: Token, + ) -> Result<(), io::Error> where T: AsRef, { @@ -184,25 +188,60 @@ impl Serialize for JSONToken { /// List of tokens in a JSON object #[derive(Debug, Clone)] pub(crate) struct JSONTokens { - token_map: BTreeMap, + token_map: HashMap, +} + +impl Serialize for JSONTokens { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.collect_seq(self.token_map.values()) + } +} + +impl<'de> Deserialize<'de> for JSONTokens { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct V; + impl<'de> serde::de::Visitor<'de> for V { + type Value = JSONTokens; + + // Format a message stating what data this Visitor expects to receive. + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a sequence of JSONToken's") + } + + fn visit_seq(self, mut access: M) -> Result + where + M: serde::de::SeqAccess<'de>, + { + let mut token_map = HashMap::with_capacity(access.size_hint().unwrap_or(0)); + while let Some(json_token) = access.next_element::()? { + token_map.insert(json_token.hash, json_token); + } + Ok(JSONTokens { token_map }) + } + } + + // Instantiate our Visitor and ask the Deserializer to drive + // it over the input data, resulting in an instance of MyMap. + deserializer.deserialize_seq(V) + } } impl JSONTokens { pub(crate) fn new() -> Self { JSONTokens { - token_map: BTreeMap::new(), + token_map: HashMap::new(), } } - pub(crate) async fn load_from_file(filename: &Path) -> Result { + async fn load_from_file(filename: &Path) -> Result { let contents = tokio::fs::read(filename).await?; - let token_map: BTreeMap = - serde_json::from_slice::>(&contents) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? - .into_iter() - .map(|json_token| (json_token.hash, json_token)) - .collect(); - Ok(JSONTokens { token_map }) + serde_json::from_slice(&contents).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) } fn get( @@ -242,10 +281,11 @@ impl JSONTokens { scopes, }: ScopeSet, token: Token, - ) where + ) -> Result<(), io::Error> + where T: AsRef, { - use std::collections::btree_map::Entry; + use std::collections::hash_map::Entry; match self.token_map.entry(hash) { Entry::Occupied(mut entry) => { entry.get_mut().token = token; @@ -260,71 +300,45 @@ impl JSONTokens { entry.insert(json_token.clone()); } } - } - - fn all_tokens(&self) -> Vec { - self.token_map - .values() - .map(|t: &JSONToken| t.clone()) - .collect() + Ok(()) } } pub(crate) struct DiskStorage { tokens: Mutex, - write_tx: tokio::sync::mpsc::Sender>, + filename: PathBuf, } -fn is_send() {} - impl DiskStorage { - pub(crate) async fn new(path: PathBuf) -> Result { - is_send::(); - let tokens = match JSONTokens::load_from_file(&path).await { + pub(crate) async fn new(filename: PathBuf) -> Result { + let tokens = match JSONTokens::load_from_file(&filename).await { Ok(tokens) => tokens, Err(e) if e.kind() == io::ErrorKind::NotFound => JSONTokens::new(), Err(e) => return Err(e), }; - // Writing to disk will happen in a separate task. This means in the - // common case returning a token to the user will not be required to - // wait for disk i/o. We communicate with a dedicated writer task via a - // buffered channel. This ensures that the writes happen in the order - // received, and if writes fall too far behind we will block GetToken - // requests until disk i/o completes. - let (write_tx, mut write_rx) = tokio::sync::mpsc::channel::>(2); - tokio::spawn(async move { - while let Some(tokens) = write_rx.recv().await { - match serde_json::to_string(&tokens) { - Err(e) => log::error!("Failed to serialize tokens: {}", e), - Ok(ser) => { - if let Err(e) = tokio::fs::write(path.clone(), &ser).await { - log::error!("Failed to write tokens to disk: {}", e); - } - } - } - } - }); Ok(DiskStorage { tokens: Mutex::new(tokens), - write_tx, + filename, }) } - async fn set(&self, scopes: ScopeSet<'_, T>, token: Token) + pub(crate) async fn set( + &self, + scopes: ScopeSet<'_, T>, + token: Token, + ) -> Result<(), io::Error> where T: AsRef, { - let cloned_tokens = { - let mut tokens = self.tokens.lock().unwrap(); - tokens.set(scopes, token); - tokens.all_tokens() + let json = { + use std::ops::Deref; + let mut lock = self.tokens.lock().unwrap(); + lock.set(scopes, token)?; + serde_json::to_string(lock.deref()) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? }; - self.write_tx - .clone() - .send(cloned_tokens) - .await - .expect("disk storage task not running"); + tokio::fs::write(self.filename.clone(), json).await } pub(crate) fn get(&self, scopes: ScopeSet) -> Option @@ -358,4 +372,31 @@ mod tests { assert!(foobar.is_subset_of(foo) == FilterResponse::No); assert!(foobar.is_subset_of(bar) == FilterResponse::No); } + + #[tokio::test] + async fn test_disk_storage() { + let new_token = |access_token: &str| Token { + access_token: access_token.to_owned(), + refresh_token: None, + token_type: "Bearer".to_owned(), + expires_at: None, + }; + let tempdir = tempfile::tempdir().unwrap(); + let storage = DiskStorage::new(tempdir.path().join("tokenstorage.json")) + .await + .unwrap(); + let scope_set = ScopeSet::from(&["myscope"]); + assert!(storage.get(scope_set).is_none()); + storage + .set(scope_set, new_token("my_access_token")) + .await + .unwrap(); + assert_eq!(storage.get(scope_set), Some(new_token("my_access_token"))); + + // Create a new DiskStorage instance and verify the tokens were read from disk correctly. + let storage = DiskStorage::new(tempdir.path().join("tokenstorage.json")) + .await + .unwrap(); + assert_eq!(storage.get(scope_set), Some(new_token("my_access_token"))); + } } From 045c3e77355e242a0978493114637d20f9302a3c Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Wed, 27 Nov 2019 09:27:34 -0800 Subject: [PATCH 65/71] Move all the end to end tests into an integration test All the same functionality can be tested through the publicly exposed API providing more extensive coverage. --- src/device.rs | 155 ----------- src/error.rs | 3 +- src/helper.rs | 10 - src/installed.rs | 164 +----------- src/refresh.rs | 61 ----- src/service_account.rs | 75 ------ src/storage.rs | 35 +-- tests/tests.rs | 574 +++++++++++++++++++++++++++++++++++++++++ 8 files changed, 595 insertions(+), 482 deletions(-) create mode 100644 tests/tests.rs diff --git a/src/device.rs b/src/device.rs index 6888091..e80ad01 100644 --- a/src/device.rs +++ b/src/device.rs @@ -191,158 +191,3 @@ impl DeviceFlow { Token::from_json(&body) } } - -#[cfg(test)] -mod tests { - use hyper_rustls::HttpsConnector; - use std::pin::Pin; - - use super::*; - - #[tokio::test] - async fn test_device_end2end() { - #[derive(Clone)] - struct FD; - impl DeviceFlowDelegate for FD { - fn present_user_code<'a>( - &'a self, - pi: &'a DeviceAuthResponse, - ) -> Pin + 'a + Send>> { - assert_eq!("https://example.com/verify", pi.verification_uri); - Box::pin(futures::future::ready(())) - } - } - - let server_url = mockito::server_url(); - let app_secret: ApplicationSecret = crate::parse_json!({ - "client_id": "902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com", - "project_id": "yup-test-243420", - "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": format!("{}/token", server_url), - "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", - "client_secret": "iuMPN6Ne1PD7cos29Tk9rlqH", - "redirect_uris": ["urn:ietf:wg:oauth:2.0:oob","http://localhost"], - }); - let device_code_url = format!("{}/code", server_url); - - let https = HttpsConnector::new(); - let client = hyper::Client::builder() - .keep_alive(false) - .build::<_, hyper::Body>(https); - - let flow = DeviceFlow { - app_secret, - device_code_url: device_code_url.into(), - flow_delegate: Box::new(FD), - grant_type: GOOGLE_GRANT_TYPE.into(), - }; - - // Successful path - { - let code_response = serde_json::json!({ - "device_code": "devicecode", - "user_code": "usercode", - "verification_url": "https://example.com/verify", - "expires_in": 1234567, - "interval": 1 - }); - let _m = mockito::mock("POST", "/code") - .match_body(mockito::Matcher::Regex( - ".*client_id=902216714886.*".to_string(), - )) - .with_status(200) - .with_body(code_response.to_string()) - .create(); - let token_response = serde_json::json!({ - "access_token": "accesstoken", - "refresh_token": "refreshtoken", - "token_type": "Bearer", - "expires_in": 1234567 - }); - let _m = mockito::mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*client_secret=iuMPN6Ne1PD7cos29Tk9rlqH&code=devicecode.*".to_string(), - )) - .with_status(200) - .with_body(token_response.to_string()) - .create(); - - let token = flow - .token(&client, &["https://www.googleapis.com/scope/1"]) - .await - .expect("token failed"); - assert_eq!("accesstoken", token.access_token); - _m.assert(); - } - - // Code is not delivered. - { - let code_response = serde_json::json!({ - "error": "invalid_client_id", - "error_description": "description" - }); - let _m = mockito::mock("POST", "/code") - .match_body(mockito::Matcher::Regex( - ".*client_id=902216714886.*".to_string(), - )) - .with_status(400) - .with_body(code_response.to_string()) - .create(); - let token_response = serde_json::json!({ - "access_token": "accesstoken", - "refresh_token": "refreshtoken", - "token_type": "Bearer", - "expires_in": 1234567 - }); - let _m = mockito::mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*client_secret=iuMPN6Ne1PD7cos29Tk9rlqH&code=devicecode.*".to_string(), - )) - .with_status(200) - .with_body(token_response.to_string()) - .expect(0) // Never called! - .create(); - - let res = flow - .token(&client, &["https://www.googleapis.com/scope/1"]) - .await; - assert!(res.is_err()); - assert!(format!("{}", res.unwrap_err()).contains("invalid_client_id")); - _m.assert(); - } - - // Token is not delivered. - { - let code_response = serde_json::json!({ - "device_code": "devicecode", - "user_code": "usercode", - "verification_url": "https://example.com/verify", - "expires_in": 1234567, - "interval": 1 - }); - let _m = mockito::mock("POST", "/code") - .match_body(mockito::Matcher::Regex( - ".*client_id=902216714886.*".to_string(), - )) - .with_status(200) - .with_body(code_response.to_string()) - .create(); - let token_response = serde_json::json!({"error": "access_denied"}); - let _m = mockito::mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*client_secret=iuMPN6Ne1PD7cos29Tk9rlqH&code=devicecode.*".to_string(), - )) - .with_status(400) - .with_body(token_response.to_string()) - .expect(1) - .create(); - - let res = flow - .token(&client, &["https://www.googleapis.com/scope/1"]) - .await; - assert!(res.is_err()); - assert!(format!("{}", res.unwrap_err()).contains("access_denied")); - _m.assert(); - } - } -} diff --git a/src/error.rs b/src/error.rs index d83a6aa..63cf636 100644 --- a/src/error.rs +++ b/src/error.rs @@ -10,7 +10,7 @@ use serde::Deserialize; /// Error returned by the authorization server. /// https://tools.ietf.org/html/rfc6749#section-5.2 /// https://tools.ietf.org/html/rfc8628#section-3.5 -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Debug, PartialEq, Eq)] pub struct AuthError { /// Error code from the server. pub error: AuthErrorCode, @@ -36,7 +36,6 @@ impl StdError for AuthError {} /// The error code returned by the authorization server. #[derive(Debug, Clone, Eq, PartialEq)] - pub enum AuthErrorCode { /// invalid_request InvalidRequest, diff --git a/src/helper.rs b/src/helper.rs index 81200f6..b143db7 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -70,13 +70,3 @@ where debug_assert_eq!(size, result.len()); result } - -#[cfg(test)] -#[macro_export] -/// Utility function for parsing json. Useful in unit tests. Simply wrap the -/// json! macro in a from_value to deserialize the contents to arbitrary structs. -macro_rules! parse_json { - ($($json:tt)+) => { - ::serde_json::from_value(::serde_json::json!($($json)+)).expect("failed to deserialize") - } -} diff --git a/src/installed.rs b/src/installed.rs index 22034d8..f2ae102 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -354,170 +354,8 @@ mod installed_flow_server { #[cfg(test)] mod tests { - use std::str::FromStr; - - use hyper::client::connect::HttpConnector; - use hyper::Uri; - use hyper_rustls::HttpsConnector; - use mockito::mock; - use super::*; - use crate::authenticator_delegate::InstalledFlowDelegate; - - #[tokio::test] - async fn test_end2end() { - #[derive(Clone)] - struct FD( - String, - hyper::Client, hyper::Body>, - ); - impl InstalledFlowDelegate for FD { - /// Depending on need_code, return the pre-set code or send the code to the server at - /// the redirect_uri given in the url. - fn present_user_url<'a>( - &'a self, - url: &'a str, - need_code: bool, - ) -> Pin> + Send + 'a>> { - Box::pin(async move { - if need_code { - Ok(self.0.clone()) - } else { - // Parse presented url to obtain redirect_uri with location of local - // code-accepting server. - let uri = Uri::from_str(url.as_ref()).unwrap(); - let query = uri.query().unwrap(); - let parsed = form_urlencoded::parse(query.as_bytes()).into_owned(); - let mut rduri = None; - for (k, v) in parsed { - if k == "redirect_uri" { - rduri = Some(v); - break; - } - } - if rduri.is_none() { - return Err("no redirect_uri!".into()); - } - let mut rduri = rduri.unwrap(); - rduri.push_str(&format!("?code={}", self.0)); - let rduri = Uri::from_str(rduri.as_ref()).unwrap(); - // Hit server. - self.1 - .get(rduri) - .await - .map_err(|e| e.to_string()) - .map(|_| "".to_string()) - } - }) - } - } - - let server_url = mockito::server_url(); - let app_secret: ApplicationSecret = crate::parse_json!({ - "client_id": "902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com", - "project_id": "yup-test-243420", - "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": format!("{}/token", server_url), - "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", - "client_secret": "iuMPN6Ne1PD7cos29Tk9rlqH", - "redirect_uris": ["urn:ietf:wg:oauth:2.0:oob","http://localhost"] - }); - - let https = HttpsConnector::new(); - let client = hyper::Client::builder() - .keep_alive(false) - .build::<_, hyper::Body>(https); - - let fd = FD("authorizationcode".to_string(), client.clone()); - let inf = InstalledFlow { - app_secret: app_secret.clone(), - method: InstalledFlowReturnMethod::Interactive, - flow_delegate: Box::new(fd), - }; - - // Successful path. - { - let _m = mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*code=authorizationcode.*client_id=9022167.*".to_string(), - )) - .with_body( - serde_json::json!({ - "access_token": "accesstoken", - "refresh_token": "refreshtoken", - "token_type": "Bearer", - "expires_in": 12345678 - }) - .to_string(), - ) - .expect(1) - .create(); - - let tok = inf - .token(&client, &["https://googleapis.com/some/scope"]) - .await - .expect("failed to get token"); - assert_eq!("accesstoken", tok.access_token); - assert_eq!("refreshtoken", tok.refresh_token.unwrap()); - assert_eq!("Bearer", tok.token_type); - _m.assert(); - } - - // Successful path with HTTP redirect. - { - let inf = InstalledFlow { - app_secret: app_secret.clone(), - method: InstalledFlowReturnMethod::HTTPRedirect, - flow_delegate: Box::new(FD( - "authorizationcodefromlocalserver".to_string(), - client.clone(), - )), - }; - let _m = mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*code=authorizationcodefromlocalserver.*client_id=9022167.*".to_string(), - )) - .with_body( - serde_json::json!({ - "access_token": "accesstoken", - "refresh_token": "refreshtoken", - "token_type": "Bearer", - "expires_in": 12345678 - }) - .to_string(), - ) - .expect(1) - .create(); - - let tok = inf - .token(&client, &["https://googleapis.com/some/scope"]) - .await - .expect("failed to get token"); - assert_eq!("accesstoken", tok.access_token); - assert_eq!("refreshtoken", tok.refresh_token.unwrap()); - assert_eq!("Bearer", tok.token_type); - _m.assert(); - } - - // Error from server. - { - let _m = mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*code=authorizationcode.*client_id=9022167.*".to_string(), - )) - .with_status(400) - .with_body(serde_json::json!({"error": "invalid_code"}).to_string()) - .expect(1) - .create(); - - let tokr = inf - .token(&client, &["https://googleapis.com/some/scope"]) - .await; - assert!(tokr.is_err()); - assert!(format!("{}", tokr.unwrap_err()).contains("invalid_code")); - _m.assert(); - } - } + use hyper::Uri; #[test] fn test_request_url_builder() { diff --git a/src/refresh.rs b/src/refresh.rs index 910a38b..1ed72e8 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -57,64 +57,3 @@ impl RefreshFlow { Ok(token) } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::helper; - - use hyper_rustls::HttpsConnector; - - #[tokio::test] - async fn test_refresh_end2end() { - let server_url = mockito::server_url(); - - let app_secret = r#"{"installed":{"client_id":"902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com","project_id":"yup-test-243420","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://oauth2.googleapis.com/token","auth_provider_x509_cert_url":"https://www.googleapis.com/oauth2/v1/certs","client_secret":"iuMPN6Ne1PD7cos29Tk9rlqH","redirect_uris":["urn:ietf:wg:oauth:2.0:oob","http://localhost"]}}"#; - let mut app_secret = helper::parse_application_secret(app_secret).unwrap(); - app_secret.token_uri = format!("{}/token", server_url); - let refresh_token = "my-refresh-token"; - - let https = HttpsConnector::new(); - let client = hyper::Client::builder() - .keep_alive(false) - .build::<_, hyper::Body>(https); - - // Success - { - let _m = mockito::mock("POST", "/token") - .match_body( - mockito::Matcher::Regex(".*client_id=902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com.*refresh_token=my-refresh-token.*".to_string())) - .with_status(200) - .with_body(r#"{"access_token": "new-access-token", "token_type": "Bearer", "expires_in": 1234567}"#) - .create(); - let token = RefreshFlow::refresh_token(&client, &app_secret, refresh_token) - .await - .expect("token failed"); - assert_eq!("new-access-token", token.access_token); - assert_eq!("Bearer", token.token_type); - _m.assert(); - } - - // Refresh error. - { - let _m = mockito::mock("POST", "/token") - .match_body( - mockito::Matcher::Regex(".*client_id=902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com.*refresh_token=my-refresh-token.*".to_string())) - .with_status(400) - .with_body(r#"{"error": "invalid_request"}"#) - .create(); - - let rr = RefreshFlow::refresh_token(&client, &app_secret, refresh_token).await; - match rr { - Err(Error::AuthError(auth_error)) => { - assert_eq!( - auth_error.error, - crate::error::AuthErrorCode::InvalidRequest - ); - } - _ => panic!(format!("unexpected RefreshResult {:?}", rr)), - } - _m.assert(); - } - } -} diff --git a/src/service_account.rs b/src/service_account.rs index 168edc2..1eeeb47 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -210,83 +210,8 @@ impl ServiceAccountFlow { mod tests { use super::*; use crate::helper::read_service_account_key; - use crate::parse_json; - use chrono::Utc; use hyper_rustls::HttpsConnector; - use mockito::mock; - - #[tokio::test] - async fn test_mocked_http() { - env_logger::try_init().unwrap(); - let https = HttpsConnector::new(); - let client = hyper::Client::builder() - .keep_alive(false) - .build::<_, hyper::Body>(https); - let server_url = &mockito::server_url(); - let key: ServiceAccountKey = parse_json!({ - "type": "service_account", - "project_id": "yup-test-243420", - "private_key_id": "26de294916614a5ebdf7a065307ed3ea9941902b", - "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDemmylrvp1KcOn\n9yTAVVKPpnpYznvBvcAU8Qjwr2fSKylpn7FQI54wCk5VJVom0jHpAmhxDmNiP8yv\nHaqsef+87Oc0n1yZ71/IbeRcHZc2OBB33/LCFqf272kThyJo3qspEqhuAw0e8neg\nLQb4jpm9PsqR8IjOoAtXQSu3j0zkXemMYFy93PWHjVpPEUX16NGfsWH7oxspBHOk\n9JPGJL8VJdbiAoDSDgF0y9RjJY5I52UeHNhMsAkTYs6mIG4kKXt2+T9tAyHw8aho\nwmuytQAfydTflTfTG8abRtliF3nil2taAc5VB07dP1b4dVYy/9r6M8Z0z4XM7aP+\nNdn2TKm3AgMBAAECggEAWi54nqTlXcr2M5l535uRb5Xz0f+Q/pv3ceR2iT+ekXQf\n+mUSShOr9e1u76rKu5iDVNE/a7H3DGopa7ZamzZvp2PYhSacttZV2RbAIZtxU6th\n7JajPAM+t9klGh6wj4jKEcE30B3XVnbHhPJI9TCcUyFZoscuPXt0LLy/z8Uz0v4B\nd5JARwyxDMb53VXwukQ8nNY2jP7WtUig6zwE5lWBPFMbi8GwGkeGZOruAK5sPPwY\nGBAlfofKANI7xKx9UXhRwisB4+/XI1L0Q6xJySv9P+IAhDUI6z6kxR+WkyT/YpG3\nX9gSZJc7qEaxTIuDjtep9GTaoEqiGntjaFBRKoe+VQKBgQDzM1+Ii+REQqrGlUJo\nx7KiVNAIY/zggu866VyziU6h5wjpsoW+2Npv6Dv7nWvsvFodrwe50Y3IzKtquIal\nVd8aa50E72JNImtK/o5Nx6xK0VySjHX6cyKENxHRDnBmNfbALRM+vbD9zMD0lz2q\nmns/RwRGq3/98EqxP+nHgHSr9QKBgQDqUYsFAAfvfT4I75Glc9svRv8IsaemOm07\nW1LCwPnj1MWOhsTxpNF23YmCBupZGZPSBFQobgmHVjQ3AIo6I2ioV6A+G2Xq/JCF\nmzfbvZfqtbbd+nVgF9Jr1Ic5T4thQhAvDHGUN77BpjEqZCQLAnUWJx9x7e2xvuBl\n1A6XDwH/ewKBgQDv4hVyNyIR3nxaYjFd7tQZYHTOQenVffEAd9wzTtVbxuo4sRlR\nNM7JIRXBSvaATQzKSLHjLHqgvJi8LITLIlds1QbNLl4U3UVddJbiy3f7WGTqPFfG\nkLhUF4mgXpCpkMLxrcRU14Bz5vnQiDmQRM4ajS7/kfwue00BZpxuZxst3QKBgQCI\nRI3FhaQXyc0m4zPfdYYVc4NjqfVmfXoC1/REYHey4I1XetbT9Nb/+ow6ew0UbgSC\nUZQjwwJ1m1NYXU8FyovVwsfk9ogJ5YGiwYb1msfbbnv/keVq0c/Ed9+AG9th30qM\nIf93hAfClITpMz2mzXIMRQpLdmQSR4A2l+E4RjkSOwKBgQCB78AyIdIHSkDAnCxz\nupJjhxEhtQ88uoADxRoEga7H/2OFmmPsqfytU4+TWIdal4K+nBCBWRvAX1cU47vH\nJOlSOZI0gRKe0O4bRBQc8GXJn/ubhYSxI02IgkdGrIKpOb5GG10m85ZvqsXw3bKn\nRVHMD0ObF5iORjZUqD0yRitAdg==\n-----END PRIVATE KEY-----\n", - "client_email": "yup-test-sa-1@yup-test-243420.iam.gserviceaccount.com", - "client_id": "102851967901799660408", - "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": format!("{}/token", server_url), - "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", - "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/yup-test-sa-1%40yup-test-243420.iam.gserviceaccount.com" - }); - - let json_response = serde_json::json!({ - "access_token": "ya29.c.ElouBywiys0LyNaZoLPJcp1Fdi2KjFMxzvYKLXkTdvM-rDfqKlvEq6PiMhGoGHx97t5FAvz3eb_ahdwlBjSStxHtDVQB4ZPRJQ_EOi-iS7PnayahU2S9Jp8S6rk", - "expires_in": 3600, - "token_type": "Bearer" - }); - let bad_json_response = serde_json::json!({ - "error": "access_denied", - }); - - // Successful path. - { - let _m = mock("POST", "/token") - .with_status(200) - .with_header("content-type", "text/json") - .with_body(json_response.to_string()) - .expect(1) - .create(); - let acc = ServiceAccountFlow::new(ServiceAccountFlowOpts { - key: key.clone(), - subject: None, - }) - .unwrap(); - let tok = acc - .token(&client, &["https://www.googleapis.com/auth/pubsub"]) - .await - .expect("token failed"); - assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); - assert!(Utc::now() + chrono::Duration::seconds(3600) >= tok.expires_at.unwrap()); - _m.assert(); - } - // Malformed response. - { - let _m = mock("POST", "/token") - .with_status(200) - .with_header("content-type", "text/json") - .with_body(bad_json_response.to_string()) - .create(); - let acc = ServiceAccountFlow::new(ServiceAccountFlowOpts { - key: key.clone(), - subject: None, - }) - .unwrap(); - let result = acc - .token(&client, &["https://www.googleapis.com/auth/pubsub"]) - .await; - assert!(result.is_err()); - _m.assert(); - } - } - // Valid but deactivated key. const TEST_PRIVATE_KEY_PATH: &'static str = "examples/Sanguine-69411a0c0eea.json"; diff --git a/src/storage.rs b/src/storage.rs index d838ba1..7d38575 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -381,22 +381,25 @@ mod tests { token_type: "Bearer".to_owned(), expires_at: None, }; - let tempdir = tempfile::tempdir().unwrap(); - let storage = DiskStorage::new(tempdir.path().join("tokenstorage.json")) - .await - .unwrap(); let scope_set = ScopeSet::from(&["myscope"]); - assert!(storage.get(scope_set).is_none()); - storage - .set(scope_set, new_token("my_access_token")) - .await - .unwrap(); - assert_eq!(storage.get(scope_set), Some(new_token("my_access_token"))); - - // Create a new DiskStorage instance and verify the tokens were read from disk correctly. - let storage = DiskStorage::new(tempdir.path().join("tokenstorage.json")) - .await - .unwrap(); - assert_eq!(storage.get(scope_set), Some(new_token("my_access_token"))); + let tempdir = tempfile::tempdir().unwrap(); + { + let storage = DiskStorage::new(tempdir.path().join("tokenstorage.json")) + .await + .unwrap(); + assert!(storage.get(scope_set).is_none()); + storage + .set(scope_set, new_token("my_access_token")) + .await + .unwrap(); + assert_eq!(storage.get(scope_set), Some(new_token("my_access_token"))); + } + { + // Create a new DiskStorage instance and verify the tokens were read from disk correctly. + let storage = DiskStorage::new(tempdir.path().join("tokenstorage.json")) + .await + .unwrap(); + assert_eq!(storage.get(scope_set), Some(new_token("my_access_token"))); + } } } diff --git a/tests/tests.rs b/tests/tests.rs new file mode 100644 index 0000000..40fe487 --- /dev/null +++ b/tests/tests.rs @@ -0,0 +1,574 @@ +use yup_oauth2::{ + authenticator::Authenticator, + authenticator_delegate::{DeviceAuthResponse, DeviceFlowDelegate, InstalledFlowDelegate}, + error::{AuthError, AuthErrorCode}, + ApplicationSecret, DeviceFlowAuthenticator, Error, InstalledFlowAuthenticator, + InstalledFlowReturnMethod, ServiceAccountAuthenticator, ServiceAccountKey, +}; + +use std::future::Future; +use std::path::PathBuf; +use std::pin::Pin; + +use hyper::client::connect::HttpConnector; +use hyper::Uri; +use hyper_rustls::HttpsConnector; +use url::form_urlencoded; + +/// Utility function for parsing json. Useful in unit tests. Simply wrap the +/// json! macro in a from_value to deserialize the contents to arbitrary structs. +macro_rules! parse_json { + ($($json:tt)+) => { + ::serde_json::from_value(::serde_json::json!($($json)+)).expect("failed to deserialize") + } +} + +async fn create_device_flow_auth() -> Authenticator> { + let server_url = mockito::server_url(); + let app_secret: ApplicationSecret = parse_json!({ + "client_id": "902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com", + "project_id": "yup-test-243420", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": format!("{}/token", server_url), + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_secret": "iuMPN6Ne1PD7cos29Tk9rlqH", + "redirect_uris": ["urn:ietf:wg:oauth:2.0:oob","http://localhost"], + }); + struct FD; + impl DeviceFlowDelegate for FD { + fn present_user_code<'a>( + &'a self, + pi: &'a DeviceAuthResponse, + ) -> Pin + 'a + Send>> { + assert_eq!("https://example.com/verify", pi.verification_uri); + Box::pin(futures::future::ready(())) + } + } + + DeviceFlowAuthenticator::builder(app_secret) + .flow_delegate(Box::new(FD)) + .device_code_url(format!("{}/code", server_url)) + .build() + .await + .unwrap() +} + +#[tokio::test] +async fn test_device_success() { + let auth = create_device_flow_auth().await; + let code_response = serde_json::json!({ + "device_code": "devicecode", + "user_code": "usercode", + "verification_url": "https://example.com/verify", + "expires_in": 1234567, + "interval": 1 + }); + let _m = mockito::mock("POST", "/code") + .match_body(mockito::Matcher::Regex( + ".*client_id=902216714886.*".to_string(), + )) + .with_status(200) + .with_body(code_response.to_string()) + .create(); + let token_response = serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 1234567 + }); + let _m = mockito::mock("POST", "/token") + .match_body(mockito::Matcher::Regex( + ".*client_secret=iuMPN6Ne1PD7cos29Tk9rlqH&code=devicecode.*".to_string(), + )) + .with_status(200) + .with_body(token_response.to_string()) + .create(); + + let token = auth + .token(&["https://www.googleapis.com/scope/1"]) + .await + .expect("token failed"); + assert_eq!("accesstoken", token.access_token); + _m.assert(); +} + +#[tokio::test] +async fn test_device_no_code() { + let auth = create_device_flow_auth().await; + let code_response = serde_json::json!({ + "error": "invalid_client_id", + "error_description": "description" + }); + let _m = mockito::mock("POST", "/code") + .match_body(mockito::Matcher::Regex( + ".*client_id=902216714886.*".to_string(), + )) + .with_status(400) + .with_body(code_response.to_string()) + .create(); + let token_response = serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 1234567 + }); + let _m = mockito::mock("POST", "/token") + .match_body(mockito::Matcher::Regex( + ".*client_secret=iuMPN6Ne1PD7cos29Tk9rlqH&code=devicecode.*".to_string(), + )) + .with_status(200) + .with_body(token_response.to_string()) + .expect(0) // Never called! + .create(); + + let res = auth.token(&["https://www.googleapis.com/scope/1"]).await; + assert!(res.is_err()); + assert!(format!("{}", res.unwrap_err()).contains("invalid_client_id")); + _m.assert(); +} + +#[tokio::test] +async fn test_device_no_token() { + let auth = create_device_flow_auth().await; + let code_response = serde_json::json!({ + "device_code": "devicecode", + "user_code": "usercode", + "verification_url": "https://example.com/verify", + "expires_in": 1234567, + "interval": 1 + }); + let _m = mockito::mock("POST", "/code") + .match_body(mockito::Matcher::Regex( + ".*client_id=902216714886.*".to_string(), + )) + .with_status(200) + .with_body(code_response.to_string()) + .create(); + let token_response = serde_json::json!({"error": "access_denied"}); + let _m = mockito::mock("POST", "/token") + .match_body(mockito::Matcher::Regex( + ".*client_secret=iuMPN6Ne1PD7cos29Tk9rlqH&code=devicecode.*".to_string(), + )) + .with_status(400) + .with_body(token_response.to_string()) + .expect(1) + .create(); + + let res = auth.token(&["https://www.googleapis.com/scope/1"]).await; + assert!(res.is_err()); + assert!(format!("{}", res.unwrap_err()).contains("access_denied")); + _m.assert(); +} + +async fn create_installed_flow_auth( + method: InstalledFlowReturnMethod, + filename: Option, +) -> Authenticator> { + let server_url = mockito::server_url(); + let app_secret: ApplicationSecret = parse_json!({ + "client_id": "902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com", + "project_id": "yup-test-243420", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": format!("{}/token", server_url), + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_secret": "iuMPN6Ne1PD7cos29Tk9rlqH", + "redirect_uris": ["urn:ietf:wg:oauth:2.0:oob","http://localhost"], + }); + struct FD(hyper::Client>); + impl InstalledFlowDelegate for FD { + /// Depending on need_code, return the pre-set code or send the code to the server at + /// the redirect_uri given in the url. + fn present_user_url<'a>( + &'a self, + url: &'a str, + need_code: bool, + ) -> Pin> + Send + 'a>> { + use std::str::FromStr; + Box::pin(async move { + if need_code { + Ok("authorizationcode".to_owned()) + } else { + // Parse presented url to obtain redirect_uri with location of local + // code-accepting server. + let uri = Uri::from_str(url.as_ref()).unwrap(); + let query = uri.query().unwrap(); + let parsed = form_urlencoded::parse(query.as_bytes()).into_owned(); + let mut rduri = None; + for (k, v) in parsed { + if k == "redirect_uri" { + rduri = Some(v); + break; + } + } + if rduri.is_none() { + return Err("no redirect_uri!".into()); + } + let mut rduri = rduri.unwrap(); + rduri.push_str("?code=authorizationcode"); + let rduri = Uri::from_str(rduri.as_ref()).unwrap(); + // Hit server. + self.0 + .get(rduri) + .await + .map_err(|e| e.to_string()) + .map(|_| "".to_string()) + } + }) + } + } + + let mut builder = InstalledFlowAuthenticator::builder(app_secret, method).flow_delegate( + Box::new(FD(hyper::Client::builder().build(HttpsConnector::new()))), + ); + + builder = if let Some(filename) = filename { + builder.persist_tokens_to_disk(filename) + } else { + builder + }; + + builder.build().await.unwrap() +} + +#[tokio::test] +async fn test_installed_interactive_success() { + let auth = create_installed_flow_auth(InstalledFlowReturnMethod::Interactive, None).await; + let _m = mockito::mock("POST", "/token") + .match_body(mockito::Matcher::Regex( + ".*code=authorizationcode.*client_id=9022167.*".to_string(), + )) + .with_body( + serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 12345678 + }) + .to_string(), + ) + .expect(1) + .create(); + + let tok = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!("accesstoken", tok.access_token); + assert_eq!("refreshtoken", tok.refresh_token.unwrap()); + assert_eq!("Bearer", tok.token_type); + _m.assert(); +} + +#[tokio::test] +async fn test_installed_redirect_success() { + let auth = create_installed_flow_auth(InstalledFlowReturnMethod::HTTPRedirect, None).await; + let _m = mockito::mock("POST", "/token") + .match_body(mockito::Matcher::Regex( + ".*code=authorizationcode.*client_id=9022167.*".to_string(), + )) + .with_body( + serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 12345678 + }) + .to_string(), + ) + .expect(1) + .create(); + + let tok = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!("accesstoken", tok.access_token); + assert_eq!("refreshtoken", tok.refresh_token.unwrap()); + assert_eq!("Bearer", tok.token_type); + _m.assert(); +} + +#[tokio::test] +async fn test_installed_error() { + let auth = create_installed_flow_auth(InstalledFlowReturnMethod::Interactive, None).await; + let _m = mockito::mock("POST", "/token") + .match_body(mockito::Matcher::Regex( + ".*code=authorizationcode.*client_id=9022167.*".to_string(), + )) + .with_status(400) + .with_body(serde_json::json!({"error": "invalid_code"}).to_string()) + .expect(1) + .create(); + + let tokr = auth.token(&["https://googleapis.com/some/scope"]).await; + assert!(tokr.is_err()); + assert!(format!("{}", tokr.unwrap_err()).contains("invalid_code")); + _m.assert(); +} + +async fn create_service_account_auth() -> Authenticator> { + let server_url = &mockito::server_url(); + let key: ServiceAccountKey = parse_json!({ + "type": "service_account", + "project_id": "yup-test-243420", + "private_key_id": "26de294916614a5ebdf7a065307ed3ea9941902b", + "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDemmylrvp1KcOn\n9yTAVVKPpnpYznvBvcAU8Qjwr2fSKylpn7FQI54wCk5VJVom0jHpAmhxDmNiP8yv\nHaqsef+87Oc0n1yZ71/IbeRcHZc2OBB33/LCFqf272kThyJo3qspEqhuAw0e8neg\nLQb4jpm9PsqR8IjOoAtXQSu3j0zkXemMYFy93PWHjVpPEUX16NGfsWH7oxspBHOk\n9JPGJL8VJdbiAoDSDgF0y9RjJY5I52UeHNhMsAkTYs6mIG4kKXt2+T9tAyHw8aho\nwmuytQAfydTflTfTG8abRtliF3nil2taAc5VB07dP1b4dVYy/9r6M8Z0z4XM7aP+\nNdn2TKm3AgMBAAECggEAWi54nqTlXcr2M5l535uRb5Xz0f+Q/pv3ceR2iT+ekXQf\n+mUSShOr9e1u76rKu5iDVNE/a7H3DGopa7ZamzZvp2PYhSacttZV2RbAIZtxU6th\n7JajPAM+t9klGh6wj4jKEcE30B3XVnbHhPJI9TCcUyFZoscuPXt0LLy/z8Uz0v4B\nd5JARwyxDMb53VXwukQ8nNY2jP7WtUig6zwE5lWBPFMbi8GwGkeGZOruAK5sPPwY\nGBAlfofKANI7xKx9UXhRwisB4+/XI1L0Q6xJySv9P+IAhDUI6z6kxR+WkyT/YpG3\nX9gSZJc7qEaxTIuDjtep9GTaoEqiGntjaFBRKoe+VQKBgQDzM1+Ii+REQqrGlUJo\nx7KiVNAIY/zggu866VyziU6h5wjpsoW+2Npv6Dv7nWvsvFodrwe50Y3IzKtquIal\nVd8aa50E72JNImtK/o5Nx6xK0VySjHX6cyKENxHRDnBmNfbALRM+vbD9zMD0lz2q\nmns/RwRGq3/98EqxP+nHgHSr9QKBgQDqUYsFAAfvfT4I75Glc9svRv8IsaemOm07\nW1LCwPnj1MWOhsTxpNF23YmCBupZGZPSBFQobgmHVjQ3AIo6I2ioV6A+G2Xq/JCF\nmzfbvZfqtbbd+nVgF9Jr1Ic5T4thQhAvDHGUN77BpjEqZCQLAnUWJx9x7e2xvuBl\n1A6XDwH/ewKBgQDv4hVyNyIR3nxaYjFd7tQZYHTOQenVffEAd9wzTtVbxuo4sRlR\nNM7JIRXBSvaATQzKSLHjLHqgvJi8LITLIlds1QbNLl4U3UVddJbiy3f7WGTqPFfG\nkLhUF4mgXpCpkMLxrcRU14Bz5vnQiDmQRM4ajS7/kfwue00BZpxuZxst3QKBgQCI\nRI3FhaQXyc0m4zPfdYYVc4NjqfVmfXoC1/REYHey4I1XetbT9Nb/+ow6ew0UbgSC\nUZQjwwJ1m1NYXU8FyovVwsfk9ogJ5YGiwYb1msfbbnv/keVq0c/Ed9+AG9th30qM\nIf93hAfClITpMz2mzXIMRQpLdmQSR4A2l+E4RjkSOwKBgQCB78AyIdIHSkDAnCxz\nupJjhxEhtQ88uoADxRoEga7H/2OFmmPsqfytU4+TWIdal4K+nBCBWRvAX1cU47vH\nJOlSOZI0gRKe0O4bRBQc8GXJn/ubhYSxI02IgkdGrIKpOb5GG10m85ZvqsXw3bKn\nRVHMD0ObF5iORjZUqD0yRitAdg==\n-----END PRIVATE KEY-----\n", + "client_email": "yup-test-sa-1@yup-test-243420.iam.gserviceaccount.com", + "client_id": "102851967901799660408", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": format!("{}/token", server_url), + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/yup-test-sa-1%40yup-test-243420.iam.gserviceaccount.com" + }); + + ServiceAccountAuthenticator::builder(key) + .build() + .await + .unwrap() +} + +#[tokio::test] +async fn test_service_account_success() { + use chrono::Utc; + let auth = create_service_account_auth().await; + + let json_response = serde_json::json!({ + "access_token": "ya29.c.ElouBywiys0LyNaZoLPJcp1Fdi2KjFMxzvYKLXkTdvM-rDfqKlvEq6PiMhGoGHx97t5FAvz3eb_ahdwlBjSStxHtDVQB4ZPRJQ_EOi-iS7PnayahU2S9Jp8S6rk", + "expires_in": 3600, + "token_type": "Bearer" + }); + let _m = mockito::mock("POST", "/token") + .with_status(200) + .with_header("content-type", "text/json") + .with_body(json_response.to_string()) + .expect(1) + .create(); + let tok = auth + .token(&["https://www.googleapis.com/auth/pubsub"]) + .await + .expect("token failed"); + assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); + assert!(Utc::now() + chrono::Duration::seconds(3600) >= tok.expires_at.unwrap()); + _m.assert(); +} + +#[tokio::test] +async fn test_service_account_error() { + let auth = create_service_account_auth().await; + let bad_json_response = serde_json::json!({ + "error": "access_denied", + }); + + let _m = mockito::mock("POST", "/token") + .with_status(200) + .with_header("content-type", "text/json") + .with_body(bad_json_response.to_string()) + .create(); + let result = auth + .token(&["https://www.googleapis.com/auth/pubsub"]) + .await; + assert!(result.is_err()); + _m.assert(); +} + +#[tokio::test] +async fn test_refresh() { + let auth = create_installed_flow_auth(InstalledFlowReturnMethod::Interactive, None).await; + // We refresh a token whenever it's within 1 minute of expiring. So + // acquiring a token that expires in 59 seconds will force a refresh on + // the next token call. + let _m = mockito::mock("POST", "/token") + .match_body(mockito::Matcher::Regex( + ".*code=authorizationcode.*client_id=9022167.*".to_string(), + )) + .with_body( + serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 59, + }) + .to_string(), + ) + .expect(1) + .create(); + let tok = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!("accesstoken", tok.access_token); + assert_eq!("refreshtoken", tok.refresh_token.unwrap()); + assert_eq!("Bearer", tok.token_type); + _m.assert(); + + let _m = mockito::mock("POST", "/token") + .match_body(mockito::Matcher::Regex( + ".*client_id=9022167.*refresh_token=refreshtoken.*".to_string(), + )) + .with_body( + serde_json::json!({ + "access_token": "accesstoken2", + "token_type": "Bearer", + "expires_in": 59, + }) + .to_string(), + ) + .expect(1) + .create(); + + let tok = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!("accesstoken2", tok.access_token); + assert_eq!("refreshtoken", tok.refresh_token.unwrap()); + assert_eq!("Bearer", tok.token_type); + _m.assert(); + + let _m = mockito::mock("POST", "/token") + .match_body(mockito::Matcher::Regex( + ".*client_id=9022167.*refresh_token=refreshtoken.*".to_string(), + )) + .with_body( + serde_json::json!({ + "error": "invalid_request", + }) + .to_string(), + ) + .expect(1) + .create(); + + let tok_err = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect_err("token refresh succeeded unexpectedly"); + match tok_err { + Error::AuthError(AuthError { + error: AuthErrorCode::InvalidRequest, + .. + }) => {} + e => panic!("unexpected error on refresh: {:?}", e), + } + _m.assert(); +} + +#[tokio::test] +async fn test_memory_storage() { + let auth = create_installed_flow_auth(InstalledFlowReturnMethod::Interactive, None).await; + let _m = mockito::mock("POST", "/token") + .match_body(mockito::Matcher::Regex( + ".*code=authorizationcode.*client_id=9022167.*".to_string(), + )) + .with_body( + serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 12345678 + }) + .to_string(), + ) + .expect(1) + .create(); + + // Call token twice. Ensure that only one http request is made and + // identical tokens are returned. + let token1 = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + let token2 = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!(token1.access_token.as_str(), "accesstoken"); + assert_eq!(token1, token2); + _m.assert(); + + // Create a new authenticator. This authenticator does not share a cache + // with the previous one. Validate that it receives a different token. + let auth2 = create_installed_flow_auth(InstalledFlowReturnMethod::Interactive, None).await; + let _m = mockito::mock("POST", "/token") + .match_body(mockito::Matcher::Regex( + ".*code=authorizationcode.*client_id=9022167.*".to_string(), + )) + .with_body( + serde_json::json!({ + "access_token": "accesstoken2", + "refresh_token": "refreshtoken2", + "token_type": "Bearer", + "expires_in": 12345678 + }) + .to_string(), + ) + .expect(1) + .create(); + let token3 = auth2 + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!(token3.access_token.as_str(), "accesstoken2"); + _m.assert(); +} + +#[tokio::test] +async fn test_disk_storage() { + let tempdir = tempfile::tempdir().unwrap(); + let storage_path = tempdir.path().join("tokenstorage.json"); + { + let auth = create_installed_flow_auth( + InstalledFlowReturnMethod::Interactive, + Some(storage_path.clone()), + ) + .await; + let _m = mockito::mock("POST", "/token") + .match_body(mockito::Matcher::Regex( + ".*code=authorizationcode.*client_id=9022167.*".to_string(), + )) + .with_body( + serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 12345678 + }) + .to_string(), + ) + .expect(1) + .create(); + + // Call token twice. Ensure that only one http request is made and + // identical tokens are returned. + let token1 = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + let token2 = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!(token1.access_token.as_str(), "accesstoken"); + assert_eq!(token1, token2); + _m.assert(); + } + + // Create a new authenticator. This authenticator uses the same token + // storage file as the previous one so should receive a token without + // making any http requests. + let auth = create_installed_flow_auth( + InstalledFlowReturnMethod::Interactive, + Some(storage_path.clone()), + ) + .await; + // Call token twice. Ensure that identical tokens are returned. + let token1 = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + let token2 = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!(token1.access_token.as_str(), "accesstoken"); + assert_eq!(token1, token2); +} From 36d186deb405c7e3bc4b66a144a769323fdf6775 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 29 Nov 2019 15:25:23 -0800 Subject: [PATCH 66/71] Authenticator now returns an AccessToken. What was previously called Token is now TokenInfo and is merely an internal implementation detail. The publicly visible type is now called AccessToken and differs from TokenInfo by not including the refresh token. This makes it a smaller type for users to pass around as well as reducing the ways that a refresh token may be leaked. Since the Authenticator is responsible for refreshing the tokens there isn't any reason users should need to concern themselves with refresh tokens. --- src/authenticator.rs | 26 +++++++------- src/device.rs | 10 +++--- src/installed.rs | 12 +++---- src/lib.rs | 2 +- src/refresh.rs | 10 +++--- src/service_account.rs | 6 ++-- src/storage.rs | 23 ++++++------- src/types.rs | 78 ++++++++++++++++++++++++++++++++++++------ tests/tests.rs | 30 ++++++---------- 9 files changed, 122 insertions(+), 75 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index 6eea182..9d0feda 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -6,7 +6,7 @@ use crate::installed::{InstalledFlow, InstalledFlowReturnMethod}; use crate::refresh::RefreshFlow; use crate::service_account::{ServiceAccountFlow, ServiceAccountFlowOpts, ServiceAccountKey}; use crate::storage::{self, Storage}; -use crate::types::{ApplicationSecret, Token}; +use crate::types::{AccessToken, ApplicationSecret, TokenInfo}; use private::AuthFlow; use std::borrow::Cow; @@ -27,35 +27,35 @@ where C: hyper::client::connect::Connect + 'static, { /// Return the current token for the provided scopes. - pub async fn token<'a, T>(&'a self, scopes: &'a [T]) -> Result + pub async fn token<'a, T>(&'a self, scopes: &'a [T]) -> Result where T: AsRef, { let hashed_scopes = storage::ScopeSet::from(scopes); match (self.storage.get(hashed_scopes), self.auth_flow.app_secret()) { - (Some(t), _) if !t.expired() => { + (Some(t), _) if !t.is_expired() => { // unexpired token found - Ok(t) + Ok(t.into()) } ( - Some(Token { + Some(TokenInfo { refresh_token: Some(refresh_token), .. }), Some(app_secret), ) => { // token is expired but has a refresh token. - let token = + let token_info = RefreshFlow::refresh_token(&self.hyper_client, app_secret, &refresh_token) .await?; - self.storage.set(hashed_scopes, token.clone()).await?; - Ok(token) + self.storage.set(hashed_scopes, token_info.clone()).await?; + Ok(token_info.into()) } _ => { // no token in the cache or the token returned can't be refreshed. - let t = self.auth_flow.token(&self.hyper_client, scopes).await?; - self.storage.set(hashed_scopes, t.clone()).await?; - Ok(t) + let token_info = self.auth_flow.token(&self.hyper_client, scopes).await?; + self.storage.set(hashed_scopes, token_info.clone()).await?; + Ok(token_info.into()) } } } @@ -354,7 +354,7 @@ mod private { use crate::error::Error; use crate::installed::InstalledFlow; use crate::service_account::ServiceAccountFlow; - use crate::types::{ApplicationSecret, Token}; + use crate::types::{ApplicationSecret, TokenInfo}; pub enum AuthFlow { DeviceFlow(DeviceFlow), @@ -375,7 +375,7 @@ mod private { &'a self, hyper_client: &'a hyper::Client, scopes: &'a [T], - ) -> Result + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, diff --git a/src/device.rs b/src/device.rs index e80ad01..b21098b 100644 --- a/src/device.rs +++ b/src/device.rs @@ -2,7 +2,7 @@ use crate::authenticator_delegate::{ DefaultDeviceFlowDelegate, DeviceAuthResponse, DeviceFlowDelegate, }; use crate::error::{AuthError, Error}; -use crate::types::{ApplicationSecret, Token}; +use crate::types::{ApplicationSecret, TokenInfo}; use std::borrow::Cow; use std::time::Duration; @@ -43,7 +43,7 @@ impl DeviceFlow { &self, hyper_client: &hyper::Client, scopes: &[T], - ) -> Result + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, @@ -73,7 +73,7 @@ impl DeviceFlow { app_secret: &ApplicationSecret, device_auth_resp: &DeviceAuthResponse, grant_type: &str, - ) -> Result + ) -> Result where C: hyper::client::connect::Connect + 'static, { @@ -168,7 +168,7 @@ impl DeviceFlow { client: &hyper::Client, device_code: &str, grant_type: &str, - ) -> Result + ) -> Result where C: hyper::client::connect::Connect + 'static, { @@ -188,6 +188,6 @@ impl DeviceFlow { .unwrap(); // TODO: Error checking let res = client.request(request).await?; let body = res.into_body().try_concat().await?; - Token::from_json(&body) + TokenInfo::from_json(&body) } } diff --git a/src/installed.rs b/src/installed.rs index f2ae102..4a6bebd 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -4,7 +4,7 @@ // use crate::authenticator_delegate::{DefaultInstalledFlowDelegate, InstalledFlowDelegate}; use crate::error::Error; -use crate::types::{ApplicationSecret, Token}; +use crate::types::{ApplicationSecret, TokenInfo}; use std::convert::AsRef; use std::future::Future; @@ -93,7 +93,7 @@ impl InstalledFlow { &self, hyper_client: &hyper::Client, scopes: &[T], - ) -> Result + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, @@ -115,7 +115,7 @@ impl InstalledFlow { hyper_client: &hyper::Client, app_secret: &ApplicationSecret, scopes: &[T], - ) -> Result + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, @@ -140,7 +140,7 @@ impl InstalledFlow { hyper_client: &hyper::Client, app_secret: &ApplicationSecret, scopes: &[T], - ) -> Result + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, @@ -178,7 +178,7 @@ impl InstalledFlow { hyper_client: &hyper::Client, app_secret: &ApplicationSecret, server_addr: Option, - ) -> Result + ) -> Result where C: hyper::client::connect::Connect + 'static, { @@ -186,7 +186,7 @@ impl InstalledFlow { let request = Self::request_token(app_secret, authcode, redirect_uri, server_addr); let resp = hyper_client.request(request).await?; let body = resp.into_body().try_concat().await?; - Token::from_json(&body) + TokenInfo::from_json(&body) } /// Sends the authorization code to the provider in order to obtain access and refresh tokens. diff --git a/src/lib.rs b/src/lib.rs index aa7c0fb..7fa2e58 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,4 +93,4 @@ pub use crate::service_account::ServiceAccountKey; #[doc(inline)] pub use crate::error::Error; -pub use crate::types::{ApplicationSecret, ConsoleApplicationSecret, Token}; +pub use crate::types::{AccessToken, ApplicationSecret, ConsoleApplicationSecret}; diff --git a/src/refresh.rs b/src/refresh.rs index 1ed72e8..0ce14bf 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -1,5 +1,5 @@ use crate::error::Error; -use crate::types::{ApplicationSecret, Token}; +use crate::types::{ApplicationSecret, TokenInfo}; use futures_util::try_stream::TryStreamExt; use hyper::header; @@ -10,7 +10,7 @@ use url::form_urlencoded; /// Refresh an expired access token, as obtained by any other authentication flow. /// This flow is useful when your `Token` is expired and allows to obtain a new /// and valid access token. -pub struct RefreshFlow; +pub(crate) struct RefreshFlow; impl RefreshFlow { /// Attempt to refresh the given token, and obtain a new, valid one. @@ -27,11 +27,11 @@ impl RefreshFlow { /// /// # Examples /// Please see the crate landing page for an example. - pub async fn refresh_token( + pub(crate) async fn refresh_token( client: &hyper::Client, client_secret: &ApplicationSecret, refresh_token: &str, - ) -> Result { + ) -> Result { let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ ("client_id", client_secret.client_id.as_str()), @@ -48,7 +48,7 @@ impl RefreshFlow { let resp = client.request(request).await?; let body = resp.into_body().try_concat().await?; - let mut token = Token::from_json(&body)?; + let mut token = TokenInfo::from_json(&body)?; // If the refresh result contains a refresh_token use it, otherwise // continue using our previous refresh_token. token diff --git a/src/service_account.rs b/src/service_account.rs index 1eeeb47..55c545f 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -12,7 +12,7 @@ //! use crate::error::Error; -use crate::types::Token; +use crate::types::TokenInfo; use std::io; @@ -181,7 +181,7 @@ impl ServiceAccountFlow { &self, hyper_client: &hyper::Client, scopes: &[T], - ) -> Result + ) -> Result where T: AsRef, C: hyper::client::connect::Connect + 'static, @@ -202,7 +202,7 @@ impl ServiceAccountFlow { .unwrap(); let response = hyper_client.request(request).await?; let body = response.into_body().try_concat().await?; - Token::from_json(&body) + TokenInfo::from_json(&body) } } diff --git a/src/storage.rs b/src/storage.rs index 7d38575..89ed9c2 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -2,7 +2,7 @@ // // See project root for licensing information. // -use crate::types::Token; +use crate::types::TokenInfo; use std::collections::HashMap; use std::io; @@ -114,7 +114,7 @@ impl Storage { pub(crate) async fn set( &self, scopes: ScopeSet<'_, T>, - token: Token, + token: TokenInfo, ) -> Result<(), io::Error> where T: AsRef, @@ -125,7 +125,7 @@ impl Storage { } } - pub(crate) fn get(&self, scopes: ScopeSet) -> Option + pub(crate) fn get(&self, scopes: ScopeSet) -> Option where T: AsRef, { @@ -141,7 +141,7 @@ impl Storage { #[derive(Debug, Clone)] struct JSONToken { scopes: Vec, - token: Token, + token: TokenInfo, hash: ScopeHash, filter: ScopeFilter, } @@ -154,7 +154,7 @@ impl<'de> Deserialize<'de> for JSONToken { #[derive(Deserialize)] struct RawJSONToken { scopes: Vec, - token: Token, + token: TokenInfo, } let RawJSONToken { scopes, token } = RawJSONToken::deserialize(deserializer)?; let ScopeSet { hash, filter, .. } = ScopeSet::from(&scopes); @@ -175,7 +175,7 @@ impl Serialize for JSONToken { #[derive(Serialize)] struct RawJSONToken<'a> { scopes: &'a [String], - token: &'a Token, + token: &'a TokenInfo, } RawJSONToken { scopes: &self.scopes, @@ -251,7 +251,7 @@ impl JSONTokens { filter, scopes, }: ScopeSet, - ) -> Option + ) -> Option where T: AsRef, { @@ -280,7 +280,7 @@ impl JSONTokens { filter, scopes, }: ScopeSet, - token: Token, + token: TokenInfo, ) -> Result<(), io::Error> where T: AsRef, @@ -326,7 +326,7 @@ impl DiskStorage { pub(crate) async fn set( &self, scopes: ScopeSet<'_, T>, - token: Token, + token: TokenInfo, ) -> Result<(), io::Error> where T: AsRef, @@ -341,7 +341,7 @@ impl DiskStorage { tokio::fs::write(self.filename.clone(), json).await } - pub(crate) fn get(&self, scopes: ScopeSet) -> Option + pub(crate) fn get(&self, scopes: ScopeSet) -> Option where T: AsRef, { @@ -375,10 +375,9 @@ mod tests { #[tokio::test] async fn test_disk_storage() { - let new_token = |access_token: &str| Token { + let new_token = |access_token: &str| TokenInfo { access_token: access_token.to_owned(), refresh_token: None, - token_type: "Bearer".to_owned(), expires_at: None, }; let scope_set = ScopeSet::from(&["myscope"]); diff --git a/src/types.rs b/src/types.rs index c9b3225..5633f0d 100644 --- a/src/types.rs +++ b/src/types.rs @@ -3,25 +3,70 @@ use crate::error::{AuthErrorOr, Error}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; +/// Represents an access token returned by oauth2 servers. All access tokens are +/// Bearer tokens. Other types of tokens are not supported. +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] +pub struct AccessToken { + value: String, + expires_at: Option>, +} + +impl AccessToken { + /// A string representation of the access token. + pub fn as_str(&self) -> &str { + &self.value + } + + /// The time the access token will expire, if any. + pub fn expiration_time(&self) -> Option> { + self.expires_at + } + + /// Determine if the access token is expired. + /// This will report that the token is expired 1 minute prior to the + /// expiration time to ensure that when the token is actually sent to the + /// server it's still valid. + pub fn is_expired(&self) -> bool { + // Consider the token expired if it's within 1 minute of it's expiration + // time. + self.expires_at + .map(|expiration_time| expiration_time - chrono::Duration::minutes(1) <= Utc::now()) + .unwrap_or(false) + } +} + +impl AsRef for AccessToken { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl From for AccessToken { + fn from(value: TokenInfo) -> Self { + AccessToken { + value: value.access_token, + expires_at: value.expires_at, + } + } +} + /// Represents a token as returned by OAuth2 servers. /// /// It is produced by all authentication flows. /// It authenticates certain operations, and must be refreshed once /// it reached it's expiry date. #[derive(Clone, PartialEq, Debug, Deserialize, Serialize)] -pub struct Token { +pub(crate) struct TokenInfo { /// used when authenticating calls to oauth2 enabled services. - pub access_token: String, + pub(crate) access_token: String, /// used to refresh an expired access_token. - pub refresh_token: Option, - /// The token type as string - usually 'Bearer'. - pub token_type: String, + pub(crate) refresh_token: Option, /// The time when the token expires. - pub expires_at: Option>, + pub(crate) expires_at: Option>, } -impl Token { - pub(crate) fn from_json(json_data: &[u8]) -> Result { +impl TokenInfo { + pub(crate) fn from_json(json_data: &[u8]) -> Result { #[derive(Deserialize)] struct RawToken { access_token: String, @@ -37,19 +82,30 @@ impl Token { expires_in, } = serde_json::from_slice::>(json_data)?.into_result()?; + if token_type.to_lowercase().as_str() != "bearer" { + use std::io; + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + r#"unknown token type returned; expected "bearer" found {}"#, + token_type + ), + ) + .into()); + } + let expires_at = expires_in .map(|seconds_from_now| Utc::now() + chrono::Duration::seconds(seconds_from_now)); - Ok(Token { + Ok(TokenInfo { access_token, refresh_token, - token_type, expires_at, }) } /// Returns true if we are expired. - pub fn expired(&self) -> bool { + pub fn is_expired(&self) -> bool { self.expires_at .map(|expiration_time| expiration_time - chrono::Duration::minutes(1) <= Utc::now()) .unwrap_or(false) diff --git a/tests/tests.rs b/tests/tests.rs index 40fe487..fe2d6b3 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -88,7 +88,7 @@ async fn test_device_success() { .token(&["https://www.googleapis.com/scope/1"]) .await .expect("token failed"); - assert_eq!("accesstoken", token.access_token); + assert_eq!("accesstoken", token.as_str()); _m.assert(); } @@ -253,9 +253,7 @@ async fn test_installed_interactive_success() { .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); - assert_eq!("accesstoken", tok.access_token); - assert_eq!("refreshtoken", tok.refresh_token.unwrap()); - assert_eq!("Bearer", tok.token_type); + assert_eq!("accesstoken", tok.as_str()); _m.assert(); } @@ -282,9 +280,7 @@ async fn test_installed_redirect_success() { .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); - assert_eq!("accesstoken", tok.access_token); - assert_eq!("refreshtoken", tok.refresh_token.unwrap()); - assert_eq!("Bearer", tok.token_type); + assert_eq!("accesstoken", tok.as_str()); _m.assert(); } @@ -347,8 +343,8 @@ async fn test_service_account_success() { .token(&["https://www.googleapis.com/auth/pubsub"]) .await .expect("token failed"); - assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); - assert!(Utc::now() + chrono::Duration::seconds(3600) >= tok.expires_at.unwrap()); + assert!(tok.as_str().contains("ya29.c.ElouBywiys0Ly")); + assert!(Utc::now() + chrono::Duration::seconds(3600) >= tok.expiration_time().unwrap()); _m.assert(); } @@ -396,9 +392,7 @@ async fn test_refresh() { .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); - assert_eq!("accesstoken", tok.access_token); - assert_eq!("refreshtoken", tok.refresh_token.unwrap()); - assert_eq!("Bearer", tok.token_type); + assert_eq!("accesstoken", tok.as_str()); _m.assert(); let _m = mockito::mock("POST", "/token") @@ -420,9 +414,7 @@ async fn test_refresh() { .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); - assert_eq!("accesstoken2", tok.access_token); - assert_eq!("refreshtoken", tok.refresh_token.unwrap()); - assert_eq!("Bearer", tok.token_type); + assert_eq!("accesstoken2", tok.as_str()); _m.assert(); let _m = mockito::mock("POST", "/token") @@ -481,7 +473,7 @@ async fn test_memory_storage() { .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); - assert_eq!(token1.access_token.as_str(), "accesstoken"); + assert_eq!(token1.as_str(), "accesstoken"); assert_eq!(token1, token2); _m.assert(); @@ -507,7 +499,7 @@ async fn test_memory_storage() { .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); - assert_eq!(token3.access_token.as_str(), "accesstoken2"); + assert_eq!(token3.as_str(), "accesstoken2"); _m.assert(); } @@ -547,7 +539,7 @@ async fn test_disk_storage() { .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); - assert_eq!(token1.access_token.as_str(), "accesstoken"); + assert_eq!(token1.as_str(), "accesstoken"); assert_eq!(token1, token2); _m.assert(); } @@ -569,6 +561,6 @@ async fn test_disk_storage() { .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); - assert_eq!(token1.access_token.as_str(), "accesstoken"); + assert_eq!(token1.as_str(), "accesstoken"); assert_eq!(token1, token2); } From 6817fce0bc9d14009a1d1dfb9365d991ce6fbc4b Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 29 Nov 2019 15:31:41 -0800 Subject: [PATCH 67/71] Extend the refresh tests. Verify that a second refresh can happen after the first. This adds coverage to ensure that a refresh flow keeps the refresh token intact by showing that a second refresh can succeed. --- tests/tests.rs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/tests.rs b/tests/tests.rs index fe2d6b3..44ff000 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -417,6 +417,28 @@ async fn test_refresh() { assert_eq!("accesstoken2", tok.as_str()); _m.assert(); + let _m = mockito::mock("POST", "/token") + .match_body(mockito::Matcher::Regex( + ".*client_id=9022167.*refresh_token=refreshtoken.*".to_string(), + )) + .with_body( + serde_json::json!({ + "access_token": "accesstoken3", + "token_type": "Bearer", + "expires_in": 59, + }) + .to_string(), + ) + .expect(1) + .create(); + + let tok = auth + .token(&["https://googleapis.com/some/scope"]) + .await + .expect("failed to get token"); + assert_eq!("accesstoken3", tok.as_str()); + _m.assert(); + let _m = mockito::mock("POST", "/token") .match_body(mockito::Matcher::Regex( ".*client_id=9022167.*refresh_token=refreshtoken.*".to_string(), From 5c0334ee6fec9548fe8c3966e9574debf69e2797 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Mon, 2 Dec 2019 10:56:13 -0800 Subject: [PATCH 68/71] Add debug logging. Could be helpful when troubleshooting issues with various providers if the user is able to turn on debug logging. The most critical logging provided is the request and responses sent and received from the oauth servers. --- src/authenticator.rs | 25 +++++++++++++++++++++++++ src/device.rs | 22 +++++++++++++++++----- src/installed.rs | 18 +++++++++++++----- src/refresh.rs | 11 ++++++++--- src/service_account.rs | 6 ++++-- src/storage.rs | 2 +- 6 files changed, 68 insertions(+), 16 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index 9d0feda..6b71424 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -10,6 +10,7 @@ use crate::types::{AccessToken, ApplicationSecret, TokenInfo}; use private::AuthFlow; use std::borrow::Cow; +use std::fmt; use std::io; use std::path::PathBuf; use std::sync::Mutex; @@ -22,6 +23,25 @@ pub struct Authenticator { auth_flow: AuthFlow, } +struct DisplayScopes<'a, T>(&'a [T]); +impl<'a, T> fmt::Display for DisplayScopes<'a, T> +where + T: AsRef, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("[")?; + let mut iter = self.0.iter(); + if let Some(first) = iter.next() { + f.write_str(first.as_ref())?; + for scope in iter { + f.write_str(", ")?; + f.write_str(scope.as_ref())?; + } + } + f.write_str("]") + } +} + impl Authenticator where C: hyper::client::connect::Connect + 'static, @@ -31,10 +51,15 @@ where where T: AsRef, { + log::debug!( + "access token requested for scopes: {}", + DisplayScopes(scopes) + ); let hashed_scopes = storage::ScopeSet::from(scopes); match (self.storage.get(hashed_scopes), self.auth_flow.app_secret()) { (Some(t), _) if !t.is_expired() => { // unexpired token found + log::debug!("found valid token in cache: {:?}", t); Ok(t.into()) } ( diff --git a/src/device.rs b/src/device.rs index b21098b..092cb80 100644 --- a/src/device.rs +++ b/src/device.rs @@ -55,6 +55,7 @@ impl DeviceFlow { scopes, ) .await?; + log::debug!("Presenting code to user"); self.flow_delegate .present_user_code(&device_auth_resp) .await; @@ -78,6 +79,7 @@ impl DeviceFlow { C: hyper::client::connect::Connect + 'static, { let mut interval = device_auth_resp.interval; + log::debug!("Polling every {:?} for device token", interval); loop { tokio::timer::delay_for(interval).await; interval = match Self::poll_token( @@ -92,10 +94,16 @@ impl DeviceFlow { Err(Error::AuthError(AuthError { error, .. })) if error.as_str() == "authorization_pending" => { + log::debug!("still waiting on authorization from the server"); interval } Err(Error::AuthError(AuthError { error, .. })) if error.as_str() == "slow_down" => { - interval + Duration::from_secs(5) + let interval = interval + Duration::from_secs(5); + log::debug!( + "server requested slow_down. Increasing polling interval to {:?}", + interval + ); + interval } Err(err) => return Err(err), } @@ -140,8 +148,10 @@ impl DeviceFlow { .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(req)) .unwrap(); - let resp = client.request(req).await?; - let body = resp.into_body().try_concat().await?; + log::debug!("requesting code from server: {:?}", req); + let (head, body) = client.request(req).await?.into_parts(); + let body = body.try_concat().await?; + log::debug!("received response; head: {:?}, body: {:?}", head, body); DeviceAuthResponse::from_json(&body) } @@ -186,8 +196,10 @@ impl DeviceFlow { .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(req)) .unwrap(); // TODO: Error checking - let res = client.request(request).await?; - let body = res.into_body().try_concat().await?; + log::debug!("polling for token: {:?}", request); + let (head, body) = client.request(request).await?.into_parts(); + let body = body.try_concat().await?; + log::debug!("received response; head: {:?} body: {:?}", head, body); TokenInfo::from_json(&body) } } diff --git a/src/installed.rs b/src/installed.rs index 4a6bebd..fd85154 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -126,12 +126,14 @@ impl InstalledFlow { scopes, self.flow_delegate.redirect_uri(), ); - let authcode = self + log::debug!("Presenting auth url to user: {}", url); + let auth_code = self .flow_delegate .present_user_url(&url, true /* need code */) .await .map_err(Error::UserError)?; - self.exchange_auth_code(&authcode, hyper_client, app_secret, None) + log::debug!("Received auth code: {}", auth_code); + self.exchange_auth_code(&auth_code, hyper_client, app_secret, None) .await } @@ -162,11 +164,11 @@ impl InstalledFlow { scopes, Some(redirect_uri.as_ref()), ); + log::debug!("Presenting auth url to user: {}", url); let _ = self .flow_delegate .present_user_url(&url, false /* need code */) .await; - let auth_code = server.wait_for_auth_code().await; self.exchange_auth_code(&auth_code, hyper_client, app_secret, Some(server_addr)) .await @@ -184,8 +186,10 @@ impl InstalledFlow { { let redirect_uri = self.flow_delegate.redirect_uri(); let request = Self::request_token(app_secret, authcode, redirect_uri, server_addr); - let resp = hyper_client.request(request).await?; - let body = resp.into_body().try_concat().await?; + log::debug!("Sending request: {:?}", request); + let (head, body) = hyper_client.request(request).await?.into_parts(); + let body = body.try_concat().await?; + log::debug!("Received response; head: {:?} body: {:?}", head, body); TokenInfo::from_json(&body) } @@ -265,6 +269,7 @@ impl InstalledFlowServer { }) .await; }); + log::debug!("HTTP server listening on {}", addr); Ok(InstalledFlowServer { addr, auth_code_rx, @@ -278,11 +283,14 @@ impl InstalledFlowServer { } async fn wait_for_auth_code(self) -> String { + log::debug!("Waiting for HTTP server to receive auth code"); // Wait for the auth code from the server. let auth_code = self .auth_code_rx .await .expect("server shutdown while waiting for auth_code"); + log::debug!("HTTP server received auth code: {}", auth_code); + log::debug!("Shutting down HTTP server"); // auth code received. shutdown the server let _ = self.trigger_shutdown_tx.send(()); self.shutdown_complete.await; diff --git a/src/refresh.rs b/src/refresh.rs index 0ce14bf..3e28f87 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -32,6 +32,10 @@ impl RefreshFlow { client_secret: &ApplicationSecret, refresh_token: &str, ) -> Result { + log::debug!( + "refreshing access token with refresh token: {}", + refresh_token + ); let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ ("client_id", client_secret.client_id.as_str()), @@ -45,9 +49,10 @@ impl RefreshFlow { .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(req)) .unwrap(); - - let resp = client.request(request).await?; - let body = resp.into_body().try_concat().await?; + log::debug!("Sending request: {:?}", request); + let (head, body) = client.request(request).await?.into_parts(); + let body = body.try_concat().await?; + log::debug!("Received response; head: {:?}, body: {:?}", head, body); let mut token = TokenInfo::from_json(&body)?; // If the refresh result contains a refresh_token use it, otherwise // continue using our previous refresh_token. diff --git a/src/service_account.rs b/src/service_account.rs index 55c545f..2427f7c 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -200,8 +200,10 @@ impl ServiceAccountFlow { .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(rqbody)) .unwrap(); - let response = hyper_client.request(request).await?; - let body = response.into_body().try_concat().await?; + log::debug!("requesting token from service account: {:?}", request); + let (head, body) = hyper_client.request(request).await?.into_parts(); + let body = body.try_concat().await?; + log::debug!("received response; head: {:?}, body: {:?}", head, body); TokenInfo::from_json(&body) } } diff --git a/src/storage.rs b/src/storage.rs index 89ed9c2..2a6f274 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -227,7 +227,7 @@ impl<'de> Deserialize<'de> for JSONTokens { } // Instantiate our Visitor and ask the Deserializer to drive - // it over the input data, resulting in an instance of MyMap. + // it over the input data. deserializer.deserialize_seq(V) } } From 348a59d96e13d6dcc43920ad623676d9d0dce86c Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Wed, 11 Dec 2019 13:49:07 -0800 Subject: [PATCH 69/71] Create the token file with more secure permissions on unix. This creates files with 0600 permissions on unix. Still the default permissions on non-unix platforms. --- src/storage.rs | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/src/storage.rs b/src/storage.rs index 2a6f274..4acc30a 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -331,6 +331,7 @@ impl DiskStorage { where T: AsRef, { + use tokio::io::AsyncWriteExt; let json = { use std::ops::Deref; let mut lock = self.tokens.lock().unwrap(); @@ -338,7 +339,9 @@ impl DiskStorage { serde_json::to_string(lock.deref()) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? }; - tokio::fs::write(self.filename.clone(), json).await + let mut f = open_writeable_file(&self.filename).await?; + f.write_all(json.as_bytes()).await?; + Ok(()) } pub(crate) fn get(&self, scopes: ScopeSet) -> Option @@ -349,6 +352,30 @@ impl DiskStorage { } } +#[cfg(unix)] +async fn open_writeable_file( + filename: impl AsRef, +) -> Result { + // Ensure if the file is created it's only readable and writable by the + // current user. + use std::os::unix::fs::OpenOptionsExt; + let opts: tokio::fs::OpenOptions = { + let mut opts = std::fs::OpenOptions::new(); + opts.write(true).create(true).truncate(true).mode(0o600); + opts.into() + }; + opts.open(filename).await +} + +#[cfg(not(unix))] +async fn open_writeable_file( + filename: impl AsRef, +) -> Result { + // I don't have knowledge of windows or other platforms to know how to + // create a file that's only readable by the current user. + tokio::fs::File::create(filename).await +} + #[cfg(test)] mod tests { use super::*; From 9238153723715d2e3a308ca077356fcaec8fd903 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Tue, 17 Dec 2019 16:00:31 -0800 Subject: [PATCH 70/71] Move to hyper 0.13.1!!!! --- Cargo.toml | 12 +++++------ examples/test-device/Cargo.toml | 5 +---- examples/test-device/src/main.rs | 5 +---- examples/test-installed/Cargo.toml | 5 +---- examples/test-installed/src/main.rs | 4 +--- examples/test-svc-acct/Cargo.toml | 5 +---- examples/test-svc-acct/src/main.rs | 1 - src/authenticator.rs | 8 +++---- src/authenticator_delegate.rs | 2 +- src/device.rs | 15 ++++++------- src/installed.rs | 33 ++++++++--------------------- src/refresh.rs | 10 +++++---- src/service_account.rs | 5 ++--- tests/tests.rs | 2 +- 14 files changed, 40 insertions(+), 72 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d9a377a..ffcb393 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,18 +13,16 @@ edition = "2018" [dependencies] base64 = "0.10" chrono = { version = "0.4", features = ["serde"] } -http = "0.1" -hyper = {version = "0.13.0-alpha.4", features = ["unstable-stream"]} -hyper-rustls = "=0.18.0-alpha.2" +http = "0.2" +hyper = "0.13.1" +hyper-rustls = "0.19" log = "0.4" rustls = "0.16" +seahash = "3.0.6" serde = {version = "1.0", features = ["derive"]} serde_json = "1.0" +tokio = { version = "0.2", features = ["fs", "macros", "io-std", "time"] } url = "1" -futures-preview = "=0.3.0-alpha.19" -tokio = "=0.2.0-alpha.6" -futures-util-preview = "=0.3.0-alpha.19" -seahash = "3.0.6" [dev-dependencies] mockito = "0.17" diff --git a/examples/test-device/Cargo.toml b/examples/test-device/Cargo.toml index 39ca484..3cf49c5 100644 --- a/examples/test-device/Cargo.toml +++ b/examples/test-device/Cargo.toml @@ -6,7 +6,4 @@ edition = "2018" [dependencies] yup-oauth2 = { path = "../../" } -hyper = {version = "0.13.0-alpha.4", features = ["unstable-stream"]} -hyper-rustls = "=0.18.0-alpha.2" -futures-preview = "=0.3.0-alpha.19" -tokio = "=0.2.0-alpha.6" +tokio = { version = "0.2", features = ["macros"] } \ No newline at end of file diff --git a/examples/test-device/src/main.rs b/examples/test-device/src/main.rs index 6d505b4..6586d40 100644 --- a/examples/test-device/src/main.rs +++ b/examples/test-device/src/main.rs @@ -1,11 +1,8 @@ use yup_oauth2::DeviceFlowAuthenticator; -use std::path; -use tokio; - #[tokio::main] async fn main() { - let app_secret = yup_oauth2::read_application_secret(path::Path::new("clientsecret.json")) + let app_secret = yup_oauth2::read_application_secret("clientsecret.json") .await .expect("clientsecret"); let auth = DeviceFlowAuthenticator::builder(app_secret) diff --git a/examples/test-installed/Cargo.toml b/examples/test-installed/Cargo.toml index e7fa5d2..6b69312 100644 --- a/examples/test-installed/Cargo.toml +++ b/examples/test-installed/Cargo.toml @@ -6,7 +6,4 @@ edition = "2018" [dependencies] yup-oauth2 = { path = "../../" } -hyper = {version = "0.13.0-alpha.4", features = ["unstable-stream"]} -hyper-rustls = "=0.18.0-alpha.2" -futures-preview = "=0.3.0-alpha.19" -tokio = "=0.2.0-alpha.6" +tokio = { version = "0.2", features = ["macros"] } diff --git a/examples/test-installed/src/main.rs b/examples/test-installed/src/main.rs index c59f9c9..43a797c 100644 --- a/examples/test-installed/src/main.rs +++ b/examples/test-installed/src/main.rs @@ -1,10 +1,8 @@ use yup_oauth2::{InstalledFlowAuthenticator, InstalledFlowReturnMethod}; -use std::path::Path; - #[tokio::main] async fn main() { - let app_secret = yup_oauth2::read_application_secret(Path::new("clientsecret.json")) + let app_secret = yup_oauth2::read_application_secret("clientsecret.json") .await .expect("clientsecret.json"); diff --git a/examples/test-svc-acct/Cargo.toml b/examples/test-svc-acct/Cargo.toml index 14c7d9b..abc2694 100644 --- a/examples/test-svc-acct/Cargo.toml +++ b/examples/test-svc-acct/Cargo.toml @@ -6,7 +6,4 @@ edition = "2018" [dependencies] yup-oauth2 = { path = "../../" } -hyper = {version = "0.13.0-alpha.4", features = ["unstable-stream"]} -hyper-rustls = "=0.18.0-alpha.2" -futures-preview = "=0.3.0-alpha.19" -tokio = "=0.2.0-alpha.6" +tokio = { version = "0.2", features = ["macros"] } diff --git a/examples/test-svc-acct/src/main.rs b/examples/test-svc-acct/src/main.rs index ee79ece..ee67c03 100644 --- a/examples/test-svc-acct/src/main.rs +++ b/examples/test-svc-acct/src/main.rs @@ -1,4 +1,3 @@ -use tokio; use yup_oauth2::ServiceAccountAuthenticator; #[tokio::main] diff --git a/src/authenticator.rs b/src/authenticator.rs index 6b71424..0500e67 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -44,7 +44,7 @@ where impl Authenticator where - C: hyper::client::connect::Connect + 'static, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, { /// Return the current token for the provided scopes. pub async fn token<'a, T>(&'a self, scopes: &'a [T]) -> Result @@ -403,7 +403,7 @@ mod private { ) -> Result where T: AsRef, - C: hyper::client::connect::Connect + 'static, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, { match self { AuthFlow::DeviceFlow(device_flow) => device_flow.token(hyper_client, scopes).await, @@ -421,7 +421,7 @@ mod private { /// A trait implemented for any hyper::Client as well as the DefaultHyperClient. pub trait HyperClientBuilder { /// The hyper connector that the resulting hyper client will use. - type Connector: hyper::client::connect::Connect + 'static; + type Connector: hyper::client::connect::Connect + Clone + Send + Sync + 'static; /// Create a hyper::Client fn build_hyper_client(self) -> hyper::Client; @@ -441,7 +441,7 @@ impl HyperClientBuilder for DefaultHyperClient { impl HyperClientBuilder for hyper::Client where - C: hyper::client::connect::Connect + 'static, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, { type Connector = C; diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index 2f3933f..6308bd5 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -5,7 +5,7 @@ use std::pin::Pin; use std::time::Duration; use chrono::{DateTime, Local, Utc}; -use futures::prelude::*; +use std::future::Future; /// Contains state of pending authentication requests #[derive(Clone, Debug, PartialEq)] diff --git a/src/device.rs b/src/device.rs index 092cb80..70f0780 100644 --- a/src/device.rs +++ b/src/device.rs @@ -7,7 +7,6 @@ use crate::types::{ApplicationSecret, TokenInfo}; use std::borrow::Cow; use std::time::Duration; -use futures::prelude::*; use hyper::header; use url::form_urlencoded; @@ -46,7 +45,7 @@ impl DeviceFlow { ) -> Result where T: AsRef, - C: hyper::client::connect::Connect + 'static, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, { let device_auth_resp = Self::request_code( &self.app_secret, @@ -76,12 +75,12 @@ impl DeviceFlow { grant_type: &str, ) -> Result where - C: hyper::client::connect::Connect + 'static, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, { let mut interval = device_auth_resp.interval; log::debug!("Polling every {:?} for device token", interval); loop { - tokio::timer::delay_for(interval).await; + tokio::time::delay_for(interval).await; interval = match Self::poll_token( &app_secret, hyper_client, @@ -133,7 +132,7 @@ impl DeviceFlow { ) -> Result where T: AsRef, - C: hyper::client::connect::Connect + 'static, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, { let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ @@ -150,7 +149,7 @@ impl DeviceFlow { .unwrap(); log::debug!("requesting code from server: {:?}", req); let (head, body) = client.request(req).await?.into_parts(); - let body = body.try_concat().await?; + let body = hyper::body::to_bytes(body).await?; log::debug!("received response; head: {:?}, body: {:?}", head, body); DeviceAuthResponse::from_json(&body) } @@ -180,7 +179,7 @@ impl DeviceFlow { grant_type: &str, ) -> Result where - C: hyper::client::connect::Connect + 'static, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, { // We should be ready for a new request let req = form_urlencoded::Serializer::new(String::new()) @@ -198,7 +197,7 @@ impl DeviceFlow { .unwrap(); // TODO: Error checking log::debug!("polling for token: {:?}", request); let (head, body) = client.request(request).await?.into_parts(); - let body = body.try_concat().await?; + let body = hyper::body::to_bytes(body).await?; log::debug!("received response; head: {:?} body: {:?}", head, body); TokenInfo::from_json(&body) } diff --git a/src/installed.rs b/src/installed.rs index fd85154..3611162 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -7,13 +7,9 @@ use crate::error::Error; use crate::types::{ApplicationSecret, TokenInfo}; use std::convert::AsRef; -use std::future::Future; use std::net::SocketAddr; -use std::pin::Pin; use std::sync::{Arc, Mutex}; -use futures::future::FutureExt; -use futures_util::try_stream::TryStreamExt; use hyper::header; use tokio::sync::oneshot; use url::form_urlencoded; @@ -96,7 +92,7 @@ impl InstalledFlow { ) -> Result where T: AsRef, - C: hyper::client::connect::Connect + 'static, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, { match self.method { InstalledFlowReturnMethod::HTTPRedirect => { @@ -118,7 +114,7 @@ impl InstalledFlow { ) -> Result where T: AsRef, - C: hyper::client::connect::Connect + 'static, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, { let url = build_authentication_request_url( &app_secret.auth_uri, @@ -145,7 +141,7 @@ impl InstalledFlow { ) -> Result where T: AsRef, - C: hyper::client::connect::Connect + 'static, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, { use std::borrow::Cow; let server = InstalledFlowServer::run()?; @@ -182,13 +178,13 @@ impl InstalledFlow { server_addr: Option, ) -> Result where - C: hyper::client::connect::Connect + 'static, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, { let redirect_uri = self.flow_delegate.redirect_uri(); let request = Self::request_token(app_secret, authcode, redirect_uri, server_addr); log::debug!("Sending request: {:?}", request); let (head, body) = hyper_client.request(request).await?.into_parts(); - let body = body.try_concat().await?; + let body = hyper::body::to_bytes(body).await?; log::debug!("Received response; head: {:?} body: {:?}", head, body); TokenInfo::from_json(&body) } @@ -224,22 +220,11 @@ impl InstalledFlow { } } -fn spawn_with_handle(f: F) -> impl Future -where - F: Future + 'static + Send, -{ - let (tx, rx) = oneshot::channel(); - tokio::spawn(f.map(move |_| tx.send(()).unwrap())); - async { - let _ = rx.await; - } -} - struct InstalledFlowServer { addr: SocketAddr, auth_code_rx: oneshot::Receiver, trigger_shutdown_tx: oneshot::Sender<()>, - shutdown_complete: Pin + Send>>, + shutdown_complete: tokio::task::JoinHandle<()>, } impl InstalledFlowServer { @@ -262,7 +247,7 @@ impl InstalledFlowServer { let server = hyper::server::Server::try_bind(&addr)?; let server = server.http1_only(true).serve(service); let addr = server.local_addr(); - let shutdown_complete = spawn_with_handle(async { + let shutdown_complete = tokio::spawn(async { let _ = server .with_graceful_shutdown(async move { let _ = trigger_shutdown_rx.await; @@ -274,7 +259,7 @@ impl InstalledFlowServer { addr, auth_code_rx, trigger_shutdown_tx, - shutdown_complete: Box::pin(shutdown_complete), + shutdown_complete, }) } @@ -293,7 +278,7 @@ impl InstalledFlowServer { log::debug!("Shutting down HTTP server"); // auth code received. shutdown the server let _ = self.trigger_shutdown_tx.send(()); - self.shutdown_complete.await; + let _ = self.shutdown_complete.await; auth_code } } diff --git a/src/refresh.rs b/src/refresh.rs index 3e28f87..53d12fa 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -1,7 +1,6 @@ use crate::error::Error; use crate::types::{ApplicationSecret, TokenInfo}; -use futures_util::try_stream::TryStreamExt; use hyper::header; use url::form_urlencoded; @@ -27,11 +26,14 @@ impl RefreshFlow { /// /// # Examples /// Please see the crate landing page for an example. - pub(crate) async fn refresh_token( + pub(crate) async fn refresh_token( client: &hyper::Client, client_secret: &ApplicationSecret, refresh_token: &str, - ) -> Result { + ) -> Result + where + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, + { log::debug!( "refreshing access token with refresh token: {}", refresh_token @@ -51,7 +53,7 @@ impl RefreshFlow { .unwrap(); log::debug!("Sending request: {:?}", request); let (head, body) = client.request(request).await?.into_parts(); - let body = body.try_concat().await?; + let body = hyper::body::to_bytes(body).await?; log::debug!("Received response; head: {:?}, body: {:?}", head, body); let mut token = TokenInfo::from_json(&body)?; // If the refresh result contains a refresh_token use it, otherwise diff --git a/src/service_account.rs b/src/service_account.rs index 2427f7c..81d4c6e 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -16,7 +16,6 @@ use crate::types::TokenInfo; use std::io; -use futures::prelude::*; use hyper::header; use rustls::{ self, @@ -184,7 +183,7 @@ impl ServiceAccountFlow { ) -> Result where T: AsRef, - C: hyper::client::connect::Connect + 'static, + C: hyper::client::connect::Connect + Clone + Send + Sync + 'static, { let claims = Claims::new(&self.key, scopes, self.subject.as_ref().map(|x| x.as_str())); let signed = self.signer.sign_claims(&claims).map_err(|_| { @@ -202,7 +201,7 @@ impl ServiceAccountFlow { .unwrap(); log::debug!("requesting token from service account: {:?}", request); let (head, body) = hyper_client.request(request).await?.into_parts(); - let body = body.try_concat().await?; + let body = hyper::body::to_bytes(body).await?; log::debug!("received response; head: {:?}, body: {:?}", head, body); TokenInfo::from_json(&body) } diff --git a/tests/tests.rs b/tests/tests.rs index 44ff000..4f75991 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -41,7 +41,7 @@ async fn create_device_flow_auth() -> Authenticator Pin + 'a + Send>> { assert_eq!("https://example.com/verify", pi.verification_uri); - Box::pin(futures::future::ready(())) + Box::pin(async {}) } } From 1d5c3a4512026dbbb0560cae3108c33d726dc993 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Tue, 17 Dec 2019 15:23:40 -0800 Subject: [PATCH 71/71] Switch from mockito to httptest --- Cargo.toml | 2 +- examples/test-device/Cargo.toml | 2 +- tests/tests.rs | 568 +++++++++++++++++--------------- 3 files changed, 301 insertions(+), 271 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ffcb393..a30608d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,7 @@ tokio = { version = "0.2", features = ["fs", "macros", "io-std", "time"] } url = "1" [dev-dependencies] -mockito = "0.17" +httptest = "0.5" env_logger = "0.6" tempfile = "3.1" diff --git a/examples/test-device/Cargo.toml b/examples/test-device/Cargo.toml index 3cf49c5..21557bb 100644 --- a/examples/test-device/Cargo.toml +++ b/examples/test-device/Cargo.toml @@ -6,4 +6,4 @@ edition = "2018" [dependencies] yup-oauth2 = { path = "../../" } -tokio = { version = "0.2", features = ["macros"] } \ No newline at end of file +tokio = { version = "0.2", features = ["macros"] } diff --git a/tests/tests.rs b/tests/tests.rs index 4f75991..265916c 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -10,6 +10,7 @@ use std::future::Future; use std::path::PathBuf; use std::pin::Pin; +use httptest::{mappers::*, responders::json_encoded, Expectation, Server}; use hyper::client::connect::HttpConnector; use hyper::Uri; use hyper_rustls::HttpsConnector; @@ -23,13 +24,12 @@ macro_rules! parse_json { } } -async fn create_device_flow_auth() -> Authenticator> { - let server_url = mockito::server_url(); +async fn create_device_flow_auth(server: &Server) -> Authenticator> { let app_secret: ApplicationSecret = parse_json!({ "client_id": "902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com", "project_id": "yup-test-243420", "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": format!("{}/token", server_url), + "token_uri": server.url_str("/token"), "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", "client_secret": "iuMPN6Ne1PD7cos29Tk9rlqH", "redirect_uris": ["urn:ietf:wg:oauth:2.0:oob","http://localhost"], @@ -47,7 +47,7 @@ async fn create_device_flow_auth() -> Authenticator Authenticator, ) -> Authenticator> { - let server_url = mockito::server_url(); let app_secret: ApplicationSecret = parse_json!({ "client_id": "902216714886-k2v9uei3p1dk6h686jbsn9mo96tnbvto.apps.googleusercontent.com", "project_id": "yup-test-243420", "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": format!("{}/token", server_url), + "token_uri": server.url_str("/token"), "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", "client_secret": "iuMPN6Ne1PD7cos29Tk9rlqH", "redirect_uris": ["urn:ietf:wg:oauth:2.0:oob","http://localhost"], @@ -232,78 +235,95 @@ async fn create_installed_flow_auth( #[tokio::test] async fn test_installed_interactive_success() { - let auth = create_installed_flow_auth(InstalledFlowReturnMethod::Interactive, None).await; - let _m = mockito::mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*code=authorizationcode.*client_id=9022167.*".to_string(), - )) - .with_body( - serde_json::json!({ - "access_token": "accesstoken", - "refresh_token": "refreshtoken", - "token_type": "Bearer", - "expires_in": 12345678 - }) - .to_string(), - ) - .expect(1) - .create(); + let _ = env_logger::try_init(); + let server = Server::run(); + let auth = + create_installed_flow_auth(&server, InstalledFlowReturnMethod::Interactive, None).await; + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("code", "authorizationcode")), + contains_entry(("client_id", matches("9022167.*"))), + ])) + ]) + .respond_with(json_encoded(serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 12345678 + }))), + ); let tok = auth .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); assert_eq!("accesstoken", tok.as_str()); - _m.assert(); } #[tokio::test] async fn test_installed_redirect_success() { - let auth = create_installed_flow_auth(InstalledFlowReturnMethod::HTTPRedirect, None).await; - let _m = mockito::mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*code=authorizationcode.*client_id=9022167.*".to_string(), - )) - .with_body( - serde_json::json!({ - "access_token": "accesstoken", - "refresh_token": "refreshtoken", - "token_type": "Bearer", - "expires_in": 12345678 - }) - .to_string(), - ) - .expect(1) - .create(); + let _ = env_logger::try_init(); + let server = Server::run(); + let auth = + create_installed_flow_auth(&server, InstalledFlowReturnMethod::HTTPRedirect, None).await; + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("code", "authorizationcode")), + contains_entry(("client_id", matches("9022167.*"))), + ])) + ]) + .respond_with(json_encoded(serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 12345678 + }))), + ); let tok = auth .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); assert_eq!("accesstoken", tok.as_str()); - _m.assert(); } #[tokio::test] async fn test_installed_error() { - let auth = create_installed_flow_auth(InstalledFlowReturnMethod::Interactive, None).await; - let _m = mockito::mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*code=authorizationcode.*client_id=9022167.*".to_string(), - )) - .with_status(400) - .with_body(serde_json::json!({"error": "invalid_code"}).to_string()) - .expect(1) - .create(); + let _ = env_logger::try_init(); + let server = Server::run(); + let auth = + create_installed_flow_auth(&server, InstalledFlowReturnMethod::Interactive, None).await; + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("code", "authorizationcode")), + contains_entry(("client_id", matches("9022167.*"))), + ])) + ]) + .respond_with( + http::Response::builder() + .status(404) + .body(serde_json::json!({"error": "invalid_code"}).to_string()) + .unwrap(), + ), + ); let tokr = auth.token(&["https://googleapis.com/some/scope"]).await; assert!(tokr.is_err()); assert!(format!("{}", tokr.unwrap_err()).contains("invalid_code")); - _m.assert(); } -async fn create_service_account_auth() -> Authenticator> { - let server_url = &mockito::server_url(); +async fn create_service_account_auth( + server: &Server, +) -> Authenticator> { let key: ServiceAccountKey = parse_json!({ "type": "service_account", "project_id": "yup-test-243420", @@ -312,7 +332,7 @@ async fn create_service_account_auth() -> Authenticator Authenticator= tok.expiration_time().unwrap()); - _m.assert(); } #[tokio::test] async fn test_service_account_error() { - let auth = create_service_account_auth().await; - let bad_json_response = serde_json::json!({ - "error": "access_denied", - }); + let _ = env_logger::try_init(); + let server = Server::run(); + let auth = create_service_account_auth(&server).await; + server.expect( + Expectation::matching(all_of![request::method("POST"), request::path("/token"),]) + .respond_with(json_encoded(serde_json::json!({ + "error": "access_denied", + }))), + ); - let _m = mockito::mock("POST", "/token") - .with_status(200) - .with_header("content-type", "text/json") - .with_body(bad_json_response.to_string()) - .create(); let result = auth .token(&["https://www.googleapis.com/auth/pubsub"]) .await; assert!(result.is_err()); - _m.assert(); } #[tokio::test] async fn test_refresh() { - let auth = create_installed_flow_auth(InstalledFlowReturnMethod::Interactive, None).await; + let _ = env_logger::try_init(); + let server = Server::run(); + let auth = + create_installed_flow_auth(&server, InstalledFlowReturnMethod::Interactive, None).await; // We refresh a token whenever it's within 1 minute of expiring. So // acquiring a token that expires in 59 seconds will force a refresh on // the next token call. - let _m = mockito::mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*code=authorizationcode.*client_id=9022167.*".to_string(), - )) - .with_body( - serde_json::json!({ - "access_token": "accesstoken", - "refresh_token": "refreshtoken", - "token_type": "Bearer", - "expires_in": 59, - }) - .to_string(), - ) - .expect(1) - .create(); + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("code", "authorizationcode")), + contains_entry(("client_id", matches("^9022167"))), + ])) + ]) + .respond_with(json_encoded(serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 59, + }))), + ); let tok = auth .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); assert_eq!("accesstoken", tok.as_str()); - _m.assert(); - let _m = mockito::mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*client_id=9022167.*refresh_token=refreshtoken.*".to_string(), - )) - .with_body( - serde_json::json!({ - "access_token": "accesstoken2", - "token_type": "Bearer", - "expires_in": 59, - }) - .to_string(), - ) - .expect(1) - .create(); + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("refresh_token", "refreshtoken")), + contains_entry(("client_id", matches("^9022167"))), + ])) + ]) + .respond_with(json_encoded(serde_json::json!({ + "access_token": "accesstoken2", + "token_type": "Bearer", + "expires_in": 59, + }))), + ); let tok = auth .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); assert_eq!("accesstoken2", tok.as_str()); - _m.assert(); - let _m = mockito::mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*client_id=9022167.*refresh_token=refreshtoken.*".to_string(), - )) - .with_body( - serde_json::json!({ - "access_token": "accesstoken3", - "token_type": "Bearer", - "expires_in": 59, - }) - .to_string(), - ) - .expect(1) - .create(); + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("refresh_token", "refreshtoken")), + contains_entry(("client_id", matches("^9022167"))), + ])) + ]) + .respond_with(json_encoded(serde_json::json!({ + "access_token": "accesstoken3", + "token_type": "Bearer", + "expires_in": 59, + }))), + ); let tok = auth .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); assert_eq!("accesstoken3", tok.as_str()); - _m.assert(); - let _m = mockito::mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*client_id=9022167.*refresh_token=refreshtoken.*".to_string(), - )) - .with_body( - serde_json::json!({ - "error": "invalid_request", - }) - .to_string(), - ) - .expect(1) - .create(); + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("refresh_token", "refreshtoken")), + contains_entry(("client_id", matches("^9022167"))), + ])) + ]) + .respond_with(json_encoded(serde_json::json!({ + "error": "invalid_request", + }))), + ); let tok_err = auth .token(&["https://googleapis.com/some/scope"]) @@ -463,27 +486,30 @@ async fn test_refresh() { }) => {} e => panic!("unexpected error on refresh: {:?}", e), } - _m.assert(); } #[tokio::test] async fn test_memory_storage() { - let auth = create_installed_flow_auth(InstalledFlowReturnMethod::Interactive, None).await; - let _m = mockito::mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*code=authorizationcode.*client_id=9022167.*".to_string(), - )) - .with_body( - serde_json::json!({ - "access_token": "accesstoken", - "refresh_token": "refreshtoken", - "token_type": "Bearer", - "expires_in": 12345678 - }) - .to_string(), - ) - .expect(1) - .create(); + let _ = env_logger::try_init(); + let server = Server::run(); + let auth = + create_installed_flow_auth(&server, InstalledFlowReturnMethod::Interactive, None).await; + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("code", "authorizationcode")), + contains_entry(("client_id", matches("^9022167"))), + ])) + ]) + .respond_with(json_encoded(serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 12345678, + }))), + ); // Call token twice. Ensure that only one http request is made and // identical tokens are returned. @@ -497,59 +523,63 @@ async fn test_memory_storage() { .expect("failed to get token"); assert_eq!(token1.as_str(), "accesstoken"); assert_eq!(token1, token2); - _m.assert(); // Create a new authenticator. This authenticator does not share a cache // with the previous one. Validate that it receives a different token. - let auth2 = create_installed_flow_auth(InstalledFlowReturnMethod::Interactive, None).await; - let _m = mockito::mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*code=authorizationcode.*client_id=9022167.*".to_string(), - )) - .with_body( - serde_json::json!({ - "access_token": "accesstoken2", - "refresh_token": "refreshtoken2", - "token_type": "Bearer", - "expires_in": 12345678 - }) - .to_string(), - ) - .expect(1) - .create(); + let auth2 = + create_installed_flow_auth(&server, InstalledFlowReturnMethod::Interactive, None).await; + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("code", "authorizationcode")), + contains_entry(("client_id", matches("^9022167"))), + ])) + ]) + .respond_with(json_encoded(serde_json::json!({ + "access_token": "accesstoken2", + "refresh_token": "refreshtoken2", + "token_type": "Bearer", + "expires_in": 12345678, + }))), + ); let token3 = auth2 .token(&["https://googleapis.com/some/scope"]) .await .expect("failed to get token"); assert_eq!(token3.as_str(), "accesstoken2"); - _m.assert(); } #[tokio::test] async fn test_disk_storage() { + let _ = env_logger::try_init(); + let server = Server::run(); let tempdir = tempfile::tempdir().unwrap(); let storage_path = tempdir.path().join("tokenstorage.json"); + server.expect( + Expectation::matching(all_of![ + request::method("POST"), + request::path("/token"), + request::body(url_decoded(all_of![ + contains_entry(("code", "authorizationcode")), + contains_entry(("client_id", matches("^9022167"))), + ])), + ]) + .respond_with(json_encoded(serde_json::json!({ + "access_token": "accesstoken", + "refresh_token": "refreshtoken", + "token_type": "Bearer", + "expires_in": 12345678 + }))), + ); { let auth = create_installed_flow_auth( + &server, InstalledFlowReturnMethod::Interactive, Some(storage_path.clone()), ) .await; - let _m = mockito::mock("POST", "/token") - .match_body(mockito::Matcher::Regex( - ".*code=authorizationcode.*client_id=9022167.*".to_string(), - )) - .with_body( - serde_json::json!({ - "access_token": "accesstoken", - "refresh_token": "refreshtoken", - "token_type": "Bearer", - "expires_in": 12345678 - }) - .to_string(), - ) - .expect(1) - .create(); // Call token twice. Ensure that only one http request is made and // identical tokens are returned. @@ -563,13 +593,13 @@ async fn test_disk_storage() { .expect("failed to get token"); assert_eq!(token1.as_str(), "accesstoken"); assert_eq!(token1, token2); - _m.assert(); } // Create a new authenticator. This authenticator uses the same token // storage file as the previous one so should receive a token without // making any http requests. let auth = create_installed_flow_auth( + &server, InstalledFlowReturnMethod::Interactive, Some(storage_path.clone()), )