feat: Use futures-aware mutex

This commit is contained in:
Abdul Rehman
2020-04-06 15:31:49 +05:00
parent c5bad4c209
commit 09d1f05a00
4 changed files with 28 additions and 16 deletions

View File

@@ -24,6 +24,7 @@ serde_json = "1.0"
tokio = { version = "0.2", features = ["fs", "macros", "io-std", "time"] }
url = "2"
percent-encoding = "2"
futures = "0.3"
[dev-dependencies]
httptest = "0.11.1"

View File

@@ -9,11 +9,11 @@ use crate::storage::{self, Storage};
use crate::types::{AccessToken, ApplicationSecret, TokenInfo};
use private::AuthFlow;
use futures::lock::Mutex;
use std::borrow::Cow;
use std::fmt;
use std::io;
use std::path::PathBuf;
use std::sync::Mutex;
/// Authenticator is responsible for fetching tokens, handling refreshing tokens,
/// and optionally persisting tokens to disk.
@@ -80,7 +80,10 @@ where
DisplayScopes(scopes)
);
let hashed_scopes = storage::ScopeSet::from(scopes);
match (self.storage.get(hashed_scopes), self.auth_flow.app_secret()) {
match (
self.storage.get(hashed_scopes).await,
self.auth_flow.app_secret(),
) {
(Some(t), _) if !t.is_expired() && !force_refresh => {
// unexpired token found
log::debug!("found valid token in cache: {:?}", t);

View File

@@ -6,9 +6,10 @@ use crate::authenticator_delegate::{DefaultInstalledFlowDelegate, InstalledFlowD
use crate::error::Error;
use crate::types::{ApplicationSecret, TokenInfo};
use futures::lock::Mutex;
use std::convert::AsRef;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use hyper::header;
use percent_encoding::{percent_encode, AsciiSet, CONTROLS};
@@ -286,8 +287,9 @@ impl InstalledFlowServer {
}
mod installed_flow_server {
use futures::lock::Mutex;
use hyper::{Body, Request, Response, StatusCode, Uri};
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use tokio::sync::oneshot;
use url::form_urlencoded;
@@ -312,7 +314,7 @@ mod installed_flow_server {
.body(hyper::Body::from("Unparseable URL")),
Ok(url) => match auth_code_from_url(url) {
Some(auth_code) => {
if let Some(sender) = auth_code_tx.lock().unwrap().take() {
if let Some(sender) = auth_code_tx.lock().await.take() {
let _ = sender.send(auth_code);
}
hyper::Response::builder().status(StatusCode::OK).body(

View File

@@ -4,10 +4,10 @@
//
use crate::types::TokenInfo;
use futures::lock::Mutex;
use std::collections::HashMap;
use std::io;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use serde::{Deserialize, Serialize};
@@ -120,18 +120,18 @@ impl Storage {
T: AsRef<str>,
{
match self {
Storage::Memory { tokens } => tokens.lock().unwrap().set(scopes, token),
Storage::Memory { tokens } => tokens.lock().await.set(scopes, token),
Storage::Disk(disk_storage) => disk_storage.set(scopes, token).await,
}
}
pub(crate) fn get<T>(&self, scopes: ScopeSet<T>) -> Option<TokenInfo>
pub(crate) async fn get<T>(&self, scopes: ScopeSet<'_, T>) -> Option<TokenInfo>
where
T: AsRef<str>,
{
match self {
Storage::Memory { tokens } => tokens.lock().unwrap().get(scopes),
Storage::Disk(disk_storage) => disk_storage.get(scopes),
Storage::Memory { tokens } => tokens.lock().await.get(scopes),
Storage::Disk(disk_storage) => disk_storage.get(scopes).await,
}
}
}
@@ -334,7 +334,7 @@ impl DiskStorage {
use tokio::io::AsyncWriteExt;
let json = {
use std::ops::Deref;
let mut lock = self.tokens.lock().unwrap();
let mut lock = self.tokens.lock().await;
lock.set(scopes, token)?;
serde_json::to_string(lock.deref())
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
@@ -344,11 +344,11 @@ impl DiskStorage {
Ok(())
}
pub(crate) fn get<T>(&self, scopes: ScopeSet<T>) -> Option<TokenInfo>
pub(crate) async fn get<T>(&self, scopes: ScopeSet<'_, T>) -> Option<TokenInfo>
where
T: AsRef<str>,
{
self.tokens.lock().unwrap().get(scopes)
self.tokens.lock().await.get(scopes)
}
}
@@ -413,19 +413,25 @@ mod tests {
let storage = DiskStorage::new(tempdir.path().join("tokenstorage.json"))
.await
.unwrap();
assert!(storage.get(scope_set).is_none());
assert!(storage.get(scope_set).await.is_none());
storage
.set(scope_set, new_token("my_access_token"))
.await
.unwrap();
assert_eq!(storage.get(scope_set), Some(new_token("my_access_token")));
assert_eq!(
storage.get(scope_set).await,
Some(new_token("my_access_token"))
);
}
{
// Create a new DiskStorage instance and verify the tokens were read from disk correctly.
let storage = DiskStorage::new(tempdir.path().join("tokenstorage.json"))
.await
.unwrap();
assert_eq!(storage.get(scope_set), Some(new_token("my_access_token")));
assert_eq!(
storage.get(scope_set).await,
Some(new_token("my_access_token"))
);
}
}
}