From 09d1f05a003ce3f25363741a5510d6d0cf67ffed Mon Sep 17 00:00:00 2001 From: Abdul Rehman <10097155+abdul-rehman0@users.noreply.github.com> Date: Mon, 6 Apr 2020 15:31:49 +0500 Subject: [PATCH] feat: Use futures-aware mutex --- Cargo.toml | 1 + src/authenticator.rs | 7 +++++-- src/installed.rs | 8 +++++--- src/storage.rs | 28 +++++++++++++++++----------- 4 files changed, 28 insertions(+), 16 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8663691..6250c42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ serde_json = "1.0" tokio = { version = "0.2", features = ["fs", "macros", "io-std", "time"] } url = "2" percent-encoding = "2" +futures = "0.3" [dev-dependencies] httptest = "0.11.1" diff --git a/src/authenticator.rs b/src/authenticator.rs index 009e6e7..69cbe1d 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -9,11 +9,11 @@ use crate::storage::{self, Storage}; use crate::types::{AccessToken, ApplicationSecret, TokenInfo}; use private::AuthFlow; +use futures::lock::Mutex; use std::borrow::Cow; use std::fmt; use std::io; use std::path::PathBuf; -use std::sync::Mutex; /// Authenticator is responsible for fetching tokens, handling refreshing tokens, /// and optionally persisting tokens to disk. @@ -80,7 +80,10 @@ where DisplayScopes(scopes) ); let hashed_scopes = storage::ScopeSet::from(scopes); - match (self.storage.get(hashed_scopes), self.auth_flow.app_secret()) { + match ( + self.storage.get(hashed_scopes).await, + self.auth_flow.app_secret(), + ) { (Some(t), _) if !t.is_expired() && !force_refresh => { // unexpired token found log::debug!("found valid token in cache: {:?}", t); diff --git a/src/installed.rs b/src/installed.rs index 35fe91f..41507bb 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -6,9 +6,10 @@ use crate::authenticator_delegate::{DefaultInstalledFlowDelegate, InstalledFlowD use crate::error::Error; use crate::types::{ApplicationSecret, TokenInfo}; +use futures::lock::Mutex; use std::convert::AsRef; use std::net::SocketAddr; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use hyper::header; use percent_encoding::{percent_encode, AsciiSet, CONTROLS}; @@ -286,8 +287,9 @@ impl InstalledFlowServer { } mod installed_flow_server { + use futures::lock::Mutex; use hyper::{Body, Request, Response, StatusCode, Uri}; - use std::sync::{Arc, Mutex}; + use std::sync::Arc; use tokio::sync::oneshot; use url::form_urlencoded; @@ -312,7 +314,7 @@ mod installed_flow_server { .body(hyper::Body::from("Unparseable URL")), Ok(url) => match auth_code_from_url(url) { Some(auth_code) => { - if let Some(sender) = auth_code_tx.lock().unwrap().take() { + if let Some(sender) = auth_code_tx.lock().await.take() { let _ = sender.send(auth_code); } hyper::Response::builder().status(StatusCode::OK).body( diff --git a/src/storage.rs b/src/storage.rs index 4acc30a..308eeea 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -4,10 +4,10 @@ // use crate::types::TokenInfo; +use futures::lock::Mutex; use std::collections::HashMap; use std::io; use std::path::{Path, PathBuf}; -use std::sync::Mutex; use serde::{Deserialize, Serialize}; @@ -120,18 +120,18 @@ impl Storage { T: AsRef, { match self { - Storage::Memory { tokens } => tokens.lock().unwrap().set(scopes, token), + Storage::Memory { tokens } => tokens.lock().await.set(scopes, token), Storage::Disk(disk_storage) => disk_storage.set(scopes, token).await, } } - pub(crate) fn get(&self, scopes: ScopeSet) -> Option + pub(crate) async fn get(&self, scopes: ScopeSet<'_, T>) -> Option where T: AsRef, { match self { - Storage::Memory { tokens } => tokens.lock().unwrap().get(scopes), - Storage::Disk(disk_storage) => disk_storage.get(scopes), + Storage::Memory { tokens } => tokens.lock().await.get(scopes), + Storage::Disk(disk_storage) => disk_storage.get(scopes).await, } } } @@ -334,7 +334,7 @@ impl DiskStorage { use tokio::io::AsyncWriteExt; let json = { use std::ops::Deref; - let mut lock = self.tokens.lock().unwrap(); + let mut lock = self.tokens.lock().await; lock.set(scopes, token)?; serde_json::to_string(lock.deref()) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? @@ -344,11 +344,11 @@ impl DiskStorage { Ok(()) } - pub(crate) fn get(&self, scopes: ScopeSet) -> Option + pub(crate) async fn get(&self, scopes: ScopeSet<'_, T>) -> Option where T: AsRef, { - self.tokens.lock().unwrap().get(scopes) + self.tokens.lock().await.get(scopes) } } @@ -413,19 +413,25 @@ mod tests { let storage = DiskStorage::new(tempdir.path().join("tokenstorage.json")) .await .unwrap(); - assert!(storage.get(scope_set).is_none()); + assert!(storage.get(scope_set).await.is_none()); storage .set(scope_set, new_token("my_access_token")) .await .unwrap(); - assert_eq!(storage.get(scope_set), Some(new_token("my_access_token"))); + assert_eq!( + storage.get(scope_set).await, + Some(new_token("my_access_token")) + ); } { // Create a new DiskStorage instance and verify the tokens were read from disk correctly. let storage = DiskStorage::new(tempdir.path().join("tokenstorage.json")) .await .unwrap(); - assert_eq!(storage.get(scope_set), Some(new_token("my_access_token"))); + assert_eq!( + storage.get(scope_set).await, + Some(new_token("my_access_token")) + ); } } }