From 696577aa01c2bda04714e82081020e92a896f1b5 Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Fri, 8 Nov 2019 12:43:17 -0800 Subject: [PATCH] Accept scopes as a slice of anything that can produce a &str. Along with the public facing change the implementation has been modified to no longer clone the scopes instead using the pointer to the scopes the user provided. This greatly reduces the number of allocations on each token() call. Note that this also changes the hashing method used for token storage in an incompatible way with the previous implementation. The previous implementation pre-sorted the vector and hashed the contents to make the result independent of the ordering of the scopes. Instead we now combine the hash values of each scope together with XOR, thus producing a hash value that does not depend on order without needing to allocate another vector and sort. --- examples/test-device/src/main.rs | 2 +- examples/test-installed/src/main.rs | 3 +- examples/test-svc-acct/src/main.rs | 5 +- src/authenticator.rs | 26 ++++---- src/device.rs | 43 +++++++------- src/helper.rs | 22 +++++++ src/installed.rs | 42 ++++++------- src/lib.rs | 3 +- src/service_account.rs | 53 +++++++++-------- src/storage.rs | 92 ++++++++++++++++------------- src/types.rs | 7 +-- 11 files changed, 166 insertions(+), 132 deletions(-) diff --git a/examples/test-device/src/main.rs b/examples/test-device/src/main.rs index 42b3ab8..64e413c 100644 --- a/examples/test-device/src/main.rs +++ b/examples/test-device/src/main.rs @@ -12,7 +12,7 @@ async fn main() { .build() .expect("authenticator"); - let scopes = vec!["https://www.googleapis.com/auth/youtube.readonly"]; + let scopes = &["https://www.googleapis.com/auth/youtube.readonly"]; match auth.token(scopes).await { Err(e) => println!("error: {:?}", e), Ok(t) => println!("token: {:?}", t), diff --git a/examples/test-installed/src/main.rs b/examples/test-installed/src/main.rs index f4909a3..c333255 100644 --- a/examples/test-installed/src/main.rs +++ b/examples/test-installed/src/main.rs @@ -15,8 +15,7 @@ async fn main() { .persist_tokens_to_disk("tokencache.json") .build() .unwrap(); - let s = "https://www.googleapis.com/auth/drive.file".to_string(); - let scopes = vec![s]; + let scopes = &["https://www.googleapis.com/auth/drive.file"]; match auth.token(scopes).await { Err(e) => println!("error: {:?}", e), diff --git a/examples/test-svc-acct/src/main.rs b/examples/test-svc-acct/src/main.rs index 3d18fdc..6ad49f1 100644 --- a/examples/test-svc-acct/src/main.rs +++ b/examples/test-svc-acct/src/main.rs @@ -8,14 +8,15 @@ async fn main() { let creds = yup_oauth2::service_account_key_from_file(path::Path::new("serviceaccount.json")).unwrap(); let sa = yup_oauth2::ServiceAccountAccess::new(creds).build(); + let scopes = &["https://www.googleapis.com/auth/pubsub"]; let tok = sa - .token(vec!["https://www.googleapis.com/auth/pubsub"]) + .token(scopes) .await .unwrap(); println!("token is: {:?}", tok); let tok = sa - .token(vec!["https://www.googleapis.com/auth/pubsub"]) + .token(scopes) .await .unwrap(); println!("cached token is {:?} and should be identical", tok); diff --git a/src/authenticator.rs b/src/authenticator.rs index f39461f..bbd6781 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -192,7 +192,11 @@ where AD: 'static + AuthenticatorDelegate, C: 'static + hyper::client::connect::Connect + Clone + Send, { - async fn get_token(&self, scope_key: u64, scopes: Vec) -> Result { + async fn get_token(&self, scopes: &[T]) -> Result + where + T: AsRef + Sync, + { + let scope_key = hash_scopes(scopes); let store = self.store.clone(); let delegate = &self.delegate; let client = self.client.clone(); @@ -200,8 +204,8 @@ where let gettoken = self.inner.clone(); loop { match store.get( - scope_key.clone(), - &scopes, + scope_key, + scopes, ) { Ok(Some(t)) => { if !t.expired() { @@ -234,7 +238,7 @@ where RefreshResult::Success(t) => { let x = store.set( scope_key, - &scopes, + scopes, Some(t.clone()), ); if let Err(e) = x { @@ -251,10 +255,10 @@ where } Ok(None) => { let store = store.clone(); - let t = gettoken.token(scopes.clone()).await?; + let t = gettoken.token(scopes).await?; if let Err(e) = store.set( scope_key, - &scopes, + scopes, Some(t.clone()), ) { match delegate.token_storage_failure(true, &e) { @@ -291,15 +295,13 @@ impl< self.inner.application_secret() } - fn token<'a, I, T>( + fn token<'a, T>( &'a self, - scopes: I, + scopes: &'a [T], ) -> Pin> + Send + 'a>> where - T: Into, - I: IntoIterator, + T: AsRef + Sync, { - let (scope_key, scopes) = hash_scopes(scopes); - Box::pin(self.get_token(scope_key, scopes)) + Box::pin(self.get_token(scopes)) } } diff --git a/src/device.rs b/src/device.rs index a924f17..018bde7 100644 --- a/src/device.rs +++ b/src/device.rs @@ -1,4 +1,3 @@ -use std::iter::{FromIterator, IntoIterator}; use std::pin::Pin; use std::time::Duration; @@ -7,7 +6,6 @@ use chrono::{self, Utc}; use futures::{prelude::*}; use hyper; use hyper::header; -use itertools::Itertools; use serde_json as json; use url::form_urlencoded; @@ -104,15 +102,14 @@ where FD: FlowDelegate + 'static, C: hyper::client::connect::Connect + 'static, { - fn token<'a, I, T>( + fn token<'a, T>( &'a self, - scopes: I, + scopes: &'a [T], ) -> Pin> + Send + 'a>> where - T: Into, - I: IntoIterator, + T: AsRef + Sync, { - Box::pin(self.retrieve_device_token(Vec::from_iter(scopes.into_iter().map(Into::into)))) + Box::pin(self.retrieve_device_token(scopes)) } fn api_key(&self) -> Option { None @@ -131,10 +128,13 @@ where { /// Essentially what `GetToken::token` does: Retrieve a token for the given scopes without /// caching. - pub async fn retrieve_device_token<'a>( + pub async fn retrieve_device_token( &self, - scopes: Vec, - ) -> Result { + scopes: &[T], + ) -> Result + where + T: AsRef, + { let application_secret = self.application_secret.clone(); let client = self.client.clone(); let wait = self.wait; @@ -193,24 +193,21 @@ where /// * If called after a successful result was returned at least once. /// # Examples /// See test-cases in source code for a more complete example. - async fn request_code( + async fn request_code( application_secret: ApplicationSecret, client: hyper::Client, device_code_url: String, - scopes: Vec, - ) -> Result<(PollInformation, String), RequestError> { + scopes: &[T], + ) -> Result<(PollInformation, String), RequestError> + where + T: AsRef, + { // note: cloned() shouldn't be needed, see issue // https://github.com/servo/rust-url/issues/81 let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ ("client_id", application_secret.client_id.clone()), - ( - "scope", - scopes - .into_iter() - .intersperse(" ".to_string()) - .collect::(), - ), + ("scope", crate::helper::join(scopes, " ")), ]) .finish(); @@ -409,7 +406,7 @@ mod tests { let fut = async { let token = flow - .token(vec!["https://www.googleapis.com/scope/1"]) + .token(&["https://www.googleapis.com/scope/1"]) .await .unwrap(); assert_eq!("accesstoken", token.access_token); @@ -441,7 +438,7 @@ mod tests { .create(); let fut = async { - let res = flow.token(vec!["https://www.googleapis.com/scope/1"]).await; + let res = flow.token(&["https://www.googleapis.com/scope/1"]).await; assert!(res.is_err()); assert!(format!("{}", res.unwrap_err()).contains("invalid_client_id")); Ok(()) as Result<(), ()> @@ -471,7 +468,7 @@ mod tests { .create(); let fut = async { - let res = flow.token(vec!["https://www.googleapis.com/scope/1"]).await; + let res = flow.token(&["https://www.googleapis.com/scope/1"]).await; assert!(res.is_err()); assert!(format!("{}", res.unwrap_err()).contains("Access denied by user")); Ok(()) as Result<(), ()> diff --git a/src/helper.rs b/src/helper.rs index c9471fc..7ef7d06 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -61,3 +61,25 @@ pub fn service_account_key_from_file>(path: S) -> io::Result Ok(decoded), } } + +pub(crate) fn join(pieces: &[T], separator: &str) -> String +where + T: AsRef, +{ + let mut iter = pieces.iter(); + let first = match iter.next() { + Some(p) => p, + None => return String::new(), + }; + let num_separators = pieces.len() - 1; + let pieces_size: usize = pieces.iter().map(|p| p.as_ref().len()).sum(); + let size = pieces_size + separator.len() * num_separators; + let mut result = String::with_capacity(size); + result.push_str(first.as_ref()); + for p in iter { + result.push_str(separator); + result.push_str(p.as_ref()); + } + debug_assert_eq!(size, result.len()); + result +} \ No newline at end of file diff --git a/src/installed.rs b/src/installed.rs index ba6758e..26ee573 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -66,15 +66,14 @@ where FD: FlowDelegate + 'static, C: hyper::client::connect::Connect + 'static, { - fn token<'a, I, T>( + fn token<'a, T>( &'a self, - scopes: I, + scopes: &'a [T], ) -> Pin> + Send + 'a>> where - T: Into, - I: IntoIterator, + T: AsRef + Sync, { - Box::pin(self.obtain_token(scopes.into_iter().map(Into::into).collect())) + Box::pin(self.obtain_token(scopes)) } fn api_key(&self) -> Option { None @@ -175,27 +174,29 @@ where /// . Return that token /// /// It's recommended not to use the DefaultFlowDelegate, but a specialized one. - async fn obtain_token<'a>( + async fn obtain_token( &self, - scopes: Vec, // Note: I haven't found a better way to give a list of strings here, due to ownership issues with futures. - ) -> Result { + scopes: &[T], + ) -> Result + where + T: AsRef, + { match self.method { InstalledFlowReturnMethod::HTTPRedirect(port) => { - self.ask_auth_code_via_http(scopes.iter(), port).await + self.ask_auth_code_via_http(scopes, port).await } InstalledFlowReturnMethod::HTTPRedirectEphemeral => { - self.ask_auth_code_via_http(scopes.iter(), 0).await + self.ask_auth_code_via_http(scopes, 0).await } InstalledFlowReturnMethod::Interactive => { - self.ask_auth_code_interactively(scopes.iter()).await + self.ask_auth_code_interactively(scopes).await } } } - async fn ask_auth_code_interactively<'a, S, T>(&self, scopes: S) -> Result + async fn ask_auth_code_interactively(&self, scopes: &[T]) -> Result where - T: AsRef + 'a, - S: Iterator, + T: AsRef, { let auth_delegate = &self.fd; let appsecret = &self.appsecret; @@ -223,14 +224,13 @@ where self.exchange_auth_code(authcode, None).await } - async fn ask_auth_code_via_http<'a, S, T>( + async fn ask_auth_code_via_http( &self, - scopes: S, + scopes: &[T], desired_port: u16, ) -> Result where - T: AsRef + 'a, - S: Iterator, + T: AsRef, { let auth_delegate = &self.fd; let appsecret = &self.appsecret; @@ -583,7 +583,7 @@ mod tests { let fut = || { async { let tok = inf - .token(vec!["https://googleapis.com/some/scope"]) + .token(&["https://googleapis.com/some/scope"]) .await .map_err(|_| ())?; assert_eq!("accesstoken", tok.access_token); @@ -612,7 +612,7 @@ mod tests { let fut = async { let tok = inf - .token(vec!["https://googleapis.com/some/scope"]) + .token(&["https://googleapis.com/some/scope"]) .await .map_err(|_| ())?; assert_eq!("accesstoken", tok.access_token); @@ -635,7 +635,7 @@ mod tests { .create(); let fut = async { - let tokr = inf.token(vec!["https://googleapis.com/some/scope"]).await; + let tokr = inf.token(&["https://googleapis.com/some/scope"]).await; assert!(tokr.is_err()); assert!(format!("{}", tokr.unwrap_err()).contains("invalid_code")); Ok(()) as Result<(), ()> diff --git a/src/lib.rs b/src/lib.rs index 523e703..181e1e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,8 +65,7 @@ //! .build() //! .unwrap(); //! -//! let s = "https://www.googleapis.com/auth/drive.file".to_string(); -//! let scopes = vec![s]; +//! let scopes = &["https://www.googleapis.com/auth/drive.file"]; //! //! // token() is the one important function of this crate; it does everything to //! // obtain a token that can be sent e.g. as Bearer token. diff --git a/src/service_account.rs b/src/service_account.rs index 355f77e..88831c8 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -157,17 +157,15 @@ impl JWT { } } -/// Set `iss`, `aud`, `exp`, `iat`, `scope` field in the returned `Claims`. `scopes` is an iterator -/// yielding strings with OAuth scopes. -fn init_claims_from_key<'a, I, T>(key: &ServiceAccountKey, scopes: I) -> Claims +/// Set `iss`, `aud`, `exp`, `iat`, `scope` field in the returned `Claims`. +fn init_claims_from_key(key: &ServiceAccountKey, scopes: &[T]) -> Claims where - T: AsRef + 'a, - I: IntoIterator, + T: AsRef, { let iat = chrono::Utc::now().timestamp(); let expiry = iat + 3600 - 5; // Max validity is 1h. - let mut scopes_string = scopes.into_iter().fold(String::new(), |mut acc, sc| { + let mut scopes_string = scopes.iter().fold(String::new(), |mut acc, sc| { acc.push_str(sc.as_ref()); acc.push_str(" "); acc @@ -271,13 +269,16 @@ where C: hyper::client::connect::Connect + 'static, { /// Send a request for a new Bearer token to the OAuth provider. - async fn request_token( + async fn request_token( client: hyper::client::Client, sub: Option, key: ServiceAccountKey, - scopes: Vec, - ) -> Result { - let mut claims = init_claims_from_key(&key, &scopes); + scopes: &[T], + ) -> Result + where + T: AsRef, + { + let mut claims = init_claims_from_key(&key, scopes); claims.sub = sub.clone(); let signed = JWT::new(claims) .sign(key.private_key.as_ref().unwrap()) @@ -335,12 +336,16 @@ where Ok(token) } - async fn get_token(&self, hash: u64, scopes: Vec) -> Result { + async fn get_token(&self, scopes: &[T]) -> Result + where + T: AsRef, + { + let hash = hash_scopes(scopes); let cache = self.cache.clone(); match cache .lock() .unwrap() - .get(hash, scopes.iter()) + .get(hash, scopes) { Ok(Some(token)) if !token.expired() => return Ok(token), _ => {} @@ -349,12 +354,12 @@ where self.client.clone(), self.sub.clone(), self.key.clone(), - scopes.iter().map(|s| s.to_string()).collect(), + scopes, ) .await?; let _ = cache.lock().unwrap().set( hash, - scopes.iter(), + scopes, Some(token.clone()), ); Ok(token) @@ -365,16 +370,14 @@ impl GetToken for ServiceAccountAccessImpl where C: hyper::client::connect::Connect + 'static, { - fn token<'a, I, T>( + fn token<'a, T>( &'a self, - scopes: I, + scopes: &'a [T], ) -> Pin> + Send + 'a>> where - T: Into, - I: IntoIterator, + T: AsRef + Sync, { - let (hash, scps0) = hash_scopes(scopes); - Box::pin(self.get_token(hash, scps0)) + Box::pin(self.get_token(scopes)) } /// Returns an empty ApplicationSecret as tokens for service accounts don't need to be @@ -449,7 +452,7 @@ mod tests { let acc = ServiceAccountAccessImpl::new(client.clone(), key.clone(), None); let fut = async { let tok = acc - .token(vec!["https://www.googleapis.com/auth/pubsub"]) + .token(&["https://www.googleapis.com/auth/pubsub"]) .await?; assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); assert_eq!(Some(3600), tok.expires_in); @@ -463,14 +466,14 @@ mod tests { .unwrap() .get( 3502164897243251857, - ["https://www.googleapis.com/auth/pubsub"].iter(), + &["https://www.googleapis.com/auth/pubsub"], ) .unwrap() .is_some()); // Test that token is in cache (otherwise mock will tell us) let fut = async { let tok = acc - .token(vec!["https://www.googleapis.com/auth/pubsub"]) + .token(&["https://www.googleapis.com/auth/pubsub"]) .await?; assert!(tok.access_token.contains("ya29.c.ElouBywiys0Ly")); assert_eq!(Some(3600), tok.expires_in); @@ -492,7 +495,7 @@ mod tests { .build(); let fut = async { let result = acc - .token(vec!["https://www.googleapis.com/auth/pubsub"]) + .token(&["https://www.googleapis.com/auth/pubsub"]) .await; assert!(result.is_err()); Ok(()) as Result<(), ()> @@ -522,7 +525,7 @@ mod tests { rt.block_on(async { println!( "{:?}", - acc.token(vec!["https://www.googleapis.com/auth/pubsub"]) + acc.token(&["https://www.googleapis.com/auth/pubsub"]) .await ); }); diff --git a/src/storage.rs b/src/storage.rs index 551c563..c1403ca 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -26,34 +26,35 @@ pub trait TokenStorage: Send + Sync { /// If `token` is None, it is invalid or revoked and should be removed from storage. /// Otherwise, it should be saved. - fn set( + fn set( &self, scope_hash: u64, - scopes: I, + scopes: &[T], token: Option, ) -> Result<(), Self::Error> where - I: IntoIterator, - I::Item: AsRef; + T: AsRef; /// A `None` result indicates that there is no token for the given scope_hash. - fn get(&self, scope_hash: u64, scopes: I) -> Result, Self::Error> + fn get(&self, scope_hash: u64, scopes: &[T]) -> Result, Self::Error> where - I: IntoIterator + Clone, - I::Item: AsRef; + T: AsRef; } -/// Calculate a hash value describing the scopes, and return a sorted Vec of the scopes. -pub fn hash_scopes(scopes: I) -> (u64, Vec) +/// Calculate a hash value describing the scopes. The order of the scopes in the +/// list does not change the hash value. i.e. two lists that contains the exact +/// same scopes, but in different order will return the same hash value. +pub fn hash_scopes(scopes: &[T]) -> u64 where - T: Into, - I: IntoIterator, + T: AsRef, { - let mut sv: Vec = scopes.into_iter().map(Into::into).collect(); - sv.sort(); - let mut sh = DefaultHasher::new(); - sv.hash(&mut sh); - (sh.finish(), sv) + let mut hash_sum = DefaultHasher::new().finish(); + for scope in scopes { + let mut hasher = DefaultHasher::new(); + scope.as_ref().hash(&mut hasher); + hash_sum ^= hasher.finish(); + } + hash_sum } /// A storage that remembers nothing. @@ -62,18 +63,16 @@ pub struct NullStorage; impl TokenStorage for NullStorage { type Error = std::convert::Infallible; - fn set(&self, _: u64, _: I, _: Option) -> Result<(), Self::Error> + fn set(&self, _: u64, _: &[T], _: Option) -> Result<(), Self::Error> where - I: IntoIterator, - I::Item: AsRef, + T: AsRef { Ok(()) } - fn get(&self, _: u64, _: I) -> Result, Self::Error> + fn get(&self, _: u64, _: &[T]) -> Result, Self::Error> where - I: IntoIterator + Clone, - I::Item: AsRef, + T: AsRef { Ok(None) } @@ -94,15 +93,14 @@ impl MemoryStorage { impl TokenStorage for MemoryStorage { type Error = std::convert::Infallible; - fn set( + fn set( &self, scope_hash: u64, - scopes: I, + scopes: &[T], token: Option, ) -> Result<(), Self::Error> where - I: IntoIterator, - I::Item: AsRef, + T: AsRef { let mut tokens = self.tokens.lock().expect("poisoned mutex"); let matched = tokens.iter().find_position(|x| x.hash == scope_hash); @@ -124,10 +122,9 @@ impl TokenStorage for MemoryStorage { Ok(()) } - fn get(&self, scope_hash: u64, scopes: I) -> Result, Self::Error> + fn get(&self, scope_hash: u64, scopes: &[T]) -> Result, Self::Error> where - I: IntoIterator + Clone, - I::Item: AsRef, + T: AsRef { let tokens = self.tokens.lock().expect("poisoned mutex"); Ok(token_for_scopes(&tokens, scope_hash, scopes)) @@ -225,15 +222,14 @@ fn load_from_file(filename: &Path) -> Result, io::Error> { impl TokenStorage for DiskTokenStorage { type Error = io::Error; - fn set( + fn set( &self, scope_hash: u64, - scopes: I, + scopes: &[T], token: Option, ) -> Result<(), Self::Error> where - I: IntoIterator, - I::Item: AsRef, + T: AsRef { { let mut tokens = self.tokens.lock().expect("poisoned mutex"); @@ -257,24 +253,22 @@ impl TokenStorage for DiskTokenStorage { self.dump_to_file() } - fn get(&self, scope_hash: u64, scopes: I) -> Result, Self::Error> + fn get(&self, scope_hash: u64, scopes: &[T]) -> Result, Self::Error> where - I: IntoIterator + Clone, - I::Item: AsRef, + T: AsRef { let tokens = self.tokens.lock().expect("poisoned mutex"); Ok(token_for_scopes(&tokens, scope_hash, scopes)) } } -fn token_for_scopes(tokens: &[JSONToken], scope_hash: u64, scopes: I) -> Option +fn token_for_scopes(tokens: &[JSONToken], scope_hash: u64, scopes: &[T]) -> Option where - I: IntoIterator + Clone, - I::Item: AsRef, + T: AsRef, { for t in tokens.iter() { if let Some(token_scopes) = &t.scopes { - if scopes.clone().into_iter().all(|s| token_scopes.iter().any(|t| t == s.as_ref())) { + if scopes.iter().all(|s| token_scopes.iter().any(|t| t == s.as_ref())) { return Some(t.token.clone()); } } else if scope_hash == t.hash { @@ -283,3 +277,21 @@ where } None } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_hash_scopes() { + // Idential list should hash equal. + assert_eq!(hash_scopes(&["foo", "bar"]), hash_scopes(&["foo", "bar"])); + // The hash should be order independent. + assert_eq!(hash_scopes(&["bar", "foo"]), hash_scopes(&["foo", "bar"])); + assert_eq!(hash_scopes(&["bar", "baz", "bat"]), hash_scopes(&["baz", "bar", "bat"])); + + // Ensure hashes differ when the contents are different by more than + // just order. + assert_ne!(hash_scopes(&["foo", "bar", "baz"]), hash_scopes(&["foo", "bar"])); + } +} diff --git a/src/types.rs b/src/types.rs index a47a0fa..5d5b54f 100644 --- a/src/types.rs +++ b/src/types.rs @@ -236,13 +236,12 @@ impl FromStr for Scheme { /// The `api_key()` method is an alternative in case there are no scopes or /// if no user is involved. pub trait GetToken: Send + Sync { - fn token<'a, I, T>( + fn token<'a, T>( &'a self, - scopes: I, + scopes: &'a [T], ) -> Pin> + Send + 'a>> where - T: Into, - I: IntoIterator; + T: AsRef + Sync; fn api_key(&self) -> Option;