diff --git a/Cargo.toml b/Cargo.toml
index e1ca60f..88618be 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,8 +16,7 @@ chrono = "0.4"
 http = "0.1"
 hyper = {version = "0.13.0-alpha.4", features = ["unstable-stream"]}
 hyper-rustls = "=0.18.0-alpha.2"
-itertools = "0.8"
-log = "0.3"
+log = "0.4"
 rustls = "0.16"
 serde = "1.0"
 serde_json = "1.0"
diff --git a/examples/test-device/src/main.rs b/examples/test-device/src/main.rs
index 64e413c..62bdf10 100644
--- a/examples/test-device/src/main.rs
+++ b/examples/test-device/src/main.rs
@@ -10,6 +10,7 @@ async fn main() {
     let auth = Authenticator::new(DeviceFlow::new(creds))
         .persist_tokens_to_disk("tokenstorage.json")
         .build()
+        .await
         .expect("authenticator");
 
     let scopes = &["https://www.googleapis.com/auth/youtube.readonly"];
diff --git a/examples/test-installed/src/main.rs b/examples/test-installed/src/main.rs
index 54be93f..1bdfee0 100644
--- a/examples/test-installed/src/main.rs
+++ b/examples/test-installed/src/main.rs
@@ -14,6 +14,7 @@ async fn main() {
     ))
     .persist_tokens_to_disk("tokencache.json")
     .build()
+    .await
     .unwrap();
 
     let scopes = &["https://www.googleapis.com/auth/drive.file"];
diff --git a/src/authenticator.rs b/src/authenticator.rs
index dd2292b..b332e5a 100644
--- a/src/authenticator.rs
+++ b/src/authenticator.rs
@@ -1,14 +1,15 @@
-use crate::authenticator_delegate::{AuthenticatorDelegate, DefaultAuthenticatorDelegate, Retry};
+use crate::authenticator_delegate::{AuthenticatorDelegate, DefaultAuthenticatorDelegate};
 use crate::refresh::RefreshFlow;
-use crate::storage::{hash_scopes, DiskTokenStorage, MemoryStorage, TokenStorage};
+use crate::storage::{self, Storage};
 use crate::types::{ApplicationSecret, GetToken, RefreshResult, RequestError, Token};
 use futures::prelude::*;
 
 use std::error::Error;
 use std::io;
-use std::path::Path;
+use std::path::PathBuf;
 use std::pin::Pin;
+use std::sync::Mutex;
 
 /// Authenticator abstracts different `GetToken` implementations behind one type and handles
 /// caching received tokens. It's important to use it (instead of the flows directly) because
@@ -20,15 +21,11 @@ use std::pin::Pin;
 /// NOTE: It is recommended to use a client constructed like this in order to prevent functions
 /// like `hyper::run()` from hanging: `let client = hyper::Client::builder().keep_alive(false);`.
 /// Due to token requests being rare, this should not result in too bad a performance problem.
-struct AuthenticatorImpl<
-    T: GetToken,
-    S: TokenStorage,
-    AD: AuthenticatorDelegate,
-    C: hyper::client::connect::Connect,
-> {
+struct AuthenticatorImpl<T, AD, C>
+{
     client: hyper::Client<C>,
     inner: T,
-    store: S,
+    store: Storage,
     delegate: AD,
 }
 
@@ -69,17 +66,22 @@ pub trait AuthFlow<C> {
     fn build_token_getter(self, client: hyper::Client<C>) -> Self::TokenGetter;
 }
 
+enum StorageType {
+    Memory,
+    Disk(PathBuf),
+}
+
 /// An authenticator can be used with `InstalledFlow`'s or `DeviceFlow`'s and
 /// will refresh tokens as they expire as well as optionally persist tokens to
 /// disk.
-pub struct Authenticator<C, T, S, AD> {
+pub struct Authenticator<C, T, AD> {
     client: C,
     token_getter: T,
-    store: io::Result<S>,
+    storage_type: StorageType,
     delegate: AD,
 }
 
-impl<T> Authenticator<DefaultHyperClient, T, MemoryStorage, DefaultAuthenticatorDelegate>
+impl<T> Authenticator<DefaultHyperClient, T, DefaultAuthenticatorDelegate>
 where
     T: AuthFlow<<DefaultHyperClient as HyperClientBuilder>::Connector>,
 {
@@ -90,27 +92,27 @@ where
     ///
     /// Examples
     /// ```
+    /// # #[tokio::main]
+    /// # async fn main() {
     /// use std::path::Path;
     /// use yup_oauth2::{ApplicationSecret, Authenticator, DeviceFlow};
     /// let creds = ApplicationSecret::default();
-    /// let auth = Authenticator::new(DeviceFlow::new(creds)).build().unwrap();
+    /// let auth = Authenticator::new(DeviceFlow::new(creds)).build().await.unwrap();
+    /// # }
     /// ```
-    pub fn new(
-        flow: T,
-    ) -> Authenticator<DefaultHyperClient, T, MemoryStorage, DefaultAuthenticatorDelegate> {
+    pub fn new(flow: T) -> Authenticator<DefaultHyperClient, T, DefaultAuthenticatorDelegate> {
         Authenticator {
             client: DefaultHyperClient,
             token_getter: flow,
-            store: Ok(MemoryStorage::new()),
+            storage_type: StorageType::Memory,
             delegate: DefaultAuthenticatorDelegate,
         }
     }
 }
 
-impl<T, S, AD, C> Authenticator<C, T, S, AD>
+impl<T, AD, C> Authenticator<C, T, AD>
 where
     T: AuthFlow<C::Connector>,
-    S: TokenStorage,
     AD: AuthenticatorDelegate,
     C: HyperClientBuilder,
 {
@@ -118,7 +120,7 @@ where
     pub fn hyper_client<NewC>(
         self,
         hyper_client: hyper::Client<NewC>,
-    ) -> Authenticator<hyper::Client<NewC>, T, S, AD>
+    ) -> Authenticator<hyper::Client<NewC>, T, AD>
     where
         NewC: hyper::client::connect::Connect + 'static,
         T: AuthFlow<NewC>,
     {
         Authenticator {
             client: hyper_client,
             token_getter: self.token_getter,
-            store: self.store,
+            storage_type: self.storage_type,
             delegate: self.delegate,
         }
     }
 
     /// Persist tokens to disk in the provided filename.
-    pub fn persist_tokens_to_disk<P: AsRef<Path>>(
-        self,
-        path: P,
-    ) -> Authenticator<C, T, DiskTokenStorage, AD> {
-        let disk_storage = DiskTokenStorage::new(path.as_ref().to_str().unwrap());
+    pub fn persist_tokens_to_disk<P: Into<PathBuf>>(self, path: P) -> Authenticator<C, T, AD> {
         Authenticator {
             client: self.client,
             token_getter: self.token_getter,
-            store: disk_storage,
+            storage_type: StorageType::Disk(path.into()),
             delegate: self.delegate,
         }
     }
 
@@ -149,24 +147,29 @@ where
     pub fn delegate<NewAD>(
         self,
         delegate: NewAD,
-    ) -> Authenticator<C, T, S, NewAD> {
+    ) -> Authenticator<C, T, NewAD> {
         Authenticator {
             client: self.client,
             token_getter: self.token_getter,
-            store: self.store,
+            storage_type: self.storage_type,
             delegate,
         }
     }
 
     /// Create the authenticator.
-    pub fn build(self) -> io::Result<impl GetToken>
+    pub async fn build(self) -> io::Result<impl GetToken>
     where
         T::TokenGetter: GetToken,
         C::Connector: hyper::client::connect::Connect + 'static,
     {
         let client = self.client.build_hyper_client();
-        let store = self.store?;
         let inner = self.token_getter.build_token_getter(client.clone());
+        let store = match self.storage_type {
+            StorageType::Memory => Storage::Memory {
+                tokens: Mutex::new(storage::JSONTokens::new()),
+            },
+            StorageType::Disk(path) => Storage::Disk(storage::DiskStorage::new(path).await?),
+        };
 
         Ok(AuthenticatorImpl {
             client,
@@ -177,10 +180,9 @@ where
         }
     }
 }
 
-impl<GT, S, AD, C> AuthenticatorImpl<GT, S, AD, C>
+impl<GT, AD, C> AuthenticatorImpl<GT, AD, C>
 where
     GT: GetToken,
-    S: TokenStorage,
     AD: AuthenticatorDelegate,
     C: hyper::client::connect::Connect + 'static,
 {
@@ -188,83 +190,61 @@ where
     where
         T: AsRef<str> + Sync,
     {
-        let scope_key = hash_scopes(scopes);
+        let scope_key = storage::ScopeHash::new(scopes);
         let store = &self.store;
         let delegate = &self.delegate;
         let client = &self.client;
         let gettoken = &self.inner;
         let appsecret = gettoken.application_secret();
-        loop {
-            match store.get(scope_key, scopes) {
-                Ok(Some(t)) if !t.expired() => {
-                    // unexpired token found
-                    return Ok(t);
-                }
-                Ok(Some(Token {
-                    refresh_token: Some(refresh_token),
-                    ..
-                })) => {
-                    // token is expired but has a refresh token.
-                    let rr = RefreshFlow::refresh_token(client, appsecret, &refresh_token).await?;
-                    match rr {
-                        RefreshResult::Error(ref e) => {
-                            delegate.token_refresh_failed(
-                                e.description(),
-                                Some("the request has likely timed out"),
-                            );
-                            return Err(RequestError::Refresh(rr));
-                        }
-                        RefreshResult::RefreshError(ref s, ref ss) => {
-                            delegate.token_refresh_failed(
-                                &format!("{}{}", s, ss.as_ref().map(|s| format!(" ({})", s)).unwrap_or_else(String::new)),
-                                Some("the refresh token is likely invalid and your authorization has been revoked"),
-                            );
-                            return Err(RequestError::Refresh(rr));
-                        }
-                        RefreshResult::Success(t) => {
-                            let x = store.set(scope_key, scopes, Some(t.clone()));
-                            if let Err(e) = x {
-                                match delegate.token_storage_failure(true, &e) {
-                                    Retry::Skip => return Ok(t),
-                                    Retry::Abort => return Err(RequestError::Cache(Box::new(e))),
-                                    Retry::After(d) => tokio::timer::delay_for(d).await,
-                                }
-                            } else {
-                                return Ok(t);
-                            }
-                        }
-                    }
-                }
-                Ok(None)
-                | Ok(Some(Token {
-                    refresh_token: None,
-                    ..
-                })) => {
-                    // no token in the cache or the token returned does not contain a refresh token.
-                    let t = gettoken.token(scopes).await?;
-                    if let Err(e) = store.set(scope_key, scopes, Some(t.clone())) {
-                        match delegate.token_storage_failure(true, &e) {
-                            Retry::Skip => return Ok(t),
-                            Retry::Abort => return Err(RequestError::Cache(Box::new(e))),
-                            Retry::After(d) => tokio::timer::delay_for(d).await,
-                        }
-                    } else {
-                        return Ok(t);
-                    }
-                }
-                Err(err) => match delegate.token_storage_failure(false, &err) {
-                    Retry::Abort | Retry::Skip => return Err(RequestError::Cache(Box::new(err))),
-                    Retry::After(d) => tokio::timer::delay_for(d).await,
-                },
+        match store.get(scope_key, scopes) {
+            Some(t) if !t.expired() => {
+                // unexpired token found
+                Ok(t)
+            }
+            Some(Token {
+                refresh_token: Some(refresh_token),
+                ..
+            }) => {
+                // token is expired but has a refresh token.
+                let rr = RefreshFlow::refresh_token(client, appsecret, &refresh_token).await?;
+                match rr {
+                    RefreshResult::Error(ref e) => {
+                        delegate.token_refresh_failed(
+                            e.description(),
+                            Some("the request has likely timed out"),
+                        );
+                        Err(RequestError::Refresh(rr))
+                    }
+                    RefreshResult::RefreshError(ref s, ref ss) => {
+                        delegate.token_refresh_failed(
+                            &format!("{}{}", s, ss.as_ref().map(|s| format!(" ({})", s)).unwrap_or_else(String::new)),
+                            Some("the refresh token is likely invalid and your authorization has been revoked"),
+                        );
+                        Err(RequestError::Refresh(rr))
+                    }
+                    RefreshResult::Success(t) => {
+                        store.set(scope_key, scopes, Some(t.clone())).await;
+                        Ok(t)
+                    }
+                }
+            }
+            None
+            | Some(Token {
+                refresh_token: None,
+                ..
+            }) => {
+                // no token in the cache or the token returned does not contain a refresh token.
+                let t = gettoken.token(scopes).await?;
+                store.set(scope_key, scopes, Some(t.clone())).await;
+                Ok(t)
             }
         }
     }
 }
 
-impl<GT, S, AD, C> GetToken for AuthenticatorImpl<GT, S, AD, C>
+impl<GT, AD, C> GetToken for AuthenticatorImpl<GT, AD, C>
 where
     GT: GetToken,
-    S: TokenStorage,
     AD: AuthenticatorDelegate,
     C: hyper::client::connect::Connect + 'static,
 {
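The authenticator changes above move storage selection into a plain `StorageType` value and make `build()` async, because constructing a `DiskStorage` now loads the token cache via `tokio::fs` before the authenticator is returned. A minimal end-to-end sketch of the resulting call pattern, assuming a tokio runtime and the exported `GetToken` trait for the `token()` call (the path and scope are placeholders, not part of this diff):

```rust
use yup_oauth2::{ApplicationSecret, Authenticator, DeviceFlow, GetToken};

#[tokio::main]
async fn main() {
    let creds = ApplicationSecret::default();
    let mut auth = Authenticator::new(DeviceFlow::new(creds))
        .persist_tokens_to_disk("tokencache.json") // placeholder path
        .build()
        .await // build() is now a future; disk-backed storage does its I/O here
        .expect("failed to build authenticator");
    // Requests consult the cache first; an expired entry with a refresh
    // token triggers RefreshFlow, anything else re-runs the auth flow.
    let token = auth
        .token(&["https://www.googleapis.com/auth/youtube.readonly"])
        .await
        .expect("token request failed");
    println!("token: {:?}", token);
}
```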
diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs
index eee7ec9..b039e3b 100644
--- a/src/authenticator_delegate.rs
+++ b/src/authenticator_delegate.rs
@@ -79,16 +79,6 @@ pub trait AuthenticatorDelegate: Send + Sync {
         Retry::Abort
     }
 
-    /// Called whenever we failed to retrieve a token or set a token due to a storage error.
-    /// You may use it to either ignore the incident or retry.
-    /// This can be useful if the underlying `TokenStorage` may fail occasionally.
-    /// If `is_set` is true, the failure resulted from `TokenStorage.set(...)`. Otherwise,
-    /// it was `TokenStorage.get(...)`.
-    fn token_storage_failure(&self, is_set: bool, _: &(dyn Error + Send + Sync)) -> Retry {
-        let _ = is_set;
-        Retry::Abort
-    }
-
     /// The server denied the attempt to obtain a request code
     fn request_failure(&self, _: RequestError) {}
diff --git a/src/device.rs b/src/device.rs
index 51d97b3..a72fddd 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -1,7 +1,7 @@
 use std::pin::Pin;
 use std::time::Duration;
 
-use ::log::{error, log};
+use ::log::error;
 use chrono::{DateTime, Utc};
 use futures::prelude::*;
 use hyper;
diff --git a/src/lib.rs b/src/lib.rs
index f53d54c..549ea28 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -61,6 +61,7 @@
 //!     )
 //!     .persist_tokens_to_disk("tokencache.json")
 //!     .build()
+//!     .await
 //!     .unwrap();
 //!
 //! let scopes = &["https://www.googleapis.com/auth/drive.file"];
@@ -96,7 +97,6 @@ pub use crate::device::{DeviceFlow, GOOGLE_DEVICE_CODE_URL};
 pub use crate::helper::*;
 pub use crate::installed::{InstalledFlow, InstalledFlowReturnMethod};
 pub use crate::service_account::*;
-pub use crate::storage::{DiskTokenStorage, MemoryStorage, NullStorage, TokenStorage};
 pub use crate::types::{
     ApplicationSecret, ConsoleApplicationSecret, GetToken, PollError, RefreshResult,
     RequestError, Scheme, Token, TokenType,
 };
diff --git a/src/service_account.rs b/src/service_account.rs
index 03525a5..1b10de8 100644
--- a/src/service_account.rs
+++ b/src/service_account.rs
@@ -11,12 +11,11 @@
 //! Copyright (c) 2016 Google Inc (lewinb@google.com).
 //!
-use std::default::Default;
 use std::pin::Pin;
-use std::sync::{Arc, Mutex};
+use std::sync::Mutex;
 
 use crate::authenticator::{DefaultHyperClient, HyperClientBuilder};
-use crate::storage::{hash_scopes, MemoryStorage, TokenStorage};
+use crate::storage::{self, Storage};
 use crate::types::{ApplicationSecret, GetToken, JsonErrorOr, RequestError, Token};
 
 use futures::prelude::*;
@@ -210,7 +209,7 @@ where
 struct ServiceAccountAccessImpl<C> {
     client: hyper::Client<C>,
     key: ServiceAccountKey,
-    cache: Arc<Mutex<MemoryStorage>>,
+    cache: Storage,
     subject: Option<String>,
     signer: JWTSigner,
 }
@@ -228,7 +227,9 @@ where
         Ok(ServiceAccountAccessImpl {
             client,
             key,
-            cache: Arc::new(Mutex::new(MemoryStorage::default())),
+            cache: Storage::Memory {
+                tokens: Mutex::new(storage::JSONTokens::new()),
+            },
             subject,
             signer,
         })
@@ -309,10 +310,10 @@ where
     where
         T: AsRef<str>,
     {
-        let hash = hash_scopes(scopes);
+        let hash = storage::ScopeHash::new(scopes);
         let cache = &self.cache;
-        match cache.lock().unwrap().get(hash, scopes) {
-            Ok(Some(token)) if !token.expired() => return Ok(token),
+        match cache.get(hash, scopes) {
+            Some(token) if !token.expired() => return Ok(token),
             _ => {}
         }
         let token = Self::request_token(
@@ -323,7 +324,7 @@ where
             scopes,
         )
         .await?;
-        let _ = cache.lock().unwrap().set(hash, scopes, Some(token.clone()));
+        cache.set(hash, scopes, Some(token.clone())).await;
         Ok(token)
     }
 }
@@ -425,13 +426,12 @@ mod tests {
 
         assert!(acc
             .cache
-            .lock()
-            .unwrap()
             .get(
-                3502164897243251857,
+                storage::ScopeHash::new(&[
+                    "https://www.googleapis.com/auth/pubsub"
+                ]),
                 &["https://www.googleapis.com/auth/pubsub"],
             )
-            .unwrap()
             .is_some()); // Test that token is in cache (otherwise mock will tell us)
 
         let fut = async {
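The storage rewrite below replaces the free function `hash_scopes` with a `ScopeHash` newtype but keeps the hashing scheme: each scope string is hashed on its own and the per-scope hashes are folded together with XOR. Because XOR is commutative and associative, the combined value is independent of the order of the scopes. A standalone sketch of the idea (the function name is hypothetical, not part of this diff):

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Fold per-scope hashes with XOR. Any permutation of `scopes` produces the
// same result. Caveat of XOR folding: a scope listed twice cancels itself
// out, which is acceptable here because scope lists are effectively sets.
fn order_independent_hash(scopes: &[&str]) -> u64 {
    scopes
        .iter()
        .map(|scope| {
            let mut hasher = DefaultHasher::new();
            scope.hash(&mut hasher);
            hasher.finish()
        })
        .fold(DefaultHasher::new().finish(), |acc, h| acc ^ h)
}

fn main() {
    assert_eq!(
        order_independent_hash(&["a", "b", "c"]),
        order_independent_hash(&["c", "a", "b"])
    );
}
```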
diff --git a/src/storage.rs b/src/storage.rs
index fef3c0a..17cf2db 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -5,127 +5,65 @@
 use std::cmp::Ordering;
 use std::collections::hash_map::DefaultHasher;
-use std::error::Error;
-use std::fs;
 use std::hash::{Hash, Hasher};
 use std::io;
-use std::io::Write;
 use std::path::{Path, PathBuf};
 use std::sync::Mutex;
 
 use crate::types::Token;
-use itertools::Itertools;
 
-/// Implements a specialized storage to set and retrieve `Token` instances.
-/// The `scope_hash` represents the signature of the scopes for which the given token
-/// should be stored or retrieved.
-/// For completeness, the underlying, sorted scopes are provided as well. They might be
-/// useful for presentation to the user.
-pub trait TokenStorage: Send + Sync {
-    type Error: 'static + Error + Send + Sync;
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
+pub struct ScopeHash(u64);
 
-    /// If `token` is None, it is invalid or revoked and should be removed from storage.
-    /// Otherwise, it should be saved.
-    fn set<T>(
-        &self,
-        scope_hash: u64,
-        scopes: &[T],
-        token: Option<Token>,
-    ) -> Result<(), Self::Error>
-    where
-        T: AsRef<str>;
-
-    /// A `None` result indicates that there is no token for the given scope_hash.
-    fn get<T>(&self, scope_hash: u64, scopes: &[T]) -> Result<Option<Token>, Self::Error>
-    where
-        T: AsRef<str>;
-}
-
-/// Calculate a hash value describing the scopes. The order of the scopes in the
-/// list does not change the hash value. i.e. two lists that contain the exact
-/// same scopes, but in different order will return the same hash value.
-pub fn hash_scopes<T>(scopes: &[T]) -> u64
-where
-    T: AsRef<str>,
-{
-    let mut hash_sum = DefaultHasher::new().finish();
-    for scope in scopes {
-        let mut hasher = DefaultHasher::new();
-        scope.as_ref().hash(&mut hasher);
-        hash_sum ^= hasher.finish();
-    }
-    hash_sum
-}
-
-/// A storage that remembers nothing.
-#[derive(Default)]
-pub struct NullStorage;
-
-impl TokenStorage for NullStorage {
-    type Error = std::convert::Infallible;
-    fn set<T>(&self, _: u64, _: &[T], _: Option<Token>) -> Result<(), Self::Error>
+impl ScopeHash {
+    /// Calculate a hash value describing the scopes. The order of the scopes in the
+    /// list does not change the hash value. i.e. two lists that contain the exact
+    /// same scopes, but in different order will return the same hash value.
+    pub fn new<T>(scopes: &[T]) -> Self
     where
         T: AsRef<str>,
     {
-        Ok(())
-    }
-
-    fn get<T>(&self, _: u64, _: &[T]) -> Result<Option<Token>, Self::Error>
-    where
-        T: AsRef<str>,
-    {
-        Ok(None)
-    }
-}
-
-/// A storage that remembers values for one session only.
-#[derive(Debug, Default)]
-pub struct MemoryStorage {
-    tokens: Mutex<Vec<JSONToken>>,
-}
-
-impl MemoryStorage {
-    pub fn new() -> MemoryStorage {
-        Default::default()
-    }
-}
-
-impl TokenStorage for MemoryStorage {
-    type Error = std::convert::Infallible;
-
-    fn set<T>(&self, scope_hash: u64, scopes: &[T], token: Option<Token>) -> Result<(), Self::Error>
-    where
-        T: AsRef<str>,
-    {
-        let mut tokens = self.tokens.lock().expect("poisoned mutex");
-        let matched = tokens.iter().find_position(|x| x.hash == scope_hash);
-        if let Some((idx, _)) = matched {
-            tokens.remove(idx);
+        let mut hash_sum = DefaultHasher::new().finish();
+        for scope in scopes {
+            let mut hasher = DefaultHasher::new();
+            scope.as_ref().hash(&mut hasher);
+            hash_sum ^= hasher.finish();
         }
-
-        if let Some(t) = token {
-            tokens.push(JSONToken {
-                hash: scope_hash,
-                scopes: Some(scopes.iter().map(|x| x.as_ref().to_string()).collect()),
-                token: t,
-            });
-        }
-        Ok(())
+        ScopeHash(hash_sum)
     }
+}
 
-    fn get<T>(&self, scope_hash: u64, scopes: &[T]) -> Result<Option<Token>, Self::Error>
+pub(crate) enum Storage {
+    Memory { tokens: Mutex<JSONTokens> },
+    Disk(DiskStorage),
+}
+
+impl Storage {
+    pub(crate) async fn set<T>(&self, h: ScopeHash, scopes: &[T], token: Option<Token>)
     where
         T: AsRef<str>,
     {
-        let tokens = self.tokens.lock().expect("poisoned mutex");
-        Ok(token_for_scopes(&tokens, scope_hash, scopes))
+        match self {
+            Storage::Memory { tokens } => tokens.lock().unwrap().set(h, scopes, token),
+            Storage::Disk(disk_storage) => disk_storage.set(h, scopes, token).await,
+        }
+    }
+
+    pub(crate) fn get<T>(&self, h: ScopeHash, scopes: &[T]) -> Option<Token>
+    where
+        T: AsRef<str>,
+    {
+        match self {
+            Storage::Memory { tokens } => tokens.lock().unwrap().get(h, scopes),
+            Storage::Disk(disk_storage) => disk_storage.get(h, scopes),
+        }
     }
 }
 
 /// A single stored token.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 struct JSONToken {
-    pub hash: u64,
+    pub hash: ScopeHash,
     pub scopes: Option<Vec<String>>,
     pub token: Token,
 }
@@ -151,121 +89,123 @@ impl Ord for JSONToken {
     }
 }
 
 /// List of tokens in a JSON object
-#[derive(Serialize, Deserialize)]
-struct JSONTokens {
-    pub tokens: Vec<JSONToken>,
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub(crate) struct JSONTokens {
+    tokens: Vec<JSONToken>,
 }
 
-/// Serializes tokens to a JSON file on disk.
-#[derive(Default)]
-pub struct DiskTokenStorage {
-    location: PathBuf,
-    tokens: Mutex<Vec<JSONToken>>,
+impl JSONTokens {
+    pub(crate) fn new() -> Self {
+        JSONTokens { tokens: Vec::new() }
+    }
+
+    pub(crate) async fn load_from_file(filename: &Path) -> Result<Self, io::Error> {
+        let contents = tokio::fs::read(filename).await?;
+        let container: JSONTokens = serde_json::from_slice(&contents)
+            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
+        Ok(container)
+    }
+
+    fn get<T>(&self, h: ScopeHash, scopes: &[T]) -> Option<Token>
+    where
+        T: AsRef<str>,
+    {
+        for t in self.tokens.iter() {
+            if let Some(token_scopes) = &t.scopes {
+                if scopes
+                    .iter()
+                    .all(|s| token_scopes.iter().any(|t| t == s.as_ref()))
+                {
+                    return Some(t.token.clone());
+                }
+            } else if h == t.hash {
+                return Some(t.token.clone());
+            }
+        }
+        None
+    }
+
+    fn set<T>(&mut self, h: ScopeHash, scopes: &[T], token: Option<Token>)
+    where
+        T: AsRef<str>,
+    {
+        self.tokens.retain(|x| x.hash != h);
+
+        match token {
+            None => (),
+            Some(t) => {
+                self.tokens.push(JSONToken {
+                    hash: h,
+                    scopes: Some(scopes.iter().map(|x| x.as_ref().to_string()).collect()),
+                    token: t,
+                });
+            }
+        }
+    }
+
+    // TODO: ideally this function would accept &Path, but tokio requires the
+    // path be 'static. Revisit this and ask why tokio::fs::write has that
+    // limitation.
+    async fn dump_to_file(&self, path: PathBuf) -> Result<(), io::Error> {
+        let serialized = serde_json::to_string(self)
+            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
+        tokio::fs::write(path, &serialized).await
+    }
 }
 
-impl DiskTokenStorage {
-    pub fn new<S: Into<PathBuf>>(location: S) -> Result<DiskTokenStorage, io::Error> {
-        let filename = location.into();
-        let tokens = match load_from_file(&filename) {
+pub(crate) struct DiskStorage {
+    tokens: Mutex<JSONTokens>,
+    write_tx: tokio::sync::mpsc::Sender<JSONTokens>,
+}
+
+impl DiskStorage {
+    pub(crate) async fn new(path: PathBuf) -> Result<Self, io::Error> {
+        let tokens = match JSONTokens::load_from_file(&path).await {
             Ok(tokens) => tokens,
-            Err(e) if e.kind() == io::ErrorKind::NotFound => Vec::new(),
+            Err(e) if e.kind() == io::ErrorKind::NotFound => JSONTokens::new(),
             Err(e) => return Err(e),
         };
-        Ok(DiskTokenStorage {
-            location: filename,
+        // Writing to disk will happen in a separate task. This means in the
+        // common case returning a token to the user will not be required to
+        // wait for disk i/o. We communicate with a dedicated writer task via a
+        // buffered channel. This ensures that the writes happen in the order
+        // received, and if writes fall too far behind we will block GetToken
+        // requests until disk i/o completes.
+        let (write_tx, mut write_rx) = tokio::sync::mpsc::channel::<JSONTokens>(2);
+        tokio::spawn(async move {
+            while let Some(tokens) = write_rx.recv().await {
+                if let Err(e) = tokens.dump_to_file(path.to_path_buf()).await {
+                    log::error!("Failed to write token storage to disk: {}", e);
+                }
+            }
+        });
+        Ok(DiskStorage {
             tokens: Mutex::new(tokens),
+            write_tx,
         })
     }
 
-    pub fn dump_to_file(&self) -> Result<(), io::Error> {
-        let mut jsontokens = JSONTokens { tokens: Vec::new() };
-
-        {
-            let tokens = self.tokens.lock().expect("mutex poisoned");
-            for token in tokens.iter() {
-                jsontokens.tokens.push((*token).clone());
-            }
-        }
-
-        let serialized;
-
-        match serde_json::to_string(&jsontokens) {
-            Result::Err(e) => return Result::Err(io::Error::new(io::ErrorKind::InvalidData, e)),
-            Result::Ok(s) => serialized = s,
-        }
-
-        // TODO: Write to disk asynchronously so that we don't stall the eventloop if invoked in async context.
-        let mut f = fs::OpenOptions::new()
-            .create(true)
-            .write(true)
-            .truncate(true)
-            .open(&self.location)?;
-        f.write(serialized.as_ref()).map(|_| ())
-    }
-}
-
-fn load_from_file(filename: &Path) -> Result<Vec<JSONToken>, io::Error> {
-    let contents = std::fs::read_to_string(filename)?;
-    let container: JSONTokens = serde_json::from_str(&contents)
-        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
-    Ok(container.tokens)
-}
-
-impl TokenStorage for DiskTokenStorage {
-    type Error = io::Error;
-    fn set<T>(&self, scope_hash: u64, scopes: &[T], token: Option<Token>) -> Result<(), Self::Error>
+    async fn set<T>(&self, h: ScopeHash, scopes: &[T], token: Option<Token>)
     where
         T: AsRef<str>,
     {
-        {
-            let mut tokens = self.tokens.lock().expect("poisoned mutex");
-            let matched = tokens.iter().find_position(|x| x.hash == scope_hash);
-            if let Some((idx, _)) = matched {
-                tokens.remove(idx);
-            }
-
-            match token {
-                None => (),
-                Some(t) => {
-                    tokens.push(JSONToken {
-                        hash: scope_hash,
-                        scopes: Some(scopes.iter().map(|x| x.as_ref().to_string()).collect()),
-                        token: t,
-                    });
-                }
-            }
-        }
-        self.dump_to_file()
+        let cloned_tokens = {
+            let mut tokens = self.tokens.lock().unwrap();
+            tokens.set(h, scopes, token);
+            tokens.clone()
+        };
+        self.write_tx
+            .clone()
+            .send(cloned_tokens)
+            .await
+            .expect("disk storage task not running");
     }
 
-    fn get<T>(&self, scope_hash: u64, scopes: &[T]) -> Result<Option<Token>, Self::Error>
+    pub(crate) fn get<T>(&self, h: ScopeHash, scopes: &[T]) -> Option<Token>
     where
         T: AsRef<str>,
     {
-        let tokens = self.tokens.lock().expect("poisoned mutex");
-        Ok(token_for_scopes(&tokens, scope_hash, scopes))
+        self.tokens.lock().unwrap().get(h, scopes)
     }
 }
 
-fn token_for_scopes<T>(tokens: &[JSONToken], scope_hash: u64, scopes: &[T]) -> Option<Token>
-where
-    T: AsRef<str>,
-{
-    for t in tokens.iter() {
-        if let Some(token_scopes) = &t.scopes {
-            if scopes
-                .iter()
-                .all(|s| token_scopes.iter().any(|t| t == s.as_ref()))
-            {
-                return Some(t.token.clone());
-            }
-        } else if scope_hash == t.hash {
-            return Some(t.token.clone());
-        }
-    }
-    None
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -273,19 +213,25 @@ mod tests {
 
     #[test]
     fn test_hash_scopes() {
         // Identical lists should hash equal.
-        assert_eq!(hash_scopes(&["foo", "bar"]), hash_scopes(&["foo", "bar"]));
-        // The hash should be order independent.
-        assert_eq!(hash_scopes(&["bar", "foo"]), hash_scopes(&["foo", "bar"]));
         assert_eq!(
-            hash_scopes(&["bar", "baz", "bat"]),
-            hash_scopes(&["baz", "bar", "bat"])
+            ScopeHash::new(&["foo", "bar"]),
+            ScopeHash::new(&["foo", "bar"])
+        );
+        // The hash should be order independent.
+        assert_eq!(
+            ScopeHash::new(&["bar", "foo"]),
+            ScopeHash::new(&["foo", "bar"])
+        );
+        assert_eq!(
+            ScopeHash::new(&["bar", "baz", "bat"]),
+            ScopeHash::new(&["baz", "bar", "bat"])
         );
         // Ensure hashes differ when the contents are different by more than
         // just order.
         assert_ne!(
-            hash_scopes(&["foo", "bar", "baz"]),
-            hash_scopes(&["foo", "bar"])
+            ScopeHash::new(&["foo", "bar", "baz"]),
+            ScopeHash::new(&["foo", "bar"])
         );
     }
 }
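The `DiskStorage` introduced above decouples token updates from disk I/O: `set` mutates the in-memory `JSONTokens` under a mutex, clones a snapshot, and sends it to a spawned writer task over a bounded channel, so a slow disk applies backpressure instead of growing an unbounded queue, and snapshots reach the file in order. A self-contained sketch of that writer-task pattern, assuming a tokio 0.2-era API (identifiers are illustrative, not this crate's):

```rust
use std::path::PathBuf;

// The writer task owns the file path and persists snapshots in the order
// received. Capacity 2 bounds the queue: if disk I/O falls behind, senders
// await on `send` instead of piling up snapshots in memory.
fn spawn_writer(path: PathBuf) -> tokio::sync::mpsc::Sender<String> {
    let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(2);
    tokio::spawn(async move {
        while let Some(snapshot) = rx.recv().await {
            // Each message is a full serialization, so the file is always a
            // consistent snapshot of some recent state, never a partial mix
            // of two updates.
            if let Err(e) = tokio::fs::write(path.clone(), snapshot).await {
                eprintln!("failed to persist snapshot: {}", e);
            }
        }
    });
    tx
}

#[tokio::main]
async fn main() {
    let mut tx = spawn_writer(PathBuf::from("snapshot.json"));
    tx.send(String::from("{\"tokens\":[]}"))
        .await
        .expect("writer task exited");
}
```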