refactor(DeviceFlow): Make DeviceFlow work with Futures

This commit is contained in:
Lewin Bormann
2019-06-12 18:43:30 +02:00
parent 732e594962
commit 58383f9a03
6 changed files with 240 additions and 136 deletions

View File

@@ -13,6 +13,7 @@ edition = "2018"
[dependencies]
base64 = "0.10"
chrono = "0.4"
http = "0.1"
hyper = {version = "0.12", default-features = false}
hyper-tls = "0.3"
itertools = "0.8"
@@ -25,6 +26,7 @@ url = "1"
futures = "0.1"
tokio-threadpool = "0.1"
tokio = "0.1"
tokio-timer = "0.2"
[dev-dependencies]
getopts = "0.2"
@@ -32,4 +34,4 @@ open = "1.1"
yup-hyper-mock = "3.14"
[workspace]
members = ["examples/test-installed/", "examples/test-svc-acct/"]
members = ["examples/test-installed/", "examples/test-svc-acct/", "examples/test-device/"]

View File

@@ -0,0 +1,12 @@
[package]
name = "test-device"
version = "0.1.0"
authors = ["Lewin Bormann <lewin@lewin-bormann.info>"]
edition = "2018"
[dependencies]
yup-oauth2 = { path = "../../" }
hyper = "0.12"
hyper-tls = "0.3"
futures = "0.1"
tokio = "0.1"

View File

@@ -0,0 +1,26 @@
use futures::prelude::*;
use yup_oauth2;
use hyper::client::Client;
use hyper_tls::HttpsConnector;
use std::path;
use tokio;
fn main() {
let creds = yup_oauth2::read_application_secret(path::Path::new("clientsecret.json"))
.expect("clientsecret");
let https = HttpsConnector::new(1).expect("tls");
let client = Client::builder().build::<_, hyper::Body>(https);
let scopes = &["https://www.googleapis.com/auth/youtube.readonly".to_string()];
let ad = yup_oauth2::DefaultAuthenticatorDelegate;
let mut df = yup_oauth2::DeviceFlow::new::<String>(client, creds, ad, None);
let mut rt = tokio::runtime::Runtime::new().unwrap();
let fut = df
.retrieve_device_token(scopes.to_vec())
.and_then(|tok| Ok(println!("{:?}", tok)));
rt.block_on(fut).unwrap()
}

View File

@@ -50,10 +50,12 @@ impl fmt::Display for PollInformation {
pub enum PollError {
/// Connection failure - retry if you think it's worth it
HttpError(hyper::Error),
/// indicates we are expired, including the expiration date
/// Indicates we are expired, including the expiration date
Expired(DateTime<Utc>),
/// Indicates that the user declined access. String is server response
AccessDenied,
/// Indicates that too many attempts failed.
TimedOut,
}
impl fmt::Display for PollError {
@@ -62,6 +64,16 @@ impl fmt::Display for PollError {
PollError::HttpError(ref err) => err.fmt(f),
PollError::Expired(ref date) => writeln!(f, "Authentication expired at {}", date),
PollError::AccessDenied => "Access denied by user".fmt(f),
PollError::TimedOut => "Timed out waiting for token".fmt(f),
}
}
}
impl Error for PollError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match *self {
PollError::HttpError(ref e) => Some(e),
_ => None,
}
}
}

View File

@@ -1,71 +1,107 @@
use std::default::Default;
use std::error::Error;
use std::iter::IntoIterator;
use std::time::Duration;
use chrono::{self, Utc};
use futures::stream::Stream;
use futures::Future;
use futures::{future, prelude::*};
use http;
use hyper;
use hyper::header;
use itertools::Itertools;
use serde_json as json;
use tokio_timer;
use url::form_urlencoded;
use crate::authenticator_delegate::{PollError, PollInformation};
use crate::authenticator_delegate::{AuthenticatorDelegate, PollError, PollInformation};
use crate::types::{ApplicationSecret, Flow, FlowType, JsonError, RequestError, Token};
pub const GOOGLE_DEVICE_CODE_URL: &'static str = "https://accounts.google.com/o/oauth2/device/code";
/// Encapsulates all possible states of the Device Flow
enum DeviceFlowState {
/// We failed to poll a result
Error,
/// We received poll information and will periodically poll for a token
Pending(PollInformation),
/// The flow finished successfully, providing token information
Success(Token),
}
/// Implements the [Oauth2 Device Flow](https://developers.google.com/youtube/v3/guides/authentication#devices)
/// It operates in two steps:
/// * obtain a code to show to the user
/// * (repeatedly) poll for the user to authenticate your application
pub struct DeviceFlow<C> {
pub struct DeviceFlow<AD, C> {
client: hyper::Client<C, hyper::Body>,
device_code: String,
state: Option<DeviceFlowState>,
error: Option<PollError>,
application_secret: ApplicationSecret,
/// Usually GOOGLE_DEVICE_CODE_URL
device_code_url: String,
ad: AD,
}
impl<C> Flow for DeviceFlow<C> {
impl<AD, C> Flow for DeviceFlow<AD, C> {
fn type_id() -> FlowType {
FlowType::Device(String::new())
}
}
impl<C> DeviceFlow<C>
impl<AD, C> DeviceFlow<AD, C>
where
C: hyper::client::connect::Connect + Sync + 'static,
C::Transport: 'static,
C::Future: 'static,
AD: AuthenticatorDelegate + Clone + Send + 'static,
{
pub fn new<S: AsRef<str>>(
pub fn new<S: 'static + AsRef<str>>(
client: hyper::Client<C, hyper::Body>,
secret: &ApplicationSecret,
device_code_url: S,
) -> DeviceFlow<C> {
secret: ApplicationSecret,
ad: AD,
device_code_url: Option<S>,
) -> DeviceFlow<AD, C> {
DeviceFlow {
client: client,
device_code: Default::default(),
application_secret: secret.clone(),
device_code_url: device_code_url.as_ref().to_string(),
state: None,
error: None,
application_secret: secret,
device_code_url: device_code_url
.as_ref()
.map(|s| s.as_ref().to_string())
.unwrap_or(GOOGLE_DEVICE_CODE_URL.to_string()),
ad: ad,
}
}
pub fn retrieve_device_token<'a>(
&mut self,
scopes: Vec<String>,
) -> Box<dyn Future<Item = Option<Token>, Error = Box<dyn Error + Send>> + Send> {
let mut ad = self.ad.clone();
let application_secret = self.application_secret.clone();
let client = self.client.clone();
let request_code = Self::request_code(
application_secret.clone(),
client.clone(),
self.device_code_url.clone(),
scopes,
)
.and_then(move |(pollinf, device_code)| {
println!("presenting, {}", device_code);
ad.present_user_code(&pollinf);
Ok((pollinf, device_code))
});
Box::new(request_code.and_then(|(pollinf, device_code)| {
future::loop_fn(0, move |i| {
// Make a copy of everything every time, because the loop function needs to be
// repeatable, i.e. we can't move anything out.
//
let pt = Self::poll_token(
application_secret.clone(),
client.clone(),
device_code.clone(),
pollinf.clone(),
);
println!("waiting {:?}", pollinf.interval);
tokio_timer::sleep(pollinf.interval)
.then(|_| pt)
.then(move |r| match r {
Ok(None) if i < 10 => Ok(future::Loop::Continue(i + 1)),
Ok(Some(tok)) => Ok(future::Loop::Break(Some(tok))),
Err(_) if i < 10 => Ok(future::Loop::Continue(i + 1)),
_ => Ok(future::Loop::Break(None)),
})
})
}))
}
/// The first step involves asking the server for a code that the user
/// can type into a field at a specified URL. It is called only once, assuming
/// there was no connection error. Otherwise, it may be called again until
@@ -81,26 +117,23 @@ where
/// * If called after a successful result was returned at least once.
/// # Examples
/// See test-cases in source code for a more complete example.
pub fn request_code<'b, T, I>(&mut self, scopes: I) -> Result<PollInformation, RequestError>
where
T: AsRef<str> + 'b,
I: IntoIterator<Item = &'b T>,
fn request_code(
application_secret: ApplicationSecret,
client: hyper::Client<C>,
device_code_url: String,
scopes: Vec<String>,
) -> impl Future<Item = (PollInformation, String), Error = Box<dyn 'static + Error + Send>>
{
if self.state.is_some() {
panic!("Must not be called after we have obtained a token and have no error");
}
// note: cloned() shouldn't be needed, see issue
// https://github.com/servo/rust-url/issues/81
let req = form_urlencoded::Serializer::new(String::new())
.extend_pairs(&[
("client_id", &self.application_secret.client_id),
("client_id", application_secret.client_id.clone()),
(
"scope",
&scopes
scopes
.into_iter()
.map(|s| s.as_ref())
.intersperse(" ")
.intersperse(" ".to_string())
.collect::<String>(),
),
])
@@ -108,54 +141,67 @@ where
// note: works around bug in rustlang
// https://github.com/rust-lang/rust/issues/22252
let request = hyper::Request::post(&self.device_code_url)
let request = hyper::Request::post(device_code_url)
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.body(hyper::Body::from(req))?;
.body(hyper::Body::from(req))
.into_future();
request
.then(
move |request: Result<hyper::Request<hyper::Body>, http::Error>| {
let request = request.unwrap();
println!("request: {:?}", request);
client.request(request)
},
)
.then(
|r: Result<hyper::Response<hyper::Body>, hyper::error::Error>| {
match r {
Err(err) => {
return Err(
Box::new(RequestError::ClientError(err)) as Box<dyn Error + Send>
);
}
Ok(res) => {
#[derive(Deserialize)]
struct JsonData {
device_code: String,
user_code: String,
verification_url: String,
expires_in: i64,
interval: i64,
}
// TODO: move the ? on request
let ret = match self.client.request(request).wait() {
Err(err) => {
return Err(RequestError::ClientError(err)); // TODO: failed here
}
Ok(res) => {
#[derive(Deserialize)]
struct JsonData {
device_code: String,
user_code: String,
verification_url: String,
expires_in: i64,
interval: i64,
}
let json_str: String = res
.into_body()
.concat2()
.wait()
.map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap())
.unwrap(); // TODO: error handling
let json_str: String = res
.into_body()
.concat2()
.wait()
.map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap())
.unwrap(); // TODO: error handling
// check for error
match json::from_str::<JsonError>(&json_str) {
Err(_) => {} // ignore, move on
Ok(res) => {
return Err(
Box::new(RequestError::from(res)) as Box<dyn Error + Send>
)
}
}
// check for error
match json::from_str::<JsonError>(&json_str) {
Err(_) => {} // ignore, move on
Ok(res) => return Err(RequestError::from(res)),
}
let decoded: JsonData = json::from_str(&json_str).unwrap();
let decoded: JsonData = json::from_str(&json_str).unwrap();
self.device_code = decoded.device_code;
let pi = PollInformation {
user_code: decoded.user_code,
verification_url: decoded.verification_url,
expires_at: Utc::now() + chrono::Duration::seconds(decoded.expires_in),
interval: Duration::from_secs(i64::abs(decoded.interval) as u64),
};
self.state = Some(DeviceFlowState::Pending(pi.clone()));
Ok(pi)
}
};
ret
let pi = PollInformation {
user_code: decoded.user_code,
verification_url: decoded.verification_url,
expires_at: Utc::now()
+ chrono::Duration::seconds(decoded.expires_in),
interval: Duration::from_secs(i64::abs(decoded.interval) as u64),
};
Ok((pi, decoded.device_code))
}
}
},
)
}
/// If the first call is successful, this method may be called.
@@ -175,78 +221,73 @@ where
///
/// # Examples
/// See test-cases in source code for a more complete example.
pub fn poll_token(&mut self) -> Result<Option<Token>, &PollError> {
// clone, as we may re-assign our state later
let pi = match self.state {
Some(ref s) => match *s {
DeviceFlowState::Pending(ref pi) => pi.clone(),
DeviceFlowState::Error => return Err(self.error.as_ref().unwrap()),
DeviceFlowState::Success(ref t) => return Ok(Some(t.clone())),
},
_ => panic!("You have to call request_code() beforehand"),
fn poll_token<'a>(
application_secret: ApplicationSecret,
client: hyper::Client<C>,
device_code: String,
pi: PollInformation,
) -> impl Future<Item = Option<Token>, Error = Box<dyn 'a + Error + Send>> {
let expired = if pi.expires_at <= Utc::now() {
Err(PollError::Expired(pi.expires_at)).into_future()
} else {
Ok(()).into_future()
};
if pi.expires_at <= Utc::now() {
self.error = Some(PollError::Expired(pi.expires_at));
self.state = Some(DeviceFlowState::Error);
return Err(&self.error.as_ref().unwrap());
}
// We should be ready for a new request
let req = form_urlencoded::Serializer::new(String::new())
.extend_pairs(&[
("client_id", &self.application_secret.client_id[..]),
("client_secret", &self.application_secret.client_secret),
("code", &self.device_code),
("client_id", &application_secret.client_id[..]),
("client_secret", &application_secret.client_secret),
("code", &device_code),
("grant_type", "http://oauth.net/grant_type/device/1.0"),
])
.finish();
let request = hyper::Request::post(&self.application_secret.token_uri)
let request = hyper::Request::post(&application_secret.token_uri)
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.body(hyper::Body::from(req))
.unwrap(); // TODO: Error checking
let json_str: String = match self.client.request(request).wait() {
Err(err) => {
self.error = Some(PollError::HttpError(err));
return Err(self.error.as_ref().unwrap());
}
Ok(res) => {
expired
.map_err(|e| Box::new(e) as Box<dyn Error + Send>)
.and_then(move |_| {
client
.request(request)
.map_err(|e| Box::new(e) as Box<dyn Error + Send>)
})
.map(|res| {
res.into_body()
.concat2()
.wait()
.map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap())
.unwrap() // TODO: error handling
}
};
})
.and_then(|json_str: String| {
#[derive(Deserialize)]
struct JsonError {
error: String,
}
#[derive(Deserialize)]
struct JsonError {
error: String,
}
match json::from_str::<JsonError>(&json_str) {
Err(_) => {} // ignore, move on, it's not an error
Ok(res) => {
match res.error.as_ref() {
"access_denied" => {
self.error = Some(PollError::AccessDenied);
self.state = Some(DeviceFlowState::Error);
return Err(self.error.as_ref().unwrap());
match json::from_str::<JsonError>(&json_str) {
Err(_) => {} // ignore, move on, it's not an error
Ok(res) => {
match res.error.as_ref() {
"access_denied" => {
return Err(
Box::new(PollError::AccessDenied) as Box<dyn Error + Send>
);
}
"authorization_pending" => return Ok(None),
_ => panic!("server message '{}' not understood", res.error),
};
}
"authorization_pending" => return Ok(None),
_ => panic!("server message '{}' not understood", res.error),
};
}
}
}
// yes, we expect that !
let mut t: Token = json::from_str(&json_str).unwrap();
t.set_expiry_absolute();
// yes, we expect that !
let mut t: Token = json::from_str(&json_str).unwrap();
t.set_expiry_absolute();
let res = Ok(Some(t.clone()));
self.state = Some(DeviceFlowState::Success(t));
return res;
Ok(Some(t.clone()))
})
}
}

View File

@@ -19,6 +19,7 @@ pub struct JsonError {
}
/// Encapsulates all possible results of the `request_token(...)` operation
#[derive(Debug)]
pub enum RequestError {
/// Indicates connection failure
ClientError(hyper::Error),
@@ -78,6 +79,16 @@ impl fmt::Display for RequestError {
}
}
impl Error for RequestError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match *self {
RequestError::ClientError(ref err) => Some(err),
RequestError::HttpError(ref err) => Some(err),
_ => None,
}
}
}
#[derive(Debug)]
pub struct StringError {
error: String,