mirror of
https://github.com/OMGeeky/yup-oauth2.git
synced 2026-01-06 19:29:39 +01:00
feat: Use futures-aware mutex
This commit is contained in:
@@ -24,6 +24,7 @@ serde_json = "1.0"
|
||||
tokio = { version = "0.2", features = ["fs", "macros", "io-std", "time"] }
|
||||
url = "2"
|
||||
percent-encoding = "2"
|
||||
futures = "0.3"
|
||||
|
||||
[dev-dependencies]
|
||||
httptest = "0.11.1"
|
||||
|
||||
@@ -9,11 +9,11 @@ use crate::storage::{self, Storage};
|
||||
use crate::types::{AccessToken, ApplicationSecret, TokenInfo};
|
||||
use private::AuthFlow;
|
||||
|
||||
use futures::lock::Mutex;
|
||||
use std::borrow::Cow;
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// Authenticator is responsible for fetching tokens, handling refreshing tokens,
|
||||
/// and optionally persisting tokens to disk.
|
||||
@@ -80,7 +80,10 @@ where
|
||||
DisplayScopes(scopes)
|
||||
);
|
||||
let hashed_scopes = storage::ScopeSet::from(scopes);
|
||||
match (self.storage.get(hashed_scopes), self.auth_flow.app_secret()) {
|
||||
match (
|
||||
self.storage.get(hashed_scopes).await,
|
||||
self.auth_flow.app_secret(),
|
||||
) {
|
||||
(Some(t), _) if !t.is_expired() && !force_refresh => {
|
||||
// unexpired token found
|
||||
log::debug!("found valid token in cache: {:?}", t);
|
||||
|
||||
@@ -6,9 +6,10 @@ use crate::authenticator_delegate::{DefaultInstalledFlowDelegate, InstalledFlowD
|
||||
use crate::error::Error;
|
||||
use crate::types::{ApplicationSecret, TokenInfo};
|
||||
|
||||
use futures::lock::Mutex;
|
||||
use std::convert::AsRef;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::Arc;
|
||||
|
||||
use hyper::header;
|
||||
use percent_encoding::{percent_encode, AsciiSet, CONTROLS};
|
||||
@@ -286,8 +287,9 @@ impl InstalledFlowServer {
|
||||
}
|
||||
|
||||
mod installed_flow_server {
|
||||
use futures::lock::Mutex;
|
||||
use hyper::{Body, Request, Response, StatusCode, Uri};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::oneshot;
|
||||
use url::form_urlencoded;
|
||||
|
||||
@@ -312,7 +314,7 @@ mod installed_flow_server {
|
||||
.body(hyper::Body::from("Unparseable URL")),
|
||||
Ok(url) => match auth_code_from_url(url) {
|
||||
Some(auth_code) => {
|
||||
if let Some(sender) = auth_code_tx.lock().unwrap().take() {
|
||||
if let Some(sender) = auth_code_tx.lock().await.take() {
|
||||
let _ = sender.send(auth_code);
|
||||
}
|
||||
hyper::Response::builder().status(StatusCode::OK).body(
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
//
|
||||
use crate::types::TokenInfo;
|
||||
|
||||
use futures::lock::Mutex;
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Mutex;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -120,18 +120,18 @@ impl Storage {
|
||||
T: AsRef<str>,
|
||||
{
|
||||
match self {
|
||||
Storage::Memory { tokens } => tokens.lock().unwrap().set(scopes, token),
|
||||
Storage::Memory { tokens } => tokens.lock().await.set(scopes, token),
|
||||
Storage::Disk(disk_storage) => disk_storage.set(scopes, token).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get<T>(&self, scopes: ScopeSet<T>) -> Option<TokenInfo>
|
||||
pub(crate) async fn get<T>(&self, scopes: ScopeSet<'_, T>) -> Option<TokenInfo>
|
||||
where
|
||||
T: AsRef<str>,
|
||||
{
|
||||
match self {
|
||||
Storage::Memory { tokens } => tokens.lock().unwrap().get(scopes),
|
||||
Storage::Disk(disk_storage) => disk_storage.get(scopes),
|
||||
Storage::Memory { tokens } => tokens.lock().await.get(scopes),
|
||||
Storage::Disk(disk_storage) => disk_storage.get(scopes).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -334,7 +334,7 @@ impl DiskStorage {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
let json = {
|
||||
use std::ops::Deref;
|
||||
let mut lock = self.tokens.lock().unwrap();
|
||||
let mut lock = self.tokens.lock().await;
|
||||
lock.set(scopes, token)?;
|
||||
serde_json::to_string(lock.deref())
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
|
||||
@@ -344,11 +344,11 @@ impl DiskStorage {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn get<T>(&self, scopes: ScopeSet<T>) -> Option<TokenInfo>
|
||||
pub(crate) async fn get<T>(&self, scopes: ScopeSet<'_, T>) -> Option<TokenInfo>
|
||||
where
|
||||
T: AsRef<str>,
|
||||
{
|
||||
self.tokens.lock().unwrap().get(scopes)
|
||||
self.tokens.lock().await.get(scopes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -413,19 +413,25 @@ mod tests {
|
||||
let storage = DiskStorage::new(tempdir.path().join("tokenstorage.json"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(storage.get(scope_set).is_none());
|
||||
assert!(storage.get(scope_set).await.is_none());
|
||||
storage
|
||||
.set(scope_set, new_token("my_access_token"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(storage.get(scope_set), Some(new_token("my_access_token")));
|
||||
assert_eq!(
|
||||
storage.get(scope_set).await,
|
||||
Some(new_token("my_access_token"))
|
||||
);
|
||||
}
|
||||
{
|
||||
// Create a new DiskStorage instance and verify the tokens were read from disk correctly.
|
||||
let storage = DiskStorage::new(tempdir.path().join("tokenstorage.json"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(storage.get(scope_set), Some(new_token("my_access_token")));
|
||||
assert_eq!(
|
||||
storage.get(scope_set).await,
|
||||
Some(new_token("my_access_token"))
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user