diff --git a/Cargo.toml b/Cargo.toml index 680af8b..f124e0a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ edition = "2018" [dependencies] base64 = "0.10" chrono = "0.4" +http = "0.1" hyper = {version = "0.12", default-features = false} hyper-tls = "0.3" itertools = "0.8" @@ -25,6 +26,7 @@ url = "1" futures = "0.1" tokio-threadpool = "0.1" tokio = "0.1" +tokio-timer = "0.2" [dev-dependencies] getopts = "0.2" @@ -32,4 +34,4 @@ open = "1.1" yup-hyper-mock = "3.14" [workspace] -members = ["examples/test-installed/", "examples/test-svc-acct/"] +members = ["examples/test-installed/", "examples/test-svc-acct/", "examples/test-device/"] diff --git a/examples/test-device/Cargo.toml b/examples/test-device/Cargo.toml new file mode 100644 index 0000000..648213b --- /dev/null +++ b/examples/test-device/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "test-device" +version = "0.1.0" +authors = ["Lewin Bormann "] +edition = "2018" + +[dependencies] +yup-oauth2 = { path = "../../" } +hyper = "0.12" +hyper-tls = "0.3" +futures = "0.1" +tokio = "0.1" diff --git a/examples/test-device/src/main.rs b/examples/test-device/src/main.rs new file mode 100644 index 0000000..0b0b565 --- /dev/null +++ b/examples/test-device/src/main.rs @@ -0,0 +1,26 @@ +use futures::prelude::*; +use yup_oauth2; + +use hyper::client::Client; +use hyper_tls::HttpsConnector; +use std::path; +use tokio; + +fn main() { + let creds = yup_oauth2::read_application_secret(path::Path::new("clientsecret.json")) + .expect("clientsecret"); + let https = HttpsConnector::new(1).expect("tls"); + let client = Client::builder().build::<_, hyper::Body>(https); + + let scopes = &["https://www.googleapis.com/auth/youtube.readonly".to_string()]; + + let ad = yup_oauth2::DefaultAuthenticatorDelegate; + let mut df = yup_oauth2::DeviceFlow::new::(client, creds, ad, None); + let mut rt = tokio::runtime::Runtime::new().unwrap(); + + let fut = df + .retrieve_device_token(scopes.to_vec()) + .and_then(|tok| Ok(println!("{:?}", tok))); + + rt.block_on(fut).unwrap() +} diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index 6d0de4f..e918fd1 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -50,10 +50,12 @@ impl fmt::Display for PollInformation { pub enum PollError { /// Connection failure - retry if you think it's worth it HttpError(hyper::Error), - /// indicates we are expired, including the expiration date + /// Indicates we are expired, including the expiration date Expired(DateTime), /// Indicates that the user declined access. String is server response AccessDenied, + /// Indicates that too many attempts failed. + TimedOut, } impl fmt::Display for PollError { @@ -62,6 +64,16 @@ impl fmt::Display for PollError { PollError::HttpError(ref err) => err.fmt(f), PollError::Expired(ref date) => writeln!(f, "Authentication expired at {}", date), PollError::AccessDenied => "Access denied by user".fmt(f), + PollError::TimedOut => "Timed out waiting for token".fmt(f), + } + } +} + +impl Error for PollError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match *self { + PollError::HttpError(ref e) => Some(e), + _ => None, } } } diff --git a/src/device.rs b/src/device.rs index a4d72bd..6dbb022 100644 --- a/src/device.rs +++ b/src/device.rs @@ -1,71 +1,107 @@ -use std::default::Default; +use std::error::Error; use std::iter::IntoIterator; use std::time::Duration; use chrono::{self, Utc}; use futures::stream::Stream; -use futures::Future; +use futures::{future, prelude::*}; +use http; use hyper; use hyper::header; use itertools::Itertools; use serde_json as json; +use tokio_timer; use url::form_urlencoded; -use crate::authenticator_delegate::{PollError, PollInformation}; +use crate::authenticator_delegate::{AuthenticatorDelegate, PollError, PollInformation}; use crate::types::{ApplicationSecret, Flow, FlowType, JsonError, RequestError, Token}; pub const GOOGLE_DEVICE_CODE_URL: &'static str = "https://accounts.google.com/o/oauth2/device/code"; -/// Encapsulates all possible states of the Device Flow -enum DeviceFlowState { - /// We failed to poll a result - Error, - /// We received poll information and will periodically poll for a token - Pending(PollInformation), - /// The flow finished successfully, providing token information - Success(Token), -} - /// Implements the [Oauth2 Device Flow](https://developers.google.com/youtube/v3/guides/authentication#devices) /// It operates in two steps: /// * obtain a code to show to the user /// * (repeatedly) poll for the user to authenticate your application -pub struct DeviceFlow { +pub struct DeviceFlow { client: hyper::Client, - device_code: String, - state: Option, - error: Option, application_secret: ApplicationSecret, + /// Usually GOOGLE_DEVICE_CODE_URL device_code_url: String, + ad: AD, } -impl Flow for DeviceFlow { +impl Flow for DeviceFlow { fn type_id() -> FlowType { FlowType::Device(String::new()) } } -impl DeviceFlow +impl DeviceFlow where C: hyper::client::connect::Connect + Sync + 'static, C::Transport: 'static, C::Future: 'static, + AD: AuthenticatorDelegate + Clone + Send + 'static, { - pub fn new>( + pub fn new>( client: hyper::Client, - secret: &ApplicationSecret, - device_code_url: S, - ) -> DeviceFlow { + secret: ApplicationSecret, + ad: AD, + device_code_url: Option, + ) -> DeviceFlow { DeviceFlow { client: client, - device_code: Default::default(), - application_secret: secret.clone(), - device_code_url: device_code_url.as_ref().to_string(), - state: None, - error: None, + application_secret: secret, + device_code_url: device_code_url + .as_ref() + .map(|s| s.as_ref().to_string()) + .unwrap_or(GOOGLE_DEVICE_CODE_URL.to_string()), + ad: ad, } } + pub fn retrieve_device_token<'a>( + &mut self, + scopes: Vec, + ) -> Box, Error = Box> + Send> { + let mut ad = self.ad.clone(); + let application_secret = self.application_secret.clone(); + let client = self.client.clone(); + let request_code = Self::request_code( + application_secret.clone(), + client.clone(), + self.device_code_url.clone(), + scopes, + ) + .and_then(move |(pollinf, device_code)| { + println!("presenting, {}", device_code); + ad.present_user_code(&pollinf); + Ok((pollinf, device_code)) + }); + Box::new(request_code.and_then(|(pollinf, device_code)| { + future::loop_fn(0, move |i| { + // Make a copy of everything every time, because the loop function needs to be + // repeatable, i.e. we can't move anything out. + // + let pt = Self::poll_token( + application_secret.clone(), + client.clone(), + device_code.clone(), + pollinf.clone(), + ); + println!("waiting {:?}", pollinf.interval); + tokio_timer::sleep(pollinf.interval) + .then(|_| pt) + .then(move |r| match r { + Ok(None) if i < 10 => Ok(future::Loop::Continue(i + 1)), + Ok(Some(tok)) => Ok(future::Loop::Break(Some(tok))), + Err(_) if i < 10 => Ok(future::Loop::Continue(i + 1)), + _ => Ok(future::Loop::Break(None)), + }) + }) + })) + } + /// The first step involves asking the server for a code that the user /// can type into a field at a specified URL. It is called only once, assuming /// there was no connection error. Otherwise, it may be called again until @@ -81,26 +117,23 @@ where /// * If called after a successful result was returned at least once. /// # Examples /// See test-cases in source code for a more complete example. - pub fn request_code<'b, T, I>(&mut self, scopes: I) -> Result - where - T: AsRef + 'b, - I: IntoIterator, + fn request_code( + application_secret: ApplicationSecret, + client: hyper::Client, + device_code_url: String, + scopes: Vec, + ) -> impl Future> { - if self.state.is_some() { - panic!("Must not be called after we have obtained a token and have no error"); - } - // note: cloned() shouldn't be needed, see issue // https://github.com/servo/rust-url/issues/81 let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ - ("client_id", &self.application_secret.client_id), + ("client_id", application_secret.client_id.clone()), ( "scope", - &scopes + scopes .into_iter() - .map(|s| s.as_ref()) - .intersperse(" ") + .intersperse(" ".to_string()) .collect::(), ), ]) @@ -108,54 +141,67 @@ where // note: works around bug in rustlang // https://github.com/rust-lang/rust/issues/22252 - let request = hyper::Request::post(&self.device_code_url) + let request = hyper::Request::post(device_code_url) .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") - .body(hyper::Body::from(req))?; + .body(hyper::Body::from(req)) + .into_future(); + request + .then( + move |request: Result, http::Error>| { + let request = request.unwrap(); + println!("request: {:?}", request); + client.request(request) + }, + ) + .then( + |r: Result, hyper::error::Error>| { + match r { + Err(err) => { + return Err( + Box::new(RequestError::ClientError(err)) as Box + ); + } + Ok(res) => { + #[derive(Deserialize)] + struct JsonData { + device_code: String, + user_code: String, + verification_url: String, + expires_in: i64, + interval: i64, + } - // TODO: move the ? on request - let ret = match self.client.request(request).wait() { - Err(err) => { - return Err(RequestError::ClientError(err)); // TODO: failed here - } - Ok(res) => { - #[derive(Deserialize)] - struct JsonData { - device_code: String, - user_code: String, - verification_url: String, - expires_in: i64, - interval: i64, - } + let json_str: String = res + .into_body() + .concat2() + .wait() + .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) + .unwrap(); // TODO: error handling - let json_str: String = res - .into_body() - .concat2() - .wait() - .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) - .unwrap(); // TODO: error handling + // check for error + match json::from_str::(&json_str) { + Err(_) => {} // ignore, move on + Ok(res) => { + return Err( + Box::new(RequestError::from(res)) as Box + ) + } + } - // check for error - match json::from_str::(&json_str) { - Err(_) => {} // ignore, move on - Ok(res) => return Err(RequestError::from(res)), - } + let decoded: JsonData = json::from_str(&json_str).unwrap(); - let decoded: JsonData = json::from_str(&json_str).unwrap(); - - self.device_code = decoded.device_code; - let pi = PollInformation { - user_code: decoded.user_code, - verification_url: decoded.verification_url, - expires_at: Utc::now() + chrono::Duration::seconds(decoded.expires_in), - interval: Duration::from_secs(i64::abs(decoded.interval) as u64), - }; - self.state = Some(DeviceFlowState::Pending(pi.clone())); - - Ok(pi) - } - }; - - ret + let pi = PollInformation { + user_code: decoded.user_code, + verification_url: decoded.verification_url, + expires_at: Utc::now() + + chrono::Duration::seconds(decoded.expires_in), + interval: Duration::from_secs(i64::abs(decoded.interval) as u64), + }; + Ok((pi, decoded.device_code)) + } + } + }, + ) } /// If the first call is successful, this method may be called. @@ -175,78 +221,73 @@ where /// /// # Examples /// See test-cases in source code for a more complete example. - pub fn poll_token(&mut self) -> Result, &PollError> { - // clone, as we may re-assign our state later - let pi = match self.state { - Some(ref s) => match *s { - DeviceFlowState::Pending(ref pi) => pi.clone(), - DeviceFlowState::Error => return Err(self.error.as_ref().unwrap()), - DeviceFlowState::Success(ref t) => return Ok(Some(t.clone())), - }, - _ => panic!("You have to call request_code() beforehand"), + fn poll_token<'a>( + application_secret: ApplicationSecret, + client: hyper::Client, + device_code: String, + pi: PollInformation, + ) -> impl Future, Error = Box> { + let expired = if pi.expires_at <= Utc::now() { + Err(PollError::Expired(pi.expires_at)).into_future() + } else { + Ok(()).into_future() }; - if pi.expires_at <= Utc::now() { - self.error = Some(PollError::Expired(pi.expires_at)); - self.state = Some(DeviceFlowState::Error); - return Err(&self.error.as_ref().unwrap()); - } - // We should be ready for a new request let req = form_urlencoded::Serializer::new(String::new()) .extend_pairs(&[ - ("client_id", &self.application_secret.client_id[..]), - ("client_secret", &self.application_secret.client_secret), - ("code", &self.device_code), + ("client_id", &application_secret.client_id[..]), + ("client_secret", &application_secret.client_secret), + ("code", &device_code), ("grant_type", "http://oauth.net/grant_type/device/1.0"), ]) .finish(); - let request = hyper::Request::post(&self.application_secret.token_uri) + let request = hyper::Request::post(&application_secret.token_uri) .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") .body(hyper::Body::from(req)) .unwrap(); // TODO: Error checking - let json_str: String = match self.client.request(request).wait() { - Err(err) => { - self.error = Some(PollError::HttpError(err)); - return Err(self.error.as_ref().unwrap()); - } - Ok(res) => { + expired + .map_err(|e| Box::new(e) as Box) + .and_then(move |_| { + client + .request(request) + .map_err(|e| Box::new(e) as Box) + }) + .map(|res| { res.into_body() .concat2() .wait() .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) .unwrap() // TODO: error handling - } - }; + }) + .and_then(|json_str: String| { + #[derive(Deserialize)] + struct JsonError { + error: String, + } - #[derive(Deserialize)] - struct JsonError { - error: String, - } - - match json::from_str::(&json_str) { - Err(_) => {} // ignore, move on, it's not an error - Ok(res) => { - match res.error.as_ref() { - "access_denied" => { - self.error = Some(PollError::AccessDenied); - self.state = Some(DeviceFlowState::Error); - return Err(self.error.as_ref().unwrap()); + match json::from_str::(&json_str) { + Err(_) => {} // ignore, move on, it's not an error + Ok(res) => { + match res.error.as_ref() { + "access_denied" => { + return Err( + Box::new(PollError::AccessDenied) as Box + ); + } + "authorization_pending" => return Ok(None), + _ => panic!("server message '{}' not understood", res.error), + }; } - "authorization_pending" => return Ok(None), - _ => panic!("server message '{}' not understood", res.error), - }; - } - } + } - // yes, we expect that ! - let mut t: Token = json::from_str(&json_str).unwrap(); - t.set_expiry_absolute(); + // yes, we expect that ! + let mut t: Token = json::from_str(&json_str).unwrap(); + t.set_expiry_absolute(); - let res = Ok(Some(t.clone())); - self.state = Some(DeviceFlowState::Success(t)); - return res; + Ok(Some(t.clone())) + }) } } diff --git a/src/types.rs b/src/types.rs index 520388d..2f353f9 100644 --- a/src/types.rs +++ b/src/types.rs @@ -19,6 +19,7 @@ pub struct JsonError { } /// Encapsulates all possible results of the `request_token(...)` operation +#[derive(Debug)] pub enum RequestError { /// Indicates connection failure ClientError(hyper::Error), @@ -78,6 +79,16 @@ impl fmt::Display for RequestError { } } +impl Error for RequestError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match *self { + RequestError::ClientError(ref err) => Some(err), + RequestError::HttpError(ref err) => Some(err), + _ => None, + } + } +} + #[derive(Debug)] pub struct StringError { error: String,