mirror of
https://github.com/OMGeeky/yup-oauth2.git
synced 2026-01-25 03:00:34 +01:00
refactor(DeviceFlow): Make DeviceFlow work with Futures
This commit is contained in:
@@ -13,6 +13,7 @@ edition = "2018"
|
||||
[dependencies]
|
||||
base64 = "0.10"
|
||||
chrono = "0.4"
|
||||
http = "0.1"
|
||||
hyper = {version = "0.12", default-features = false}
|
||||
hyper-tls = "0.3"
|
||||
itertools = "0.8"
|
||||
@@ -25,6 +26,7 @@ url = "1"
|
||||
futures = "0.1"
|
||||
tokio-threadpool = "0.1"
|
||||
tokio = "0.1"
|
||||
tokio-timer = "0.2"
|
||||
|
||||
[dev-dependencies]
|
||||
getopts = "0.2"
|
||||
@@ -32,4 +34,4 @@ open = "1.1"
|
||||
yup-hyper-mock = "3.14"
|
||||
|
||||
[workspace]
|
||||
members = ["examples/test-installed/", "examples/test-svc-acct/"]
|
||||
members = ["examples/test-installed/", "examples/test-svc-acct/", "examples/test-device/"]
|
||||
|
||||
12
examples/test-device/Cargo.toml
Normal file
12
examples/test-device/Cargo.toml
Normal file
@@ -0,0 +1,12 @@
|
||||
[package]
|
||||
name = "test-device"
|
||||
version = "0.1.0"
|
||||
authors = ["Lewin Bormann <lewin@lewin-bormann.info>"]
|
||||
edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
yup-oauth2 = { path = "../../" }
|
||||
hyper = "0.12"
|
||||
hyper-tls = "0.3"
|
||||
futures = "0.1"
|
||||
tokio = "0.1"
|
||||
26
examples/test-device/src/main.rs
Normal file
26
examples/test-device/src/main.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
use futures::prelude::*;
|
||||
use yup_oauth2;
|
||||
|
||||
use hyper::client::Client;
|
||||
use hyper_tls::HttpsConnector;
|
||||
use std::path;
|
||||
use tokio;
|
||||
|
||||
fn main() {
|
||||
let creds = yup_oauth2::read_application_secret(path::Path::new("clientsecret.json"))
|
||||
.expect("clientsecret");
|
||||
let https = HttpsConnector::new(1).expect("tls");
|
||||
let client = Client::builder().build::<_, hyper::Body>(https);
|
||||
|
||||
let scopes = &["https://www.googleapis.com/auth/youtube.readonly".to_string()];
|
||||
|
||||
let ad = yup_oauth2::DefaultAuthenticatorDelegate;
|
||||
let mut df = yup_oauth2::DeviceFlow::new::<String>(client, creds, ad, None);
|
||||
let mut rt = tokio::runtime::Runtime::new().unwrap();
|
||||
|
||||
let fut = df
|
||||
.retrieve_device_token(scopes.to_vec())
|
||||
.and_then(|tok| Ok(println!("{:?}", tok)));
|
||||
|
||||
rt.block_on(fut).unwrap()
|
||||
}
|
||||
@@ -50,10 +50,12 @@ impl fmt::Display for PollInformation {
|
||||
pub enum PollError {
|
||||
/// Connection failure - retry if you think it's worth it
|
||||
HttpError(hyper::Error),
|
||||
/// indicates we are expired, including the expiration date
|
||||
/// Indicates we are expired, including the expiration date
|
||||
Expired(DateTime<Utc>),
|
||||
/// Indicates that the user declined access. String is server response
|
||||
AccessDenied,
|
||||
/// Indicates that too many attempts failed.
|
||||
TimedOut,
|
||||
}
|
||||
|
||||
impl fmt::Display for PollError {
|
||||
@@ -62,6 +64,16 @@ impl fmt::Display for PollError {
|
||||
PollError::HttpError(ref err) => err.fmt(f),
|
||||
PollError::Expired(ref date) => writeln!(f, "Authentication expired at {}", date),
|
||||
PollError::AccessDenied => "Access denied by user".fmt(f),
|
||||
PollError::TimedOut => "Timed out waiting for token".fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for PollError {
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
match *self {
|
||||
PollError::HttpError(ref e) => Some(e),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
309
src/device.rs
309
src/device.rs
@@ -1,71 +1,107 @@
|
||||
use std::default::Default;
|
||||
use std::error::Error;
|
||||
use std::iter::IntoIterator;
|
||||
use std::time::Duration;
|
||||
|
||||
use chrono::{self, Utc};
|
||||
use futures::stream::Stream;
|
||||
use futures::Future;
|
||||
use futures::{future, prelude::*};
|
||||
use http;
|
||||
use hyper;
|
||||
use hyper::header;
|
||||
use itertools::Itertools;
|
||||
use serde_json as json;
|
||||
use tokio_timer;
|
||||
use url::form_urlencoded;
|
||||
|
||||
use crate::authenticator_delegate::{PollError, PollInformation};
|
||||
use crate::authenticator_delegate::{AuthenticatorDelegate, PollError, PollInformation};
|
||||
use crate::types::{ApplicationSecret, Flow, FlowType, JsonError, RequestError, Token};
|
||||
|
||||
pub const GOOGLE_DEVICE_CODE_URL: &'static str = "https://accounts.google.com/o/oauth2/device/code";
|
||||
|
||||
/// Encapsulates all possible states of the Device Flow
|
||||
enum DeviceFlowState {
|
||||
/// We failed to poll a result
|
||||
Error,
|
||||
/// We received poll information and will periodically poll for a token
|
||||
Pending(PollInformation),
|
||||
/// The flow finished successfully, providing token information
|
||||
Success(Token),
|
||||
}
|
||||
|
||||
/// Implements the [Oauth2 Device Flow](https://developers.google.com/youtube/v3/guides/authentication#devices)
|
||||
/// It operates in two steps:
|
||||
/// * obtain a code to show to the user
|
||||
/// * (repeatedly) poll for the user to authenticate your application
|
||||
pub struct DeviceFlow<C> {
|
||||
pub struct DeviceFlow<AD, C> {
|
||||
client: hyper::Client<C, hyper::Body>,
|
||||
device_code: String,
|
||||
state: Option<DeviceFlowState>,
|
||||
error: Option<PollError>,
|
||||
application_secret: ApplicationSecret,
|
||||
/// Usually GOOGLE_DEVICE_CODE_URL
|
||||
device_code_url: String,
|
||||
ad: AD,
|
||||
}
|
||||
|
||||
impl<C> Flow for DeviceFlow<C> {
|
||||
impl<AD, C> Flow for DeviceFlow<AD, C> {
|
||||
fn type_id() -> FlowType {
|
||||
FlowType::Device(String::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> DeviceFlow<C>
|
||||
impl<AD, C> DeviceFlow<AD, C>
|
||||
where
|
||||
C: hyper::client::connect::Connect + Sync + 'static,
|
||||
C::Transport: 'static,
|
||||
C::Future: 'static,
|
||||
AD: AuthenticatorDelegate + Clone + Send + 'static,
|
||||
{
|
||||
pub fn new<S: AsRef<str>>(
|
||||
pub fn new<S: 'static + AsRef<str>>(
|
||||
client: hyper::Client<C, hyper::Body>,
|
||||
secret: &ApplicationSecret,
|
||||
device_code_url: S,
|
||||
) -> DeviceFlow<C> {
|
||||
secret: ApplicationSecret,
|
||||
ad: AD,
|
||||
device_code_url: Option<S>,
|
||||
) -> DeviceFlow<AD, C> {
|
||||
DeviceFlow {
|
||||
client: client,
|
||||
device_code: Default::default(),
|
||||
application_secret: secret.clone(),
|
||||
device_code_url: device_code_url.as_ref().to_string(),
|
||||
state: None,
|
||||
error: None,
|
||||
application_secret: secret,
|
||||
device_code_url: device_code_url
|
||||
.as_ref()
|
||||
.map(|s| s.as_ref().to_string())
|
||||
.unwrap_or(GOOGLE_DEVICE_CODE_URL.to_string()),
|
||||
ad: ad,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn retrieve_device_token<'a>(
|
||||
&mut self,
|
||||
scopes: Vec<String>,
|
||||
) -> Box<dyn Future<Item = Option<Token>, Error = Box<dyn Error + Send>> + Send> {
|
||||
let mut ad = self.ad.clone();
|
||||
let application_secret = self.application_secret.clone();
|
||||
let client = self.client.clone();
|
||||
let request_code = Self::request_code(
|
||||
application_secret.clone(),
|
||||
client.clone(),
|
||||
self.device_code_url.clone(),
|
||||
scopes,
|
||||
)
|
||||
.and_then(move |(pollinf, device_code)| {
|
||||
println!("presenting, {}", device_code);
|
||||
ad.present_user_code(&pollinf);
|
||||
Ok((pollinf, device_code))
|
||||
});
|
||||
Box::new(request_code.and_then(|(pollinf, device_code)| {
|
||||
future::loop_fn(0, move |i| {
|
||||
// Make a copy of everything every time, because the loop function needs to be
|
||||
// repeatable, i.e. we can't move anything out.
|
||||
//
|
||||
let pt = Self::poll_token(
|
||||
application_secret.clone(),
|
||||
client.clone(),
|
||||
device_code.clone(),
|
||||
pollinf.clone(),
|
||||
);
|
||||
println!("waiting {:?}", pollinf.interval);
|
||||
tokio_timer::sleep(pollinf.interval)
|
||||
.then(|_| pt)
|
||||
.then(move |r| match r {
|
||||
Ok(None) if i < 10 => Ok(future::Loop::Continue(i + 1)),
|
||||
Ok(Some(tok)) => Ok(future::Loop::Break(Some(tok))),
|
||||
Err(_) if i < 10 => Ok(future::Loop::Continue(i + 1)),
|
||||
_ => Ok(future::Loop::Break(None)),
|
||||
})
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
/// The first step involves asking the server for a code that the user
|
||||
/// can type into a field at a specified URL. It is called only once, assuming
|
||||
/// there was no connection error. Otherwise, it may be called again until
|
||||
@@ -81,26 +117,23 @@ where
|
||||
/// * If called after a successful result was returned at least once.
|
||||
/// # Examples
|
||||
/// See test-cases in source code for a more complete example.
|
||||
pub fn request_code<'b, T, I>(&mut self, scopes: I) -> Result<PollInformation, RequestError>
|
||||
where
|
||||
T: AsRef<str> + 'b,
|
||||
I: IntoIterator<Item = &'b T>,
|
||||
fn request_code(
|
||||
application_secret: ApplicationSecret,
|
||||
client: hyper::Client<C>,
|
||||
device_code_url: String,
|
||||
scopes: Vec<String>,
|
||||
) -> impl Future<Item = (PollInformation, String), Error = Box<dyn 'static + Error + Send>>
|
||||
{
|
||||
if self.state.is_some() {
|
||||
panic!("Must not be called after we have obtained a token and have no error");
|
||||
}
|
||||
|
||||
// note: cloned() shouldn't be needed, see issue
|
||||
// https://github.com/servo/rust-url/issues/81
|
||||
let req = form_urlencoded::Serializer::new(String::new())
|
||||
.extend_pairs(&[
|
||||
("client_id", &self.application_secret.client_id),
|
||||
("client_id", application_secret.client_id.clone()),
|
||||
(
|
||||
"scope",
|
||||
&scopes
|
||||
scopes
|
||||
.into_iter()
|
||||
.map(|s| s.as_ref())
|
||||
.intersperse(" ")
|
||||
.intersperse(" ".to_string())
|
||||
.collect::<String>(),
|
||||
),
|
||||
])
|
||||
@@ -108,54 +141,67 @@ where
|
||||
|
||||
// note: works around bug in rustlang
|
||||
// https://github.com/rust-lang/rust/issues/22252
|
||||
let request = hyper::Request::post(&self.device_code_url)
|
||||
let request = hyper::Request::post(device_code_url)
|
||||
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
|
||||
.body(hyper::Body::from(req))?;
|
||||
.body(hyper::Body::from(req))
|
||||
.into_future();
|
||||
request
|
||||
.then(
|
||||
move |request: Result<hyper::Request<hyper::Body>, http::Error>| {
|
||||
let request = request.unwrap();
|
||||
println!("request: {:?}", request);
|
||||
client.request(request)
|
||||
},
|
||||
)
|
||||
.then(
|
||||
|r: Result<hyper::Response<hyper::Body>, hyper::error::Error>| {
|
||||
match r {
|
||||
Err(err) => {
|
||||
return Err(
|
||||
Box::new(RequestError::ClientError(err)) as Box<dyn Error + Send>
|
||||
);
|
||||
}
|
||||
Ok(res) => {
|
||||
#[derive(Deserialize)]
|
||||
struct JsonData {
|
||||
device_code: String,
|
||||
user_code: String,
|
||||
verification_url: String,
|
||||
expires_in: i64,
|
||||
interval: i64,
|
||||
}
|
||||
|
||||
// TODO: move the ? on request
|
||||
let ret = match self.client.request(request).wait() {
|
||||
Err(err) => {
|
||||
return Err(RequestError::ClientError(err)); // TODO: failed here
|
||||
}
|
||||
Ok(res) => {
|
||||
#[derive(Deserialize)]
|
||||
struct JsonData {
|
||||
device_code: String,
|
||||
user_code: String,
|
||||
verification_url: String,
|
||||
expires_in: i64,
|
||||
interval: i64,
|
||||
}
|
||||
let json_str: String = res
|
||||
.into_body()
|
||||
.concat2()
|
||||
.wait()
|
||||
.map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap())
|
||||
.unwrap(); // TODO: error handling
|
||||
|
||||
let json_str: String = res
|
||||
.into_body()
|
||||
.concat2()
|
||||
.wait()
|
||||
.map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap())
|
||||
.unwrap(); // TODO: error handling
|
||||
// check for error
|
||||
match json::from_str::<JsonError>(&json_str) {
|
||||
Err(_) => {} // ignore, move on
|
||||
Ok(res) => {
|
||||
return Err(
|
||||
Box::new(RequestError::from(res)) as Box<dyn Error + Send>
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// check for error
|
||||
match json::from_str::<JsonError>(&json_str) {
|
||||
Err(_) => {} // ignore, move on
|
||||
Ok(res) => return Err(RequestError::from(res)),
|
||||
}
|
||||
let decoded: JsonData = json::from_str(&json_str).unwrap();
|
||||
|
||||
let decoded: JsonData = json::from_str(&json_str).unwrap();
|
||||
|
||||
self.device_code = decoded.device_code;
|
||||
let pi = PollInformation {
|
||||
user_code: decoded.user_code,
|
||||
verification_url: decoded.verification_url,
|
||||
expires_at: Utc::now() + chrono::Duration::seconds(decoded.expires_in),
|
||||
interval: Duration::from_secs(i64::abs(decoded.interval) as u64),
|
||||
};
|
||||
self.state = Some(DeviceFlowState::Pending(pi.clone()));
|
||||
|
||||
Ok(pi)
|
||||
}
|
||||
};
|
||||
|
||||
ret
|
||||
let pi = PollInformation {
|
||||
user_code: decoded.user_code,
|
||||
verification_url: decoded.verification_url,
|
||||
expires_at: Utc::now()
|
||||
+ chrono::Duration::seconds(decoded.expires_in),
|
||||
interval: Duration::from_secs(i64::abs(decoded.interval) as u64),
|
||||
};
|
||||
Ok((pi, decoded.device_code))
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// If the first call is successful, this method may be called.
|
||||
@@ -175,78 +221,73 @@ where
|
||||
///
|
||||
/// # Examples
|
||||
/// See test-cases in source code for a more complete example.
|
||||
pub fn poll_token(&mut self) -> Result<Option<Token>, &PollError> {
|
||||
// clone, as we may re-assign our state later
|
||||
let pi = match self.state {
|
||||
Some(ref s) => match *s {
|
||||
DeviceFlowState::Pending(ref pi) => pi.clone(),
|
||||
DeviceFlowState::Error => return Err(self.error.as_ref().unwrap()),
|
||||
DeviceFlowState::Success(ref t) => return Ok(Some(t.clone())),
|
||||
},
|
||||
_ => panic!("You have to call request_code() beforehand"),
|
||||
fn poll_token<'a>(
|
||||
application_secret: ApplicationSecret,
|
||||
client: hyper::Client<C>,
|
||||
device_code: String,
|
||||
pi: PollInformation,
|
||||
) -> impl Future<Item = Option<Token>, Error = Box<dyn 'a + Error + Send>> {
|
||||
let expired = if pi.expires_at <= Utc::now() {
|
||||
Err(PollError::Expired(pi.expires_at)).into_future()
|
||||
} else {
|
||||
Ok(()).into_future()
|
||||
};
|
||||
|
||||
if pi.expires_at <= Utc::now() {
|
||||
self.error = Some(PollError::Expired(pi.expires_at));
|
||||
self.state = Some(DeviceFlowState::Error);
|
||||
return Err(&self.error.as_ref().unwrap());
|
||||
}
|
||||
|
||||
// We should be ready for a new request
|
||||
let req = form_urlencoded::Serializer::new(String::new())
|
||||
.extend_pairs(&[
|
||||
("client_id", &self.application_secret.client_id[..]),
|
||||
("client_secret", &self.application_secret.client_secret),
|
||||
("code", &self.device_code),
|
||||
("client_id", &application_secret.client_id[..]),
|
||||
("client_secret", &application_secret.client_secret),
|
||||
("code", &device_code),
|
||||
("grant_type", "http://oauth.net/grant_type/device/1.0"),
|
||||
])
|
||||
.finish();
|
||||
|
||||
let request = hyper::Request::post(&self.application_secret.token_uri)
|
||||
let request = hyper::Request::post(&application_secret.token_uri)
|
||||
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
|
||||
.body(hyper::Body::from(req))
|
||||
.unwrap(); // TODO: Error checking
|
||||
let json_str: String = match self.client.request(request).wait() {
|
||||
Err(err) => {
|
||||
self.error = Some(PollError::HttpError(err));
|
||||
return Err(self.error.as_ref().unwrap());
|
||||
}
|
||||
Ok(res) => {
|
||||
expired
|
||||
.map_err(|e| Box::new(e) as Box<dyn Error + Send>)
|
||||
.and_then(move |_| {
|
||||
client
|
||||
.request(request)
|
||||
.map_err(|e| Box::new(e) as Box<dyn Error + Send>)
|
||||
})
|
||||
.map(|res| {
|
||||
res.into_body()
|
||||
.concat2()
|
||||
.wait()
|
||||
.map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap())
|
||||
.unwrap() // TODO: error handling
|
||||
}
|
||||
};
|
||||
})
|
||||
.and_then(|json_str: String| {
|
||||
#[derive(Deserialize)]
|
||||
struct JsonError {
|
||||
error: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct JsonError {
|
||||
error: String,
|
||||
}
|
||||
|
||||
match json::from_str::<JsonError>(&json_str) {
|
||||
Err(_) => {} // ignore, move on, it's not an error
|
||||
Ok(res) => {
|
||||
match res.error.as_ref() {
|
||||
"access_denied" => {
|
||||
self.error = Some(PollError::AccessDenied);
|
||||
self.state = Some(DeviceFlowState::Error);
|
||||
return Err(self.error.as_ref().unwrap());
|
||||
match json::from_str::<JsonError>(&json_str) {
|
||||
Err(_) => {} // ignore, move on, it's not an error
|
||||
Ok(res) => {
|
||||
match res.error.as_ref() {
|
||||
"access_denied" => {
|
||||
return Err(
|
||||
Box::new(PollError::AccessDenied) as Box<dyn Error + Send>
|
||||
);
|
||||
}
|
||||
"authorization_pending" => return Ok(None),
|
||||
_ => panic!("server message '{}' not understood", res.error),
|
||||
};
|
||||
}
|
||||
"authorization_pending" => return Ok(None),
|
||||
_ => panic!("server message '{}' not understood", res.error),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// yes, we expect that !
|
||||
let mut t: Token = json::from_str(&json_str).unwrap();
|
||||
t.set_expiry_absolute();
|
||||
// yes, we expect that !
|
||||
let mut t: Token = json::from_str(&json_str).unwrap();
|
||||
t.set_expiry_absolute();
|
||||
|
||||
let res = Ok(Some(t.clone()));
|
||||
self.state = Some(DeviceFlowState::Success(t));
|
||||
return res;
|
||||
Ok(Some(t.clone()))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
11
src/types.rs
11
src/types.rs
@@ -19,6 +19,7 @@ pub struct JsonError {
|
||||
}
|
||||
|
||||
/// Encapsulates all possible results of the `request_token(...)` operation
|
||||
#[derive(Debug)]
|
||||
pub enum RequestError {
|
||||
/// Indicates connection failure
|
||||
ClientError(hyper::Error),
|
||||
@@ -78,6 +79,16 @@ impl fmt::Display for RequestError {
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for RequestError {
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
match *self {
|
||||
RequestError::ClientError(ref err) => Some(err),
|
||||
RequestError::HttpError(ref err) => Some(err),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct StringError {
|
||||
error: String,
|
||||
|
||||
Reference in New Issue
Block a user