From 7e210a22c5459362d8ffd226b56bc38aa1de152d Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Thu, 7 Nov 2019 16:22:09 -0800 Subject: [PATCH] Have TokenStorage take scopes by iterator rather than Vec. This reduces the number of allocations needed. --- src/authenticator.rs | 8 ++- src/service_account.rs | 6 +-- src/storage.rs | 115 ++++++++++++++++++++++++----------------- 3 files changed, 73 insertions(+), 56 deletions(-) diff --git a/src/authenticator.rs b/src/authenticator.rs index 6e17d2e..f39461f 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -201,7 +201,7 @@ where loop { match store.get( scope_key.clone(), - &scopes.iter().map(|s| s.as_str()).collect(), + &scopes, ) { Ok(Some(t)) => { if !t.expired() { @@ -210,7 +210,6 @@ where // Implement refresh flow. let refresh_token = t.refresh_token.clone(); let store = store.clone(); - let scopes = scopes.clone(); let rr = RefreshFlow::refresh_token( client.clone(), appsecret.clone(), @@ -235,7 +234,7 @@ where RefreshResult::Success(t) => { let x = store.set( scope_key, - &scopes.iter().map(|s| s.as_str()).collect(), + &scopes, Some(t.clone()), ); if let Err(e) = x { @@ -252,11 +251,10 @@ where } Ok(None) => { let store = store.clone(); - let scopes = scopes.clone(); let t = gettoken.token(scopes.clone()).await?; if let Err(e) = store.set( scope_key, - &scopes.iter().map(|s| s.as_str()).collect(), + &scopes, Some(t.clone()), ) { match delegate.token_storage_failure(true, &e) { diff --git a/src/service_account.rs b/src/service_account.rs index c56a9cb..355f77e 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -340,7 +340,7 @@ where match cache .lock() .unwrap() - .get(hash, &scopes.iter().map(|s| s.as_str()).collect()) + .get(hash, scopes.iter()) { Ok(Some(token)) if !token.expired() => return Ok(token), _ => {} @@ -354,7 +354,7 @@ where .await?; let _ = cache.lock().unwrap().set( hash, - &scopes.iter().map(|s| s.as_str()).collect(), + scopes.iter(), Some(token.clone()), ); Ok(token) @@ -463,7 +463,7 @@ mod tests { .unwrap() .get( 3502164897243251857, - &vec!["https://www.googleapis.com/auth/pubsub"] + ["https://www.googleapis.com/auth/pubsub"].iter(), ) .unwrap() .is_some()); diff --git a/src/storage.rs b/src/storage.rs index a1224e7..551c563 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -26,14 +26,21 @@ pub trait TokenStorage: Send + Sync { /// If `token` is None, it is invalid or revoked and should be removed from storage. /// Otherwise, it should be saved. - fn set( + fn set( &self, scope_hash: u64, - scopes: &Vec<&str>, + scopes: I, token: Option, - ) -> Result<(), Self::Error>; + ) -> Result<(), Self::Error> + where + I: IntoIterator, + I::Item: AsRef; + /// A `None` result indicates that there is no token for the given scope_hash. - fn get(&self, scope_hash: u64, scopes: &Vec<&str>) -> Result, Self::Error>; + fn get(&self, scope_hash: u64, scopes: I) -> Result, Self::Error> + where + I: IntoIterator + Clone, + I::Item: AsRef; } /// Calculate a hash value describing the scopes, and return a sorted Vec of the scopes. @@ -55,10 +62,19 @@ pub struct NullStorage; impl TokenStorage for NullStorage { type Error = std::convert::Infallible; - fn set(&self, _: u64, _: &Vec<&str>, _: Option) -> Result<(), Self::Error> { + fn set(&self, _: u64, _: I, _: Option) -> Result<(), Self::Error> + where + I: IntoIterator, + I::Item: AsRef, + { Ok(()) } - fn get(&self, _: u64, _: &Vec<&str>) -> Result, Self::Error> { + + fn get(&self, _: u64, _: I) -> Result, Self::Error> + where + I: IntoIterator + Clone, + I::Item: AsRef, + { Ok(None) } } @@ -78,12 +94,16 @@ impl MemoryStorage { impl TokenStorage for MemoryStorage { type Error = std::convert::Infallible; - fn set( + fn set( &self, scope_hash: u64, - scopes: &Vec<&str>, + scopes: I, token: Option, - ) -> Result<(), Self::Error> { + ) -> Result<(), Self::Error> + where + I: IntoIterator, + I::Item: AsRef, + { let mut tokens = self.tokens.lock().expect("poisoned mutex"); let matched = tokens.iter().find_position(|x| x.hash == scope_hash); if let Some((idx, _)) = matched { @@ -94,7 +114,7 @@ impl TokenStorage for MemoryStorage { Some(t) => { tokens.push(JSONToken { hash: scope_hash, - scopes: Some(scopes.iter().map(|x| x.to_string()).collect()), + scopes: Some(scopes.into_iter().map(|x| x.as_ref().to_string()).collect()), token: t.clone(), }); () @@ -104,24 +124,13 @@ impl TokenStorage for MemoryStorage { Ok(()) } - fn get(&self, scope_hash: u64, scopes: &Vec<&str>) -> Result, Self::Error> { - let scopes: Vec<_> = scopes.iter().sorted().unique().collect(); - + fn get(&self, scope_hash: u64, scopes: I) -> Result, Self::Error> + where + I: IntoIterator + Clone, + I::Item: AsRef, + { let tokens = self.tokens.lock().expect("poisoned mutex"); - for t in tokens.iter() { - if let Some(token_scopes) = &t.scopes { - let matched = token_scopes - .iter() - .filter(|x| scopes.contains(&&&x[..])) - .count(); - if matched >= scopes.len() { - return Result::Ok(Some(t.token.clone())); - } - } else if scope_hash == t.hash { - return Result::Ok(Some(t.token.clone())); - } - } - Result::Ok(None) + Ok(token_for_scopes(&tokens, scope_hash, scopes)) } } @@ -216,12 +225,16 @@ fn load_from_file(filename: &Path) -> Result, io::Error> { impl TokenStorage for DiskTokenStorage { type Error = io::Error; - fn set( + fn set( &self, scope_hash: u64, - scopes: &Vec<&str>, + scopes: I, token: Option, - ) -> Result<(), Self::Error> { + ) -> Result<(), Self::Error> + where + I: IntoIterator, + I::Item: AsRef, + { { let mut tokens = self.tokens.lock().expect("poisoned mutex"); let matched = tokens.iter().find_position(|x| x.hash == scope_hash); @@ -234,7 +247,7 @@ impl TokenStorage for DiskTokenStorage { Some(t) => { tokens.push(JSONToken { hash: scope_hash, - scopes: Some(scopes.iter().map(|x| x.to_string()).collect()), + scopes: Some(scopes.into_iter().map(|x| x.as_ref().to_string()).collect()), token: t.clone(), }); () @@ -243,24 +256,30 @@ impl TokenStorage for DiskTokenStorage { } self.dump_to_file() } - fn get(&self, scope_hash: u64, scopes: &Vec<&str>) -> Result, Self::Error> { - let scopes: Vec<_> = scopes.iter().sorted().unique().collect(); + fn get(&self, scope_hash: u64, scopes: I) -> Result, Self::Error> + where + I: IntoIterator + Clone, + I::Item: AsRef, + { let tokens = self.tokens.lock().expect("poisoned mutex"); - for t in tokens.iter() { - if let Some(token_scopes) = &t.scopes { - let matched = token_scopes - .iter() - .filter(|x| scopes.contains(&&&x[..])) - .count(); - // we may have some of the tokens as denormalized (many namespaces repeated) - if matched >= scopes.len() { - return Result::Ok(Some(t.token.clone())); - } - } else if scope_hash == t.hash { - return Result::Ok(Some(t.token.clone())); - } - } - Result::Ok(None) + Ok(token_for_scopes(&tokens, scope_hash, scopes)) } } + +fn token_for_scopes(tokens: &[JSONToken], scope_hash: u64, scopes: I) -> Option +where + I: IntoIterator + Clone, + I::Item: AsRef, +{ + for t in tokens.iter() { + if let Some(token_scopes) = &t.scopes { + if scopes.clone().into_iter().all(|s| token_scopes.iter().any(|t| t == s.as_ref())) { + return Some(t.token.clone()); + } + } else if scope_hash == t.hash { + return Some(t.token.clone()) + } + } + None +}