diff --git a/.travis.yml b/.travis.yml index 9205d1d..ad81c00 100644 --- a/.travis.yml +++ b/.travis.yml @@ -21,7 +21,8 @@ before_script: script: - | - travis-cargo build && travis-cargo test + travis-cargo build && travis-cargo test && + travis-cargo build -- --features tls && travis-cargo test -- --features tls after_success: - travis-cargo coveralls --no-sudo diff --git a/Cargo.toml b/Cargo.toml index d50554d..37f546c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,17 +6,19 @@ license = "MIT" documentation = "https://docs.rs/tarpc" homepage = "https://github.com/google/tarpc" repository = "https://github.com/google/tarpc" -keywords = ["rpc", "protocol", "remote", "procedure", "serialize"] +keywords = ["rpc", "protocol", "remote", "procedure", "serialize", "tls"] readme = "README.md" description = "An RPC framework for Rust with a focus on ease of use." [dependencies] bincode = "0.6" byteorder = "0.5" +cfg-if = "0.1.0" bytes = "0.3" futures = "0.1.7" lazy_static = "0.2" log = "0.3" +native-tls = { version = "0.1.1", optional = true } scoped-pool = "1.0" serde = "0.8" serde_derive = "0.8" @@ -25,6 +27,7 @@ take = "0.1" tokio-service = "0.1" tokio-proto = "0.1" tokio-core = "0.1" +tokio-tls = { version = "0.1", optional = true } net2 = "0.2" [dev-dependencies] @@ -33,7 +36,12 @@ env_logger = "0.3" futures-cpupool = "0.1" clap = "2.0" +[target.'cfg(target_os = "macos")'.dev-dependencies] +security-framework = "0.1" + [features] +default = [] +tls = ["tokio-tls", "native-tls"] unstable = ["serde/unstable"] [workspace] diff --git a/README.md b/README.md index d3fbb67..da7b06c 100644 --- a/README.md +++ b/README.md @@ -133,6 +133,85 @@ fn main() { } ``` +## Example: Futures + TLS + +By default, tarpc internally uses a [`TcpStream`] for communication between your clients and +servers. However, TCP by itself has no encryption. As a result, your communication will be sent in +the clear. If you want your RPC communications to be encrypted, you can choose to use [TLS]. TLS +operates as an encryption layer on top of TCP. When using TLS, your communication will occur over a +[`TlsStream`]. You can add the ability to make TLS clients and servers by adding `tarpc` +with the `tls` feature flag enabled. + +When using TLS, some additional information is required. You will need to make [`TlsAcceptor`] and +`client::tls::Context` structs; `client::tls::Context` requires a [`TlsConnector`]. The +[`TlsAcceptor`] and [`TlsConnector`] types are defined in the [native-tls]. tarpc re-exports +external TLS-related types in its `native_tls` module (`tarpc::native_tls`). + +[TLS]: https://en.wikipedia.org/wiki/Transport_Layer_Security +[`TcpStream`]: https://docs.rs/tokio-core/0.1/tokio_core/net/struct.TcpStream.html +[`TlsStream`]: https://docs.rs/native-tls/0.1/native_tls/struct.TlsStream.html +[`TlsAcceptor`]: https://docs.rs/native-tls/0.1/native_tls/struct.TlsAcceptor.html +[`TlsConnector`]: https://docs.rs/native-tls/0.1/native_tls/struct.TlsConnector.html +[native-tls]: https://github.com/sfackler/rust-native-tls + +Both TLS streams and TCP streams are supported in the same binary when the `tls` feature is enabled. +However, if you are working with both stream types, ensure that you use the TLS clients with TLS +servers and TCP clients with TCP servers. + +```rust +#![feature(conservative_impl_trait, plugin)] +#![plugin(tarpc_plugins)] + +extern crate futures; +#[macro_use] +extern crate tarpc; +extern crate tokio_core; + +use futures::Future; +use tarpc::{client, server}; +use tarpc::client::future::Connect; +use tarpc::util::{FirstSocketAddr, Never}; +use tokio_core::reactor; +use tarpc::native_tls::{Pkcs12, TlsAcceptor}; + +service! { + rpc hello(name: String) -> String; +} + +#[derive(Clone)] +struct HelloServer; + +impl FutureService for HelloServer { + type HelloFut = futures::Finished; + + fn hello(&mut self, name: String) -> Self::HelloFut { + futures::finished(format!("Hello, {}!", name)) + } +} + +fn get_acceptor() -> TlsAcceptor { + let buf = include_bytes!("test/identity.p12"); + let pkcs12 = Pkcs12::from_der(buf, "password").unwrap(); + TlsAcceptor::builder(pkcs12).unwrap().build().unwrap() +} + +fn main() { + let addr = "localhost:10000".first_socket_addr(); + let mut core = reactor::Core::new().unwrap(); + let acceptor = get_acceptor(); + HelloServer.listen(addr, server::Options::default() + .handle(core.handle()) + .tls(acceptor)).wait().unwrap(); + let options = client::Options::default().handle(core.handle() + .tls(client::tls::Context::new("foobar.com").unwrap())); + core.run(FutureClient::connect(addr, options) + .map_err(tarpc::Error::from) + .and_then(|client| client.hello("Mom".to_string())) + .map(|resp| println!("{}", resp))) + .unwrap(); +} +``` + ## Tips ### Sync vs Futures diff --git a/hooks/pre-push b/hooks/pre-push index 6d63c15..b3ddbfe 100755 --- a/hooks/pre-push +++ b/hooks/pre-push @@ -70,10 +70,12 @@ run_cargo() { rustup run $2 cargo $1 &>/dev/null else rustup run nightly cargo $1 --features unstable &>/dev/null + rustup run nightly cargo $1 --features unstable,tls &>/dev/null fi else printf "${PREFIX} $VERB... " cargo $1 &>/dev/null + cargo $1 --features tls &>/dev/null fi if [ "$?" != "0" ]; then printf "${FAILURE}\n" diff --git a/src/client.rs b/src/client.rs index ac4768c..a1a2b19 100644 --- a/src/client.rs +++ b/src/client.rs @@ -7,10 +7,12 @@ use {Reactor, WireError}; use bincode::serde::DeserializeError; use futures::{self, Future}; use protocol::Proto; +#[cfg(feature = "tls")] +use self::tls::*; use serde::{Deserialize, Serialize}; use std::fmt; use std::io; -use tokio_core::net::TcpStream; +use stream_type::StreamType; use tokio_core::reactor; use tokio_proto::BindClient as ProtoBindClient; use tokio_proto::multiplex::Multiplex; @@ -19,8 +21,48 @@ use tokio_service::Service; type WireResponse = Result>, DeserializeError>; type ResponseFuture = futures::Map< as Service>::Future, fn(WireResponse) -> Result>>; -type BindClient = - >> as ProtoBindClient>::BindClient; +type BindClient = >> as + ProtoBindClient>::BindClient; + +/// TLS-specific functionality +#[cfg(feature = "tls")] +pub mod tls { + use native_tls::{Error, TlsConnector}; + + /// TLS context for client + pub struct Context { + /// Domain to connect to + pub domain: String, + /// TLS connector + pub tls_connector: TlsConnector, + } + + impl Context { + /// Try to construct a new `Context`. + /// + /// The provided domain will be used for both + /// [SNI](https://en.wikipedia.org/wiki/Server_Name_Indication) and certificate hostname + /// validation. + pub fn new>(domain: S) -> Result { + Ok(Context { + domain: domain.into(), + tls_connector: TlsConnector::builder()?.build()?, + }) + } + + /// Construct a new `Context` using the provided domain and `TlsConnector` + /// + /// The domain will be used for both + /// [SNI](https://en.wikipedia.org/wiki/Server_Name_Indication) and certificate hostname + /// validation. + pub fn from_connector>(domain: S, tls_connector: TlsConnector) -> Self { + Context { + domain: domain.into(), + tls_connector: tls_connector, + } + } + } +} /// A client that impls `tokio_service::Service` that writes and reads bytes. /// @@ -91,9 +133,11 @@ impl fmt::Debug for Client } /// Additional options to configure how the client connects and operates. -#[derive(Clone, Default)] +#[derive(Default)] pub struct Options { reactor: Option, + #[cfg(feature = "tls")] + tls_ctx: Option, } impl Options { @@ -108,11 +152,17 @@ impl Options { self.reactor = Some(Reactor::Remote(remote)); self } + + /// Connect using the given `Context` + #[cfg(feature = "tls")] + pub fn tls(mut self, tls_ctx: Context) -> Self { + self.tls_ctx = Some(tls_ctx); + self + } } /// Exposes a trait for connecting asynchronously to servers. pub mod future { - use super::{Client, Options}; use {REMOTE, Reactor}; use futures::{self, Async, Future, future}; use protocol::Proto; @@ -120,9 +170,18 @@ pub mod future { use std::io; use std::marker::PhantomData; use std::net::SocketAddr; - use tokio_core::{self, reactor}; - use tokio_core::net::TcpStream; + use stream_type::StreamType; + use super::{Client, Options}; + use tokio_core::net::{TcpStream, TcpStreamNew}; + use tokio_core::reactor; use tokio_proto::BindClient; + cfg_if! { + if #[cfg(feature = "tls")] { + use tokio_tls::{ConnectAsync, TlsStream, TlsConnectorExt}; + use super::tls::Context; + use errors::native_to_io; + } else {} + } /// Types that can connect to a server asynchronously. pub trait Connect: Sized { @@ -133,21 +192,26 @@ pub mod future { fn connect(addr: SocketAddr, options: Options) -> Self::ConnectFut; } + type ConnectFutureInner = future::Either, MultiplexConnect>, futures::Flatten< + futures::MapErr>>, + fn(futures::Canceled) -> io::Error>>>; + /// A future that resolves to a `Client` or an `io::Error`. #[doc(hidden)] pub struct ConnectFuture where Req: Serialize + 'static, Resp: Deserialize + 'static, - E: Deserialize + 'static, + E: Deserialize + 'static { - #[allow(unknown_lints, type_complexity)] - inner: - future::Either< - futures::Map>, - futures::Flatten< - futures::MapErr< - futures::Oneshot>>, - fn(futures::Canceled) -> io::Error>>>, + #[cfg(not(feature = "tls"))] + #[cfg_attr(feature = "cargo-clippy", allow(type_complexity))] + inner: ConnectFutureInner>, + #[cfg(feature = "tls")] + #[cfg_attr(feature = "cargo-clippy", allow(type_complexity))] + inner: ConnectFutureInner, futures::Map, + fn(::native_tls::Error) -> io::Error>, fn(TlsStream) -> StreamType>>>, } impl Future for ConnectFuture @@ -175,15 +239,49 @@ pub mod future { } } - impl FnOnce<(TcpStream,)> for MultiplexConnect + impl FnOnce<(I,)> for MultiplexConnect where Req: Serialize + Sync + Send + 'static, Resp: Deserialize + Sync + Send + 'static, - E: Deserialize + Sync + Send + 'static + E: Deserialize + Sync + Send + 'static, + I: Into { type Output = Client; - extern "rust-call" fn call_once(self, (tcp,): (TcpStream,)) -> Client { - Client::new(Proto::new().bind_client(&self.0, tcp)) + extern "rust-call" fn call_once(self, (stream,): (I,)) -> Self::Output { + Client::new(Proto::new().bind_client(&self.0, stream.into())) + } + } + + /// Provides the connection Fn impl for Tls + struct ConnectFn { + #[cfg(feature = "tls")] + tls_ctx: Option, + } + + impl FnOnce<(TcpStream,)> for ConnectFn { + #[cfg(feature = "tls")] + #[cfg_attr(feature = "cargo-clippy", allow(type_complexity))] + type Output = future::Either, + futures::Map, + fn(::native_tls::Error) + -> io::Error>, + fn(TlsStream) -> StreamType>>; + #[cfg(not(feature = "tls"))] + type Output = future::FutureResult; + + extern "rust-call" fn call_once(self, (tcp,): (TcpStream,)) -> Self::Output { + #[cfg(feature = "tls")] + match self.tls_ctx { + None => future::Either::A(future::ok(StreamType::from(tcp))), + Some(tls_ctx) => { + future::Either::B(tls_ctx.tls_connector + .connect_async(&tls_ctx.domain, tcp) + .map_err(native_to_io as fn(_) -> _) + .map(StreamType::from as fn(_) -> _)) + } + } + #[cfg(not(feature = "tls"))] + future::ok(StreamType::from(tcp)) } } @@ -195,10 +293,31 @@ pub mod future { type ConnectFut = ConnectFuture; fn connect(addr: SocketAddr, options: Options) -> Self::ConnectFut { + // we need to do this for tls because we need to avoid moving the entire `Options` + // struct into the `setup` closure, since `Reactor` is not `Send`. + #[cfg(feature = "tls")] + let mut options = options; + #[cfg(feature = "tls")] + let tls_ctx = options.tls_ctx.take(); + let setup = move |tx: futures::sync::oneshot::Sender<_>| { move |handle: &reactor::Handle| { let handle2 = handle.clone(); TcpStream::connect(&addr, handle) + .and_then(move |socket| { + #[cfg(feature = "tls")] + match tls_ctx { + Some(tls_ctx) => { + future::Either::A(tls_ctx.tls_connector + .connect_async(&tls_ctx.domain, socket) + .map(StreamType::Tls) + .map_err(native_to_io)) + } + None => future::Either::B(future::ok(StreamType::Tcp(socket))), + } + #[cfg(not(feature = "tls"))] + future::ok(StreamType::Tcp(socket)) + }) .map(move |tcp| Client::new(Proto::new().bind_client(&handle2, tcp))) .then(move |result| { tx.complete(result); @@ -206,9 +325,16 @@ pub mod future { }) } }; + let rx = match options.reactor { Some(Reactor::Handle(handle)) => { - let tcp = TcpStream::connect(&addr, &handle).map(MultiplexConnect::new(handle)); + #[cfg(feature = "tls")] + let connect_fn = ConnectFn { tls_ctx: options.tls_ctx }; + #[cfg(not(feature = "tls"))] + let connect_fn = ConnectFn {}; + let tcp = TcpStream::connect(&addr, &handle) + .and_then(connect_fn) + .map(MultiplexConnect::new(handle)); return ConnectFuture { inner: future::Either::A(tcp) }; } Some(Reactor::Remote(remote)) => { @@ -232,12 +358,12 @@ pub mod future { /// Exposes a trait for connecting synchronously to servers. pub mod sync { - use super::{Client, Options}; use client::future::Connect as FutureConnect; use futures::{Future, future}; use serde::{Deserialize, Serialize}; use std::io; use std::net::ToSocketAddrs; + use super::{Client, Options}; use util::FirstSocketAddr; /// Types that can connect to a server synchronously. diff --git a/src/errors.rs b/src/errors.rs index 01329b9..3bb4189 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -102,3 +102,9 @@ pub enum WireError { /// The server was unable to reply to the rpc for some reason. App(E), } + +/// Convert `native_tls::Error` to `std::io::Error` +#[cfg(feature = "tls")] +pub fn native_to_io(e: ::native_tls::Error) -> io::Error { + io::Error::new(io::ErrorKind::Other, e) +} diff --git a/src/lib.rs b/src/lib.rs index 434f1fe..da44cec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -59,6 +59,52 @@ //! } //! ``` //! +//! Example usage with TLS: +//! +//! ```ignore +//! // required by `FutureClient` (not used in this example) +//! #![feature(conservative_impl_trait, plugin)] +//! #![plugin(tarpc_plugins)] +//! +//! #[macro_use] +//! extern crate tarpc; +//! +//! use tarpc::{client, server}; +//! use tarpc::client::sync::Connect; +//! use tarpc::util::Never; +//! use tarpc::native_tls::{TlsAcceptor, Pkcs12}; +//! +//! service! { +//! rpc hello(name: String) -> String; +//! } +//! +//! #[derive(Clone)] +//! struct HelloServer; +//! +//! impl SyncService for HelloServer { +//! fn hello(&self, name: String) -> Result { +//! Ok(format!("Hello, {}!", name)) +//! } +//! } +//! +//! fn get_acceptor() -> TlsAcceptor { +//! let buf = include_bytes!("test/identity.p12"); +//! let pkcs12 = Pkcs12::from_der(buf, "password").unwrap(); +//! TlsAcceptor::builder(pkcs12).unwrap().build().unwrap() +//! } +//! +//! fn main() { +//! let addr = "localhost:10000"; +//! let acceptor = get_acceptor(); +//! let _server = HelloServer.listen(addr, server::Options::default().tls(acceptor)); +//! let client = SyncClient::connect(addr, +//! client::Options::default() +//! .tls(client::tls::Context::new("foobar.com").unwrap())) +//! .unwrap(); +//! println!("{}", client.hello("Mom".to_string()).unwrap()); +//! } +//! ``` +//! #![deny(missing_docs)] #![feature(plugin, conservative_impl_trait, never_type, unboxed_closures, fn_traits, specialization)] @@ -74,6 +120,8 @@ extern crate net2; #[macro_use] extern crate serde_derive; extern crate take; +#[macro_use] +extern crate cfg_if; #[doc(hidden)] pub extern crate bincode; @@ -108,6 +156,8 @@ pub mod server; mod protocol; /// Provides a few different error types. mod errors; +/// Provides an abstraction over TLS and TCP streams. +mod stream_type; use std::sync::mpsc; use std::thread; @@ -138,3 +188,15 @@ enum Reactor { Handle(reactor::Handle), Remote(reactor::Remote), } + +cfg_if! { + if #[cfg(feature = "tls")] { + extern crate tokio_tls; + extern crate native_tls as native_tls_inner; + + /// Re-exported TLS-related types + pub mod native_tls { + pub use native_tls_inner::{Error, Pkcs12, TlsAcceptor, TlsConnector}; + } + } else {} +} diff --git a/src/macros.rs b/src/macros.rs index 4207bd8..2eaeffa 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -696,21 +696,169 @@ mod syntax_test { #[cfg(test)] mod functional_test { + use {client, server}; use futures::{Future, failed}; + use std::io; + use std::net::SocketAddr; use util::FirstSocketAddr; extern crate env_logger; + macro_rules! unwrap { + ($e:expr) => (match $e { + Ok(e) => e, + Err(e) => panic!("{} failed with {:?}", stringify!($e), e), + }) + } + service! { rpc add(x: i32, y: i32) -> i32; rpc hey(name: String) -> String; } + cfg_if! { + if #[cfg(feature = "tls")] { + const DOMAIN: &'static str = "foobar.com"; + + use client::tls::Context; + use native_tls::{Pkcs12, TlsAcceptor, TlsConnector}; + + fn tls_context() -> (server::Options, client::Options) { + let buf = include_bytes!("../test/identity.p12"); + let pkcs12 = unwrap!(Pkcs12::from_der(buf, "mypass")); + let acceptor = unwrap!(unwrap!(TlsAcceptor::builder(pkcs12)).build()); + let server_options = server::Options::default().tls(acceptor); + let client_options = get_tls_client_options(); + + (server_options, client_options) + } + + // Making the TlsConnector for testing needs to be OS-dependent just like native-tls. + // We need to go through this trickery because the test self-signed cert is not part + // of the system's cert chain. If it was, then all that is required is + // `TlsConnector::builder().unwrap().build().unwrap()`. + cfg_if! { + if #[cfg(target_os = "macos")] { + extern crate security_framework; + + use self::security_framework::certificate::SecCertificate; + use native_tls_inner::backend::security_framework::TlsConnectorBuilderExt; + + fn get_tls_client_options() -> client::Options { + let buf = include_bytes!("../test/root-ca.der"); + let cert = unwrap!(SecCertificate::from_der(buf)); + let mut connector = unwrap!(TlsConnector::builder()); + connector.anchor_certificates(&[cert]); + + client::Options::default().tls(Context { + domain: DOMAIN.into(), + tls_connector: unwrap!(connector.build()), + }) + } + } else if #[cfg(all(not(target_os = "macos"), not(windows)))] { + use native_tls_inner::backend::openssl::TlsConnectorBuilderExt; + + fn get_tls_client_options() -> client::Options { + let mut connector = unwrap!(TlsConnector::builder()); + unwrap!(connector.builder_mut() + .builder_mut() + .set_ca_file("test/root-ca.pem")); + + client::Options::default().tls(Context { + domain: DOMAIN.into(), + tls_connector: unwrap!(connector.build()), + }) + } + // not implemented for windows or other platforms + } else { + fn get_tls_client_context() -> Context { + unimplemented!() + } + } + } + + fn get_sync_client(addr: SocketAddr) -> io::Result + where C: client::sync::Connect + { + let client_options = get_tls_client_options(); + C::connect(addr, client_options) + } + + fn start_server_with_sync_client(server: S) -> (SocketAddr, io::Result) + where C: client::sync::Connect, S: SyncServiceExt + { + let (server_options, client_options) = tls_context(); + let addr = unwrap!(server.listen("localhost:0".first_socket_addr(), + server_options)); + let client = C::connect(addr, client_options); + (addr, client) + } + + fn start_server_with_async_client(server: S) -> (SocketAddr, C) + where C: client::future::Connect, S: FutureServiceExt + { + let (server_options, client_options) = tls_context(); + let addr = unwrap!(server.listen("localhost:0".first_socket_addr(), + server_options).wait()); + let client = unwrap!(C::connect(addr, client_options).wait()); + (addr, client) + } + + fn start_err_server_with_async_client(server: S) -> (SocketAddr, C) + where C: client::future::Connect, S: error_service::FutureServiceExt + { + let (server_options, client_options) = tls_context(); + let addr = unwrap!(server.listen("localhost:0".first_socket_addr(), + server_options).wait()); + let client = unwrap!(C::connect(addr, client_options).wait()); + (addr, client) + } + } else { + fn get_server_options() -> server::Options { + server::Options::default() + } + + fn get_client_options() -> client::Options { + client::Options::default() + } + + fn get_sync_client(addr: SocketAddr) -> io::Result + where C: client::sync::Connect + { + C::connect(addr, get_client_options()) + } + + fn start_server_with_sync_client(server: S) -> (SocketAddr, io::Result) + where C: client::sync::Connect, S: SyncServiceExt + { + let addr = unwrap!(server.listen("localhost:0".first_socket_addr(), + get_server_options())); + let client = C::connect(addr, get_client_options()); + (addr, client) + } + + fn start_server_with_async_client(server: S) -> (SocketAddr, C) + where C: client::future::Connect, S: FutureServiceExt + { + let addr = unwrap!(server.listen("localhost:0".first_socket_addr(), + get_server_options()).wait()); + let client = unwrap!(C::connect(addr, get_client_options()).wait()); + (addr, client) + } + + fn start_err_server_with_async_client(server: S) -> (SocketAddr, C) + where C: client::future::Connect, S: error_service::FutureServiceExt + { + let addr = unwrap!(server.listen("localhost:0".first_socket_addr(), + get_server_options()).wait()); + let client = unwrap!(C::connect(addr, get_client_options()).wait()); + (addr, client) + } + } + } + + mod sync { - use super::{SyncClient, SyncService, SyncServiceExt}; - use super::env_logger; - use {client, server}; - use client::sync::Connect; - use util::FirstSocketAddr; + use super::{SyncClient, SyncService, env_logger, start_server_with_sync_client}; use util::Never; #[derive(Clone, Copy)] @@ -728,10 +876,8 @@ mod functional_test { #[test] fn simple() { let _ = env_logger::init(); - let addr = Server.listen("localhost:0".first_socket_addr(), - server::Options::default()) - .unwrap(); - let client = SyncClient::connect(addr, client::Options::default()).unwrap(); + let (_, client) = start_server_with_sync_client::(Server); + let client = unwrap!(client); assert_eq!(3, client.add(1, 2).unwrap()); assert_eq!("Hey, Tim.", client.hey("Tim".to_string()).unwrap()); } @@ -739,13 +885,10 @@ mod functional_test { #[test] fn other_service() { let _ = env_logger::init(); - let addr = Server.listen("localhost:0".first_socket_addr(), - server::Options::default()) - .unwrap(); - let client = super::other_service::SyncClient::connect(addr, - client::Options::default()) - .expect("Could not connect!"); - match client.foo().err().unwrap() { + let (_, client) = start_server_with_sync_client::(Server); + let client = client.expect("Could not connect!"); + match client.foo().err().expect("failed unwrap") { ::Error::ServerDeserialize(_) => {} // good bad => panic!("Expected Error::ServerDeserialize but got {}", bad), } @@ -753,12 +896,8 @@ mod functional_test { } mod future { - use super::{FutureClient, FutureService, FutureServiceExt}; - use super::env_logger; - use {client, server}; - use client::future::Connect; use futures::{Finished, Future, finished}; - use util::FirstSocketAddr; + use super::{FutureClient, FutureService, env_logger, start_server_with_async_client}; use util::Never; #[derive(Clone)] @@ -781,11 +920,7 @@ mod functional_test { #[test] fn simple() { let _ = env_logger::init(); - let addr = Server.listen("localhost:0".first_socket_addr(), - server::Options::default()) - .wait() - .unwrap(); - let client = FutureClient::connect(addr, client::Options::default()).wait().unwrap(); + let (_, client) = start_server_with_async_client::(Server); assert_eq!(3, client.add(1, 2).wait().unwrap()); assert_eq!("Hey, Tim.", client.hey("Tim".to_string()).wait().unwrap()); } @@ -793,11 +928,7 @@ mod functional_test { #[test] fn concurrent() { let _ = env_logger::init(); - let addr = Server.listen("localhost:0".first_socket_addr(), - server::Options::default()) - .wait() - .unwrap(); - let client = FutureClient::connect(addr, client::Options::default()).wait().unwrap(); + let (_, client) = start_server_with_async_client::(Server); let req1 = client.add(1, 2); let req2 = client.add(3, 4); let req3 = client.hey("Tim".to_string()); @@ -809,14 +940,9 @@ mod functional_test { #[test] fn other_service() { let _ = env_logger::init(); - let addr = Server.listen("localhost:0".first_socket_addr(), - server::Options::default()) - .wait() - .unwrap(); - let client = super::other_service::FutureClient::connect(addr, - client::Options::default()) - .wait() - .unwrap(); + let (_, client) = + start_server_with_async_client::(Server); match client.foo().wait().err().unwrap() { ::Error::ServerDeserialize(_) => {} // good bad => panic!(r#"Expected Error::ServerDeserialize but got "{}""#, bad), @@ -825,6 +951,10 @@ mod functional_test { #[test] fn reuse_addr() { + use util::FirstSocketAddr; + use server; + use super::FutureServiceExt; + let _ = env_logger::init(); let addr = Server.listen("localhost:0".first_socket_addr(), server::Options::default()) .wait() @@ -833,6 +963,28 @@ mod functional_test { .wait() .unwrap(); } + + #[cfg(feature = "tls")] + #[test] + fn tcp_and_tls() { + use {client, server}; + use util::FirstSocketAddr; + use client::future::Connect; + use super::FutureServiceExt; + + let _ = env_logger::init(); + let (_, client) = start_server_with_async_client::(Server); + assert_eq!(3, client.add(1, 2).wait().unwrap()); + assert_eq!("Hey, Tim.", client.hey("Tim".to_string()).wait().unwrap()); + + let addr = Server.listen("localhost:0".first_socket_addr(), + server::Options::default()) + .wait() + .unwrap(); + let client = FutureClient::connect(addr, client::Options::default()).wait().unwrap(); + assert_eq!(3, client.add(1, 2).wait().unwrap()); + assert_eq!("Hey, Tim.", client.hey("Tim".to_string()).wait().unwrap()); + } } pub mod error_service { @@ -855,18 +1007,12 @@ mod functional_test { #[test] fn error() { - use {client, server}; - use client::future::Connect as Fc; - use client::sync::Connect as Sc; use std::error::Error as E; use self::error_service::*; let _ = env_logger::init(); - let addr = ErrorServer.listen("localhost:0".first_socket_addr(), - server::Options::default()) - .wait() - .unwrap(); - let client = FutureClient::connect(addr, client::Options::default()).wait().unwrap(); + let (addr, client) = start_err_server_with_async_client::(ErrorServer); client.bar() .then(move |result| { match result.err().unwrap() { @@ -880,7 +1026,7 @@ mod functional_test { .wait() .unwrap(); - let client = SyncClient::connect(&addr, client::Options::default()).unwrap(); + let client = get_sync_client::(addr).unwrap(); match client.bar().err().unwrap() { ::Error::App(e) => { assert_eq!(e.description(), "lol jk"); diff --git a/src/server.rs b/src/server.rs index 5fb8fd7..2739c46 100644 --- a/src/server.rs +++ b/src/server.rs @@ -17,10 +17,27 @@ use tokio_core::reactor::{self, Handle}; use tokio_proto::BindServer; use tokio_service::NewService; +cfg_if! { + if #[cfg(feature = "tls")] { + use native_tls::TlsAcceptor; + use tokio_tls::TlsAcceptorExt; + use errors::native_to_io; + use stream_type::StreamType; + } else {} +} + +enum Acceptor { + Tcp, + #[cfg(feature = "tls")] + Tls(TlsAcceptor), +} + /// Additional options to configure how the server operates. -#[derive(Clone, Default)] +#[derive(Default)] pub struct Options { reactor: Option, + #[cfg(feature = "tls")] + tls_acceptor: Option, } impl Options { @@ -35,6 +52,13 @@ impl Options { self.reactor = Some(Reactor::Remote(remote)); self } + + /// Set the `TlsAcceptor` + #[cfg(feature = "tls")] + pub fn tls(mut self, tls_acceptor: TlsAcceptor) -> Self { + self.tls_acceptor = Some(tls_acceptor); + self + } } /// A message from server to client. @@ -50,24 +74,37 @@ pub fn listen(new_service: S, addr: SocketAddr, options: Option Resp: Serialize + 'static, E: Serialize + 'static { + // Similar to the client, since `Options` is not `Send`, we take the `TlsAcceptor` when it is + // available. + #[cfg(feature = "tls")] + let acceptor = match options.tls_acceptor { + Some(tls_acceptor) => Acceptor::Tls(tls_acceptor), + None => Acceptor::Tcp, + }; + #[cfg(not(feature = "tls"))] + let acceptor = Acceptor::Tcp; + match options.reactor { None => { let (tx, rx) = futures::oneshot(); REMOTE.spawn(move |handle| { - Ok(tx.complete(listen_with(new_service, addr, handle.clone()))) + Ok(tx.complete(listen_with(new_service, addr, handle.clone(), acceptor))) }); ListenFuture { inner: future::Either::A(rx) } } Some(Reactor::Remote(remote)) => { let (tx, rx) = futures::oneshot(); remote.spawn(move |handle| { - Ok(tx.complete(listen_with(new_service, addr, handle.clone()))) + Ok(tx.complete(listen_with(new_service, addr, handle.clone(), acceptor))) }); ListenFuture { inner: future::Either::A(rx) } } Some(Reactor::Handle(handle)) => { ListenFuture { - inner: future::Either::B(future::ok(listen_with(new_service, addr, handle))), + inner: future::Either::B(future::ok(listen_with(new_service, + addr, + handle, + acceptor))), } } } @@ -76,7 +113,8 @@ pub fn listen(new_service: S, addr: SocketAddr, options: Option /// Spawns a service that binds to the given address using the given handle. fn listen_with(new_service: S, addr: SocketAddr, - handle: Handle) + handle: Handle, + _acceptor: Acceptor) -> io::Result where S: NewService, Response = Response, @@ -89,8 +127,22 @@ fn listen_with(new_service: S, let addr = listener.local_addr()?; let handle2 = handle.clone(); + let server = listener.incoming() - .for_each(move |(socket, _)| { + .and_then(move |(socket, _)| { + #[cfg(feature = "tls")] + match _acceptor { + Acceptor::Tls(ref tls_acceptor) => { + future::Either::A(tls_acceptor.accept_async(socket) + .map(StreamType::Tls) + .map_err(native_to_io)) + } + Acceptor::Tcp => future::Either::B(future::ok(StreamType::Tcp(socket))), + } + #[cfg(not(feature = "tls"))] + future::ok(socket) + }) + .for_each(move |socket| { Proto::new().bind_server(&handle2, socket, new_service.new_service()?); Ok(()) diff --git a/src/stream_type.rs b/src/stream_type.rs new file mode 100644 index 0000000..78fbe17 --- /dev/null +++ b/src/stream_type.rs @@ -0,0 +1,55 @@ +use std::io; +use tokio_core::io::Io; +use tokio_core::net::TcpStream; +#[cfg(feature = "tls")] +use tokio_tls::TlsStream; + +#[derive(Debug)] +pub enum StreamType { + Tcp(TcpStream), + #[cfg(feature = "tls")] + Tls(TlsStream), +} + +impl From for StreamType { + fn from(stream: TcpStream) -> Self { + StreamType::Tcp(stream) + } +} + +#[cfg(feature = "tls")] +impl From> for StreamType { + fn from(stream: TlsStream) -> Self { + StreamType::Tls(stream) + } +} + +impl io::Read for StreamType { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match *self { + StreamType::Tcp(ref mut stream) => stream.read(buf), + #[cfg(feature = "tls")] + StreamType::Tls(ref mut stream) => stream.read(buf), + } + } +} + +impl io::Write for StreamType { + fn write(&mut self, buf: &[u8]) -> io::Result { + match *self { + StreamType::Tcp(ref mut stream) => stream.write(buf), + #[cfg(feature = "tls")] + StreamType::Tls(ref mut stream) => stream.write(buf), + } + } + + fn flush(&mut self) -> io::Result<()> { + match *self { + StreamType::Tcp(ref mut stream) => stream.flush(), + #[cfg(feature = "tls")] + StreamType::Tls(ref mut stream) => stream.flush(), + } + } +} + +impl Io for StreamType {} diff --git a/test/identity.p12 b/test/identity.p12 new file mode 100644 index 0000000..d16abb8 Binary files /dev/null and b/test/identity.p12 differ diff --git a/test/root-ca.der b/test/root-ca.der new file mode 100644 index 0000000..a9335c6 Binary files /dev/null and b/test/root-ca.der differ diff --git a/test/root-ca.pem b/test/root-ca.pem new file mode 100644 index 0000000..4ec2f53 --- /dev/null +++ b/test/root-ca.pem @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDXTCCAkWgAwIBAgIJAOIvDiVb18eVMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTYwODE0MTY1NjExWhcNMjYwODEyMTY1NjExWjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEArVHWFn52Lbl1l59exduZntVSZyDYpzDND+S2LUcO6fRBWhV/1Kzox+2G +ZptbuMGmfI3iAnb0CFT4uC3kBkQQlXonGATSVyaFTFR+jq/lc0SP+9Bd7SBXieIV +eIXlY1TvlwIvj3Ntw9zX+scTA4SXxH6M0rKv9gTOub2vCMSHeF16X8DQr4XsZuQr +7Cp7j1I4aqOJyap5JTl5ijmG8cnu0n+8UcRlBzy99dLWJG0AfI3VRJdWpGTNVZ92 +aFff3RpK3F/WI2gp3qV1ynRAKuvmncGC3LDvYfcc2dgsc1N6Ffq8GIrkgRob6eBc +klDHp1d023Lwre+VaVDSo1//Y72UFwIDAQABo1AwTjAdBgNVHQ4EFgQUbNOlA6sN +XyzJjYqciKeId7g3/ZowHwYDVR0jBBgwFoAUbNOlA6sNXyzJjYqciKeId7g3/Zow +DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAVVaR5QWLZIRR4Dw6TSBn +BQiLpBSXN6oAxdDw6n4PtwW6CzydaA+creiK6LfwEsiifUfQe9f+T+TBSpdIYtMv +Z2H2tjlFX8VrjUFvPrvn5c28CuLI0foBgY8XGSkR2YMYzWw2jPEq3Th/KM5Catn3 +AFm3bGKWMtGPR4v+90chEN0jzaAmJYRrVUh9vea27bOCn31Nse6XXQPmSI6Gyncy +OAPUsvPClF3IjeL1tmBotWqSGn1cYxLo+Lwjk22A9h6vjcNQRyZF2VLVvtwYrNU3 +mwJ6GCLsLHpwW/yjyvn8iEltnJvByM/eeRnfXV6WDObyiZsE/n6DxIRJodQzFqy9 +GA== +-----END CERTIFICATE-----