From 4cfbc6e5fcb0d794bdd16d153d506c836659f6d1 Mon Sep 17 00:00:00 2001 From: Lewin Bormann Date: Thu, 13 Jun 2019 16:07:32 +0200 Subject: [PATCH] imp(Device): Honor FlowDelegate's opinion on pending authorization. --- src/authenticator_delegate.rs | 18 ++++++------- src/device.rs | 50 ++++++++++++++++++++++++++++------- 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index b5723ac..b993758 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -110,15 +110,6 @@ pub trait AuthenticatorDelegate: Clone { /// The server denied the attempt to obtain a request code fn request_failure(&mut self, _: RequestError) {} - /// Called if the request code is expired. You will have to start over in this case. - /// This will be the last call the delegate receives. - /// Given `DateTime` is the expiration date - fn expired(&mut self, _: &DateTime) {} - - /// Called if the user denied access. You would have to start over. - /// This will be the last call the delegate receives. - fn denied(&mut self) {} - /// Called if we could not acquire a refresh token for a reason possibly specified /// by the server. /// This call is made for the delegate's information only. @@ -137,6 +128,15 @@ pub trait AuthenticatorDelegate: Clone { } pub trait FlowDelegate: Clone { + /// Called if the request code is expired. You will have to start over in this case. + /// This will be the last call the delegate receives. + /// Given `DateTime` is the expiration date + fn expired(&mut self, _: &DateTime) {} + + /// Called if the user denied access. You would have to start over. + /// This will be the last call the delegate receives. + fn denied(&mut self) {} + /// Called as long as we are waiting for the user to authorize us. /// Can be used to print progress information, or decide to time-out. /// diff --git a/src/device.rs b/src/device.rs index b58a68a..3182ef8 100644 --- a/src/device.rs +++ b/src/device.rs @@ -13,8 +13,10 @@ use serde_json as json; use tokio_timer; use url::form_urlencoded; -use crate::authenticator_delegate::{FlowDelegate, PollError, PollInformation}; -use crate::types::{ApplicationSecret, Flow, FlowType, GetToken, JsonError, RequestError, Token}; +use crate::authenticator_delegate::{FlowDelegate, PollError, PollInformation, Retry}; +use crate::types::{ + ApplicationSecret, Flow, FlowType, GetToken, JsonError, RequestError, StringError, Token, +}; pub const GOOGLE_DEVICE_CODE_URL: &'static str = "https://accounts.google.com/o/oauth2/device/code"; @@ -81,7 +83,7 @@ where .map(|s| s.as_ref().to_string()) .unwrap_or(GOOGLE_DEVICE_CODE_URL.to_string()), fd: fd, - wait: Duration::from_secs(120), + wait: Duration::from_secs(1200), } } @@ -96,10 +98,10 @@ where &mut self, scopes: Vec, ) -> Box> + Send> { - let mut fd = self.fd.clone(); let application_secret = self.application_secret.clone(); let client = self.client.clone(); let wait = self.wait; + let mut fd = self.fd.clone(); let request_code = Self::request_code( application_secret.clone(), client.clone(), @@ -110,6 +112,7 @@ where fd.present_user_code(&pollinf); Ok((pollinf, device_code)) }); + let fd = self.fd.clone(); Box::new(request_code.and_then(move |(pollinf, device_code)| { future::loop_fn(0, move |i| { // Make a copy of everything every time, because the loop function needs to be @@ -119,15 +122,41 @@ where client.clone(), device_code.clone(), pollinf.clone(), + fd.clone(), ); let maxn = wait.as_secs() / pollinf.interval.as_secs(); + let mut fd = fd.clone(); + let pollinf = pollinf.clone(); tokio_timer::sleep(pollinf.interval) .then(|_| pt) .then(move |r| match r { - Ok(None) if i < maxn => Ok(future::Loop::Continue(i + 1)), - Ok(Some(tok)) => Ok(future::Loop::Break(tok)), - Err(_) if i < maxn => Ok(future::Loop::Continue(i + 1)), - _ => Err(Box::new(PollError::TimedOut) as Box), + Ok(None) if i < maxn => match fd.pending(&pollinf) { + Retry::Abort | Retry::Skip => Box::new( + Err(Box::new(StringError::new( + "Pending authentication aborted".to_string(), + None, + )) as Box) + .into_future(), + ), + Retry::After(d) => Box::new( + tokio_timer::sleep(d) + .then(move |_| Ok(future::Loop::Continue(i + 1))), + ) + as Box< + dyn Future< + Item = future::Loop, + Error = Box, + > + Send, + >, + }, + Ok(Some(tok)) => Box::new(Ok(future::Loop::Break(tok)).into_future()), + Err(_) if i < maxn => { + Box::new(Ok(future::Loop::Continue(i + 1)).into_future()) + } + _ => Box::new( + Err(Box::new(PollError::TimedOut) as Box) + .into_future(), + ), }) }) })) @@ -256,8 +285,10 @@ where client: hyper::Client, device_code: String, pi: PollInformation, + mut fd: FD, ) -> impl Future, Error = Box> { let expired = if pi.expires_at <= Utc::now() { + fd.expired(&pi.expires_at); Err(PollError::Expired(pi.expires_at)).into_future() } else { Ok(()).into_future() @@ -291,7 +322,7 @@ where .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) .unwrap() // TODO: error handling }) - .and_then(|json_str: String| { + .and_then(move |json_str: String| { #[derive(Deserialize)] struct JsonError { error: String, @@ -302,6 +333,7 @@ where Ok(res) => { match res.error.as_ref() { "access_denied" => { + fd.denied(); return Err( Box::new(PollError::AccessDenied) as Box );