diff --git a/Cargo.toml b/Cargo.toml index 2efdf79..dd41ebd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,8 +13,8 @@ edition = "2018" [dependencies] base64 = "0.10" chrono = "0.4" -hyper = "0.10.2" -hyper-native-tls = "0.3" +hyper = {version = "0.12", default-features = false} +hyper-tls = "0.3" itertools = "0.8" log = "0.3" openssl = {version = "0.10", optional = true} @@ -23,6 +23,8 @@ serde = "1.0" serde_json = "1.0" serde_derive = "1.0" url = "1" +futures = "0.1" +tokio-threadpool = "0.1" [features] default = ["openssl"] @@ -31,4 +33,5 @@ no-openssl = ["rustls"] [dev-dependencies] getopts = "0.2" open = "1.1" -yup-hyper-mock = "2.0" +yup-hyper-mock = "3.14" +tokio = "0.1" diff --git a/examples/auth.rs b/examples/auth.rs index 6303104..15fd3af 100644 --- a/examples/auth.rs +++ b/examples/auth.rs @@ -1,10 +1,10 @@ use chrono::Local; use getopts::{Fail, HasArg, Occur, Options}; +use hyper_tls::HttpsConnector; use std::default::Default; use std::env; use std::thread::sleep; use std::time::Duration; -use yup_hyper_mock as mock; use yup_oauth2::{self as oauth2, GetToken}; fn usage(program: &str, opts: &Options, err: Option) -> ! { @@ -92,9 +92,8 @@ fn main() { } } - let client = hyper::Client::with_connector(mock::TeeConnector { - connector: hyper::net::HttpConnector, - }); + let https = HttpsConnector::new(4).unwrap(); + let client = hyper::Client::builder().build(https); match oauth2::Authenticator::new(&secret, StdoutHandler, client, oauth2::NullStorage, None) .token(&m.free) @@ -109,5 +108,5 @@ fn main() { println!("Access token wasn't obtained: {}", err); std::process::exit(10); } - } + }; } diff --git a/src/authenticator.rs b/src/authenticator.rs index 94cf514..5863d0c 100644 --- a/src/authenticator.rs +++ b/src/authenticator.rs @@ -1,4 +1,3 @@ -use std::borrow::BorrowMut; use std::cmp::min; use std::collections::hash_map::DefaultHasher; use std::convert::From; @@ -37,7 +36,7 @@ pub struct Authenticator { flow_type: FlowType, delegate: D, storage: S, - client: C, + client: hyper::Client, secret: ApplicationSecret, } @@ -53,11 +52,11 @@ pub trait GetToken { fn api_key(&mut self) -> Option; } -impl Authenticator +impl<'a, D, S, C: 'static> Authenticator where D: AuthenticatorDelegate, S: TokenStorage, - C: BorrowMut, + C: hyper::client::connect::Connect, { /// Returns a new `Authenticator` instance /// @@ -75,7 +74,7 @@ where pub fn new( secret: &ApplicationSecret, delegate: D, - client: C, + client: hyper::Client, storage: S, flow_type: Option, ) -> Authenticator { @@ -101,7 +100,7 @@ where _ => installed_type = None, } - let mut flow = InstalledFlow::new(self.client.borrow_mut(), installed_type); + let mut flow = InstalledFlow::new(self.client.clone(), installed_type); flow.obtain_token(&mut self.delegate, &self.secret, scopes.iter()) } @@ -110,7 +109,7 @@ where scopes: &Vec<&str>, code_url: String, ) -> Result> { - let mut flow = DeviceFlow::new(self.client.borrow_mut(), &self.secret, &code_url); + let mut flow = DeviceFlow::new(self.client.clone(), &self.secret, &code_url); // PHASE 1: REQUEST CODE let pi: PollInformation; @@ -120,6 +119,12 @@ where pi = match res { Err(res_err) => { match res_err { + RequestError::ClientError(err) => match self.delegate.client_error(&err) { + Retry::Abort | Retry::Skip => { + return Err(Box::new(StringError::from(&err as &Error))); + } + Retry::After(d) => sleep(d), + }, RequestError::HttpError(err) => { match self.delegate.connection_error(&err) { Retry::Abort | Retry::Skip => { @@ -152,14 +157,12 @@ where Err(ref poll_err) => { let pts = poll_err.to_string(); match poll_err { - &&PollError::HttpError(ref err) => { - match self.delegate.connection_error(err) { - Retry::Abort | Retry::Skip => { - return Err(Box::new(StringError::from(err as &Error))); - } - Retry::After(d) => sleep(d), + &&PollError::HttpError(ref err) => match self.delegate.client_error(err) { + Retry::Abort | Retry::Skip => { + return Err(Box::new(StringError::from(err as &Error))); } - } + Retry::After(d) => sleep(d), + }, &&PollError::Expired(ref t) => { self.delegate.expired(t); return Err(Box::new(StringError::from(pts))); @@ -185,11 +188,11 @@ where } } -impl GetToken for Authenticator +impl GetToken for Authenticator where D: AuthenticatorDelegate, S: TokenStorage, - C: BorrowMut, + C: hyper::client::connect::Connect, { /// Blocks until a token was retrieved from storage, from the server, or until the delegate /// decided to abort the attempt, or the user decided not to authorize the application. @@ -219,15 +222,18 @@ where Ok(Some(mut t)) => { // t needs refresh ? if t.expired() { - let mut rf = RefreshFlow::new(self.client.borrow_mut()); + let mut rf = RefreshFlow::new(self.client.clone()); loop { match *rf.refresh_token( self.flow_type.clone(), &self.secret, &t.refresh_token, ) { + RefreshResult::Uninitialized => { + panic!("Token flow should never get here"); + } RefreshResult::Error(ref err) => { - match self.delegate.connection_error(err) { + match self.delegate.client_error(err) { Retry::Abort | Retry::Skip => { return Err(Box::new(StringError::new( err.description().to_string(), @@ -344,17 +350,21 @@ mod tests { use std::default::Default; #[test] - fn flow() { + fn test_flow() { use serde_json as json; + let runtime = tokio::runtime::Runtime::new().unwrap(); let secret = json::from_str::(SECRET) .unwrap() .installed .unwrap(); + let client = hyper::Client::builder() + .executor(runtime.executor()) + .build(MockGoogleAuth::default()); let res = Authenticator::new( &secret, DefaultAuthenticatorDelegate, - hyper::Client::with_connector(::default()), + client, ::default(), None, ) @@ -362,7 +372,7 @@ mod tests { match res { Ok(t) => assert_eq!(t.access_token, "1/fFAGRNJru1FTz70BzhT3Zg"), - _ => panic!("Expected to retrieve token in one go"), + Err(err) => panic!("Expected to retrieve token in one go: {}", err), } } } diff --git a/src/authenticator_delegate.rs b/src/authenticator_delegate.rs index e7dd846..0d70d5b 100644 --- a/src/authenticator_delegate.rs +++ b/src/authenticator_delegate.rs @@ -58,10 +58,17 @@ impl fmt::Display for PollError { /// The only method that needs to be implemented manually is `present_user_code(...)`, /// as no assumptions are made on how this presentation should happen. pub trait AuthenticatorDelegate { + /// Called whenever there is an client, usually if there are network problems. + /// + /// Return retry information. + fn client_error(&mut self, _: &hyper::Error) -> Retry { + Retry::Abort + } + /// Called whenever there is an HttpError, usually if there are network problems. /// /// Return retry information. - fn connection_error(&mut self, _: &hyper::Error) -> Retry { + fn connection_error(&mut self, _: &hyper::http::Error) -> Retry { Retry::Abort } diff --git a/src/device.rs b/src/device.rs index db3bafa..a4d72bd 100644 --- a/src/device.rs +++ b/src/device.rs @@ -3,13 +3,12 @@ use std::iter::IntoIterator; use std::time::Duration; use chrono::{self, Utc}; +use futures::stream::Stream; +use futures::Future; use hyper; -use hyper::header::ContentType; +use hyper::header; use itertools::Itertools; use serde_json as json; -use std::borrow::BorrowMut; -use std::i64; -use std::io::Read; use url::form_urlencoded; use crate::authenticator_delegate::{PollError, PollInformation}; @@ -32,7 +31,7 @@ enum DeviceFlowState { /// * obtain a code to show to the user /// * (repeatedly) poll for the user to authenticate your application pub struct DeviceFlow { - client: C, + client: hyper::Client, device_code: String, state: Option, error: Option, @@ -45,12 +44,15 @@ impl Flow for DeviceFlow { FlowType::Device(String::new()) } } + impl DeviceFlow where - C: BorrowMut, + C: hyper::client::connect::Connect + Sync + 'static, + C::Transport: 'static, + C::Future: 'static, { pub fn new>( - client: C, + client: hyper::Client, secret: &ApplicationSecret, device_code_url: S, ) -> DeviceFlow { @@ -106,20 +108,16 @@ where // note: works around bug in rustlang // https://github.com/rust-lang/rust/issues/22252 - let ret = match self - .client - .borrow_mut() - .post(&self.device_code_url) - .header(ContentType( - "application/x-www-form-urlencoded".parse().unwrap(), - )) - .body(&*req) - .send() - { + let request = hyper::Request::post(&self.device_code_url) + .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") + .body(hyper::Body::from(req))?; + + // TODO: move the ? on request + let ret = match self.client.request(request).wait() { Err(err) => { - return Err(RequestError::HttpError(err)); + return Err(RequestError::ClientError(err)); // TODO: failed here } - Ok(mut res) => { + Ok(res) => { #[derive(Deserialize)] struct JsonData { device_code: String, @@ -129,8 +127,12 @@ where interval: i64, } - let mut json_str = String::new(); - res.read_to_string(&mut json_str).unwrap(); + let json_str: String = res + .into_body() + .concat2() + .wait() + .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) + .unwrap(); // TODO: error handling // check for error match json::from_str::(&json_str) { @@ -200,24 +202,21 @@ where ]) .finish(); - let json_str: String = match self - .client - .borrow_mut() - .post(&self.application_secret.token_uri) - .header(ContentType( - "application/x-www-form-urlencoded".parse().unwrap(), - )) - .body(&*req) - .send() - { + let request = hyper::Request::post(&self.application_secret.token_uri) + .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") + .body(hyper::Body::from(req)) + .unwrap(); // TODO: Error checking + let json_str: String = match self.client.request(request).wait() { Err(err) => { self.error = Some(PollError::HttpError(err)); return Err(self.error.as_ref().unwrap()); } - Ok(mut res) => { - let mut json_str = String::new(); - res.read_to_string(&mut json_str).unwrap(); - json_str + Ok(res) => { + res.into_body() + .concat2() + .wait() + .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) + .unwrap() // TODO: error handling } }; @@ -254,59 +253,30 @@ where #[cfg(test)] pub mod tests { use super::*; - use hyper; - use std::default::Default; - use std::time::Duration; - use yup_hyper_mock::{MockStream, SequentialConnector}; - pub struct MockGoogleAuth(SequentialConnector); - - impl Default for MockGoogleAuth { - fn default() -> MockGoogleAuth { - let mut c = MockGoogleAuth(Default::default()); - c.0.content.push( - "HTTP/1.1 200 OK\r\n\ - Server: BOGUS\r\n\ - \r\n\ - {\r\n\ + mock_connector_in_order!(MockGoogleAuth { + "HTTP/1.1 200 OK\r\n\ + Server: BOGUS\r\n\ + \r\n\ + {\r\n\ \"device_code\" : \"4/L9fTtLrhY96442SEuf1Rl3KLFg3y\",\r\n\ \"user_code\" : \"a9xfwk9c\",\r\n\ \"verification_url\" : \"http://www.google.com/device\",\r\n\ \"expires_in\" : 1800,\r\n\ \"interval\" : 0\r\n\ - }" - .to_string(), - ); - - c.0.content.push( - "HTTP/1.1 200 OK\r\n\ - Server: BOGUS\r\n\ - \r\n\ - {\r\n\ + }" + "HTTP/1.1 200 OK\r\n\ + Server: BOGUS\r\n\ + \r\n\ + {\r\n\ \"error\" : \"authorization_pending\"\r\n\ - }" - .to_string(), - ); - - c.0.content.push( - "HTTP/1.1 200 OK\r\nServer: \ - BOGUS\r\n\r\n{\r\n\"access_token\":\"1/fFAGRNJru1FTz70BzhT3Zg\",\ - \r\n\"expires_in\":3920,\r\n\"token_type\":\"Bearer\",\ - \r\n\"refresh_token\":\ - \"1/6BMfW9j53gdGImsixUH6kU5RsR4zwI9lUVX-tqf8JXQ\"\r\n}" - .to_string(), - ); - c - } - } - - impl hyper::net::NetworkConnector for MockGoogleAuth { - type Stream = MockStream; - - fn connect(&self, host: &str, port: u16, scheme: &str) -> ::hyper::Result { - self.0.connect(host, port, scheme) - } - } + }" + "HTTP/1.1 200 OK\r\nServer: \ + BOGUS\r\n\r\n{\r\n\"access_token\":\"1/fFAGRNJru1FTz70BzhT3Zg\",\ + \r\n\"expires_in\":3920,\r\n\"token_type\":\"Bearer\",\ + \r\n\"refresh_token\":\ + \"1/6BMfW9j53gdGImsixUH6kU5RsR4zwI9lUVX-tqf8JXQ\"\r\n}" + }); const TEST_APP_SECRET: &'static str = r#"{"installed":{"client_id":"384278056379-tr5pbot1mil66749n639jo54i4840u77.apps.googleusercontent.com","project_id":"sanguine-rhythm-105020","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://accounts.google.com/o/oauth2/token","auth_provider_x509_cert_url":"https://www.googleapis.com/oauth2/v1/certs","client_secret":"QeQUnhzsiO4t--ZGmj9muUAu","redirect_uris":["urn:ietf:wg:oauth:2.0:oob","http://localhost"]}}"#; @@ -314,16 +284,17 @@ pub mod tests { fn working_flow() { use crate::helper::parse_application_secret; - let appsecret = parse_application_secret(TEST_APP_SECRET).unwrap(); - let mut flow = DeviceFlow::new( - hyper::Client::with_connector(::default()), - &appsecret, - GOOGLE_DEVICE_CODE_URL, - ); + let runtime = tokio::runtime::Runtime::new().unwrap(); + let appsecret = parse_application_secret(&TEST_APP_SECRET.to_string()).unwrap(); + let client = hyper::Client::builder() + .executor(runtime.executor()) + .build(MockGoogleAuth::default()); + + let mut flow = DeviceFlow::new(client, &appsecret, GOOGLE_DEVICE_CODE_URL); match flow.request_code(&["https://www.googleapis.com/auth/youtube.upload"]) { Ok(pi) => assert_eq!(pi.interval, Duration::from_secs(0)), - _ => unreachable!(), + Err(err) => assert!(false, "request_code failed: {}", err), } match flow.poll_token() { diff --git a/src/installed.rs b/src/installed.rs index 43037c4..7109114 100644 --- a/src/installed.rs +++ b/src/installed.rs @@ -5,16 +5,17 @@ extern crate serde_json; extern crate url; -use std::borrow::BorrowMut; use std::convert::AsRef; use std::error::Error; use std::io; -use std::io::Read; -use std::sync::mpsc::{channel, Receiver, Sender}; -use std::sync::Mutex; +use std::sync::{Arc, Mutex}; +use futures; +use futures::stream::Stream; +use futures::sync::oneshot; +use futures::Future; use hyper; -use hyper::{client, header, server, status, uri}; +use hyper::{header, StatusCode, Uri}; use serde_json::error; use url::form_urlencoded; use url::percent_encoding::{percent_encode, QUERY_ENCODE_SET}; @@ -64,11 +65,8 @@ where } pub struct InstalledFlow { - client: C, - server: Option, - port: Option, - - auth_code_rcv: Option>, + client: hyper::Client, + server: Option, } /// cf. https://developers.google.com/identity/protocols/OAuth2InstalledApp#choosingredirecturi @@ -79,48 +77,35 @@ pub enum InstalledFlowReturnMethod { /// Involves spinning up a local HTTP server and Google redirecting the browser to /// the server with a URL containing the code (preferred, but not as reliable). The /// parameter is the port to listen on. - HTTPRedirect(u32), + HTTPRedirect(u16), } -impl InstalledFlow +impl InstalledFlow where - C: BorrowMut, + C: hyper::client::connect::Connect, { /// Starts a new Installed App auth flow. /// If HTTPRedirect is chosen as method and the server can't be started, the flow falls /// back to Interactive. - pub fn new(client: C, method: Option) -> InstalledFlow { + pub fn new( + client: hyper::Client, + method: Option, + ) -> InstalledFlow { let default = InstalledFlow { client: client, server: None, - port: None, - auth_code_rcv: None, }; match method { None => default, Some(InstalledFlowReturnMethod::Interactive) => default, // Start server on localhost to accept auth code. Some(InstalledFlowReturnMethod::HTTPRedirect(port)) => { - let server = server::Server::http(format!("127.0.0.1:{}", port).as_str()); - - match server { + match InstalledFlowServer::new(port) { Result::Err(_) => default, - Result::Ok(server) => { - let (tx, rx) = channel(); - let listening = server.handle(InstalledFlowHandler { - auth_code_snd: Mutex::new(tx), - }); - - match listening { - Result::Err(_) => default, - Result::Ok(listening) => InstalledFlow { - client: default.client, - server: Some(listening), - port: Some(port), - auth_code_rcv: Some(rx), - }, - } - } + Result::Ok(server) => InstalledFlow { + client: default.client, + server: Some(server), + }, } } } @@ -183,7 +168,8 @@ where T: AsRef + 'a, S: Iterator, { - let result: Result> = match self.server { + let server = self.server.take(); // Will shutdown the server if present when goes out of scope + let result: Result> = match server { None => { let url = build_authentication_request_url( &appsecret.auth_uri, @@ -208,39 +194,39 @@ where } } } - Some(_) => { + Some(mut server) => { // The redirect URI must be this very localhost URL, otherwise Google refuses // authorization. let url = build_authentication_request_url( &appsecret.auth_uri, &appsecret.client_id, scopes, - auth_delegate.redirect_uri().or_else(|| { - Some(format!("http://localhost:{}", self.port.unwrap_or(8080))) - }), + auth_delegate + .redirect_uri() + .or_else(|| Some(format!("http://localhost:{}", server.port))), ); auth_delegate.present_user_url(&url, false /* need_code */); - match self.auth_code_rcv.as_ref().unwrap().recv() { + match server.block_till_auth() { Result::Err(e) => Result::Err(Box::new(e)), Result::Ok(s) => Result::Ok(s), } } }; - self.server.as_mut().map(|l| l.close()).is_some(); + result } /// Sends the authorization code to the provider in order to obtain access and refresh tokens. fn request_token( - &mut self, + &self, appsecret: &ApplicationSecret, authcode: &str, custom_redirect_uri: Option, ) -> Result> { - let redirect_uri = custom_redirect_uri.unwrap_or_else(|| match self.port { + let redirect_uri = custom_redirect_uri.unwrap_or_else(|| match &self.server { None => OOB_REDIRECT_URI.to_string(), - Some(p) => format!("http://localhost:{}", p), + Some(server) => format!("http://localhost:{}", server.port), }); let body = form_urlencoded::Serializer::new(String::new()) @@ -253,22 +239,23 @@ where ]) .finish(); - let result: Result = self - .client - .borrow_mut() - .post(&appsecret.token_uri) - .body(&body) - .header(header::ContentType( - "application/x-www-form-urlencoded".parse().unwrap(), - )) - .send(); + let request = hyper::Request::post(&appsecret.token_uri) + .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") + .body(hyper::Body::from(body)) + .unwrap(); // TODO: error check - let mut resp = String::new(); + let result = self.client.request(request).wait(); + + let resp = String::new(); match result { Result::Err(e) => return Result::Err(Box::new(e)), - Result::Ok(mut response) => { - let result = response.read_to_string(&mut resp); + Result::Ok(res) => { + let result = res + .into_body() + .concat2() + .wait() + .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()); // TODO: error handling match result { Result::Err(e) => return Result::Err(Box::new(e)), @@ -297,50 +284,208 @@ struct JSONTokenResponse { error_description: Option, } -/// HTTP handler handling the redirect from the provider. -struct InstalledFlowHandler { - auth_code_snd: Mutex>, +struct InstalledFlowServer { + port: u16, + shutdown_tx: Option>, + auth_code_rx: Option>, + threadpool: Option, } -impl server::Handler for InstalledFlowHandler { - fn handle(&self, rq: server::Request, mut rp: server::Response) { - match rq.uri { - uri::RequestUri::AbsolutePath(path) => { +impl InstalledFlowServer { + fn new(port: u16) -> Result { + let bound_port = hyper::server::conn::AddrIncoming::bind(&([127, 0, 0, 1], port).into()); + match bound_port { + Result::Err(_) => Result::Err(()), + Result::Ok(bound_port) => { + let port = bound_port.local_addr().port(); + + let (auth_code_tx, auth_code_rx) = oneshot::channel::(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + + let threadpool = tokio_threadpool::Builder::new() + .pool_size(1) + .name_prefix("InstalledFlowServer-") + .build(); + let service_maker = InstalledFlowServiceMaker::new(auth_code_tx); + let server = hyper::server::Server::builder(bound_port) + .http1_only(true) + .serve(service_maker) + .with_graceful_shutdown(shutdown_rx) + .map_err(|err| panic!("Failed badly: {}", err)); // TODO: Error handling + + threadpool.spawn(server); + + Result::Ok(InstalledFlowServer { + port, + shutdown_tx: Option::Some(shutdown_tx), + auth_code_rx: Option::Some(auth_code_rx), + threadpool: Option::Some(threadpool), + }) + } + } + } + + fn block_till_auth(&mut self) -> Result { + match self.auth_code_rx.take() { + Some(auth_code_rx) => auth_code_rx.wait(), + None => Result::Err(oneshot::Canceled), + } + } +} + +impl std::ops::Drop for InstalledFlowServer { + fn drop(&mut self) { + self.shutdown_tx.take().map(|tx| tx.send(())); + self.auth_code_rx.take().map(|mut rx| rx.close()); + self.threadpool.take(); + } +} + +pub struct InstalledFlowHandlerResponseFuture { + inner: Box< + futures::Future, Error = hyper::http::Error> + Send, + >, +} + +impl InstalledFlowHandlerResponseFuture { + fn new( + fut: Box< + futures::Future, Error = hyper::http::Error> + Send, + >, + ) -> Self { + Self { inner: fut } + } +} + +impl futures::Future for InstalledFlowHandlerResponseFuture { + type Item = hyper::Response; + type Error = hyper::http::Error; + + fn poll(&mut self) -> futures::Poll { + self.inner.poll() + } +} + +/// Creates InstalledFlowService on demand +struct InstalledFlowServiceMaker { + auth_code_tx: Arc>>>, +} + +impl InstalledFlowServiceMaker { + fn new(auth_code_tx: oneshot::Sender) -> InstalledFlowServiceMaker { + let auth_code_tx = Arc::new(Mutex::new(Option::Some(auth_code_tx))); + InstalledFlowServiceMaker { auth_code_tx } + } +} + +impl hyper::service::MakeService for InstalledFlowServiceMaker { + type ReqBody = hyper::Body; + type ResBody = hyper::Body; + type Error = hyper::http::Error; + type Service = InstalledFlowService; + type Future = futures::future::FutureResult; + type MakeError = hyper::http::Error; + + fn make_service(&mut self, _ctx: Ctx) -> Self::Future { + let service = InstalledFlowService { + auth_code_tx: self.auth_code_tx.clone(), + }; + futures::future::ok(service) + } +} + +/// HTTP service handling the redirect from the provider. +struct InstalledFlowService { + auth_code_tx: Arc>>>, +} + +impl hyper::service::Service for InstalledFlowService { + type ReqBody = hyper::Body; + type ResBody = hyper::Body; + type Error = hyper::http::Error; + type Future = InstalledFlowHandlerResponseFuture; + + fn call(&mut self, req: hyper::Request) -> Self::Future { + match req.uri().path_and_query() { + Some(path_and_query) => { // We use a fake URL because the redirect goes to a URL, meaning we // can't use the url form decode (because there's slashes and hashes and stuff in // it). - let url = hyper::Url::parse(&format!("http://example.com{}", path)); + let url = Uri::builder() + .scheme("http") + .authority("example.com") + .path_and_query(path_and_query.clone()) + .build(); if url.is_err() { - *rp.status_mut() = status::StatusCode::BadRequest; - let _ = rp.send("Unparseable URL".as_ref()); + let response = hyper::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(hyper::Body::from("Unparseable URL")); + + match response { + Ok(response) => InstalledFlowHandlerResponseFuture::new(Box::new( + futures::future::ok(response), + )), + Err(err) => InstalledFlowHandlerResponseFuture::new(Box::new( + futures::future::err(err), + )), + } } else { self.handle_url(url.unwrap()); - *rp.status_mut() = status::StatusCode::Ok; - let _ = rp.send( - "SuccessYou may now \ - close this window." - .as_ref(), - ); + let response = + hyper::Response::builder() + .status(StatusCode::OK) + .body(hyper::Body::from( + "SuccessYou may now \ + close this window.", + )); + + match response { + Ok(response) => InstalledFlowHandlerResponseFuture::new(Box::new( + futures::future::ok(response), + )), + Err(err) => InstalledFlowHandlerResponseFuture::new(Box::new( + futures::future::err(err), + )), + } } } - _ => { - *rp.status_mut() = status::StatusCode::BadRequest; - let _ = rp.send("Invalid Request!".as_ref()); + None => { + let response = hyper::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(hyper::Body::from("Invalid Request!")); + + match response { + Ok(response) => InstalledFlowHandlerResponseFuture::new(Box::new( + futures::future::ok(response), + )), + Err(err) => { + InstalledFlowHandlerResponseFuture::new(Box::new(futures::future::err(err))) + } + } } } } } -impl InstalledFlowHandler { - fn handle_url(&self, url: hyper::Url) { +impl InstalledFlowService { + fn handle_url(&mut self, url: hyper::Uri) { // Google redirects to the specified localhost URL, appending the authorization // code, like this: http://localhost:8080/xyz/?code=4/731fJ3BheyCouCniPufAd280GHNV5Ju35yYcGs // We take that code and send it to the get_authorization_code() function that // waits for it. - for (param, val) in url.query_pairs().into_owned() { + for (param, val) in form_urlencoded::parse(url.query().unwrap_or("").as_bytes()) { if param == "code".to_string() { - let _ = self.auth_code_snd.lock().unwrap().send(val); + let mut auth_code_tx = self.auth_code_tx.lock().unwrap(); + match auth_code_tx.take() { + Some(auth_code_tx) => { + let _ = auth_code_tx.send(val.to_owned().to_string()); + } + None => { + // call to the server after a previous call. Each server is only designed + // to receive a single request. + } + }; } } } @@ -348,13 +493,7 @@ impl InstalledFlowHandler { #[cfg(test)] mod tests { - use super::build_authentication_request_url; - use super::InstalledFlowHandler; - - use std::sync::mpsc::channel; - use std::sync::Mutex; - - use hyper::Url; + use super::*; #[test] fn test_request_url_builder() { @@ -373,15 +512,73 @@ mod tests { ); } + #[test] + fn test_server_random_local_port() { + let addr1 = InstalledFlowServer::new(0).unwrap(); + let addr2 = InstalledFlowServer::new(0).unwrap(); + assert_ne!(addr1.port, addr2.port); + } + #[test] fn test_http_handle_url() { - let (tx, rx) = channel(); - let handler = InstalledFlowHandler { - auth_code_snd: Mutex::new(tx), + let (tx, rx) = oneshot::channel(); + let mut handler = InstalledFlowService { + auth_code_tx: Arc::new(Mutex::new(Option::Some(tx))), }; // URLs are usually a bit botched - let url = Url::parse("http://example.com:1234/?code=ab/c%2Fd#").unwrap(); + let url: Uri = "http://example.com:1234/?code=ab/c%2Fd#".parse().unwrap(); handler.handle_url(url); - assert_eq!(rx.recv().unwrap(), "ab/c/d".to_string()); + assert_eq!(rx.wait().unwrap(), "ab/c/d".to_string()); + } + + #[test] + fn test_server() { + let runtime = tokio::runtime::Runtime::new().unwrap(); + let client: hyper::Client = + hyper::Client::builder() + .executor(runtime.executor()) + .build_http(); + let mut server = InstalledFlowServer::new(0).unwrap(); + + let response = client + .get( + format!("http://127.0.0.1:{}/", server.port) + .parse() + .unwrap(), + ) + .wait(); + match response { + Result::Ok(response) => { + assert!(response.status().is_success()); + } + Result::Err(err) => { + assert!(false, "Failed to request from local server: {:?}", err); + } + } + + let response = client + .get( + format!("http://127.0.0.1:{}/?code=ab/c%2Fd#", server.port) + .parse() + .unwrap(), + ) + .wait(); + match response { + Result::Ok(response) => { + assert!(response.status().is_success()); + } + Result::Err(err) => { + assert!(false, "Failed to request from local server: {:?}", err); + } + } + + match server.block_till_auth() { + Result::Ok(response) => { + assert_eq!(response, "ab/c/d".to_string()); + } + Result::Err(err) => { + assert!(false, "Server failed to pass on the message: {:?}", err); + } + } } } diff --git a/src/lib.rs b/src/lib.rs index 03e0ff9..1df80b1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,14 +41,14 @@ //! use yup_oauth2::{Authenticator, DefaultAuthenticatorDelegate, PollInformation, ConsoleApplicationSecret, MemoryStorage, GetToken}; //! use serde_json as json; //! use std::default::Default; -//! use hyper::{Client, net::HttpsConnector}; -//! use hyper_native_tls::NativeTlsClient; +//! use hyper::Client; +//! use hyper_tls::HttpsConnector; //! # const SECRET: &'static str = "{\"installed\":{\"auth_uri\":\"https://accounts.google.com/o/oauth2/auth\",\"client_secret\":\"UqkDJd5RFwnHoiG5x5Rub8SI\",\"token_uri\":\"https://accounts.google.com/o/oauth2/token\",\"client_email\":\"\",\"redirect_uris\":[\"urn:ietf:wg:oauth:2.0:oob\",\"oob\"],\"client_x509_cert_url\":\"\",\"client_id\":\"14070749909-vgip2f1okm7bkvajhi9jugan6126io9v.apps.googleusercontent.com\",\"auth_provider_x509_cert_url\":\"https://www.googleapis.com/oauth2/v1/certs\"}}"; //! //! # #[test] fn device() { //! let secret = json::from_str::(SECRET).unwrap().installed.unwrap(); //! let res = Authenticator::new(&secret, DefaultAuthenticatorDelegate, -//! Client::with_connector(HttpsConnector::new(NativeTlsClient::new().unwrap())), +//! Client::builder().build(HttpsConnector::new(4).unwrap()), //! ::default(), None) //! .token(&["https://www.googleapis.com/auth/youtube.upload"]); //! match res { @@ -71,14 +71,18 @@ extern crate serde_json; extern crate base64; extern crate chrono; extern crate hyper; -extern crate hyper_native_tls; +extern crate hyper_tls; -extern crate itertools; #[cfg(test)] extern crate log; -extern crate url; #[cfg(test)] -extern crate yup_hyper_mock; +#[macro_use] +extern crate yup_hyper_mock as hyper_mock; +extern crate itertools; +#[cfg(test)] +extern crate tokio; +extern crate tokio_threadpool; +extern crate url; mod authenticator; mod authenticator_delegate; diff --git a/src/refresh.rs b/src/refresh.rs index 06e585d..2eef3f1 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -2,11 +2,11 @@ use crate::types::{ApplicationSecret, FlowType, JsonError}; use super::Token; use chrono::Utc; +use futures::stream::Stream; +use futures::Future; use hyper; -use hyper::header::ContentType; +use hyper::header; use serde_json as json; -use std::borrow::BorrowMut; -use std::io::Read; use url::form_urlencoded; /// Implements the [Outh2 Refresh Token Flow](https://developers.google.com/youtube/v3/guides/authentication#devices). @@ -15,12 +15,14 @@ use url::form_urlencoded; /// This flow is useful when your `Token` is expired and allows to obtain a new /// and valid access token. pub struct RefreshFlow { - client: C, + client: hyper::Client, result: RefreshResult, } /// All possible outcomes of the refresh flow pub enum RefreshResult { + // Indicates no attempt has been made to refresh yet + Uninitialized, /// Indicates connection failure Error(hyper::Error), /// The server did not answer with a new token, providing the server message @@ -29,14 +31,14 @@ pub enum RefreshResult { Success(Token), } -impl RefreshFlow +impl RefreshFlow where - C: BorrowMut, + C: hyper::client::connect::Connect, { - pub fn new(client: C) -> RefreshFlow { + pub fn new(client: hyper::Client) -> RefreshFlow { RefreshFlow { client: client, - result: RefreshResult::Error(hyper::Error::TooLarge), + result: RefreshResult::Uninitialized, } } @@ -56,11 +58,10 @@ where /// Please see the crate landing page for an example. pub fn refresh_token( &mut self, - flow_type: FlowType, + _flow_type: FlowType, client_secret: &ApplicationSecret, refresh_token: &str, ) -> &RefreshResult { - let _ = flow_type; if let RefreshResult::Success(_) = self.result { return &self.result; } @@ -74,24 +75,22 @@ where ]) .finish(); - let json_str: String = match self - .client - .borrow_mut() - .post(&client_secret.token_uri) - .header(ContentType( - "application/x-www-form-urlencoded".parse().unwrap(), - )) - .body(&*req) - .send() - { + let request = hyper::Request::post(&client_secret.token_uri) + .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") + .body(hyper::Body::from(req)) + .unwrap(); // TODO: error handling + + let json_str: String = match self.client.request(request).wait() { Err(err) => { self.result = RefreshResult::Error(err); return &self.result; } - Ok(mut res) => { - let mut json_str = String::new(); - res.read_to_string(&mut json_str).unwrap(); - json_str + Ok(res) => { + res.into_body() + .concat2() + .wait() + .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) + .unwrap() // TODO: error handling } }; @@ -125,42 +124,21 @@ where #[cfg(test)] mod tests { - use super::super::FlowType; use super::*; use crate::device::GOOGLE_DEVICE_CODE_URL; use crate::helper::parse_application_secret; - use hyper; - use std::default::Default; - use yup_hyper_mock::{MockStream, SequentialConnector}; - struct MockGoogleRefresh(SequentialConnector); - - impl Default for MockGoogleRefresh { - fn default() -> MockGoogleRefresh { - let mut c = MockGoogleRefresh(Default::default()); - c.0.content.push( - "HTTP/1.1 200 OK\r\n\ - Server: BOGUS\r\n\ - \r\n\ - {\r\n\ + mock_connector!(MockGoogleRefresh { + "https://accounts.google.com" => + "HTTP/1.1 200 OK\r\n\ + Server: BOGUS\r\n\ + \r\n\ + {\r\n\ \"access_token\":\"1/fFAGRNJru1FTz70BzhT3Zg\",\r\n\ \"expires_in\":3920,\r\n\ \"token_type\":\"Bearer\"\r\n\ - }" - .to_string(), - ); - - c - } - } - - impl hyper::net::NetworkConnector for MockGoogleRefresh { - type Stream = MockStream; - - fn connect(&self, host: &str, port: u16, scheme: &str) -> ::hyper::Result { - self.0.connect(host, port, scheme) - } - } + }" + }); const TEST_APP_SECRET: &'static str = r#"{"installed":{"client_id":"384278056379-tr5pbot1mil66749n639jo54i4840u77.apps.googleusercontent.com","project_id":"sanguine-rhythm-105020","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://accounts.google.com/o/oauth2/token","auth_provider_x509_cert_url":"https://www.googleapis.com/oauth2/v1/certs","client_secret":"QeQUnhzsiO4t--ZGmj9muUAu","redirect_uris":["urn:ietf:wg:oauth:2.0:oob","http://localhost"]}}"#; @@ -168,19 +146,31 @@ mod tests { fn refresh_flow() { let appsecret = parse_application_secret(TEST_APP_SECRET).unwrap(); - let mut c = hyper::Client::with_connector(::default()); - let mut flow = RefreshFlow::new(&mut c); + let runtime = tokio::runtime::Runtime::new().unwrap(); + let client = hyper::Client::builder() + .executor(runtime.executor()) + .build(MockGoogleRefresh::default()); + let mut flow = RefreshFlow::new(client); + let device_flow = FlowType::Device(GOOGLE_DEVICE_CODE_URL.to_string()); - match *flow.refresh_token( - FlowType::Device(GOOGLE_DEVICE_CODE_URL.to_string()), - &appsecret, - "bogus_refresh_token", - ) { + match flow.refresh_token(device_flow, &appsecret, "bogus_refresh_token") { RefreshResult::Success(ref t) => { assert_eq!(t.access_token, "1/fFAGRNJru1FTz70BzhT3Zg"); assert!(!t.expired()); } - _ => unreachable!(), + RefreshResult::Error(err) => { + assert!(false, "Refresh flow failed: RefreshResult::Error({})", err); + } + RefreshResult::RefreshError(msg, err) => { + assert!( + false, + "Refresh flow failed: RefreshResult::RefreshError({}, {:?})", + msg, err + ); + } + RefreshResult::Uninitialized => { + assert!(false, "Refresh flow failed: RefreshResult::Uninitialized"); + } } } } diff --git a/src/service_account.rs b/src/service_account.rs index bdc0a46..3f526c6 100644 --- a/src/service_account.rs +++ b/src/service_account.rs @@ -11,10 +11,8 @@ //! Copyright (c) 2016 Google Inc (lewinb@google.com). //! -use std::borrow::BorrowMut; use std::default::Default; use std::error; -use std::io::Read; use std::result; use std::str; @@ -22,6 +20,8 @@ use crate::authenticator::GetToken; use crate::storage::{hash_scopes, MemoryStorage, TokenStorage}; use crate::types::{StringError, Token}; +use futures::stream::Stream; +use futures::Future; use hyper::header; use url::form_urlencoded; @@ -211,7 +211,7 @@ fn set_sub_claim(mut claims: Claims, sub: String) -> Claims { /// A token source (`GetToken`) yielding OAuth tokens for services that use ServiceAccount authorization. /// This token source caches token and automatically renews expired ones. pub struct ServiceAccountAccess { - client: C, + client: hyper::Client, key: ServiceAccountKey, cache: MemoryStorage, sub: Option, @@ -239,13 +239,16 @@ impl TokenResponse { } } -impl<'a, C> ServiceAccountAccess +impl<'a, C: 'static> ServiceAccountAccess where - C: BorrowMut, + C: hyper::client::connect::Connect, { /// Returns a new `ServiceAccountAccess` token source. #[allow(dead_code)] - pub fn new(key: ServiceAccountKey, client: C) -> ServiceAccountAccess { + pub fn new( + key: ServiceAccountKey, + client: hyper::Client, + ) -> ServiceAccountAccess { ServiceAccountAccess { client: client, key: key, @@ -254,7 +257,11 @@ where } } - pub fn with_sub(key: ServiceAccountKey, client: C, sub: String) -> ServiceAccountAccess { + pub fn with_sub( + key: ServiceAccountKey, + client: hyper::Client, + sub: String, + ) -> ServiceAccountAccess { ServiceAccountAccess { client: client, key: key, @@ -275,21 +282,21 @@ where ]) .finish(); - let mut response = String::new(); - let mut result = self - .client - .borrow_mut() - .post(self.key.token_uri.as_ref().unwrap()) - .body(&body) - .header(header::ContentType( - "application/x-www-form-urlencoded".parse().unwrap(), - )) - .send()?; + let request = hyper::Request::post(self.key.token_uri.as_ref().unwrap()) + .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded") + .body(hyper::Body::from(body)) + .unwrap(); // TOOD: error handling + let response = self.client.request(request).wait()?; - result.read_to_string(&mut response)?; + let json_str = response + .into_body() + .concat2() + .wait() + .map(|c| String::from_utf8(c.into_bytes().to_vec()).unwrap()) + .unwrap(); // TODO: error handling let token: Result = - serde_json::from_str(&response); + serde_json::from_str(&json_str); match token { Err(e) => return Err(Box::new(e)), @@ -310,7 +317,10 @@ where } } -impl> GetToken for ServiceAccountAccess { +impl GetToken for ServiceAccountAccess +where + C: hyper::client::connect::Connect, +{ fn token<'b, I, T>(&mut self, scopes: I) -> result::Result> where T: AsRef + Ord + 'b, @@ -341,8 +351,7 @@ mod tests { use crate::authenticator::GetToken; use crate::helper::service_account_key_from_file; use hyper; - use hyper::net::HttpsConnector; - use hyper_native_tls::NativeTlsClient; + use hyper_tls::HttpsConnector; // This is a valid but deactivated key. const TEST_PRIVATE_KEY_PATH: &'static str = "examples/Sanguine-69411a0c0eea.json"; @@ -351,9 +360,12 @@ mod tests { //#[test] #[allow(dead_code)] fn test_service_account_e2e() { - let key = service_account_key_from_file(TEST_PRIVATE_KEY_PATH).unwrap(); - let client = - hyper::Client::with_connector(HttpsConnector::new(NativeTlsClient::new().unwrap())); + let key = service_account_key_from_file(&TEST_PRIVATE_KEY_PATH.to_string()).unwrap(); + let https = HttpsConnector::new(4).unwrap(); + let runtime = tokio::runtime::Runtime::new().unwrap(); + let client = hyper::Client::builder() + .executor(runtime.executor()) + .build(https); let mut acc = ServiceAccountAccess::new(key, client); println!( "{:?}", diff --git a/src/types.rs b/src/types.rs index ba29073..42746c1 100644 --- a/src/types.rs +++ b/src/types.rs @@ -19,7 +19,9 @@ pub struct JsonError { /// Encapsulates all possible results of the `request_token(...)` operation pub enum RequestError { /// Indicates connection failure - HttpError(hyper::Error), + ClientError(hyper::Error), + /// Indicates HTTP status failure + HttpError(hyper::http::Error), /// The OAuth client was not found InvalidClient, /// Some requested scopes were invalid. String contains the scopes as part of @@ -30,6 +32,18 @@ pub enum RequestError { NegativeServerResponse(String, Option), } +impl From for RequestError { + fn from(error: hyper::Error) -> RequestError { + RequestError::ClientError(error) + } +} + +impl From for RequestError { + fn from(error: hyper::http::Error) -> RequestError { + RequestError::HttpError(error) + } +} + impl From for RequestError { fn from(value: JsonError) -> RequestError { match &*value.error { @@ -47,6 +61,7 @@ impl From for RequestError { impl fmt::Display for RequestError { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { match *self { + RequestError::ClientError(ref err) => err.fmt(f), RequestError::HttpError(ref err) => err.fmt(f), RequestError::InvalidClient => "Invalid Client".fmt(f), RequestError::InvalidScope(ref scope) => writeln!(f, "Invalid Scope: '{}'", scope), @@ -136,13 +151,14 @@ pub struct Scheme { pub access_token: String, } -impl hyper::header::Scheme for Scheme { - fn scheme() -> Option<&'static str> { - None - } - - fn fmt_scheme(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} {}", self.token_type.as_ref(), self.access_token) +impl std::convert::Into for Scheme { + fn into(self) -> hyper::header::HeaderValue { + hyper::header::HeaderValue::from_str(&format!( + "{} {}", + self.token_type.as_ref(), + self.access_token + )) + .expect("Invalid Scheme header value") } } @@ -240,7 +256,7 @@ pub enum FlowType { /// browser to a web server that is running on localhost. This may not work as well with the /// Windows Firewall, but is more comfortable otherwise. The integer describes which port to /// bind to (default: 8080) - InstalledRedirect(u32), + InstalledRedirect(u16), } /// Represents either 'installed' or 'web' applications in a json secrets file. @@ -304,19 +320,18 @@ pub mod tests { token_type: TokenType::Bearer, access_token: "foo".to_string(), }; - let mut headers = hyper::header::Headers::new(); - headers.set(hyper::header::Authorization(s)); + let mut headers = hyper::HeaderMap::new(); + headers.insert(hyper::header::AUTHORIZATION, s.into()); assert_eq!( - headers.to_string(), - "Authorization: Bearer foo\r\n".to_string() + format!("{:?}", headers), + "{\"authorization\": \"Bearer foo\"}".to_string() ); } #[test] fn parse_schema() { - let auth: hyper::header::Authorization = - hyper::header::Header::parse_header(&[b"Bearer foo".to_vec()]).unwrap(); - assert_eq!(auth.0.token_type, TokenType::Bearer); - assert_eq!(auth.0.access_token, "foo".to_string()); + let auth = Scheme::from_str("Bearer foo").unwrap(); + assert_eq!(auth.token_type, TokenType::Bearer); + assert_eq!(auth.access_token, "foo".to_string()); } }