feat(DeviceFlow): Proper timeout handling for the DeviceFlow.

This commit is contained in:
Lewin Bormann
2019-06-12 19:28:37 +02:00
parent e7a89fae07
commit 46e1f1b880
2 changed files with 42 additions and 14 deletions

View File

@@ -1,9 +1,10 @@
use futures::prelude::*;
use yup_oauth2;
use yup_oauth2::{self, GetToken};
use hyper::client::Client;
use hyper_tls::HttpsConnector;
use std::path;
use std::time::Duration;
use tokio;
fn main() {
@@ -16,11 +17,12 @@ fn main() {
let ad = yup_oauth2::DefaultAuthenticatorDelegate;
let mut df = yup_oauth2::DeviceFlow::new::<String>(client, creds, ad, None);
df.set_wait_duration(Duration::from_secs(120));
let mut rt = tokio::runtime::Runtime::new().unwrap();
let fut = df
.retrieve_device_token(scopes.to_vec())
.token(scopes.iter())
.and_then(|tok| Ok(println!("{:?}", tok)));
rt.block_on(fut).unwrap()
println!("{:?}", rt.block_on(fut));
}

View File

@@ -1,5 +1,5 @@
use std::error::Error;
use std::iter::IntoIterator;
use std::iter::{FromIterator, IntoIterator};
use std::time::Duration;
use chrono::{self, Utc};
@@ -14,7 +14,7 @@ use tokio_timer;
use url::form_urlencoded;
use crate::authenticator_delegate::{AuthenticatorDelegate, PollError, PollInformation};
use crate::types::{ApplicationSecret, Flow, FlowType, JsonError, RequestError, Token};
use crate::types::{ApplicationSecret, Flow, FlowType, GetToken, JsonError, RequestError, Token};
pub const GOOGLE_DEVICE_CODE_URL: &'static str = "https://accounts.google.com/o/oauth2/device/code";
@@ -28,6 +28,7 @@ pub struct DeviceFlow<AD, C> {
/// Usually GOOGLE_DEVICE_CODE_URL
device_code_url: String,
ad: AD,
wait: Duration,
}
impl<AD, C> Flow for DeviceFlow<AD, C> {
@@ -36,6 +37,26 @@ impl<AD, C> Flow for DeviceFlow<AD, C> {
}
}
impl<
AD: AuthenticatorDelegate + Clone + Send + 'static,
C: hyper::client::connect::Connect + Sync + 'static,
> GetToken for DeviceFlow<AD, C>
{
fn token<'b, I, T>(
&mut self,
scopes: I,
) -> Box<dyn Future<Item = Token, Error = Box<dyn Error + Send>> + Send>
where
T: AsRef<str> + Ord + 'b,
I: Iterator<Item = &'b T>,
{
self.retrieve_device_token(Vec::from_iter(scopes.map(|s| s.as_ref().to_string())))
}
fn api_key(&mut self) -> Option<String> {
None
}
}
impl<AD, C> DeviceFlow<AD, C>
where
C: hyper::client::connect::Connect + Sync + 'static,
@@ -57,16 +78,23 @@ where
.map(|s| s.as_ref().to_string())
.unwrap_or(GOOGLE_DEVICE_CODE_URL.to_string()),
ad: ad,
wait: Duration::from_secs(120),
}
}
/// Set the time to wait for the user to authorize us. The default is 120 seconds.
pub fn set_wait_duration(&mut self, wait: Duration) {
self.wait = wait;
}
pub fn retrieve_device_token<'a>(
&mut self,
scopes: Vec<String>,
) -> Box<dyn Future<Item = Option<Token>, Error = Box<dyn Error + Send>> + Send> {
) -> Box<dyn Future<Item = Token, Error = Box<dyn Error + Send>> + Send> {
let mut ad = self.ad.clone();
let application_secret = self.application_secret.clone();
let client = self.client.clone();
let wait = self.wait;
let request_code = Self::request_code(
application_secret.clone(),
client.clone(),
@@ -74,11 +102,10 @@ where
scopes,
)
.and_then(move |(pollinf, device_code)| {
println!("presenting, {}", device_code);
ad.present_user_code(&pollinf);
Ok((pollinf, device_code))
});
Box::new(request_code.and_then(|(pollinf, device_code)| {
Box::new(request_code.and_then(move |(pollinf, device_code)| {
future::loop_fn(0, move |i| {
// Make a copy of everything every time, because the loop function needs to be
// repeatable, i.e. we can't move anything out.
@@ -89,14 +116,14 @@ where
device_code.clone(),
pollinf.clone(),
);
println!("waiting {:?}", pollinf.interval);
let maxn = wait.as_secs() / pollinf.interval.as_secs();
tokio_timer::sleep(pollinf.interval)
.then(|_| pt)
.then(move |r| match r {
Ok(None) if i < 10 => Ok(future::Loop::Continue(i + 1)),
Ok(Some(tok)) => Ok(future::Loop::Break(Some(tok))),
Err(_) if i < 10 => Ok(future::Loop::Continue(i + 1)),
_ => Ok(future::Loop::Break(None)),
Ok(None) if i < maxn => Ok(future::Loop::Continue(i + 1)),
Ok(Some(tok)) => Ok(future::Loop::Break(tok)),
Err(_) if i < maxn => Ok(future::Loop::Continue(i + 1)),
_ => Err(Box::new(PollError::TimedOut) as Box<dyn Error + Send>),
})
})
}))
@@ -149,7 +176,6 @@ where
.then(
move |request: Result<hyper::Request<hyper::Body>, http::Error>| {
let request = request.unwrap();
println!("request: {:?}", request);
client.request(request)
},
)