From 88a8f74406327b2f6ade50d3985426cb7fe2e1cc Mon Sep 17 00:00:00 2001 From: Glenn Griffin Date: Tue, 12 Nov 2019 13:03:01 -0800 Subject: [PATCH] Refactor token storage. The current code uses standard blocking i/o operations (std::fs::*) this is problematic as it would block the entire futures executor waiting for i/o. This change is a major refactoring to make the token storage mechansim async i/o friendly. The first major decision was to abandon the GetToken trait. The trait is only implemented internally and there was no mechanism for users to provide their own, but async fn's are not currently supported in trait impls so keeping the trait would have required Boxing futures. This probably would have been fine, but seemed unnecessary. Instead of a trait the storage mechanism is just an enum with a choice between Memory and Disk storage. The DiskStorage works primarily as it did before, rewriting the entire contents of the file on every set() invocation. The only difference is that we now defer the actual writing to a separate task so that it does not block the return of the Token to the user. If disk i/o is too slow to keep up with the rate of incoming writes it will push back and will eventually block the return of tokens, this is to prevent a buildup of in-flight requests. One major drawback to this approach is that any errors that happen on write are simply logged and no delegate function is invoked on error because the delegate no longer has the ability to say to sleep, retry, etc. --- Cargo.toml | 3 +- examples/test-device/src/main.rs | 1 + examples/test-installed/src/main.rs | 1 + src/authenticator.rs | 172 ++++++-------- src/authenticator_delegate.rs | 10 - src/device.rs | 2 +- src/lib.rs | 2 +- src/service_account.rs | 26 +-- src/storage.rs | 348 ++++++++++++---------------- 9 files changed, 241 insertions(+), 324 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e1ca60f..88618be 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,8 +16,7 @@ chrono = "0.4" http = "0.1" hyper = {version = "0.13.0-alpha.4", features = ["unstable-stream"]} hyper-rustls = "=0.18.0-alpha.2" -itertools = "0.8" -log = "0.3" +log = "0.4" rustls = "0.16" serde = "1.0" serde_json = "1.0" diff --git a/examples/test-device/src/main.rs b/examples/test-device/src/main.rs index 64e413c..62bdf10 100644 --- a/examples/test-device/src/main.rs +++ b/examples/test-device/src/main.rs @@ -10,6 +10,7 @@ async fn main() { let auth = Authenticator::new(DeviceFlow::new(creds)) .persist_tokens_to_disk("tokenstorage.json") .build() + .await .expect("authenticator"); let scopes = &["https://www.googleapis.com/auth/youtube.readonly"]; diff --git a/examples/test-installed/src/main.rs b/examples/test-installed/src/main.rs index 54be93f..1bdfee0 100644 --- a/examples/test-installed/src/main.rs +++ b/examples/test-installed/src/main.rs @@ -14,6 +14,7 @@ async fn main() { )) .persist_tokens_to_disk("tokencache.json") .build() + .await .unwrap(); let scopes = &["https://www.googleapis.com/auth/drive.file"]; diff --git a/src/authenticator.rs b/src/authenticator.rs index dd2292b..b332e5a 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -1,14 +1,15 @@ -use crate::authenticator_delegate::{AuthenticatorDelegate, DefaultAuthenticatorDelegate, Retry}; +use crate::authenticator_delegate::{AuthenticatorDelegate, DefaultAuthenticatorDelegate}; use crate::refresh::RefreshFlow; -use crate::storage::{hash_scopes, DiskTokenStorage, MemoryStorage, TokenStorage}; +use crate::storage::{self, Storage}; use crate::types::{ApplicationSecret, GetToken, RefreshResult, RequestError, Token}; use futures::prelude::*; use std::error::Error; use std::io; -use std::path::Path; +use std::path::PathBuf; use std::pin::Pin; +use std::sync::Mutex; /// Authenticator abstracts different `GetToken` implementations behind one type and handles /// caching received tokens. It's important to use it (instead of the flows directly) because @@ -20,15 +21,11 @@ use std::pin::Pin; /// NOTE: It is recommended to use a client constructed like this in order to prevent functions /// like `hyper::run()` from hanging: `let client = hyper::Client::builder().keep_alive(false);`. /// Due to token requests being rare, this should not result in a too bad performance problem. -struct AuthenticatorImpl< - T: GetToken, - S: TokenStorage, - AD: AuthenticatorDelegate, - C: hyper::client::connect::Connect, -> { +struct AuthenticatorImpl +{ client: hyper::Client, inner: T, - store: S, + store: Storage, delegate: AD, } @@ -69,17 +66,22 @@ pub trait AuthFlow { fn build_token_getter(self, client: hyper::Client) -> Self::TokenGetter; } +enum StorageType { + Memory, + Disk(PathBuf), +} + /// An authenticator can be used with `InstalledFlow`'s or `DeviceFlow`'s and /// will refresh tokens as they expire as well as optionally persist tokens to /// disk. -pub struct Authenticator { +pub struct Authenticator { client: C, token_getter: T, - store: io::Result, + storage_type: StorageType, delegate: AD, } -impl Authenticator +impl Authenticator where T: AuthFlow<::Connector>, { @@ -90,27 +92,27 @@ where /// /// Examples /// ``` + /// # #[tokio::main] + /// # async fn main() { /// use std::path::Path; /// use yup_oauth2::{ApplicationSecret, Authenticator, DeviceFlow}; /// let creds = ApplicationSecret::default(); - /// let auth = Authenticator::new(DeviceFlow::new(creds)).build().unwrap(); + /// let auth = Authenticator::new(DeviceFlow::new(creds)).build().await.unwrap(); + /// # } /// ``` - pub fn new( - flow: T, - ) -> Authenticator { + pub fn new(flow: T) -> Authenticator { Authenticator { client: DefaultHyperClient, token_getter: flow, - store: Ok(MemoryStorage::new()), + storage_type: StorageType::Memory, delegate: DefaultAuthenticatorDelegate, } } } -impl Authenticator +impl Authenticator where T: AuthFlow, - S: TokenStorage, AD: AuthenticatorDelegate, C: HyperClientBuilder, { @@ -118,7 +120,7 @@ where pub fn hyper_client( self, hyper_client: hyper::Client, - ) -> Authenticator> + ) -> Authenticator> where NewC: hyper::client::connect::Connect + 'static, T: AuthFlow, @@ -126,21 +128,17 @@ where Authenticator { client: hyper_client, token_getter: self.token_getter, - store: self.store, + storage_type: self.storage_type, delegate: self.delegate, } } /// Persist tokens to disk in the provided filename. - pub fn persist_tokens_to_disk>( - self, - path: P, - ) -> Authenticator { - let disk_storage = DiskTokenStorage::new(path.as_ref().to_str().unwrap()); + pub fn persist_tokens_to_disk>(self, path: P) -> Authenticator { Authenticator { client: self.client, token_getter: self.token_getter, - store: disk_storage, + storage_type: StorageType::Disk(path.into()), delegate: self.delegate, } } @@ -149,24 +147,29 @@ where pub fn delegate( self, delegate: NewAD, - ) -> Authenticator { + ) -> Authenticator { Authenticator { client: self.client, token_getter: self.token_getter, - store: self.store, + storage_type: self.storage_type, delegate, } } /// Create the authenticator. - pub fn build(self) -> io::Result + pub async fn build(self) -> io::Result where T::TokenGetter: GetToken, C::Connector: hyper::client::connect::Connect + 'static, { let client = self.client.build_hyper_client(); - let store = self.store?; let inner = self.token_getter.build_token_getter(client.clone()); + let store = match self.storage_type { + StorageType::Memory => Storage::Memory { + tokens: Mutex::new(storage::JSONTokens::new()), + }, + StorageType::Disk(path) => Storage::Disk(storage::DiskStorage::new(path).await?), + }; Ok(AuthenticatorImpl { client, @@ -177,10 +180,9 @@ where } } -impl AuthenticatorImpl +impl AuthenticatorImpl where GT: GetToken, - S: TokenStorage, AD: AuthenticatorDelegate, C: hyper::client::connect::Connect + 'static, { @@ -188,83 +190,61 @@ where where T: AsRef + Sync, { - let scope_key = hash_scopes(scopes); + let scope_key = storage::ScopeHash::new(scopes); let store = &self.store; let delegate = &self.delegate; let client = &self.client; let gettoken = &self.inner; let appsecret = gettoken.application_secret(); - loop { - match store.get(scope_key, scopes) { - Ok(Some(t)) if !t.expired() => { - // unexpired token found - return Ok(t); - } - Ok(Some(Token { - refresh_token: Some(refresh_token), - .. - })) => { - // token is expired but has a refresh token. - let rr = RefreshFlow::refresh_token(client, appsecret, &refresh_token).await?; - match rr { - RefreshResult::Error(ref e) => { - delegate.token_refresh_failed( - e.description(), - Some("the request has likely timed out"), + match store.get(scope_key, scopes) { + Some(t) if !t.expired() => { + // unexpired token found + Ok(t) + } + Some(Token { + refresh_token: Some(refresh_token), + .. + }) => { + // token is expired but has a refresh token. + let rr = RefreshFlow::refresh_token(client, appsecret, &refresh_token).await?; + match rr { + RefreshResult::Error(ref e) => { + delegate.token_refresh_failed( + e.description(), + Some("the request has likely timed out"), + ); + Err(RequestError::Refresh(rr)) + } + RefreshResult::RefreshError(ref s, ref ss) => { + delegate.token_refresh_failed( + &format!("{}{}", s, ss.as_ref().map(|s| format!(" ({})", s)).unwrap_or_else(String::new)), + Some("the refresh token is likely invalid and your authorization has been revoked"), ); - return Err(RequestError::Refresh(rr)); - } - RefreshResult::RefreshError(ref s, ref ss) => { - delegate.token_refresh_failed( - &format!("{}{}", s, ss.as_ref().map(|s| format!(" ({})", s)).unwrap_or_else(String::new)), - Some("the refresh token is likely invalid and your authorization has been revoked"), - ); - return Err(RequestError::Refresh(rr)); - } - RefreshResult::Success(t) => { - let x = store.set(scope_key, scopes, Some(t.clone())); - if let Err(e) = x { - match delegate.token_storage_failure(true, &e) { - Retry::Skip => return Ok(t), - Retry::Abort => return Err(RequestError::Cache(Box::new(e))), - Retry::After(d) => tokio::timer::delay_for(d).await, - } - } else { - return Ok(t); - } - } + Err(RequestError::Refresh(rr)) + } + RefreshResult::Success(t) => { + store.set(scope_key, scopes, Some(t.clone())).await; + Ok(t) } } - Ok(None) - | Ok(Some(Token { - refresh_token: None, - .. - })) => { - // no token in the cache or the token returned does not contain a refresh token. - let t = gettoken.token(scopes).await?; - if let Err(e) = store.set(scope_key, scopes, Some(t.clone())) { - match delegate.token_storage_failure(true, &e) { - Retry::Skip => return Ok(t), - Retry::Abort => return Err(RequestError::Cache(Box::new(e))), - Retry::After(d) => tokio::timer::delay_for(d).await, - } - } else { - return Ok(t); - } - } - Err(err) => match delegate.token_storage_failure(false, &err) { - Retry::Abort | Retry::Skip => return Err(RequestError::Cache(Box::new(err))), - Retry::After(d) => tokio::timer::delay_for(d).await, - }, + } + None + | Some(Token { + refresh_token: None, + .. + }) => { + // no token in the cache or the token returned does not contain a refresh token. + let t = gettoken.token(scopes).await?; + store.set(scope_key, scopes, Some(t.clone())).await; + Ok(t) } } } } -impl GetToken for AuthenticatorImpl +impl GetToken for AuthenticatorImpl where GT: GetToken, - S: TokenStorage, AD: AuthenticatorDelegate, C: hyper::client::connect::Connect + 'static, { diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index eee7ec9..b039e3b 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -79,16 +79,6 @@ pub trait AuthenticatorDelegate: Send + Sync { Retry::Abort } - /// Called whenever we failed to retrieve a token or set a token due to a storage error. - /// You may use it to either ignore the incident or retry. - /// This can be useful if the underlying `TokenStorage` may fail occasionally. - /// if `is_set` is true, the failure resulted from `TokenStorage.set(...)`. Otherwise, - /// it was `TokenStorage.get(...)` - fn token_storage_failure(&self, is_set: bool, _: &(dyn Error + Send + Sync)) -> Retry { - let _ = is_set; - Retry::Abort - } - /// The server denied the attempt to obtain a request code fn request_failure(&self, _: RequestError) {} diff --git a/src/device.rs b/src/device.rs index 51d97b3..a72fddd 100644 --- a/src/device.rs +++ b/src/device.rs @@ -1,7 +1,7 @@ use std::pin::Pin; use std::time::Duration; -use ::log::{error, log}; +use ::log::error; use chrono::{DateTime, Utc}; use futures::prelude::*; use hyper; diff --git a/src/lib.rs b/src/lib.rs index f53d54c..549ea28 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -61,6 +61,7 @@ //! ) //! .persist_tokens_to_disk("tokencache.json") //! .build() +//! .await //! .unwrap(); //! //! let scopes = &["https://www.googleapis.com/auth/drive.file"]; @@ -96,7 +97,6 @@ pub use crate::device::{DeviceFlow, GOOGLE_DEVICE_CODE_URL}; pub use crate::helper::*; pub use crate::installed::{InstalledFlow, InstalledFlowReturnMethod}; pub use crate::service_account::*; -pub use crate::storage::{DiskTokenStorage, MemoryStorage, NullStorage, TokenStorage}; pub use crate::types::{ ApplicationSecret, ConsoleApplicationSecret, GetToken, PollError, RefreshResult, RequestError, Scheme, Token, TokenType, diff --git a/src/service_account.rs b/src/service_account.rs index 03525a5..1b10de8 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -11,12 +11,11 @@ //! Copyright (c) 2016 Google Inc (lewinb@google.com). //! -use std::default::Default; use std::pin::Pin; -use std::sync::{Arc, Mutex}; +use std::sync::Mutex; use crate::authenticator::{DefaultHyperClient, HyperClientBuilder}; -use crate::storage::{hash_scopes, MemoryStorage, TokenStorage}; +use crate::storage::{self, Storage}; use crate::types::{ApplicationSecret, GetToken, JsonErrorOr, RequestError, Token}; use futures::prelude::*; @@ -210,7 +209,7 @@ where struct ServiceAccountAccessImpl { client: hyper::Client, key: ServiceAccountKey, - cache: Arc>, + cache: Storage, subject: Option, signer: JWTSigner, } @@ -228,7 +227,9 @@ where Ok(ServiceAccountAccessImpl { client, key, - cache: Arc::new(Mutex::new(MemoryStorage::default())), + cache: Storage::Memory { + tokens: Mutex::new(storage::JSONTokens::new()), + }, subject, signer, }) @@ -309,10 +310,10 @@ where where T: AsRef, { - let hash = hash_scopes(scopes); + let hash = storage::ScopeHash::new(scopes); let cache = &self.cache; - match cache.lock().unwrap().get(hash, scopes) { - Ok(Some(token)) if !token.expired() => return Ok(token), + match cache.get(hash, scopes) { + Some(token) if !token.expired() => return Ok(token), _ => {} } let token = Self::request_token( @@ -323,7 +324,7 @@ where scopes, ) .await?; - let _ = cache.lock().unwrap().set(hash, scopes, Some(token.clone())); + cache.set(hash, scopes, Some(token.clone())).await; Ok(token) } } @@ -425,13 +426,12 @@ mod tests { assert!(acc .cache - .lock() - .unwrap() .get( - 3502164897243251857, + dbg!(storage::ScopeHash::new(&[ + "https://www.googleapis.com/auth/pubsub" + ])), &["https://www.googleapis.com/auth/pubsub"], ) - .unwrap() .is_some()); // Test that token is in cache (otherwise mock will tell us) let fut = async { diff --git a/src/storage.rs b/src/storage.rs index fef3c0a..17cf2db 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -5,127 +5,65 @@ use std::cmp::Ordering; use std::collections::hash_map::DefaultHasher; -use std::error::Error; -use std::fs; use std::hash::{Hash, Hasher}; use std::io; -use std::io::Write; use std::path::{Path, PathBuf}; use std::sync::Mutex; use crate::types::Token; -use itertools::Itertools; -/// Implements a specialized storage to set and retrieve `Token` instances. -/// The `scope_hash` represents the signature of the scopes for which the given token -/// should be stored or retrieved. -/// For completeness, the underlying, sorted scopes are provided as well. They might be -/// useful for presentation to the user. -pub trait TokenStorage: Send + Sync { - type Error: 'static + Error + Send + Sync; +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct ScopeHash(u64); - /// If `token` is None, it is invalid or revoked and should be removed from storage. - /// Otherwise, it should be saved. - fn set( - &self, - scope_hash: u64, - scopes: &[T], - token: Option, - ) -> Result<(), Self::Error> - where - T: AsRef; - - /// A `None` result indicates that there is no token for the given scope_hash. - fn get(&self, scope_hash: u64, scopes: &[T]) -> Result, Self::Error> - where - T: AsRef; -} - -/// Calculate a hash value describing the scopes. The order of the scopes in the -/// list does not change the hash value. i.e. two lists that contains the exact -/// same scopes, but in different order will return the same hash value. -pub fn hash_scopes(scopes: &[T]) -> u64 -where - T: AsRef, -{ - let mut hash_sum = DefaultHasher::new().finish(); - for scope in scopes { - let mut hasher = DefaultHasher::new(); - scope.as_ref().hash(&mut hasher); - hash_sum ^= hasher.finish(); - } - hash_sum -} - -/// A storage that remembers nothing. -#[derive(Default)] -pub struct NullStorage; - -impl TokenStorage for NullStorage { - type Error = std::convert::Infallible; - fn set(&self, _: u64, _: &[T], _: Option) -> Result<(), Self::Error> +impl ScopeHash { + /// Calculate a hash value describing the scopes. The order of the scopes in the + /// list does not change the hash value. i.e. two lists that contains the exact + /// same scopes, but in different order will return the same hash value. + pub fn new(scopes: &[T]) -> Self where T: AsRef, { - Ok(()) - } - - fn get(&self, _: u64, _: &[T]) -> Result, Self::Error> - where - T: AsRef, - { - Ok(None) - } -} - -/// A storage that remembers values for one session only. -#[derive(Debug, Default)] -pub struct MemoryStorage { - tokens: Mutex>, -} - -impl MemoryStorage { - pub fn new() -> MemoryStorage { - Default::default() - } -} - -impl TokenStorage for MemoryStorage { - type Error = std::convert::Infallible; - - fn set(&self, scope_hash: u64, scopes: &[T], token: Option) -> Result<(), Self::Error> - where - T: AsRef, - { - let mut tokens = self.tokens.lock().expect("poisoned mutex"); - let matched = tokens.iter().find_position(|x| x.hash == scope_hash); - if let Some((idx, _)) = matched { - self.tokens.retain(|x| x.hash != scope_hash); + let mut hash_sum = DefaultHasher::new().finish(); + for scope in scopes { + let mut hasher = DefaultHasher::new(); + scope.as_ref().hash(&mut hasher); + hash_sum ^= hasher.finish(); } - - if let Some(t) = token { - tokens.push(JSONToken { - hash: scope_hash, - scopes: Some(scopes.iter().map(|x| x.as_ref().to_string()).collect()), - token: t, - }); - } - Ok(()) + ScopeHash(hash_sum) } +} - fn get(&self, scope_hash: u64, scopes: &[T]) -> Result, Self::Error> +pub(crate) enum Storage { + Memory { tokens: Mutex }, + Disk(DiskStorage), +} + +impl Storage { + pub(crate) async fn set(&self, h: ScopeHash, scopes: &[T], token: Option) where T: AsRef, { - let tokens = self.tokens.lock().expect("poisoned mutex"); - Ok(token_for_scopes(&tokens, scope_hash, scopes)) + match self { + Storage::Memory { tokens } => tokens.lock().unwrap().set(h, scopes, token), + Storage::Disk(disk_storage) => disk_storage.set(h, scopes, token).await, + } + } + + pub(crate) fn get(&self, h: ScopeHash, scopes: &[T]) -> Option + where + T: AsRef, + { + match self { + Storage::Memory { tokens } => tokens.lock().unwrap().get(h, scopes), + Storage::Disk(disk_storage) => disk_storage.get(h, scopes), + } } } /// A single stored token. #[derive(Debug, Clone, Serialize, Deserialize)] struct JSONToken { - pub hash: u64, + pub hash: ScopeHash, pub scopes: Option>, pub token: Token, } @@ -151,121 +89,123 @@ impl Ord for JSONToken { } /// List of tokens in a JSON object -#[derive(Serialize, Deserialize)] -struct JSONTokens { - pub tokens: Vec, +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct JSONTokens { + tokens: Vec, } -/// Serializes tokens to a JSON file on disk. -#[derive(Default)] -pub struct DiskTokenStorage { - location: PathBuf, - tokens: Mutex>, +impl JSONTokens { + pub(crate) fn new() -> Self { + JSONTokens { tokens: Vec::new() } + } + + pub(crate) async fn load_from_file(filename: &Path) -> Result { + let contents = tokio::fs::read(filename).await?; + let container: JSONTokens = serde_json::from_slice(&contents) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok(container) + } + + fn get(&self, h: ScopeHash, scopes: &[T]) -> Option + where + T: AsRef, + { + for t in self.tokens.iter() { + if let Some(token_scopes) = &t.scopes { + if scopes + .iter() + .all(|s| token_scopes.iter().any(|t| t == s.as_ref())) + { + return Some(t.token.clone()); + } + } else if h == t.hash { + return Some(t.token.clone()); + } + } + None + } + + fn set(&mut self, h: ScopeHash, scopes: &[T], token: Option) + where + T: AsRef, + { + eprintln!("setting: {:?}, {:?}", h, token); + self.tokens.retain(|x| x.hash != h); + + match token { + None => (), + Some(t) => { + self.tokens.push(JSONToken { + hash: h, + scopes: Some(scopes.iter().map(|x| x.as_ref().to_string()).collect()), + token: t, + }); + } + } + } + + // TODO: ideally this function would accept &Path, but tokio requires the + // path be 'static. Revisit this and ask why tokio::fs::write has that + // limitation. + async fn dump_to_file(&self, path: PathBuf) -> Result<(), io::Error> { + let serialized = serde_json::to_string(self) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + tokio::fs::write(path, &serialized).await + } } -impl DiskTokenStorage { - pub fn new>(location: S) -> Result { - let filename = location.into(); - let tokens = match load_from_file(&filename) { - Ok(tokens) => tokens, - Err(e) if e.kind() == io::ErrorKind::NotFound => Vec::new(), - Err(e) => return Err(e), - }; - Ok(DiskTokenStorage { - location: filename, +pub(crate) struct DiskStorage { + tokens: Mutex, + write_tx: tokio::sync::mpsc::Sender, +} + +impl DiskStorage { + pub(crate) async fn new(path: PathBuf) -> Result { + let tokens = JSONTokens::load_from_file(&path).await?; + // Writing to disk will happen in a separate task. This means in the + // common case returning a token to the user will not be required to + // wait for disk i/o. We communicate with a dedicated writer task via a + // buffered channel. This ensures that the writes happen in the order + // received, and if writes fall too far behind we will block GetToken + // requests until disk i/o completes. + let (write_tx, mut write_rx) = tokio::sync::mpsc::channel::(2); + tokio::spawn(async move { + while let Some(tokens) = write_rx.recv().await { + if let Err(e) = tokens.dump_to_file(path.to_path_buf()).await { + log::error!("Failed to write token storage to disk: {}", e); + } + } + }); + Ok(DiskStorage { tokens: Mutex::new(tokens), + write_tx, }) } - pub fn dump_to_file(&self) -> Result<(), io::Error> { - let mut jsontokens = JSONTokens { tokens: Vec::new() }; - - { - let tokens = self.tokens.lock().expect("mutex poisoned"); - for token in tokens.iter() { - jsontokens.tokens.push((*token).clone()); - } - } - - let serialized; - - match serde_json::to_string(&jsontokens) { - Result::Err(e) => return Result::Err(io::Error::new(io::ErrorKind::InvalidData, e)), - Result::Ok(s) => serialized = s, - } - - // TODO: Write to disk asynchronously so that we don't stall the eventloop if invoked in async context. - let mut f = fs::OpenOptions::new() - .create(true) - .write(true) - .truncate(true) - .open(&self.location)?; - f.write(serialized.as_ref()).map(|_| ()) - } -} - -fn load_from_file(filename: &Path) -> Result, io::Error> { - let contents = std::fs::read_to_string(filename)?; - let container: JSONTokens = serde_json::from_str(&contents) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - Ok(container.tokens) -} - -impl TokenStorage for DiskTokenStorage { - type Error = io::Error; - fn set(&self, scope_hash: u64, scopes: &[T], token: Option) -> Result<(), Self::Error> + async fn set(&self, h: ScopeHash, scopes: &[T], token: Option) where T: AsRef, { - { - let mut tokens = self.tokens.lock().expect("poisoned mutex"); - let matched = tokens.iter().find_position(|x| x.hash == scope_hash); - if let Some((idx, _)) = matched { - self.tokens.retain(|x| x.hash != scope_hash); - } - - match token { - None => (), - Some(t) => { - tokens.push(JSONToken { - hash: scope_hash, - scopes: Some(scopes.iter().map(|x| x.as_ref().to_string()).collect()), - token: t, - }); - } - } - } - self.dump_to_file() + let cloned_tokens = { + let mut tokens = self.tokens.lock().unwrap(); + tokens.set(h, scopes, token); + tokens.clone() + }; + self.write_tx + .clone() + .send(cloned_tokens) + .await + .expect("disk storage task not running"); } - fn get(&self, scope_hash: u64, scopes: &[T]) -> Result, Self::Error> + pub(crate) fn get(&self, h: ScopeHash, scopes: &[T]) -> Option where T: AsRef, { - let tokens = self.tokens.lock().expect("poisoned mutex"); - Ok(token_for_scopes(&tokens, scope_hash, scopes)) + self.tokens.lock().unwrap().get(h, scopes) } } -fn token_for_scopes(tokens: &[JSONToken], scope_hash: u64, scopes: &[T]) -> Option -where - T: AsRef, -{ - for t in tokens.iter() { - if let Some(token_scopes) = &t.scopes { - if scopes - .iter() - .all(|s| token_scopes.iter().any(|t| t == s.as_ref())) - { - return Some(t.token.clone()); - } - } else if scope_hash == t.hash { - return Some(t.token.clone()); - } - } - None -} - #[cfg(test)] mod tests { use super::*; @@ -273,19 +213,25 @@ mod tests { #[test] fn test_hash_scopes() { // Idential list should hash equal. - assert_eq!(hash_scopes(&["foo", "bar"]), hash_scopes(&["foo", "bar"])); - // The hash should be order independent. - assert_eq!(hash_scopes(&["bar", "foo"]), hash_scopes(&["foo", "bar"])); assert_eq!( - hash_scopes(&["bar", "baz", "bat"]), - hash_scopes(&["baz", "bar", "bat"]) + ScopeHash::new(&["foo", "bar"]), + ScopeHash::new(&["foo", "bar"]) + ); + // The hash should be order independent. + assert_eq!( + ScopeHash::new(&["bar", "foo"]), + ScopeHash::new(&["foo", "bar"]) + ); + assert_eq!( + ScopeHash::new(&["bar", "baz", "bat"]), + ScopeHash::new(&["baz", "bar", "bat"]) ); // Ensure hashes differ when the contents are different by more than // just order. assert_ne!( - hash_scopes(&["foo", "bar", "baz"]), - hash_scopes(&["foo", "bar"]) + ScopeHash::new(&["foo", "bar", "baz"]), + ScopeHash::new(&["foo", "bar"]) ); } }