Mirror of https://github.com/OMGeeky/yup-oauth2.git
Refactor token storage.
The current code uses standard blocking I/O operations (std::fs::*). This is problematic because it blocks the entire futures executor while waiting on I/O. This change is a major refactoring to make the token storage mechanism async-I/O friendly.

The first major decision was to abandon the TokenStorage trait. The trait was only implemented internally, and there was no mechanism for users to provide their own implementation. Because async fns are not currently supported in trait impls, keeping the trait would have required boxing futures; that probably would have been fine, but it seemed unnecessary. Instead of a trait, the storage mechanism is now an enum with a choice between Memory and Disk storage.

DiskStorage works mostly as it did before, rewriting the entire contents of the file on every set() invocation. The only difference is that the actual writing is now deferred to a separate task so that it does not block returning the Token to the user. If disk I/O is too slow to keep up with the rate of incoming writes, the storage pushes back and will eventually block the return of tokens; this prevents a buildup of in-flight requests. One major drawback of this approach is that any errors that happen on write are simply logged; no delegate function is invoked on error, because the delegate no longer has the ability to ask for a sleep, a retry, etc.
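As a concrete illustration of the deferred-write pattern, here is a minimal, self-contained sketch. It is not the code from this commit: the DeferredWriter type and its names are invented for illustration, and it assumes tokio 1.x and the log crate. A bounded channel hands serialized snapshots to a dedicated writer task; enqueueing only waits when the channel is full, which is the backpressure behavior described above.

    use std::io;
    use std::path::PathBuf;

    use tokio::sync::mpsc;

    // Illustrative only: a writer that persists snapshots in the background.
    struct DeferredWriter {
        tx: mpsc::Sender<String>, // serialized token state to persist
    }

    impl DeferredWriter {
        fn new(path: PathBuf) -> Self {
            // Capacity 2: at most two snapshots may be queued before
            // senders must wait for the writer task to catch up.
            let (tx, mut rx) = mpsc::channel::<String>(2);
            tokio::spawn(async move {
                while let Some(snapshot) = rx.recv().await {
                    // Write errors are only logged; there is no delegate to
                    // consult, mirroring the trade-off noted above.
                    if let Err(e) = tokio::fs::write(&path, &snapshot).await {
                        log::error!("failed to persist tokens: {}", e);
                    }
                }
            });
            DeferredWriter { tx }
        }

        // Queue a snapshot for writing. Returns once the snapshot is
        // enqueued, not once it is on disk; waits only when the writer
        // task has fallen two snapshots behind.
        async fn persist(&self, snapshot: String) -> io::Result<()> {
            self.tx
                .send(snapshot)
                .await
                .map_err(|_| io::Error::new(io::ErrorKind::Other, "writer task stopped"))
        }
    }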
@@ -16,8 +16,7 @@ chrono = "0.4"
 http = "0.1"
 hyper = {version = "0.13.0-alpha.4", features = ["unstable-stream"]}
 hyper-rustls = "=0.18.0-alpha.2"
-itertools = "0.8"
-log = "0.3"
+log = "0.4"
 rustls = "0.16"
 serde = "1.0"
 serde_json = "1.0"
@@ -10,6 +10,7 @@ async fn main() {
     let auth = Authenticator::new(DeviceFlow::new(creds))
         .persist_tokens_to_disk("tokenstorage.json")
         .build()
+        .await
        .expect("authenticator");

    let scopes = &["https://www.googleapis.com/auth/youtube.readonly"];
@@ -14,6 +14,7 @@ async fn main() {
     ))
     .persist_tokens_to_disk("tokencache.json")
     .build()
+    .await
     .unwrap();
     let scopes = &["https://www.googleapis.com/auth/drive.file"];
 
@@ -1,14 +1,15 @@
-use crate::authenticator_delegate::{AuthenticatorDelegate, DefaultAuthenticatorDelegate, Retry};
+use crate::authenticator_delegate::{AuthenticatorDelegate, DefaultAuthenticatorDelegate};
 use crate::refresh::RefreshFlow;
-use crate::storage::{hash_scopes, DiskTokenStorage, MemoryStorage, TokenStorage};
+use crate::storage::{self, Storage};
 use crate::types::{ApplicationSecret, GetToken, RefreshResult, RequestError, Token};
 
 use futures::prelude::*;
 
 use std::error::Error;
 use std::io;
-use std::path::Path;
+use std::path::PathBuf;
 use std::pin::Pin;
+use std::sync::Mutex;
 
 /// Authenticator abstracts different `GetToken` implementations behind one type and handles
 /// caching received tokens. It's important to use it (instead of the flows directly) because
@@ -20,15 +21,11 @@ use std::pin::Pin;
 /// NOTE: It is recommended to use a client constructed like this in order to prevent functions
 /// like `hyper::run()` from hanging: `let client = hyper::Client::builder().keep_alive(false);`.
 /// Due to token requests being rare, this should not result in a too bad performance problem.
-struct AuthenticatorImpl<
-    T: GetToken,
-    S: TokenStorage,
-    AD: AuthenticatorDelegate,
-    C: hyper::client::connect::Connect,
-> {
+struct AuthenticatorImpl<T: GetToken, AD: AuthenticatorDelegate, C: hyper::client::connect::Connect>
+{
     client: hyper::Client<C>,
     inner: T,
-    store: S,
+    store: Storage,
     delegate: AD,
 }
@@ -69,17 +66,22 @@ pub trait AuthFlow<C> {
     fn build_token_getter(self, client: hyper::Client<C>) -> Self::TokenGetter;
 }
 
+enum StorageType {
+    Memory,
+    Disk(PathBuf),
+}
+
 /// An authenticator can be used with `InstalledFlow`'s or `DeviceFlow`'s and
 /// will refresh tokens as they expire as well as optionally persist tokens to
 /// disk.
-pub struct Authenticator<T, S, AD, C> {
+pub struct Authenticator<T, AD, C> {
     client: C,
     token_getter: T,
-    store: io::Result<S>,
+    storage_type: StorageType,
     delegate: AD,
 }
 
-impl<T> Authenticator<T, MemoryStorage, DefaultAuthenticatorDelegate, DefaultHyperClient>
+impl<T> Authenticator<T, DefaultAuthenticatorDelegate, DefaultHyperClient>
 where
     T: AuthFlow<<DefaultHyperClient as HyperClientBuilder>::Connector>,
 {
@@ -90,27 +92,27 @@ where
     ///
     /// Examples
     /// ```
     /// # #[tokio::main]
     /// # async fn main() {
     /// use std::path::Path;
     /// use yup_oauth2::{ApplicationSecret, Authenticator, DeviceFlow};
     /// let creds = ApplicationSecret::default();
-    /// let auth = Authenticator::new(DeviceFlow::new(creds)).build().unwrap();
+    /// let auth = Authenticator::new(DeviceFlow::new(creds)).build().await.unwrap();
     /// # }
     /// ```
-    pub fn new(
-        flow: T,
-    ) -> Authenticator<T, MemoryStorage, DefaultAuthenticatorDelegate, DefaultHyperClient> {
+    pub fn new(flow: T) -> Authenticator<T, DefaultAuthenticatorDelegate, DefaultHyperClient> {
         Authenticator {
             client: DefaultHyperClient,
             token_getter: flow,
-            store: Ok(MemoryStorage::new()),
+            storage_type: StorageType::Memory,
             delegate: DefaultAuthenticatorDelegate,
         }
     }
 }
 
-impl<T, S, AD, C> Authenticator<T, S, AD, C>
+impl<T, AD, C> Authenticator<T, AD, C>
 where
     T: AuthFlow<C::Connector>,
-    S: TokenStorage,
     AD: AuthenticatorDelegate,
     C: HyperClientBuilder,
 {
@@ -118,7 +120,7 @@ where
     pub fn hyper_client<NewC>(
         self,
         hyper_client: hyper::Client<NewC>,
-    ) -> Authenticator<T, S, AD, hyper::Client<NewC>>
+    ) -> Authenticator<T, AD, hyper::Client<NewC>>
     where
         NewC: hyper::client::connect::Connect + 'static,
         T: AuthFlow<NewC>,
@@ -126,21 +128,17 @@ where
         Authenticator {
             client: hyper_client,
             token_getter: self.token_getter,
-            store: self.store,
+            storage_type: self.storage_type,
             delegate: self.delegate,
         }
     }
 
     /// Persist tokens to disk in the provided filename.
-    pub fn persist_tokens_to_disk<P: AsRef<Path>>(
-        self,
-        path: P,
-    ) -> Authenticator<T, DiskTokenStorage, AD, C> {
-        let disk_storage = DiskTokenStorage::new(path.as_ref().to_str().unwrap());
+    pub fn persist_tokens_to_disk<P: Into<PathBuf>>(self, path: P) -> Authenticator<T, AD, C> {
         Authenticator {
             client: self.client,
             token_getter: self.token_getter,
-            store: disk_storage,
+            storage_type: StorageType::Disk(path.into()),
             delegate: self.delegate,
         }
     }
@@ -149,24 +147,29 @@ where
     pub fn delegate<NewAD: AuthenticatorDelegate>(
         self,
         delegate: NewAD,
-    ) -> Authenticator<T, S, NewAD, C> {
+    ) -> Authenticator<T, NewAD, C> {
         Authenticator {
             client: self.client,
             token_getter: self.token_getter,
-            store: self.store,
+            storage_type: self.storage_type,
             delegate,
         }
     }
 
     /// Create the authenticator.
-    pub fn build(self) -> io::Result<impl GetToken>
+    pub async fn build(self) -> io::Result<impl GetToken>
     where
         T::TokenGetter: GetToken,
         C::Connector: hyper::client::connect::Connect + 'static,
     {
         let client = self.client.build_hyper_client();
-        let store = self.store?;
         let inner = self.token_getter.build_token_getter(client.clone());
+        let store = match self.storage_type {
+            StorageType::Memory => Storage::Memory {
+                tokens: Mutex::new(storage::JSONTokens::new()),
+            },
+            StorageType::Disk(path) => Storage::Disk(storage::DiskStorage::new(path).await?),
+        };
 
         Ok(AuthenticatorImpl {
             client,
@@ -177,10 +180,9 @@ where
     }
 }
 
-impl<GT, S, AD, C> AuthenticatorImpl<GT, S, AD, C>
+impl<GT, AD, C> AuthenticatorImpl<GT, AD, C>
 where
     GT: GetToken,
-    S: TokenStorage,
     AD: AuthenticatorDelegate,
     C: hyper::client::connect::Connect + 'static,
 {
@@ -188,83 +190,61 @@ where
     where
         T: AsRef<str> + Sync,
     {
-        let scope_key = hash_scopes(scopes);
+        let scope_key = storage::ScopeHash::new(scopes);
         let store = &self.store;
         let delegate = &self.delegate;
         let client = &self.client;
         let gettoken = &self.inner;
         let appsecret = gettoken.application_secret();
-        loop {
-            match store.get(scope_key, scopes) {
-                Ok(Some(t)) if !t.expired() => {
-                    // unexpired token found
-                    return Ok(t);
-                }
-                Ok(Some(Token {
-                    refresh_token: Some(refresh_token),
-                    ..
-                })) => {
-                    // token is expired but has a refresh token.
-                    let rr = RefreshFlow::refresh_token(client, appsecret, &refresh_token).await?;
-                    match rr {
-                        RefreshResult::Error(ref e) => {
-                            delegate.token_refresh_failed(
-                                e.description(),
-                                Some("the request has likely timed out"),
-                            );
-                            return Err(RequestError::Refresh(rr));
-                        }
-                        RefreshResult::RefreshError(ref s, ref ss) => {
-                            delegate.token_refresh_failed(
-                                &format!("{}{}", s, ss.as_ref().map(|s| format!(" ({})", s)).unwrap_or_else(String::new)),
-                                Some("the refresh token is likely invalid and your authorization has been revoked"),
-                            );
-                            return Err(RequestError::Refresh(rr));
-                        }
-                        RefreshResult::Success(t) => {
-                            let x = store.set(scope_key, scopes, Some(t.clone()));
-                            if let Err(e) = x {
-                                match delegate.token_storage_failure(true, &e) {
-                                    Retry::Skip => return Ok(t),
-                                    Retry::Abort => return Err(RequestError::Cache(Box::new(e))),
-                                    Retry::After(d) => tokio::timer::delay_for(d).await,
-                                }
-                            } else {
-                                return Ok(t);
-                            }
-                        }
-                    }
-                }
-                Ok(None)
-                | Ok(Some(Token {
-                    refresh_token: None,
-                    ..
-                })) => {
-                    // no token in the cache or the token returned does not contain a refresh token.
-                    let t = gettoken.token(scopes).await?;
-                    if let Err(e) = store.set(scope_key, scopes, Some(t.clone())) {
-                        match delegate.token_storage_failure(true, &e) {
-                            Retry::Skip => return Ok(t),
-                            Retry::Abort => return Err(RequestError::Cache(Box::new(e))),
-                            Retry::After(d) => tokio::timer::delay_for(d).await,
-                        }
-                    } else {
-                        return Ok(t);
-                    }
-                }
-                Err(err) => match delegate.token_storage_failure(false, &err) {
-                    Retry::Abort | Retry::Skip => return Err(RequestError::Cache(Box::new(err))),
-                    Retry::After(d) => tokio::timer::delay_for(d).await,
-                },
-            }
-        }
+        match store.get(scope_key, scopes) {
+            Some(t) if !t.expired() => {
+                // unexpired token found
+                Ok(t)
+            }
+            Some(Token {
+                refresh_token: Some(refresh_token),
+                ..
+            }) => {
+                // token is expired but has a refresh token.
+                let rr = RefreshFlow::refresh_token(client, appsecret, &refresh_token).await?;
+                match rr {
+                    RefreshResult::Error(ref e) => {
+                        delegate.token_refresh_failed(
+                            e.description(),
+                            Some("the request has likely timed out"),
+                        );
+                        Err(RequestError::Refresh(rr))
+                    }
+                    RefreshResult::RefreshError(ref s, ref ss) => {
+                        delegate.token_refresh_failed(
+                            &format!("{}{}", s, ss.as_ref().map(|s| format!(" ({})", s)).unwrap_or_else(String::new)),
+                            Some("the refresh token is likely invalid and your authorization has been revoked"),
+                        );
+                        Err(RequestError::Refresh(rr))
+                    }
+                    RefreshResult::Success(t) => {
+                        store.set(scope_key, scopes, Some(t.clone())).await;
+                        Ok(t)
+                    }
+                }
+            }
+            None
+            | Some(Token {
+                refresh_token: None,
+                ..
+            }) => {
+                // no token in the cache or the token returned does not contain a refresh token.
+                let t = gettoken.token(scopes).await?;
+                store.set(scope_key, scopes, Some(t.clone())).await;
+                Ok(t)
+            }
+        }
     }
 }
 
-impl<GT, S, AD, C> GetToken for AuthenticatorImpl<GT, S, AD, C>
+impl<GT, AD, C> GetToken for AuthenticatorImpl<GT, AD, C>
 where
     GT: GetToken,
-    S: TokenStorage,
     AD: AuthenticatorDelegate,
     C: hyper::client::connect::Connect + 'static,
 {
@@ -79,16 +79,6 @@ pub trait AuthenticatorDelegate: Send + Sync {
         Retry::Abort
     }
 
-    /// Called whenever we failed to retrieve a token or set a token due to a storage error.
-    /// You may use it to either ignore the incident or retry.
-    /// This can be useful if the underlying `TokenStorage` may fail occasionally.
-    /// if `is_set` is true, the failure resulted from `TokenStorage.set(...)`. Otherwise,
-    /// it was `TokenStorage.get(...)`
-    fn token_storage_failure(&self, is_set: bool, _: &(dyn Error + Send + Sync)) -> Retry {
-        let _ = is_set;
-        Retry::Abort
-    }
-
     /// The server denied the attempt to obtain a request code
     fn request_failure(&self, _: RequestError) {}
@@ -1,7 +1,7 @@
 use std::pin::Pin;
 use std::time::Duration;
 
-use ::log::{error, log};
+use ::log::error;
 use chrono::{DateTime, Utc};
 use futures::prelude::*;
 use hyper;
@@ -61,6 +61,7 @@
 //!     )
 //!     .persist_tokens_to_disk("tokencache.json")
 //!     .build()
+//!     .await
 //!     .unwrap();
 //!
 //! let scopes = &["https://www.googleapis.com/auth/drive.file"];
@@ -96,7 +97,6 @@ pub use crate::device::{DeviceFlow, GOOGLE_DEVICE_CODE_URL};
 pub use crate::helper::*;
 pub use crate::installed::{InstalledFlow, InstalledFlowReturnMethod};
 pub use crate::service_account::*;
-pub use crate::storage::{DiskTokenStorage, MemoryStorage, NullStorage, TokenStorage};
 pub use crate::types::{
     ApplicationSecret, ConsoleApplicationSecret, GetToken, PollError, RefreshResult, RequestError,
     Scheme, Token, TokenType,
@@ -11,12 +11,11 @@
 //! Copyright (c) 2016 Google Inc (lewinb@google.com).
 //!
 
 use std::default::Default;
 use std::pin::Pin;
-use std::sync::{Arc, Mutex};
+use std::sync::Mutex;
 
 use crate::authenticator::{DefaultHyperClient, HyperClientBuilder};
-use crate::storage::{hash_scopes, MemoryStorage, TokenStorage};
+use crate::storage::{self, Storage};
 use crate::types::{ApplicationSecret, GetToken, JsonErrorOr, RequestError, Token};
 
 use futures::prelude::*;
@@ -210,7 +209,7 @@ where
 struct ServiceAccountAccessImpl<C> {
     client: hyper::Client<C, hyper::Body>,
     key: ServiceAccountKey,
-    cache: Arc<Mutex<MemoryStorage>>,
+    cache: Storage,
     subject: Option<String>,
     signer: JWTSigner,
 }
@@ -228,7 +227,9 @@ where
         Ok(ServiceAccountAccessImpl {
             client,
             key,
-            cache: Arc::new(Mutex::new(MemoryStorage::default())),
+            cache: Storage::Memory {
+                tokens: Mutex::new(storage::JSONTokens::new()),
+            },
             subject,
             signer,
         })
@@ -309,10 +310,10 @@ where
     where
         T: AsRef<str>,
     {
-        let hash = hash_scopes(scopes);
+        let hash = storage::ScopeHash::new(scopes);
         let cache = &self.cache;
-        match cache.lock().unwrap().get(hash, scopes) {
-            Ok(Some(token)) if !token.expired() => return Ok(token),
+        match cache.get(hash, scopes) {
+            Some(token) if !token.expired() => return Ok(token),
             _ => {}
         }
         let token = Self::request_token(
@@ -323,7 +324,7 @@ where
             scopes,
         )
         .await?;
-        let _ = cache.lock().unwrap().set(hash, scopes, Some(token.clone()));
+        cache.set(hash, scopes, Some(token.clone())).await;
         Ok(token)
     }
 }
@@ -425,13 +426,12 @@ mod tests {
 
         assert!(acc
             .cache
-            .lock()
-            .unwrap()
             .get(
-                3502164897243251857,
+                dbg!(storage::ScopeHash::new(&[
+                    "https://www.googleapis.com/auth/pubsub"
+                ])),
                 &["https://www.googleapis.com/auth/pubsub"],
             )
-            .unwrap()
             .is_some());
         // Test that token is in cache (otherwise mock will tell us)
         let fut = async {
src/storage.rs
@@ -5,127 +5,65 @@
 
 use std::cmp::Ordering;
 use std::collections::hash_map::DefaultHasher;
-use std::error::Error;
-use std::fs;
 use std::hash::{Hash, Hasher};
 use std::io;
-use std::io::Write;
 use std::path::{Path, PathBuf};
 use std::sync::Mutex;
 
 use crate::types::Token;
-use itertools::Itertools;
 
-/// Implements a specialized storage to set and retrieve `Token` instances.
-/// The `scope_hash` represents the signature of the scopes for which the given token
-/// should be stored or retrieved.
-/// For completeness, the underlying, sorted scopes are provided as well. They might be
-/// useful for presentation to the user.
-pub trait TokenStorage: Send + Sync {
-    type Error: 'static + Error + Send + Sync;
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
+pub struct ScopeHash(u64);
 
-    /// If `token` is None, it is invalid or revoked and should be removed from storage.
-    /// Otherwise, it should be saved.
-    fn set<T>(
-        &self,
-        scope_hash: u64,
-        scopes: &[T],
-        token: Option<Token>,
-    ) -> Result<(), Self::Error>
-    where
-        T: AsRef<str>;
-
-    /// A `None` result indicates that there is no token for the given scope_hash.
-    fn get<T>(&self, scope_hash: u64, scopes: &[T]) -> Result<Option<Token>, Self::Error>
-    where
-        T: AsRef<str>;
-}
-
-/// Calculate a hash value describing the scopes. The order of the scopes in the
-/// list does not change the hash value. i.e. two lists that contains the exact
-/// same scopes, but in different order will return the same hash value.
-pub fn hash_scopes<T>(scopes: &[T]) -> u64
-where
-    T: AsRef<str>,
-{
-    let mut hash_sum = DefaultHasher::new().finish();
-    for scope in scopes {
-        let mut hasher = DefaultHasher::new();
-        scope.as_ref().hash(&mut hasher);
-        hash_sum ^= hasher.finish();
-    }
-    hash_sum
-}
-
-/// A storage that remembers nothing.
-#[derive(Default)]
-pub struct NullStorage;
-
-impl TokenStorage for NullStorage {
-    type Error = std::convert::Infallible;
-    fn set<T>(&self, _: u64, _: &[T], _: Option<Token>) -> Result<(), Self::Error>
+impl ScopeHash {
+    /// Calculate a hash value describing the scopes. The order of the scopes in the
+    /// list does not change the hash value. i.e. two lists that contains the exact
+    /// same scopes, but in different order will return the same hash value.
+    pub fn new<T>(scopes: &[T]) -> Self
     where
         T: AsRef<str>,
     {
-        Ok(())
-    }
-
-    fn get<T>(&self, _: u64, _: &[T]) -> Result<Option<Token>, Self::Error>
-    where
-        T: AsRef<str>,
-    {
-        Ok(None)
-    }
-}
-
-/// A storage that remembers values for one session only.
-#[derive(Debug, Default)]
-pub struct MemoryStorage {
-    tokens: Mutex<Vec<JSONToken>>,
-}
-
-impl MemoryStorage {
-    pub fn new() -> MemoryStorage {
-        Default::default()
-    }
-}
-
-impl TokenStorage for MemoryStorage {
-    type Error = std::convert::Infallible;
-
-    fn set<T>(&self, scope_hash: u64, scopes: &[T], token: Option<Token>) -> Result<(), Self::Error>
-    where
-        T: AsRef<str>,
-    {
-        let mut tokens = self.tokens.lock().expect("poisoned mutex");
-        let matched = tokens.iter().find_position(|x| x.hash == scope_hash);
-        if let Some((idx, _)) = matched {
-            self.tokens.retain(|x| x.hash != scope_hash);
+        let mut hash_sum = DefaultHasher::new().finish();
+        for scope in scopes {
+            let mut hasher = DefaultHasher::new();
+            scope.as_ref().hash(&mut hasher);
+            hash_sum ^= hasher.finish();
         }
 
-        if let Some(t) = token {
-            tokens.push(JSONToken {
-                hash: scope_hash,
-                scopes: Some(scopes.iter().map(|x| x.as_ref().to_string()).collect()),
-                token: t,
-            });
-        }
-        Ok(())
+        ScopeHash(hash_sum)
     }
+}
 
-    fn get<T>(&self, scope_hash: u64, scopes: &[T]) -> Result<Option<Token>, Self::Error>
+pub(crate) enum Storage {
+    Memory { tokens: Mutex<JSONTokens> },
+    Disk(DiskStorage),
+}
+
+impl Storage {
+    pub(crate) async fn set<T>(&self, h: ScopeHash, scopes: &[T], token: Option<Token>)
     where
         T: AsRef<str>,
     {
-        let tokens = self.tokens.lock().expect("poisoned mutex");
-        Ok(token_for_scopes(&tokens, scope_hash, scopes))
+        match self {
+            Storage::Memory { tokens } => tokens.lock().unwrap().set(h, scopes, token),
+            Storage::Disk(disk_storage) => disk_storage.set(h, scopes, token).await,
+        }
+    }
+
+    pub(crate) fn get<T>(&self, h: ScopeHash, scopes: &[T]) -> Option<Token>
+    where
+        T: AsRef<str>,
+    {
+        match self {
+            Storage::Memory { tokens } => tokens.lock().unwrap().get(h, scopes),
+            Storage::Disk(disk_storage) => disk_storage.get(h, scopes),
+        }
     }
 }
 
 /// A single stored token.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 struct JSONToken {
-    pub hash: u64,
+    pub hash: ScopeHash,
     pub scopes: Option<Vec<String>>,
     pub token: Token,
 }
@@ -151,121 +89,123 @@ impl Ord for JSONToken {
 }
 
 /// List of tokens in a JSON object
-#[derive(Serialize, Deserialize)]
-struct JSONTokens {
-    pub tokens: Vec<JSONToken>,
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub(crate) struct JSONTokens {
+    tokens: Vec<JSONToken>,
 }
 
-/// Serializes tokens to a JSON file on disk.
-#[derive(Default)]
-pub struct DiskTokenStorage {
-    location: PathBuf,
-    tokens: Mutex<Vec<JSONToken>>,
+impl JSONTokens {
+    pub(crate) fn new() -> Self {
+        JSONTokens { tokens: Vec::new() }
+    }
+
+    pub(crate) async fn load_from_file(filename: &Path) -> Result<Self, io::Error> {
+        let contents = tokio::fs::read(filename).await?;
+        let container: JSONTokens = serde_json::from_slice(&contents)
+            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
+        Ok(container)
+    }
+
+    fn get<T>(&self, h: ScopeHash, scopes: &[T]) -> Option<Token>
+    where
+        T: AsRef<str>,
+    {
+        for t in self.tokens.iter() {
+            if let Some(token_scopes) = &t.scopes {
+                if scopes
+                    .iter()
+                    .all(|s| token_scopes.iter().any(|t| t == s.as_ref()))
+                {
+                    return Some(t.token.clone());
+                }
+            } else if h == t.hash {
+                return Some(t.token.clone());
+            }
+        }
+        None
+    }
+
+    fn set<T>(&mut self, h: ScopeHash, scopes: &[T], token: Option<Token>)
+    where
+        T: AsRef<str>,
+    {
+        eprintln!("setting: {:?}, {:?}", h, token);
+        self.tokens.retain(|x| x.hash != h);
+
+        match token {
+            None => (),
+            Some(t) => {
+                self.tokens.push(JSONToken {
+                    hash: h,
+                    scopes: Some(scopes.iter().map(|x| x.as_ref().to_string()).collect()),
+                    token: t,
+                });
+            }
+        }
+    }
+
+    // TODO: ideally this function would accept &Path, but tokio requires the
+    // path be 'static. Revisit this and ask why tokio::fs::write has that
+    // limitation.
+    async fn dump_to_file(&self, path: PathBuf) -> Result<(), io::Error> {
+        let serialized = serde_json::to_string(self)
+            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
+        tokio::fs::write(path, &serialized).await
+    }
 }
 
-impl DiskTokenStorage {
-    pub fn new<S: Into<PathBuf>>(location: S) -> Result<DiskTokenStorage, io::Error> {
-        let filename = location.into();
-        let tokens = match load_from_file(&filename) {
-            Ok(tokens) => tokens,
-            Err(e) if e.kind() == io::ErrorKind::NotFound => Vec::new(),
-            Err(e) => return Err(e),
-        };
-        Ok(DiskTokenStorage {
-            location: filename,
+pub(crate) struct DiskStorage {
+    tokens: Mutex<JSONTokens>,
+    write_tx: tokio::sync::mpsc::Sender<JSONTokens>,
+}
+
+impl DiskStorage {
+    pub(crate) async fn new(path: PathBuf) -> Result<Self, io::Error> {
+        let tokens = JSONTokens::load_from_file(&path).await?;
+        // Writing to disk will happen in a separate task. This means in the
+        // common case returning a token to the user will not be required to
+        // wait for disk i/o. We communicate with a dedicated writer task via a
+        // buffered channel. This ensures that the writes happen in the order
+        // received, and if writes fall too far behind we will block GetToken
+        // requests until disk i/o completes.
+        let (write_tx, mut write_rx) = tokio::sync::mpsc::channel::<JSONTokens>(2);
+        tokio::spawn(async move {
+            while let Some(tokens) = write_rx.recv().await {
+                if let Err(e) = tokens.dump_to_file(path.to_path_buf()).await {
+                    log::error!("Failed to write token storage to disk: {}", e);
+                }
+            }
+        });
+        Ok(DiskStorage {
+            tokens: Mutex::new(tokens),
+            write_tx,
         })
     }
 
-    pub fn dump_to_file(&self) -> Result<(), io::Error> {
-        let mut jsontokens = JSONTokens { tokens: Vec::new() };
-
-        {
-            let tokens = self.tokens.lock().expect("mutex poisoned");
-            for token in tokens.iter() {
-                jsontokens.tokens.push((*token).clone());
-            }
-        }
-
-        let serialized;
-
-        match serde_json::to_string(&jsontokens) {
-            Result::Err(e) => return Result::Err(io::Error::new(io::ErrorKind::InvalidData, e)),
-            Result::Ok(s) => serialized = s,
-        }
-
-        // TODO: Write to disk asynchronously so that we don't stall the eventloop if invoked in async context.
-        let mut f = fs::OpenOptions::new()
-            .create(true)
-            .write(true)
-            .truncate(true)
-            .open(&self.location)?;
-        f.write(serialized.as_ref()).map(|_| ())
-    }
-}
-
-fn load_from_file(filename: &Path) -> Result<Vec<JSONToken>, io::Error> {
-    let contents = std::fs::read_to_string(filename)?;
-    let container: JSONTokens = serde_json::from_str(&contents)
-        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
-    Ok(container.tokens)
-}
-
-impl TokenStorage for DiskTokenStorage {
-    type Error = io::Error;
-    fn set<T>(&self, scope_hash: u64, scopes: &[T], token: Option<Token>) -> Result<(), Self::Error>
+    async fn set<T>(&self, h: ScopeHash, scopes: &[T], token: Option<Token>)
     where
         T: AsRef<str>,
     {
-        {
-            let mut tokens = self.tokens.lock().expect("poisoned mutex");
-            let matched = tokens.iter().find_position(|x| x.hash == scope_hash);
-            if let Some((idx, _)) = matched {
-                self.tokens.retain(|x| x.hash != scope_hash);
-            }
-
-            match token {
-                None => (),
-                Some(t) => {
-                    tokens.push(JSONToken {
-                        hash: scope_hash,
-                        scopes: Some(scopes.iter().map(|x| x.as_ref().to_string()).collect()),
-                        token: t,
-                    });
-                }
-            }
-        }
-        self.dump_to_file()
+        let cloned_tokens = {
+            let mut tokens = self.tokens.lock().unwrap();
+            tokens.set(h, scopes, token);
+            tokens.clone()
+        };
+        self.write_tx
+            .clone()
+            .send(cloned_tokens)
+            .await
+            .expect("disk storage task not running");
     }
 
-    fn get<T>(&self, scope_hash: u64, scopes: &[T]) -> Result<Option<Token>, Self::Error>
+    pub(crate) fn get<T>(&self, h: ScopeHash, scopes: &[T]) -> Option<Token>
     where
         T: AsRef<str>,
     {
-        let tokens = self.tokens.lock().expect("poisoned mutex");
-        Ok(token_for_scopes(&tokens, scope_hash, scopes))
+        self.tokens.lock().unwrap().get(h, scopes)
    }
 }
 
-fn token_for_scopes<T>(tokens: &[JSONToken], scope_hash: u64, scopes: &[T]) -> Option<Token>
-where
-    T: AsRef<str>,
-{
-    for t in tokens.iter() {
-        if let Some(token_scopes) = &t.scopes {
-            if scopes
-                .iter()
-                .all(|s| token_scopes.iter().any(|t| t == s.as_ref()))
-            {
-                return Some(t.token.clone());
-            }
-        } else if scope_hash == t.hash {
-            return Some(t.token.clone());
-        }
-    }
-    None
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -273,19 +213,25 @@ mod tests {
     #[test]
     fn test_hash_scopes() {
         // Idential list should hash equal.
-        assert_eq!(hash_scopes(&["foo", "bar"]), hash_scopes(&["foo", "bar"]));
-        // The hash should be order independent.
-        assert_eq!(hash_scopes(&["bar", "foo"]), hash_scopes(&["foo", "bar"]));
         assert_eq!(
-            hash_scopes(&["bar", "baz", "bat"]),
-            hash_scopes(&["baz", "bar", "bat"])
+            ScopeHash::new(&["foo", "bar"]),
+            ScopeHash::new(&["foo", "bar"])
+        );
+        // The hash should be order independent.
+        assert_eq!(
+            ScopeHash::new(&["bar", "foo"]),
+            ScopeHash::new(&["foo", "bar"])
+        );
+        assert_eq!(
+            ScopeHash::new(&["bar", "baz", "bat"]),
+            ScopeHash::new(&["baz", "bar", "bat"])
         );
 
         // Ensure hashes differ when the contents are different by more than
         // just order.
         assert_ne!(
-            hash_scopes(&["foo", "bar", "baz"]),
-            hash_scopes(&["foo", "bar"])
+            ScopeHash::new(&["foo", "bar", "baz"]),
+            ScopeHash::new(&["foo", "bar"])
        );
    }
 }