diff --git a/src/storage.rs b/src/storage.rs index 30eabe9..4d945ad 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -4,7 +4,6 @@ // use std::collections::hash_map::DefaultHasher; -use std::collections::HashMap; use std::error::Error; use std::fmt; use std::fs; @@ -13,6 +12,7 @@ use std::io; use std::io::{Read, Write}; use crate::types::Token; +use itertools::Itertools; /// Implements a specialized storage to set and retrieve `Token` instances. /// The `scope_hash` represents the signature of the scopes for which the given token @@ -54,6 +54,7 @@ where /// A storage that remembers nothing. #[derive(Default)] pub struct NullStorage; + #[derive(Debug)] pub struct NullError; @@ -82,7 +83,7 @@ impl TokenStorage for NullStorage { /// A storage that remembers values for one session only. #[derive(Debug, Default)] pub struct MemoryStorage { - pub tokens: HashMap, + tokens: Vec, } impl MemoryStorage { @@ -97,28 +98,53 @@ impl TokenStorage for MemoryStorage { fn set( &mut self, scope_hash: u64, - _: &Vec<&str>, + scopes: &Vec<&str>, token: Option, ) -> Result<(), NullError> { match token { - Some(t) => self.tokens.insert(scope_hash, t), - None => self.tokens.remove(&scope_hash), + Some(t) => { + self.tokens.push(JSONToken { + hash: scope_hash, + scopes: Some(scopes.iter().map(|x| x.to_string()).collect()), + token: t.clone(), + }); + () + } + None => { + let matched = self.tokens.iter().find_position(|x| x.hash == scope_hash); + if let Some((idx, _)) = matched { + self.tokens.remove(idx); + } + } }; Ok(()) } - fn get(&self, scope_hash: u64, _: &Vec<&str>) -> Result, NullError> { - match self.tokens.get(&scope_hash) { - Some(t) => Ok(Some(t.clone())), - None => Ok(None), + fn get(&self, scope_hash: u64, scopes: &Vec<&str>) -> Result, NullError> { + let scopes: Vec<_> = scopes.iter().sorted().unique().collect(); + + for t in &self.tokens { + if let Some(token_scopes) = &t.scopes { + let matched = token_scopes + .iter() + .filter(|x| scopes.contains(&&&x[..])) + .count(); + if matched >= scopes.len() { + return Result::Ok(Some(t.token.clone())); + } + } else if scope_hash == t.hash { + return Result::Ok(Some(t.token.clone())); + } } + Result::Ok(None) } } /// A single stored token. -#[derive(Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] struct JSONToken { pub hash: u64, + pub scopes: Option>, pub token: Token, } @@ -132,14 +158,14 @@ struct JSONTokens { #[derive(Default)] pub struct DiskTokenStorage { location: String, - tokens: HashMap, + tokens: Vec, } impl DiskTokenStorage { pub fn new>(location: S) -> Result { let mut dts = DiskTokenStorage { location: location.as_ref().to_owned(), - tokens: HashMap::new(), + tokens: Vec::new(), }; // best-effort @@ -173,7 +199,7 @@ impl DiskTokenStorage { } for t in tokens.tokens { - self.tokens.insert(t.hash, t.token); + self.tokens.push(t); } return Result::Ok(()); } @@ -181,11 +207,8 @@ impl DiskTokenStorage { pub fn dump_to_file(&mut self) -> Result<(), io::Error> { let mut jsontokens = JSONTokens { tokens: Vec::new() }; - for (hash, token) in self.tokens.iter() { - jsontokens.tokens.push(JSONToken { - hash: *hash, - token: token.clone(), - }); + for token in self.tokens.iter() { + jsontokens.tokens.push((*token).clone()); } let serialized;; @@ -209,22 +232,45 @@ impl TokenStorage for DiskTokenStorage { fn set( &mut self, scope_hash: u64, - _: &Vec<&str>, + scopes: &Vec<&str>, token: Option, ) -> Result<(), Self::Error> { match token { None => { - self.tokens.remove(&scope_hash); + let matched = self.tokens.iter().find_position(|x| x.hash == scope_hash); + if let Some((idx, _)) = matched { + self.tokens.remove(idx); + } () } Some(t) => { - self.tokens.insert(scope_hash, t.clone()); + self.tokens.push(JSONToken { + hash: scope_hash, + scopes: Some(scopes.iter().map(|x| x.to_string()).collect()), + token: t.clone(), + }); () } } self.dump_to_file() } - fn get(&self, scope_hash: u64, _: &Vec<&str>) -> Result, Self::Error> { - Result::Ok(self.tokens.get(&scope_hash).map(|tok| tok.clone())) + fn get(&self, scope_hash: u64, scopes: &Vec<&str>) -> Result, Self::Error> { + let scopes: Vec<_> = scopes.iter().sorted().unique().collect(); + + for t in &self.tokens { + if let Some(token_scopes) = &t.scopes { + let matched = token_scopes + .iter() + .filter(|x| scopes.contains(&&&x[..])) + .count(); + // we may have some of the tokens as denormalized (many namespaces repeated) + if matched >= scopes.len() { + return Result::Ok(Some(t.token.clone())); + } + } else if scope_hash == t.hash { + return Result::Ok(Some(t.token.clone())); + } + } + Result::Ok(None) } }