From 5add81b5f33f386f2abde024ea7adf748b991dc5 Mon Sep 17 00:00:00 2001 From: Tim Date: Fri, 31 Mar 2017 12:16:40 -0700 Subject: [PATCH] Feature rollup (#129) * Create a directory for the `future::server` module, which has become quite large. server.rs => server/mod.rs. Server submodules for shutdown and connection logic are added. * Add fn thread_pool(...) to sync::server::Options * Configure idle threads to expire after one minute * Add tarpc::util::lazy for lazily executing functions. Similar to `futures::lazy` but useful in different circumstances. Specifically, `futures::lazy` typically requires a closure, whereas `util::lazy` kind of deconstructs a closure into its function and args. * Remove some unstable features, and `cfg(plugin)` only in tests. Features `unboxed_closures` and `fn_traits` are removed by replacing manual Fn impls with Stream impls. This actually leads to slightly more performant code, as well, because some `Rc`s could be removed. * Fix tokio deprecation warnings. Update to use tokio-io in lieu of deprecated tokio-core items. impl AsyncRead's optional `unsafe fn prepare_uninitialized_buffer` for huge perf wins * Add debug impls to all public items and add `deny(missing_debug_implementations)` to the crate. * Bump tokio core version. --- Cargo.toml | 10 +- examples/readme_sync.rs | 2 +- src/future/client.rs | 19 +- src/future/server.rs | 657 -------------------------------- src/future/server/connection.rs | 76 ++++ src/future/server/mod.rs | 448 ++++++++++++++++++++++ src/future/server/shutdown.rs | 181 +++++++++ src/lib.rs | 18 +- src/macros.rs | 411 ++++++++------------ src/protocol.rs | 75 ++-- src/stream_type.rs | 43 ++- src/sync/client.rs | 27 +- src/sync/server.rs | 196 ++++++++-- src/tls.rs | 12 + src/util.rs | 69 +++- 15 files changed, 1262 insertions(+), 982 deletions(-) delete mode 100644 src/future/server.rs create mode 100644 src/future/server/connection.rs create mode 100644 src/future/server/mod.rs create mode 100644 src/future/server/shutdown.rs diff --git a/Cargo.toml b/Cargo.toml index cf5f391..2954f71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,16 +17,20 @@ travis-ci = { repository = "google/tarpc" } [dependencies] bincode = "1.0.0-alpha6" byteorder = "1.0" +bytes = "0.4" cfg-if = "0.1.0" -futures = "0.1.7" +futures = "0.1.11" lazy_static = "0.2" log = "0.3" net2 = "0.2" +num_cpus = "1.0" serde = "0.9" serde_derive = "0.9" tarpc-plugins = { path = "src/plugins" } -tokio-core = "0.1" -tokio-proto = "0.1" +thread-pool = "0.1.1" +tokio-core = "0.1.6" +tokio-io = "0.1" +tokio-proto = "0.1.1" tokio-service = "0.1" # Optional dependencies diff --git a/examples/readme_sync.rs b/examples/readme_sync.rs index 9c2b9fa..f40a5a6 100644 --- a/examples/readme_sync.rs +++ b/examples/readme_sync.rs @@ -27,7 +27,7 @@ struct HelloServer; impl SyncService for HelloServer { fn hello(&self, name: String) -> Result { - Ok(format!("Hello, {}!", name)) + Ok(format!("Hello from thread {}, {}!", thread::current().name().unwrap(), name)) } } diff --git a/src/future/client.rs b/src/future/client.rs index 0914e2c..5156743 100644 --- a/src/future/client.rs +++ b/src/future/client.rs @@ -27,6 +27,7 @@ cfg_if! { } /// Additional options to configure how the client connects and operates. +#[derive(Debug)] pub struct Options { /// Max packet size in bytes. max_payload_size: u64, @@ -55,7 +56,7 @@ impl Default for Options { } impl Options { - /// Set the max payload size in bytes. The default is 2,000,000 (2 MB). + /// Set the max payload size in bytes. The default is 2 << 20 (2 MiB). pub fn max_payload_size(mut self, bytes: u64) -> Self { self.max_payload_size = bytes; self @@ -86,6 +87,19 @@ enum Reactor { Remote(reactor::Remote), } +impl fmt::Debug for Reactor { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + const HANDLE: &'static &'static str = &"Reactor::Handle"; + const HANDLE_INNER: &'static &'static str = &"Handle { .. }"; + const REMOTE: &'static &'static str = &"Reactor::Remote"; + const REMOTE_INNER: &'static &'static str = &"Remote { .. }"; + + match *self { + Reactor::Handle(_) => f.debug_tuple(HANDLE).field(HANDLE_INNER).finish(), + Reactor::Remote(_) => f.debug_tuple(REMOTE).field(REMOTE_INNER).finish(), + } + } +} #[doc(hidden)] pub struct Client where Req: Serialize + 'static, @@ -213,7 +227,8 @@ impl ClientExt for Client let (tx, rx) = futures::oneshot(); let setup = move |handle: &reactor::Handle| { connect(handle).then(move |result| { - tx.complete(result); + // If send fails it means the client no longer cared about connecting. + let _ = tx.send(result); Ok(()) }) }; diff --git a/src/future/server.rs b/src/future/server.rs deleted file mode 100644 index 476e415..0000000 --- a/src/future/server.rs +++ /dev/null @@ -1,657 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -use {bincode, net2}; -use errors::WireError; -use futures::{Future, Poll, Stream, future as futures, stream}; -use futures::sync::{mpsc, oneshot}; -use futures::unsync; -use protocol::Proto; -use serde::{Deserialize, Serialize}; -use std::cell::Cell; -use std::io; -use std::net::SocketAddr; -use std::rc::Rc; -use tokio_core::io::Io; -use tokio_core::net::{Incoming, TcpListener, TcpStream}; -use tokio_core::reactor; -use tokio_proto::BindServer; -use tokio_service::{NewService, Service}; - -cfg_if! { - if #[cfg(feature = "tls")] { - use native_tls::{self, TlsAcceptor}; - use tokio_tls::{AcceptAsync, TlsAcceptorExt, TlsStream}; - use errors::native_to_io; - use stream_type::StreamType; - } else {} -} - -/// A handle to a bound server. -#[derive(Clone)] -pub struct Handle { - addr: SocketAddr, - shutdown: Shutdown, -} - -impl Handle { - #[doc(hidden)] - pub fn listen(new_service: S, - addr: SocketAddr, - handle: &reactor::Handle, - options: Options) - -> io::Result<(Self, Listen)> - where S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: Deserialize + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static - { - let (addr, shutdown, server) = - listen_with(new_service, - addr, handle, - options.max_payload_size, - Acceptor::from(options))?; - Ok((Handle { - addr: addr, - shutdown: shutdown, - }, - server)) - } - - /// Returns a hook for shutting down the server. - pub fn shutdown(&self) -> &Shutdown { - &self.shutdown - } - - /// The socket address the server is bound to. - pub fn addr(&self) -> SocketAddr { - self.addr - } -} - -enum Acceptor { - Tcp, - #[cfg(feature = "tls")] - Tls(TlsAcceptor), -} - -#[cfg(feature = "tls")] -type Accept = futures::Either, - fn(TlsStream) -> StreamType>, - fn(native_tls::Error) -> io::Error>, - futures::FutureResult>; - -#[cfg(not(feature = "tls"))] -type Accept = futures::FutureResult; - -impl Acceptor { - // TODO(https://github.com/tokio-rs/tokio-proto/issues/132): move this into the ServerProto impl - #[cfg(feature = "tls")] - fn accept(&self, socket: TcpStream) -> Accept { - match *self { - Acceptor::Tls(ref tls_acceptor) => { - futures::Either::A(tls_acceptor.accept_async(socket) - .map(StreamType::Tls as _) - .map_err(native_to_io)) - } - Acceptor::Tcp => futures::Either::B(futures::ok(StreamType::Tcp(socket))), - } - } - - #[cfg(not(feature = "tls"))] - fn accept(&self, socket: TcpStream) -> Accept { - futures::ok(socket) - } -} - -#[cfg(feature = "tls")] -impl From for Acceptor { - fn from(options: Options) -> Self { - match options.tls_acceptor { - Some(tls_acceptor) => Acceptor::Tls(tls_acceptor), - None => Acceptor::Tcp, - } - } -} - -#[cfg(not(feature = "tls"))] -impl From for Acceptor { - fn from(_: Options) -> Self { - Acceptor::Tcp - } -} - -impl FnOnce<((TcpStream, SocketAddr),)> for Acceptor { - type Output = Accept; - - extern "rust-call" fn call_once(self, ((socket, _),): ((TcpStream, SocketAddr),)) -> Accept { - self.accept(socket) - } -} - -impl FnMut<((TcpStream, SocketAddr),)> for Acceptor { - extern "rust-call" fn call_mut(&mut self, - ((socket, _),): ((TcpStream, SocketAddr),)) - -> Accept { - self.accept(socket) - } -} - -impl Fn<((TcpStream, SocketAddr),)> for Acceptor { - extern "rust-call" fn call(&self, ((socket, _),): ((TcpStream, SocketAddr),)) -> Accept { - self.accept(socket) - } -} - -/// Additional options to configure how the server operates. -pub struct Options { - /// Max packet size in bytes. - max_payload_size: u64, - #[cfg(feature = "tls")] - tls_acceptor: Option, -} - -impl Default for Options { - #[cfg(not(feature = "tls"))] - fn default() -> Self { - Options { - max_payload_size: 2 << 20, - } - } - - #[cfg(feature = "tls")] - fn default() -> Self { - Options { - max_payload_size: 2 << 20, - tls_acceptor: None, - } - } -} - -impl Options { - /// Set the max payload size in bytes. The default is 2,000,000 (2 MB). - pub fn max_payload_size(mut self, bytes: u64) -> Self { - self.max_payload_size = bytes; - self - } - - /// Sets the `TlsAcceptor` - #[cfg(feature = "tls")] - pub fn tls(mut self, tls_acceptor: TlsAcceptor) -> Self { - self.tls_acceptor = Some(tls_acceptor); - self - } -} - -/// A message from server to client. -#[doc(hidden)] -pub type Response = Result>; - -/// A hook to shut down a running server. -#[derive(Clone)] -pub struct Shutdown { - tx: mpsc::UnboundedSender>, -} - -/// A future that resolves when server shutdown completes. -pub struct ShutdownFuture { - inner: futures::Either, - futures::OrElse, Result<(), ()>, AlwaysOk>>, -} - -impl Future for ShutdownFuture { - type Item = (); - type Error = (); - - fn poll(&mut self) -> Poll<(), ()> { - self.inner.poll() - } -} - -impl Shutdown { - /// Initiates an orderly server shutdown. - /// - /// First, the server enters lameduck mode, in which - /// existing connections are honored but no new connections are accepted. Then, once all - /// connections are closed, it initates total shutdown. - /// - /// This fn will not return until the server is completely shut down. - pub fn shutdown(&self) -> ShutdownFuture { - let (tx, rx) = oneshot::channel(); - let inner = if let Err(_) = self.tx.send(tx) { - trace!("Server already initiated shutdown."); - futures::Either::A(futures::ok(())) - } else { - futures::Either::B(rx.or_else(AlwaysOk)) - }; - ShutdownFuture { inner: inner } - } -} - -enum ConnectionAction { - Increment, - Decrement, -} - -#[derive(Clone)] -struct ConnectionTracker { - tx: unsync::mpsc::UnboundedSender, -} - -impl ConnectionTracker { - fn increment(&self) { - let _ = self.tx.send(ConnectionAction::Increment); - } - - fn decrement(&self) { - debug!("Closing connection"); - let _ = self.tx.send(ConnectionAction::Decrement); - } -} - -struct ConnectionTrackingService { - service: S, - tracker: ConnectionTracker, -} - -impl Service for ConnectionTrackingService { - type Request = S::Request; - type Response = S::Response; - type Error = S::Error; - type Future = S::Future; - - fn call(&self, req: Self::Request) -> Self::Future { - trace!("Calling service."); - self.service.call(req) - } -} - -impl Drop for ConnectionTrackingService { - fn drop(&mut self) { - debug!("Dropping ConnnectionTrackingService."); - self.tracker.decrement(); - } -} - -struct ConnectionTrackingNewService { - new_service: S, - connection_tracker: ConnectionTracker, -} - -impl NewService for ConnectionTrackingNewService { - type Request = S::Request; - type Response = S::Response; - type Error = S::Error; - type Instance = ConnectionTrackingService; - - fn new_service(&self) -> io::Result { - self.connection_tracker.increment(); - Ok(ConnectionTrackingService { - service: self.new_service.new_service()?, - tracker: self.connection_tracker.clone(), - }) - } -} - -struct ShutdownSetter { - shutdown: Rc>>>, -} - -impl FnOnce<(oneshot::Sender<()>,)> for ShutdownSetter { - type Output = (); - - extern "rust-call" fn call_once(self, tx: (oneshot::Sender<()>,)) { - self.call(tx); - } -} - -impl FnMut<(oneshot::Sender<()>,)> for ShutdownSetter { - extern "rust-call" fn call_mut(&mut self, tx: (oneshot::Sender<()>,)) { - self.call(tx); - } -} - -impl Fn<(oneshot::Sender<()>,)> for ShutdownSetter { - extern "rust-call" fn call(&self, (tx,): (oneshot::Sender<()>,)) { - debug!("Received shutdown request."); - self.shutdown.set(Some(tx)); - } -} - -struct ConnectionWatcher { - connections: Rc>, -} - -impl FnOnce<(ConnectionAction,)> for ConnectionWatcher { - type Output = (); - - extern "rust-call" fn call_once(self, action: (ConnectionAction,)) { - self.call(action); - } -} - -impl FnMut<(ConnectionAction,)> for ConnectionWatcher { - extern "rust-call" fn call_mut(&mut self, action: (ConnectionAction,)) { - self.call(action); - } -} - -impl Fn<(ConnectionAction,)> for ConnectionWatcher { - extern "rust-call" fn call(&self, (action,): (ConnectionAction,)) { - match action { - ConnectionAction::Increment => self.connections.set(self.connections.get() + 1), - ConnectionAction::Decrement => self.connections.set(self.connections.get() - 1), - } - trace!("Open connections: {}", self.connections.get()); - } -} - -struct ShutdownPredicate { - shutdown: Rc>>>, - connections: Rc>, -} - -impl FnOnce for ShutdownPredicate { - type Output = Result; - - extern "rust-call" fn call_once(self, arg: T) -> Self::Output { - self.call(arg) - } -} - -impl FnMut for ShutdownPredicate { - extern "rust-call" fn call_mut(&mut self, arg: T) -> Self::Output { - self.call(arg) - } -} - -impl Fn for ShutdownPredicate { - extern "rust-call" fn call(&self, _: T) -> Self::Output { - match self.shutdown.take() { - Some(shutdown) => { - let num_connections = self.connections.get(); - debug!("Lameduck mode: {} open connections", num_connections); - if num_connections == 0 { - debug!("Shutting down."); - let _ = shutdown.complete(()); - Ok(false) - } else { - self.shutdown.set(Some(shutdown)); - Ok(true) - } - } - None => Ok(true), - } - } -} - -struct Warn(&'static str); - -impl FnOnce for Warn { - type Output = (); - - extern "rust-call" fn call_once(self, arg: T) -> Self::Output { - self.call(arg) - } -} - -impl FnMut for Warn { - extern "rust-call" fn call_mut(&mut self, arg: T) -> Self::Output { - self.call(arg) - } -} - -impl Fn for Warn { - extern "rust-call" fn call(&self, _: T) -> Self::Output { - warn!("{}", self.0) - } -} - -struct AlwaysOk; - -impl FnOnce for AlwaysOk { - type Output = Result<(), ()>; - - extern "rust-call" fn call_once(self, arg: T) -> Self::Output { - self.call(arg) - } -} - -impl FnMut for AlwaysOk { - extern "rust-call" fn call_mut(&mut self, arg: T) -> Self::Output { - self.call(arg) - } -} - -impl Fn for AlwaysOk { - extern "rust-call" fn call(&self, _: T) -> Self::Output { - Ok(()) - } -} - -type ShutdownStream = stream::Map>>, - ShutdownSetter>; - -type ConnectionStream = stream::Map, - ConnectionWatcher>; - -struct ShutdownWatcher { - inner: stream::ForEach, - ShutdownPredicate, - Result>, - Warn>, - AlwaysOk, - Result<(), ()>>, -} - -impl Future for ShutdownWatcher { - type Item = (); - type Error = (); - - fn poll(&mut self) -> Poll<(), ()> { - self.inner.poll() - } -} - -/// Creates a future that completes when a shutdown is signaled and no connections are open. -fn shutdown_watcher() -> (ConnectionTracker, Shutdown, ShutdownWatcher) { - let (shutdown_tx, shutdown_rx) = mpsc::unbounded::>(); - let (connection_tx, connection_rx) = unsync::mpsc::unbounded(); - let shutdown = Rc::new(Cell::new(None)); - let connections = Rc::new(Cell::new(0)); - let shutdown2 = shutdown.clone(); - let connections2 = connections.clone(); - - let inner = shutdown_rx.take(1) - .map(ShutdownSetter { shutdown: shutdown }) - .merge(connection_rx.map(ConnectionWatcher { connections: connections })) - .take_while(ShutdownPredicate { - shutdown: shutdown2, - connections: connections2, - }) - .map_err(Warn("UnboundedReceiver resolved to an Err; can it do that?")) - .for_each(AlwaysOk); - - (ConnectionTracker { tx: connection_tx }, - Shutdown { tx: shutdown_tx }, - ShutdownWatcher { inner: inner }) -} - -type AcceptStream = stream::AndThen; - -type BindStream = stream::ForEach>, - io::Result<()>>; - -/// The future representing a running server. -#[doc(hidden)] -pub struct Listen - where S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: Deserialize + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static -{ - inner: futures::Then, fn(io::Error)>, - ShutdownWatcher>, - Result<(), ()>, - AlwaysOk>, -} - -impl Future for Listen - where S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: Deserialize + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static -{ - type Item = (); - type Error = (); - - fn poll(&mut self) -> Poll<(), ()> { - self.inner.poll() - } -} - -/// Spawns a service that binds to the given address using the given handle. -fn listen_with(new_service: S, - addr: SocketAddr, - handle: &reactor::Handle, - max_payload_size: u64, - acceptor: Acceptor) - -> io::Result<(SocketAddr, Shutdown, Listen)> - where S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: Deserialize + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static -{ - let listener = listener(&addr, handle)?; - let addr = listener.local_addr()?; - debug!("Listening on {}.", addr); - - let handle = handle.clone(); - - let (connection_tracker, shutdown, shutdown_future) = shutdown_watcher(); - let server = listener.incoming() - .and_then(acceptor) - .for_each(Bind { - max_payload_size: max_payload_size, - handle: handle, - new_service: ConnectionTrackingNewService { - connection_tracker: connection_tracker, - new_service: new_service, - }, - }) - .map_err(log_err as _); - - let server = server.select(shutdown_future).then(AlwaysOk); - Ok((addr, shutdown, Listen { inner: server })) -} - -fn log_err(e: io::Error) { - error!("While processing incoming connections: {}", e); -} - -struct Bind { - max_payload_size: u64, - handle: reactor::Handle, - new_service: S, -} - -impl Bind - where S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: Deserialize + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static -{ - fn bind(&self, socket: I) -> io::Result<()> - where I: Io + 'static - { - Proto::new(self.max_payload_size) - .bind_server(&self.handle, socket, self.new_service.new_service()?); - Ok(()) - } -} - -impl FnOnce<(I,)> for Bind - where I: Io + 'static, - S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: Deserialize + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static -{ - type Output = io::Result<()>; - - extern "rust-call" fn call_once(self, (socket,): (I,)) -> io::Result<()> { - self.bind(socket) - } -} - -impl FnMut<(I,)> for Bind - where I: Io + 'static, - S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: Deserialize + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static -{ - extern "rust-call" fn call_mut(&mut self, (socket,): (I,)) -> io::Result<()> { - self.bind(socket) - } -} - -impl Fn<(I,)> for Bind - where I: Io + 'static, - S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: Deserialize + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static -{ - extern "rust-call" fn call(&self, (socket,): (I,)) -> io::Result<()> { - self.bind(socket) - } -} - -fn listener(addr: &SocketAddr, handle: &reactor::Handle) -> io::Result { - const PENDING_CONNECTION_BACKLOG: i32 = 1024; - - let builder = match *addr { - SocketAddr::V4(_) => net2::TcpBuilder::new_v4(), - SocketAddr::V6(_) => net2::TcpBuilder::new_v6(), - }?; - configure_tcp(&builder)?; - builder.reuse_address(true)?; - builder.bind(addr)? - .listen(PENDING_CONNECTION_BACKLOG) - .and_then(|l| TcpListener::from_listener(l, addr, handle)) -} - -#[cfg(unix)] -fn configure_tcp(tcp: &net2::TcpBuilder) -> io::Result<()> { - use net2::unix::UnixTcpBuilderExt; - - tcp.reuse_port(true)?; - - Ok(()) -} - -#[cfg(windows)] -fn configure_tcp(_tcp: &net2::TcpBuilder) -> io::Result<()> { - Ok(()) -} diff --git a/src/future/server/connection.rs b/src/future/server/connection.rs new file mode 100644 index 0000000..7883ad3 --- /dev/null +++ b/src/future/server/connection.rs @@ -0,0 +1,76 @@ +use futures::unsync; +use std::io; +use tokio_service::{NewService, Service}; + +#[derive(Debug)] +pub enum Action { + Increment, + Decrement, +} + +#[derive(Clone, Debug)] +pub struct Tracker { + pub tx: unsync::mpsc::UnboundedSender, +} + +impl Tracker { + pub fn pair() -> (Self, unsync::mpsc::UnboundedReceiver) { + let (tx, rx) = unsync::mpsc::unbounded(); + (Self { tx }, rx) + } + + pub fn increment(&self) { + let _ = self.tx.send(Action::Increment); + } + + pub fn decrement(&self) { + debug!("Closing connection"); + let _ = self.tx.send(Action::Decrement); + } +} + +#[derive(Debug)] +pub struct TrackingService { + pub service: S, + pub tracker: Tracker, +} + +#[derive(Debug)] +pub struct TrackingNewService { + pub new_service: S, + pub connection_tracker: Tracker, +} + +impl Service for TrackingService { + type Request = S::Request; + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn call(&self, req: Self::Request) -> Self::Future { + trace!("Calling service."); + self.service.call(req) + } +} + +impl Drop for TrackingService { + fn drop(&mut self) { + debug!("Dropping ConnnectionTrackingService."); + self.tracker.decrement(); + } +} + +impl NewService for TrackingNewService { + type Request = S::Request; + type Response = S::Response; + type Error = S::Error; + type Instance = TrackingService; + + fn new_service(&self) -> io::Result { + self.connection_tracker.increment(); + Ok(TrackingService { + service: self.new_service.new_service()?, + tracker: self.connection_tracker.clone(), + }) + } +} diff --git a/src/future/server/mod.rs b/src/future/server/mod.rs new file mode 100644 index 0000000..1a8dcb0 --- /dev/null +++ b/src/future/server/mod.rs @@ -0,0 +1,448 @@ +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Licensed under the MIT License, . +// This file may not be copied, modified, or distributed except according to those terms. + +use {bincode, net2}; +use errors::WireError; +use futures::{Async, Future, Poll, Stream, future as futures}; +use protocol::Proto; +use serde::{Deserialize, Serialize}; +use std::fmt; +use std::io; +use std::net::SocketAddr; +use stream_type::StreamType; +use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_core::net::{Incoming, TcpListener, TcpStream}; +use tokio_core::reactor; +use tokio_proto::BindServer; +use tokio_service::NewService; + +mod connection; +mod shutdown; + +cfg_if! { + if #[cfg(feature = "tls")] { + use native_tls::{self, TlsAcceptor}; + use tokio_tls::{AcceptAsync, TlsAcceptorExt, TlsStream}; + use errors::native_to_io; + } else {} +} + +pub use self::shutdown::Shutdown; + +/// A handle to a bound server. +#[derive(Clone, Debug)] +pub struct Handle { + addr: SocketAddr, + shutdown: Shutdown, +} + +impl Handle { + /// Returns a hook for shutting down the server. + pub fn shutdown(&self) -> &Shutdown { + &self.shutdown + } + + /// The socket address the server is bound to. + pub fn addr(&self) -> SocketAddr { + self.addr + } +} + +enum Acceptor { + Tcp, + #[cfg(feature = "tls")] + Tls(TlsAcceptor), +} + +struct Accept { + #[cfg(feature = "tls")] + inner: futures::Either, + fn(TlsStream) -> StreamType>, + fn(native_tls::Error) -> io::Error>, + futures::FutureResult>, + #[cfg(not(feature = "tls"))] + inner: futures::FutureResult, +} + +impl Future for Accept { + type Item = StreamType; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + self.inner.poll() + } +} + +impl Acceptor { + // TODO(https://github.com/tokio-rs/tokio-proto/issues/132): move this into the ServerProto impl + #[cfg(feature = "tls")] + fn accept(&self, socket: TcpStream) -> Accept { + Accept { + inner: match *self { + Acceptor::Tls(ref tls_acceptor) => { + futures::Either::A(tls_acceptor.accept_async(socket) + .map(StreamType::Tls as _) + .map_err(native_to_io)) + } + Acceptor::Tcp => futures::Either::B(futures::ok(StreamType::Tcp(socket))), + } + } + } + + #[cfg(not(feature = "tls"))] + fn accept(&self, socket: TcpStream) -> Accept { + Accept { + inner: futures::ok(StreamType::Tcp(socket)) + } + } +} + +#[cfg(feature = "tls")] +impl From for Acceptor { + fn from(options: Options) -> Self { + match options.tls_acceptor { + Some(tls_acceptor) => Acceptor::Tls(tls_acceptor), + None => Acceptor::Tcp, + } + } +} + +#[cfg(not(feature = "tls"))] +impl From for Acceptor { + fn from(_: Options) -> Self { + Acceptor::Tcp + } +} + +impl fmt::Debug for Acceptor { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + use self::Acceptor::*; + #[cfg(feature = "tls")] + const TLS: &'static &'static str = &"TlsAcceptor { .. }"; + + match *self { + Tcp => fmt.debug_tuple("Acceptor::Tcp").finish(), + #[cfg(feature = "tls")] + Tls(_) => fmt.debug_tuple("Acceptlr::Tls").field(TLS).finish(), + } + } +} + +impl fmt::Debug for Accept { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.debug_struct("Accept").finish() + } +} + +#[derive(Debug)] +struct AcceptStream { + stream: S, + acceptor: Acceptor, + future: Option, +} + +impl Stream for AcceptStream + where S: Stream, +{ + type Item = ::Item; + type Error = io::Error; + + fn poll(&mut self) -> Poll, io::Error> { + if self.future.is_none() { + let stream = match try_ready!(self.stream.poll()) { + None => return Ok(Async::Ready(None)), + Some((stream, _)) => stream, + }; + self.future = Some(self.acceptor.accept(stream)); + } + assert!(self.future.is_some()); + match self.future.as_mut().unwrap().poll() { + Ok(Async::Ready(e)) => { + self.future = None; + Ok(Async::Ready(Some(e))) + } + Err(e) => { + self.future = None; + Err(e) + } + Ok(Async::NotReady) => Ok(Async::NotReady) + } + } +} + +/// Additional options to configure how the server operates. +pub struct Options { + /// Max packet size in bytes. + max_payload_size: u64, + #[cfg(feature = "tls")] + tls_acceptor: Option, +} + +impl Default for Options { + #[cfg(not(feature = "tls"))] + fn default() -> Self { + Options { + max_payload_size: 2 << 20, + } + } + + #[cfg(feature = "tls")] + fn default() -> Self { + Options { + max_payload_size: 2 << 20, + tls_acceptor: None, + } + } +} + +impl Options { + /// Set the max payload size in bytes. The default is 2 << 20 (2 MiB). + pub fn max_payload_size(mut self, bytes: u64) -> Self { + self.max_payload_size = bytes; + self + } + + /// Sets the `TlsAcceptor` + #[cfg(feature = "tls")] + pub fn tls(mut self, tls_acceptor: TlsAcceptor) -> Self { + self.tls_acceptor = Some(tls_acceptor); + self + } +} + +impl fmt::Debug for Options { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + #[cfg(feature = "tls")] + const SOME: &'static &'static str = &"Some(_)"; + #[cfg(feature = "tls")] + const NONE: &'static &'static str = &"None"; + + let mut debug_struct = fmt.debug_struct("Options"); + #[cfg(feature = "tls")] + debug_struct.field("tls_acceptor", if self.tls_acceptor.is_some() { SOME } else { NONE }); + debug_struct.finish() + } +} + +/// A message from server to client. +#[doc(hidden)] +pub type Response = Result>; + +#[doc(hidden)] +pub fn listen(new_service: S, + addr: SocketAddr, + handle: &reactor::Handle, + options: Options) + -> io::Result<(Handle, Listen)> + where S: NewService, + Response = Response, + Error = io::Error> + 'static, + Req: Deserialize + 'static, + Resp: Serialize + 'static, + E: Serialize + 'static +{ + let (addr, shutdown, server) = listen_with( + new_service, addr, handle, options.max_payload_size, Acceptor::from(options))?; + Ok((Handle { + addr: addr, + shutdown: shutdown, + }, + server)) +} + +/// Spawns a service that binds to the given address using the given handle. +fn listen_with(new_service: S, + addr: SocketAddr, + handle: &reactor::Handle, + max_payload_size: u64, + acceptor: Acceptor) + -> io::Result<(SocketAddr, Shutdown, Listen)> + where S: NewService, + Response = Response, + Error = io::Error> + 'static, + Req: Deserialize + 'static, + Resp: Serialize + 'static, + E: Serialize + 'static +{ + let listener = listener(&addr, handle)?; + let addr = listener.local_addr()?; + debug!("Listening on {}.", addr); + + let handle = handle.clone(); + let (connection_tracker, shutdown, shutdown_future) = shutdown::Watcher::triple(); + let server = BindStream { + handle: handle, + new_service: connection::TrackingNewService { + connection_tracker: connection_tracker, + new_service: new_service, + }, + stream: AcceptStream { + stream: listener.incoming(), + acceptor: acceptor, + future: None, + }, + max_payload_size: max_payload_size, + }; + + let server = AlwaysOkUnit(server.select(shutdown_future)); + Ok((addr, shutdown, Listen { inner: server })) +} + +fn listener(addr: &SocketAddr, handle: &reactor::Handle) -> io::Result { + const PENDING_CONNECTION_BACKLOG: i32 = 1024; + + let builder = match *addr { + SocketAddr::V4(_) => net2::TcpBuilder::new_v4(), + SocketAddr::V6(_) => net2::TcpBuilder::new_v6(), + }?; + configure_tcp(&builder)?; + builder.reuse_address(true)?; + builder.bind(addr)? + .listen(PENDING_CONNECTION_BACKLOG) + .and_then(|l| TcpListener::from_listener(l, addr, handle)) +} + +#[cfg(unix)] +fn configure_tcp(tcp: &net2::TcpBuilder) -> io::Result<()> { + use net2::unix::UnixTcpBuilderExt; + tcp.reuse_port(true)?; + Ok(()) +} + +#[cfg(windows)] +fn configure_tcp(_tcp: &net2::TcpBuilder) -> io::Result<()> { + Ok(()) +} + +struct BindStream { + handle: reactor::Handle, + new_service: connection::TrackingNewService, + stream: St, + max_payload_size: u64, +} + +impl fmt::Debug for BindStream + where S: fmt::Debug, + St: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + const HANDLE: &'static &'static str = &"Handle { .. }"; + f.debug_struct("BindStream") + .field("handle", HANDLE) + .field("new_service", &self.new_service) + .field("stream", &self.stream) + .finish() + } +} + +impl BindStream + where S: NewService, + Response = Response, + Error = io::Error> + 'static, + Req: Deserialize + 'static, + Resp: Serialize + 'static, + E: Serialize + 'static, + I: AsyncRead + AsyncWrite + 'static, + St: Stream, +{ + fn bind_each(&mut self) -> Poll<(), io::Error> { + loop { + match try!(self.stream.poll()) { + Async::Ready(Some(socket)) => { + Proto::new(self.max_payload_size) + .bind_server(&self.handle, socket, self.new_service.new_service()?); + } + Async::Ready(None) => return Ok(Async::Ready(())), + Async::NotReady => return Ok(Async::NotReady), + } + } + } +} + +impl Future for BindStream + where S: NewService, + Response = Response, + Error = io::Error> + 'static, + Req: Deserialize + 'static, + Resp: Serialize + 'static, + E: Serialize + 'static, + I: AsyncRead + AsyncWrite + 'static, + St: Stream, +{ + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll { + match self.bind_each() { + Ok(Async::Ready(())) => Ok(Async::Ready(())), + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(e) => { + error!("While processing incoming connections: {}", e); + Err(()) + } + } + } +} + +/// The future representing a running server. +#[doc(hidden)] +pub struct Listen + where S: NewService, + Response = Response, + Error = io::Error> + 'static, + Req: Deserialize + 'static, + Resp: Serialize + 'static, + E: Serialize + 'static +{ + inner: AlwaysOkUnit>, + shutdown::Watcher>>, +} + +impl Future for Listen + where S: NewService, + Response = Response, + Error = io::Error> + 'static, + Req: Deserialize + 'static, + Resp: Serialize + 'static, + E: Serialize + 'static +{ + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll<(), ()> { + self.inner.poll() + } +} + +impl fmt::Debug for Listen + where S: NewService, + Response = Response, + Error = io::Error> + 'static, + Req: Deserialize + 'static, + Resp: Serialize + 'static, + E: Serialize + 'static, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + f.debug_struct("Listen").finish() + } +} + +#[derive(Debug)] +struct AlwaysOkUnit(F); + +impl Future for AlwaysOkUnit + where F: Future, +{ + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll<(), ()> { + match self.0.poll() { + Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())), + Ok(Async::NotReady) => Ok(Async::NotReady), + } + } +} + diff --git a/src/future/server/shutdown.rs b/src/future/server/shutdown.rs new file mode 100644 index 0000000..df707a9 --- /dev/null +++ b/src/future/server/shutdown.rs @@ -0,0 +1,181 @@ +use futures::{Async, Future, Poll, Stream, future as futures, stream}; +use futures::sync::{mpsc, oneshot}; +use futures::unsync; + +use super::{AlwaysOkUnit, connection}; + +/// A hook to shut down a running server. +#[derive(Clone, Debug)] +pub struct Shutdown { + tx: mpsc::UnboundedSender>, +} + +#[derive(Debug)] +/// A future that resolves when server shutdown completes. +pub struct ShutdownFuture { + inner: futures::Either, + AlwaysOkUnit>>, +} + +impl Future for ShutdownFuture { + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll<(), ()> { + self.inner.poll() + } +} + +impl Shutdown { + /// Initiates an orderly server shutdown. + /// + /// First, the server enters lameduck mode, in which + /// existing connections are honored but no new connections are accepted. Then, once all + /// connections are closed, it initates total shutdown. + /// + /// This fn will not return until the server is completely shut down. + pub fn shutdown(&self) -> ShutdownFuture { + let (tx, rx) = oneshot::channel(); + let inner = if let Err(_) = self.tx.send(tx) { + trace!("Server already initiated shutdown."); + futures::Either::A(futures::ok(())) + } else { + futures::Either::B(AlwaysOkUnit(rx)) + }; + ShutdownFuture { inner: inner } + } +} + +#[derive(Debug)] +pub struct Watcher { + shutdown_rx: stream::Take>>, + connections: unsync::mpsc::UnboundedReceiver, + queued_error: Option<()>, + shutdown: Option>, + done: bool, + num_connections: u64, +} + +impl Watcher { + pub fn triple() -> (connection::Tracker, Shutdown, Self) { + let (connection_tx, connections) = connection::Tracker::pair(); + let (shutdown_tx, shutdown_rx) = mpsc::unbounded(); + (connection_tx, + Shutdown { tx: shutdown_tx }, + Watcher { + shutdown_rx: shutdown_rx.take(1), + connections: connections, + queued_error: None, + shutdown: None, + done: false, + num_connections: 0, + }) + } + + fn process_connection(&mut self, action: connection::Action) { + match action { + connection::Action::Increment => self.num_connections += 1, + connection::Action::Decrement => self.num_connections -= 1, + } + } + + fn poll_shutdown_requests(&mut self) -> Poll, ()> { + Ok(Async::Ready(match try_ready!(self.shutdown_rx.poll()) { + Some(tx) => { + debug!("Received shutdown request."); + self.shutdown = Some(tx); + Some(()) + } + None => None, + })) + } + + fn poll_connections(&mut self) -> Poll, ()> { + Ok(Async::Ready(match try_ready!(self.connections.poll()) { + Some(action) => { + self.process_connection(action); + Some(()) + } + None => None, + })) + } + + fn poll_shutdown_requests_and_connections(&mut self) -> Poll, ()> { + if let Some(e) = self.queued_error.take() { + return Err(e) + } + + match try!(self.poll_shutdown_requests()) { + Async::NotReady => { + match try_ready!(self.poll_connections()) { + Some(()) => Ok(Async::Ready(Some(()))), + None => Ok(Async::NotReady), + } + } + Async::Ready(None) => { + match try_ready!(self.poll_connections()) { + Some(()) => Ok(Async::Ready(Some(()))), + None => Ok(Async::Ready(None)), + } + } + Async::Ready(Some(())) => { + match self.poll_connections() { + Err(e) => { + self.queued_error = Some(e); + Ok(Async::Ready(Some(()))) + } + Ok(Async::NotReady) | Ok(Async::Ready(None)) | Ok(Async::Ready(Some(()))) => { + Ok(Async::Ready(Some(()))) + } + } + } + } + } + + fn should_continue(&mut self) -> bool { + match self.shutdown.take() { + Some(shutdown) => { + debug!("Lameduck mode: {} open connections", self.num_connections); + if self.num_connections == 0 { + debug!("Shutting down."); + // Not required for the shutdown future to be waited on, so this + // can fail (which is fine). + let _ = shutdown.send(()); + false + } else { + self.shutdown = Some(shutdown); + true + } + } + None => true, + } + } + + fn process_request(&mut self) -> Poll, ()> { + if self.done { + return Ok(Async::Ready(None)); + } + if self.should_continue() { + self.poll_shutdown_requests_and_connections() + } else { + self.done = true; + Ok(Async::Ready(None)) + } + } +} + +impl Future for Watcher { + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll<(), ()> { + loop { + match try!(self.process_request()) { + Async::Ready(Some(())) => continue, + Async::Ready(None) => return Ok(Async::Ready(())), + Async::NotReady => return Ok(Async::NotReady), + } + } + } +} + diff --git a/src/lib.rs b/src/lib.rs index 6fe61b7..82903bb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -113,25 +113,31 @@ //! println!("{}", client.hello("Mom".to_string()).unwrap()); //! } //! ``` -//! -#![deny(missing_docs)] -#![feature(fn_traits, move_cell, never_type, plugin, struct_field_attributes, unboxed_closures)] -#![plugin(tarpc_plugins)] + +#![deny(missing_docs, missing_debug_implementations)] +#![feature(never_type)] +#![cfg_attr(test, feature(plugin))] +#![cfg_attr(test, plugin(tarpc_plugins))] extern crate byteorder; +extern crate bytes; +#[macro_use] +extern crate cfg_if; #[macro_use] extern crate lazy_static; #[macro_use] extern crate log; extern crate net2; +extern crate num_cpus; #[macro_use] extern crate serde_derive; -#[macro_use] -extern crate cfg_if; +extern crate thread_pool; +extern crate tokio_io; #[doc(hidden)] pub extern crate bincode; #[doc(hidden)] +#[macro_use] pub extern crate futures; #[doc(hidden)] pub extern crate serde; diff --git a/src/macros.rs b/src/macros.rs index 2968a2b..8edd279 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -122,12 +122,12 @@ macro_rules! impl_deserialize { formatter.write_str("an enum variant") } - fn visit_enum(self, tarpc_enum_visitor__: V) + fn visit_enum(self, visitor__: V) -> ::std::result::Result where V: $crate::serde::de::EnumVisitor { use $crate::serde::de::VariantVisitor; - match tarpc_enum_visitor__.visit_variant()? { + match visitor__.visit_variant()? { $( (impl_deserialize_Field__::$name, variant) => { ::std::result::Result::Ok( @@ -290,40 +290,40 @@ macro_rules! service { #[doc(hidden)] #[allow(non_camel_case_types, unused)] - pub enum tarpc_service_Request__ { + pub enum Request__ { NotIrrefutable(()), $( $fn_name(( $($in_,)* )) ),* } - impl_deserialize!(tarpc_service_Request__, NotIrrefutable(()) $($fn_name(($($in_),*)))*); - impl_serialize!(tarpc_service_Request__, {}, NotIrrefutable(()) $($fn_name(($($in_),*)))*); + impl_deserialize!(Request__, NotIrrefutable(()) $($fn_name(($($in_),*)))*); + impl_serialize!(Request__, {}, NotIrrefutable(()) $($fn_name(($($in_),*)))*); #[doc(hidden)] #[allow(non_camel_case_types, unused)] - pub enum tarpc_service_Response__ { + pub enum Response__ { NotIrrefutable(()), $( $fn_name($out) ),* } - impl_deserialize!(tarpc_service_Response__, NotIrrefutable(()) $($fn_name($out))*); - impl_serialize!(tarpc_service_Response__, {}, NotIrrefutable(()) $($fn_name($out))*); + impl_deserialize!(Response__, NotIrrefutable(()) $($fn_name($out))*); + impl_serialize!(Response__, {}, NotIrrefutable(()) $($fn_name($out))*); #[doc(hidden)] #[allow(non_camel_case_types, unused)] #[derive(Debug)] - pub enum tarpc_service_Error__ { + pub enum Error__ { NotIrrefutable(()), $( $fn_name($error) ),* } - impl_deserialize!(tarpc_service_Error__, NotIrrefutable(()) $($fn_name($error))*); - impl_serialize!(tarpc_service_Error__, {}, NotIrrefutable(()) $($fn_name($error))*); + impl_deserialize!(Error__, NotIrrefutable(()) $($fn_name($error))*); + impl_serialize!(Error__, {}, NotIrrefutable(()) $($fn_name($error))*); /// Defines the `Future` RPC service. Implementors must be `Clone` and `'static`, /// as required by `tokio_proto::NewService`. This is required so that the service can be used @@ -333,7 +333,6 @@ macro_rules! service { 'static { $( - snake_to_camel! { /// The type of future returned by `{}`. type $fn_name: $crate::futures::IntoFuture; @@ -346,50 +345,42 @@ macro_rules! service { #[allow(non_camel_case_types)] #[derive(Clone)] - struct tarpc_service_AsyncServer__(S); + struct TarpcNewService(S); - impl ::std::fmt::Debug for tarpc_service_AsyncServer__ { + impl ::std::fmt::Debug for TarpcNewService { fn fmt(&self, fmt: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - write!(fmt, "tarpc_service_AsyncServer__ {{ .. }}") + fmt.debug_struct("TarpcNewService").finish() } } #[allow(non_camel_case_types)] - type tarpc_service_Future__ = - $crate::futures::Finished<$crate::future::server::Response, + type ResponseFuture__ = + $crate::futures::Finished<$crate::future::server::Response, ::std::io::Error>; #[allow(non_camel_case_types)] - enum tarpc_service_FutureReply__ { - DeserializeError(tarpc_service_Future__), + enum FutureReply__ { + DeserializeError(ResponseFuture__), $($fn_name( $crate::futures::Then< - ::Future, - tarpc_service_Future__, - fn(::std::result::Result<$out, $error>) - -> tarpc_service_Future__>)),* + ::Future, + ResponseFuture__, + fn(::std::result::Result<$out, $error>) -> ResponseFuture__>)),* } - impl $crate::futures::Future for tarpc_service_FutureReply__ { - type Item = $crate::future::server::Response; - + impl $crate::futures::Future for FutureReply__ { + type Item = $crate::future::server::Response; type Error = ::std::io::Error; fn poll(&mut self) -> $crate::futures::Poll { match *self { - tarpc_service_FutureReply__::DeserializeError( - ref mut tarpc_service_future__) => - { - $crate::futures::Future::poll(tarpc_service_future__) + FutureReply__::DeserializeError(ref mut future__) => { + $crate::futures::Future::poll(future__) } $( - tarpc_service_FutureReply__::$fn_name( - ref mut tarpc_service_future__) => - { - $crate::futures::Future::poll(tarpc_service_future__) + FutureReply__::$fn_name(ref mut future__) => { + $crate::futures::Future::poll(future__) } ),* } @@ -398,53 +389,45 @@ macro_rules! service { #[allow(non_camel_case_types)] - impl $crate::tokio_service::Service - for tarpc_service_AsyncServer__ - where tarpc_service_S__: FutureService + impl $crate::tokio_service::Service for TarpcNewService + where S__: FutureService { - type Request = ::std::result::Result; - type Response = $crate::future::server::Response; + type Request = ::std::result::Result; + type Response = $crate::future::server::Response; type Error = ::std::io::Error; - type Future = tarpc_service_FutureReply__; + type Future = FutureReply__; - fn call(&self, tarpc_service_request__: Self::Request) -> Self::Future { - let tarpc_service_request__ = match tarpc_service_request__ { - Ok(tarpc_service_request__) => tarpc_service_request__, - Err(tarpc_service_deserialize_err__) => { - return tarpc_service_FutureReply__::DeserializeError( + fn call(&self, request__: Self::Request) -> Self::Future { + let request__ = match request__ { + Ok(request__) => request__, + Err(err__) => { + return FutureReply__::DeserializeError( $crate::futures::finished( ::std::result::Result::Err( $crate::WireError::RequestDeserialize( - ::std::string::ToString::to_string( - &tarpc_service_deserialize_err__))))); + ::std::string::ToString::to_string(&err__))))); } }; - match tarpc_service_request__ { - tarpc_service_Request__::NotIrrefutable(()) => unreachable!(), + match request__ { + Request__::NotIrrefutable(()) => unreachable!(), $( - tarpc_service_Request__::$fn_name(( $($arg,)* )) => { - fn tarpc_service_wrap__( - tarpc_service_response__: - ::std::result::Result<$out, $error>) - -> tarpc_service_Future__ + Request__::$fn_name(( $($arg,)* )) => { + fn wrap__(response__: ::std::result::Result<$out, $error>) + -> ResponseFuture__ { $crate::futures::finished( - tarpc_service_response__ - .map(tarpc_service_Response__::$fn_name) - .map_err(|tarpc_service_error__| { - $crate::WireError::App( - tarpc_service_Error__::$fn_name( - tarpc_service_error__)) + response__ + .map(Response__::$fn_name) + .map_err(|err__| { + $crate::WireError::App(Error__::$fn_name(err__)) }) ) } - return tarpc_service_FutureReply__::$fn_name( + return FutureReply__::$fn_name( $crate::futures::Future::then( $crate::futures::IntoFuture::into_future( FutureService::$fn_name(&self.0, $($arg),*)), - tarpc_service_wrap__)); + wrap__)); } )* } @@ -452,9 +435,9 @@ macro_rules! service { } #[allow(non_camel_case_types)] - impl $crate::tokio_service::NewService - for tarpc_service_AsyncServer__ - where tarpc_service_S__: FutureService + impl $crate::tokio_service::NewService + for TarpcNewService + where S__: FutureService { type Request = ::Request; type Response = ::Response; @@ -471,10 +454,10 @@ macro_rules! service { pub struct Listen where S: FutureService, { - inner: $crate::future::server::Listen, - tarpc_service_Request__, - tarpc_service_Response__, - tarpc_service_Error__>, + inner: $crate::future::server::Listen, + Request__, + Response__, + Error__>, } impl $crate::futures::Future for Listen @@ -502,10 +485,10 @@ macro_rules! service { options: $crate::future::server::Options) -> ::std::io::Result<($crate::future::server::Handle, Listen)> { - $crate::future::server::Handle::listen(tarpc_service_AsyncServer__(self), - addr, - handle, - options) + $crate::future::server::listen(TarpcNewService(self), + addr, + handle, + options) .map(|(handle, inner)| (handle, Listen { inner })) } } @@ -534,59 +517,36 @@ macro_rules! service { -> ::std::io::Result<$crate::sync::server::Handle> where A: ::std::net::ToSocketAddrs { - let tarpc_service__ = tarpc_service_AsyncServer__(SyncServer__ { - service: self, - }); - - let tarpc_service_addr__ = - $crate::util::FirstSocketAddr::try_first_socket_addr(&addr)?; - - return $crate::sync::server::Handle::listen(tarpc_service__, - tarpc_service_addr__, - options); - #[derive(Clone)] - struct SyncServer__ { - service: S, - } - - #[allow(non_camel_case_types)] - impl FutureService for SyncServer__ - where tarpc_service_S__: SyncService - { + struct BlockingFutureService(S); + impl FutureService for BlockingFutureService { $( impl_snake_to_camel! { type $fn_name = - $crate::futures::Flatten< - $crate::futures::MapErr< - $crate::futures::Oneshot< - $crate::futures::Done<$out, $error>>, - fn($crate::futures::Canceled) -> $error>>; + $crate::util::Lazy< + fn((S, $($in_),*)) -> ::std::result::Result<$out, $error>, + (S, $($in_),*), + ::std::result::Result<$out, $error>>; } - fn $fn_name(&self, $($arg:$in_),*) -> ty_snake_to_camel!(Self::$fn_name) { - fn unimplemented(_: $crate::futures::Canceled) -> $error { - // TODO(tikue): what do do if SyncService panics? - unimplemented!() + + $(#[$attr])* + fn $fn_name(&self, $($arg:$in_),*) + -> $crate::util::Lazy< + fn((S, $($in_),*)) -> ::std::result::Result<$out, $error>, + (S, $($in_),*), + ::std::result::Result<$out, $error>> { + fn execute((s, $($arg),*): (S, $($in_),*)) + -> ::std::result::Result<$out, $error> { + SyncService::$fn_name(&s, $($arg),*) } - let (tarpc_service_complete__, tarpc_service_promise__) = - $crate::futures::oneshot(); - let tarpc_service__ = self.clone(); - const UNIMPLEMENTED: fn($crate::futures::Canceled) -> $error = - unimplemented; - ::std::thread::spawn(move || { - let tarpc_service_reply__ = SyncService::$fn_name( - &tarpc_service__.service, $($arg),*); - tarpc_service_complete__.complete( - $crate::futures::IntoFuture::into_future( - tarpc_service_reply__)); - }); - let tarpc_service_promise__ = - $crate::futures::Future::map_err( - tarpc_service_promise__, UNIMPLEMENTED); - $crate::futures::Future::flatten(tarpc_service_promise__) + $crate::util::lazy(execute, (self.0.clone(), $($arg),*)) } )* } + + let tarpc_service__ = TarpcNewService(BlockingFutureService(self)); + let addr__ = $crate::util::FirstSocketAddr::try_first_socket_addr(&addr)?; + return $crate::sync::server::listen(tarpc_service__, addr__, options); } } @@ -597,7 +557,7 @@ macro_rules! service { #[allow(unused)] #[derive(Clone, Debug)] pub struct SyncClient { - inner: tarpc_service_SyncClient__, + inner: SyncClient__, } impl $crate::sync::client::ClientExt for SyncClient { @@ -605,8 +565,8 @@ macro_rules! service { -> ::std::io::Result where A: ::std::net::ToSocketAddrs, { - let client_ = ::connect(addr_, options_)?; + let client_ = + ::connect(addr_, options_)?; ::std::result::Result::Ok(SyncClient { inner: client_, }) @@ -620,71 +580,26 @@ macro_rules! service { pub fn $fn_name(&self, $($arg: $in_),*) -> ::std::result::Result<$out, $crate::Error<$error>> { - return then__(self.inner.call(tarpc_service_Request__::$fn_name(($($arg,)*)))); - - // TODO: this code is duplicated in both FutureClient and SyncClient. - fn then__(tarpc_service_msg__: - ::std::result::Result>) - -> ::std::result::Result<$out, $crate::Error<$error>> { - match tarpc_service_msg__ { - ::std::result::Result::Ok(tarpc_service_msg__) => { - if let tarpc_service_Response__::$fn_name(tarpc_service_msg__) = - tarpc_service_msg__ - { - ::std::result::Result::Ok(tarpc_service_msg__) - } else { - unreachable!() - } - } - ::std::result::Result::Err(tarpc_service_err__) => { - ::std::result::Result::Err(match tarpc_service_err__ { - $crate::Error::App(tarpc_service_err__) => { - if let tarpc_service_Error__::$fn_name( - tarpc_service_err__) = tarpc_service_err__ - { - $crate::Error::App(tarpc_service_err__) - } else { - unreachable!() - } - } - $crate::Error::RequestDeserialize(tarpc_service_err__) => { - $crate::Error::RequestDeserialize(tarpc_service_err__) - } - $crate::Error::ResponseDeserialize(tarpc_service_err__) => { - $crate::Error::ResponseDeserialize(tarpc_service_err__) - } - $crate::Error::Io(tarpc_service_error__) => { - $crate::Error::Io(tarpc_service_error__) - } - }) - } - } - } + tarpc_service_then__!($out, $error, $fn_name); + let resp__ = self.inner.call(Request__::$fn_name(($($arg,)*))); + tarpc_service_then__(resp__) } )* } #[allow(non_camel_case_types)] - type tarpc_service_FutureClient__ = - $crate::future::client::Client; + type FutureClient__ = $crate::future::client::Client; #[allow(non_camel_case_types)] - type tarpc_service_SyncClient__ = - $crate::sync::client::Client; + type SyncClient__ = $crate::sync::client::Client; #[allow(non_camel_case_types)] /// A future representing a client connecting to a server. pub struct Connect { - inner: $crate::futures::Map<$crate::future::client::ConnectFuture< - tarpc_service_Request__, - tarpc_service_Response__, - tarpc_service_Error__>, - fn(tarpc_service_FutureClient__) -> T>, + inner: + $crate::futures::Map< + $crate::future::client::ConnectFuture< Request__, Response__, Error__>, + fn(FutureClient__) -> T>, } impl $crate::futures::Future for Connect { @@ -699,18 +614,18 @@ macro_rules! service { #[allow(unused)] #[derive(Clone, Debug)] /// The client stub that makes RPC calls to the server. Exposes a Future interface. - pub struct FutureClient(tarpc_service_FutureClient__); + pub struct FutureClient(FutureClient__); impl<'a> $crate::future::client::ClientExt for FutureClient { type ConnectFut = Connect; - fn connect(tarpc_service_addr__: ::std::net::SocketAddr, - tarpc_service_options__: $crate::future::client::Options) + fn connect(addr__: ::std::net::SocketAddr, + options__: $crate::future::client::Options) -> Self::ConnectFut { - let client = ::connect(tarpc_service_addr__, - tarpc_service_options__); + let client = + ::connect(addr__, + options__); Connect { inner: $crate::futures::Future::map(client, FutureClient) @@ -724,61 +639,67 @@ macro_rules! service { $(#[$attr])* pub fn $fn_name(&self, $($arg: $in_),*) -> $crate::futures::future::Then< - ::Future, + ::Future, ::std::result::Result<$out, $crate::Error<$error>>, - fn(::std::result::Result>) + fn(::std::result::Result>) -> ::std::result::Result<$out, $crate::Error<$error>>> { + tarpc_service_then__!($out, $error, $fn_name); - let tarpc_service_req__ = tarpc_service_Request__::$fn_name(($($arg,)*)); - let tarpc_service_fut__ = - $crate::tokio_service::Service::call(&self.0, tarpc_service_req__); - return $crate::futures::Future::then(tarpc_service_fut__, then__); - - fn then__(tarpc_service_msg__: - ::std::result::Result>) - -> ::std::result::Result<$out, $crate::Error<$error>> { - match tarpc_service_msg__ { - ::std::result::Result::Ok(tarpc_service_msg__) => { - if let tarpc_service_Response__::$fn_name(tarpc_service_msg__) = - tarpc_service_msg__ - { - ::std::result::Result::Ok(tarpc_service_msg__) - } else { - unreachable!() - } - } - ::std::result::Result::Err(tarpc_service_err__) => { - ::std::result::Result::Err(match tarpc_service_err__ { - $crate::Error::App(tarpc_service_err__) => { - if let tarpc_service_Error__::$fn_name( - tarpc_service_err__) = tarpc_service_err__ - { - $crate::Error::App(tarpc_service_err__) - } else { - unreachable!() - } - } - $crate::Error::RequestDeserialize(tarpc_service_err__) => { - $crate::Error::RequestDeserialize(tarpc_service_err__) - } - $crate::Error::ResponseDeserialize(tarpc_service_err__) => { - $crate::Error::ResponseDeserialize(tarpc_service_err__) - } - $crate::Error::Io(tarpc_service_error__) => { - $crate::Error::Io(tarpc_service_error__) - } - }) - } - } - } + let request__ = Request__::$fn_name(($($arg,)*)); + let future__ = $crate::tokio_service::Service::call(&self.0, request__); + return $crate::futures::Future::then(future__, tarpc_service_then__); } )* - } } } + +#[doc(hidden)] +#[macro_export] +macro_rules! tarpc_service_then__ { + ($out:ty, $error:ty, $fn_name:ident) => { + fn tarpc_service_then__(msg__: + ::std::result::Result>) + -> ::std::result::Result<$out, $crate::Error<$error>> { + match msg__ { + ::std::result::Result::Ok(msg__) => { + if let Response__::$fn_name(msg__) = + msg__ + { + ::std::result::Result::Ok(msg__) + } else { + unreachable!() + } + } + ::std::result::Result::Err(err__) => { + ::std::result::Result::Err(match err__ { + $crate::Error::App(err__) => { + if let Error__::$fn_name( + err__) = err__ + { + $crate::Error::App(err__) + } else { + unreachable!() + } + } + $crate::Error::RequestDeserialize(err__) => { + $crate::Error::RequestDeserialize(err__) + } + $crate::Error::ResponseDeserialize(err__) => { + $crate::Error::ResponseDeserialize(err__) + } + $crate::Error::Io(err__) => { + $crate::Error::Io(err__) + } + }) + } + } + } + }; +} + // allow dead code; we're just testing that the macro expansion compiles #[allow(dead_code)] #[cfg(test)] @@ -1101,8 +1022,8 @@ mod functional_test { let _ = env_logger::init(); let (addr, client, shutdown) = unwrap!(start_server_with_sync_client::(Server)); - assert_eq!(3, client.add(1, 2).unwrap()); - assert_eq!("Hey, Tim.", client.hey("Tim".to_string()).unwrap()); + assert_eq!(3, unwrap!(client.add(1, 2))); + assert_eq!("Hey, Tim.", unwrap!(client.hey("Tim".to_string()))); info!("Dropping client."); drop(client); @@ -1110,23 +1031,23 @@ mod functional_test { let (tx2, rx2) = ::std::sync::mpsc::channel(); let shutdown2 = shutdown.clone(); ::std::thread::spawn(move || { - let client = get_sync_client::(addr).unwrap(); - tx.send(()).unwrap(); - let add = client.add(3, 2).unwrap(); + let client = unwrap!(get_sync_client::(addr)); + let add = unwrap!(client.add(3, 2)); + unwrap!(tx.send(())); drop(client); // Make sure 2 shutdowns are concurrent safe. - shutdown2.shutdown().wait().unwrap(); - tx2.send(add).unwrap(); + unwrap!(shutdown2.shutdown().wait()); + unwrap!(tx2.send(add)); }); - rx.recv().unwrap(); + unwrap!(rx.recv()); let mut shutdown1 = shutdown.shutdown(); - shutdown.shutdown().wait().unwrap(); + unwrap!(shutdown.shutdown().wait()); // Assert shutdown2 blocks until shutdown is complete. - if let Async::NotReady = shutdown1.poll().unwrap() { + if let Async::NotReady = unwrap!(shutdown1.poll()) { panic!("Shutdown should have completed"); } // Existing clients are served - assert_eq!(5, rx2.recv().unwrap()); + assert_eq!(5, unwrap!(rx2.recv())); let e = get_sync_client::(addr).err().unwrap(); debug!("(Success) shutdown caused client err: {}", e); @@ -1162,10 +1083,10 @@ mod functional_test { } mod bad_serialize { - use sync::{client, server}; - use sync::client::ClientExt; use serde::{Serialize, Serializer}; use serde::ser::SerializeSeq; + use sync::{client, server}; + use sync::client::ClientExt; #[derive(Deserialize)] pub struct Bad; diff --git a/src/protocol.rs b/src/protocol.rs index b002e31..b9d9c6d 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -3,23 +3,28 @@ // Licensed under the MIT License, . // This file may not be copied, modified, or distributed except according to those terms. -use {serde, tokio_core}; +use serde; use bincode::{self, Infinite}; -use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use byteorder::{BigEndian, ReadBytesExt}; +use bytes::BytesMut; +use bytes::buf::BufMut; use std::io::{self, Cursor}; use std::marker::PhantomData; use std::mem; -use tokio_core::io::{EasyBuf, Framed, Io}; +use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_io::codec::{Encoder, Decoder, Framed}; use tokio_proto::multiplex::{ClientProto, ServerProto}; use tokio_proto::streaming::multiplex::RequestId; // `Encode` is the type that `Codec` encodes. `Decode` is the type it decodes. +#[derive(Debug)] pub struct Codec { max_payload_size: u64, state: CodecState, _phantom_data: PhantomData<(Encode, Decode)>, } +#[derive(Debug)] enum CodecState { Id, Len { id: u64 }, @@ -44,32 +49,41 @@ fn too_big(payload_size: u64, max_payload_size: u64) -> io::Error { max_payload_size, payload_size)) } -impl tokio_core::io::Codec for Codec +impl Encoder for Codec where Encode: serde::Serialize, Decode: serde::Deserialize { - type Out = (RequestId, Encode); - type In = (RequestId, Result); + type Item = (RequestId, Encode); + type Error = io::Error; - fn encode(&mut self, (id, message): Self::Out, buf: &mut Vec) -> io::Result<()> { - buf.write_u64::(id).unwrap(); - trace!("Encoded request id = {} as {:?}", id, buf); + fn encode(&mut self, (id, message): Self::Item, buf: &mut BytesMut) -> io::Result<()> { let payload_size = bincode::serialized_size(&message); if payload_size > self.max_payload_size { return Err(too_big(payload_size, self.max_payload_size)); } - buf.write_u64::(payload_size).unwrap(); - bincode::serialize_into(buf, + let message_size = 2 * mem::size_of::() + payload_size as usize; + buf.reserve(message_size); + buf.put_u64::(id); + trace!("Encoded request id = {} as {:?}", id, buf); + buf.put_u64::(payload_size); + bincode::serialize_into(&mut buf.writer(), &message, Infinite) .map_err(|serialize_err| io::Error::new(io::ErrorKind::Other, serialize_err))?; trace!("Encoded buffer: {:?}", buf); Ok(()) } +} - fn decode(&mut self, buf: &mut EasyBuf) -> Result, io::Error> { +impl Decoder for Codec + where Decode: serde::Deserialize +{ + type Item = (RequestId, Result); + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result> { use self::CodecState::*; - trace!("Codec::decode: {:?}", buf.as_slice()); + trace!("Codec::decode: {:?}", buf); loop { match self.state { @@ -78,9 +92,9 @@ impl tokio_core::io::Codec for Codec return Ok(None); } Id => { - let mut id_buf = buf.drain_to(mem::size_of::()); + let mut id_buf = buf.split_to(mem::size_of::()); let id = Cursor::new(&mut id_buf).read_u64::()?; - trace!("--> Parsed id = {} from {:?}", id, id_buf.as_slice()); + trace!("--> Parsed id = {} from {:?}", id, id_buf); self.state = Len { id: id }; } Len { .. } if buf.len() < mem::size_of::() => { @@ -89,7 +103,7 @@ impl tokio_core::io::Codec for Codec return Ok(None); } Len { id } => { - let len_buf = buf.drain_to(mem::size_of::()); + let len_buf = buf.split_to(mem::size_of::()); let len = Cursor::new(len_buf).read_u64::()?; trace!("--> Parsed payload length = {}, remaining buffer length = {}", len, @@ -106,7 +120,7 @@ impl tokio_core::io::Codec for Codec return Ok(None); } Payload { id, len } => { - let payload = buf.drain_to(len as usize); + let payload = buf.split_to(len as usize); let result = bincode::deserialize_from(&mut Cursor::new(payload), Infinite); // Reset the state machine because, either way, we're done processing this @@ -121,6 +135,7 @@ impl tokio_core::io::Codec for Codec } /// Implements the `multiplex::ServerProto` trait. +#[derive(Debug)] pub struct Proto { max_payload_size: u64, _phantom_data: PhantomData<(Encode, Decode)>, @@ -137,7 +152,7 @@ impl Proto { } impl ServerProto for Proto - where T: Io + 'static, + where T: AsyncRead + AsyncWrite + 'static, Encode: serde::Serialize + 'static, Decode: serde::Deserialize + 'static { @@ -152,7 +167,7 @@ impl ServerProto for Proto } impl ClientProto for Proto - where T: Io + 'static, + where T: AsyncRead + AsyncWrite + 'static, Encode: serde::Serialize + 'static, Decode: serde::Deserialize + 'static { @@ -168,17 +183,13 @@ impl ClientProto for Proto #[test] fn serialize() { - use tokio_core::io::Codec as TokioCodec; - const MSG: (u64, (char, char, char)) = (4, ('a', 'b', 'c')); - let mut buf = EasyBuf::new(); - let mut vec = Vec::new(); + let mut buf = BytesMut::with_capacity(10); // Serialize twice to check for idempotence. for _ in 0..2 { let mut codec: Codec<(char, char, char), (char, char, char)> = Codec::new(2_000_000); - codec.encode(MSG, &mut vec).unwrap(); - buf.get_mut().append(&mut vec); + codec.encode(MSG, &mut buf).unwrap(); let actual: Result)>, io::Error> = codec.decode(&mut buf); @@ -187,26 +198,22 @@ fn serialize() { bad => panic!("Expected {:?}, but got {:?}", Some(MSG), bad), } - assert!(buf.get_mut().is_empty(), - "Expected empty buf but got {:?}", - *buf.get_mut()); + assert!(buf.is_empty(), "Expected empty buf but got {:?}", buf); } } #[test] fn deserialize_big() { - use tokio_core::io::Codec as TokioCodec; let mut codec: Codec, Vec> = Codec::new(24); - let mut vec = Vec::new(); - assert_eq!(codec.encode((0, vec![0; 24]), &mut vec).err().unwrap().kind(), + let mut buf = BytesMut::with_capacity(40); + assert_eq!(codec.encode((0, vec![0; 24]), &mut buf).err().unwrap().kind(), io::ErrorKind::InvalidData); - let mut buf = EasyBuf::new(); // Header - buf.get_mut().append(&mut vec![0; 8]); + buf.put_slice(&mut [0u8; 8]); // Len - buf.get_mut().append(&mut vec![0, 0, 0, 0, 0, 0, 0, 25]); + buf.put_slice(&mut [0u8, 0, 0, 0, 0, 0, 0, 25]); assert_eq!(codec.decode(&mut buf).err().unwrap().kind(), io::ErrorKind::InvalidData); } diff --git a/src/stream_type.rs b/src/stream_type.rs index 78fbe17..b60ddd6 100644 --- a/src/stream_type.rs +++ b/src/stream_type.rs @@ -1,6 +1,8 @@ +use bytes::{Buf, BufMut}; +use futures::Poll; use std::io; -use tokio_core::io::Io; use tokio_core::net::TcpStream; +use tokio_io::{AsyncRead, AsyncWrite}; #[cfg(feature = "tls")] use tokio_tls::TlsStream; @@ -52,4 +54,41 @@ impl io::Write for StreamType { } } -impl Io for StreamType {} +impl AsyncRead for StreamType { + // By overriding this fn, `StreamType` is obliged to never read the uninitialized buffer. + // Most sane implementations would never have a reason to, and `StreamType` does not, so + // this is safe. + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + match *self { + StreamType::Tcp(ref stream) => stream.prepare_uninitialized_buffer(buf), + #[cfg(feature = "tls")] + StreamType::Tls(ref stream) => stream.prepare_uninitialized_buffer(buf), + } + } + + fn read_buf(&mut self, buf: &mut B) -> Poll { + match *self { + StreamType::Tcp(ref mut stream) => stream.read_buf(buf), + #[cfg(feature = "tls")] + StreamType::Tls(ref mut stream) => stream.read_buf(buf), + } + } +} + +impl AsyncWrite for StreamType { + fn shutdown(&mut self) -> Poll<(), io::Error> { + match *self { + StreamType::Tcp(ref mut stream) => stream.shutdown(), + #[cfg(feature = "tls")] + StreamType::Tls(ref mut stream) => stream.shutdown(), + } + } + + fn write_buf(&mut self, buf: &mut B) -> Poll { + match *self { + StreamType::Tcp(ref mut stream) => stream.write_buf(buf), + #[cfg(feature = "tls")] + StreamType::Tls(ref mut stream) => stream.write_buf(buf), + } + } +} diff --git a/src/sync/client.rs b/src/sync/client.rs index 7031c97..f56467b 100644 --- a/src/sync/client.rs +++ b/src/sync/client.rs @@ -1,4 +1,3 @@ - use future::client::{Client as FutureClient, ClientExt as FutureClientExt, Options as FutureOptions}; /// Exposes a trait for connecting synchronously to servers. @@ -29,7 +28,10 @@ impl Clone for Client { impl fmt::Debug for Client { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - write!(f, "Client {{ .. }}") + const PROXY: &'static &'static str = &"ClientProxy { .. }"; + f.debug_struct("Client") + .field("proxy", PROXY) + .finish() } } @@ -40,6 +42,9 @@ impl Client { /// Drives an RPC call for the given request. pub fn call(&self, request: Req) -> Result> { + // Must call wait here to block on the response. + // The request handler relies on this fact to safely unwrap the + // oneshot send. self.proxy.call(request).wait() } @@ -85,6 +90,19 @@ impl Options { } } +impl fmt::Debug for Options { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + #[cfg(feature = "tls")] + const SOME: &'static &'static str = &"Some(_)"; + #[cfg(feature = "tls")] + const NONE: &'static &'static str = &"None"; + let mut f = f.debug_struct("Options"); + #[cfg(feature = "tls")] + f.field("tls_ctx", if self.tls_ctx.is_some() { SOME } else { NONE }); + f.finish() + } +} + impl Into for (reactor::Handle, Options) { #[cfg(feature = "tls")] fn into(self) -> FutureOptions { @@ -180,7 +198,10 @@ impl RequestHandler .for_each(|(request, response_tx)| { let request = client.call(request) .then(move |response| { - response_tx.complete(response); + // Safe to unwrap because clients always block on the response future. + response_tx.send(response) + .map_err(|_| ()) + .expect("Client should block on response"); Ok(()) }); handle.spawn(request); diff --git a/src/sync/server.rs b/src/sync/server.rs index 8a726c9..7f8e946 100644 --- a/src/sync/server.rs +++ b/src/sync/server.rs @@ -1,20 +1,41 @@ -use {bincode, future}; +use {bincode, future, num_cpus}; use future::server::{Response, Shutdown}; -use futures::Future; +use futures::{Future, future as futures}; +use futures::sync::oneshot; use serde::{Deserialize, Serialize}; use std::io; +use std::fmt; use std::net::SocketAddr; +use std::time::Duration; +use std::usize; +use thread_pool::{self, Sender, Task, ThreadPool}; use tokio_core::reactor; -use tokio_service::NewService; +use tokio_service::{NewService, Service}; #[cfg(feature = "tls")] use native_tls_inner::TlsAcceptor; /// Additional options to configure how the server operates. -#[derive(Default)] +#[derive(Debug)] pub struct Options { + thread_pool: thread_pool::Builder, opts: future::server::Options, } +impl Default for Options { + fn default() -> Self { + let num_cpus = num_cpus::get(); + Options { + thread_pool: thread_pool::Builder::new() + .keep_alive(Duration::from_secs(60)) + .max_pool_size(num_cpus * 100) + .core_pool_size(num_cpus) + .work_queue_capacity(usize::MAX) + .name_prefix("request-thread-"), + opts: future::server::Options::default(), + } + } +} + impl Options { /// Set the max payload size in bytes. The default is 2,000,000 (2 MB). pub fn max_payload_size(mut self, bytes: u64) -> Self { @@ -22,6 +43,12 @@ impl Options { self } + /// Sets the thread pool builder to use when creating the server's thread pool. + pub fn thread_pool(mut self, builder: thread_pool::Builder) -> Self { + self.thread_pool = builder; + self + } + /// Set the `TlsAcceptor` #[cfg(feature = "tls")] pub fn tls(mut self, tls_acceptor: TlsAcceptor) -> Self { @@ -39,29 +66,6 @@ pub struct Handle { } impl Handle { - #[doc(hidden)] - pub fn listen(new_service: S, - addr: SocketAddr, - options: Options) - -> io::Result - where S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: Deserialize + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static - { - let reactor = reactor::Core::new()?; - let (handle, server) = - future::server::Handle::listen(new_service, addr, &reactor.handle(), options.opts)?; - let server = Box::new(server); - Ok(Handle { - reactor: reactor, - handle: handle, - server: server, - }) - } - /// Runs the server on the current thread, blocking indefinitely. pub fn run(mut self) { trace!("Running..."); @@ -81,3 +85,141 @@ impl Handle { self.handle.addr() } } + +impl fmt::Debug for Handle { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + const CORE: &'static &'static str = &"Core { .. }"; + const SERVER: &'static &'static str = &"Box>"; + + f.debug_struct("Handle").field("reactor", CORE) + .field("handle", &self.handle) + .field("server", SERVER) + .finish() + } +} + +#[doc(hidden)] +pub fn listen(new_service: S, addr: SocketAddr, options: Options) + -> io::Result + where S: NewService, + Response = Response, + Error = io::Error> + 'static, + ::Future: Send + 'static, + S::Response: Send, + S::Error: Send, + Req: Deserialize + 'static, + Resp: Serialize + 'static, + E: Serialize + 'static +{ + let new_service = NewThreadService::new(new_service, options.thread_pool); + let reactor = reactor::Core::new()?; + let (handle, server) = + future::server::listen(new_service, addr, &reactor.handle(), options.opts)?; + let server = Box::new(server); + Ok(Handle { + reactor: reactor, + handle: handle, + server: server, + }) +} + +/// A service that uses a thread pool. +struct NewThreadService where S: NewService { + new_service: S, + sender: Sender::Future>>, + _pool: ThreadPool::Future>>, +} + +/// A service that runs by executing request handlers in a thread pool. +struct ThreadService where S: Service { + service: S, + sender: Sender>, +} + +/// A task that handles a single request. +struct ServiceTask where F: Future { + future: F, + tx: oneshot::Sender>, +} + +impl NewThreadService + where S: NewService, + ::Future: Send + 'static, + S::Response: Send, + S::Error: Send, +{ + /// Create a NewThreadService by wrapping another service. + fn new(new_service: S, pool: thread_pool::Builder) -> Self { + let (sender, _pool) = pool.build(); + NewThreadService { new_service, sender, _pool } + } +} + +impl NewService for NewThreadService + where S: NewService, + ::Future: Send + 'static, + S::Response: Send, + S::Error: Send, +{ + type Request = S::Request; + type Response = S::Response; + type Error = S::Error; + type Instance = ThreadService; + + fn new_service(&self) -> io::Result { + Ok(ThreadService { + service: self.new_service.new_service()?, + sender: self.sender.clone(), + }) + } +} + +impl Task for ServiceTask + where F: Future + Send + 'static, + F::Item: Send, + F::Error: Send, +{ + fn run(self) { + // Don't care if sending fails. It just means the request is no longer + // being handled (I think). + let _ = self.tx.send(self.future.wait()); + } +} + +impl Service for ThreadService + where S: Service, + S::Future: Send + 'static, + S::Response: Send, + S::Error: Send, +{ + type Request = S::Request; + type Response = S::Response; + type Error = S::Error; + type Future = + futures::AndThen< + futures::MapErr< + oneshot::Receiver>, + fn(oneshot::Canceled) -> Self::Error>, + Result, + fn(Result) -> Result>; + + fn call(&self, request: Self::Request) -> Self::Future { + let (tx, rx) = oneshot::channel(); + self.sender.send(ServiceTask { + future: self.service.call(request), + tx: tx, + }).unwrap(); + rx.map_err(unreachable as _).and_then(ident) + } +} + +fn unreachable(t: T) -> U + where T: fmt::Display +{ + unreachable!(t) +} + +fn ident(t: T) -> T { + t +} + diff --git a/src/tls.rs b/src/tls.rs index 063ed80..62327c8 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,6 +1,7 @@ /// TLS-specific functionality for clients. pub mod client { use native_tls::{Error, TlsConnector}; + use std::fmt; /// TLS context for client pub struct Context { @@ -35,5 +36,16 @@ pub mod client { } } } + + impl fmt::Debug for Context { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + const TLS_CONNECTOR: &'static &'static str = &"TlsConnector { .. }"; + f.debug_struct("Context") + .field("domain", &self.domain) + .field("tls_connector", TLS_CONNECTOR) + .finish() + } + } + } diff --git a/src/util.rs b/src/util.rs index 44f36d2..fd07b91 100644 --- a/src/util.rs +++ b/src/util.rs @@ -3,10 +3,10 @@ // Licensed under the MIT License, . // This file may not be copied, modified, or distributed except according to those terms. -use futures::{Future, Poll}; +use futures::{Future, IntoFuture, Poll}; use futures::stream::Stream; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use std::{fmt, io}; +use std::{fmt, io, mem}; use std::error::Error; use std::net::{SocketAddr, ToSocketAddrs}; @@ -111,3 +111,68 @@ pub trait FirstSocketAddr: ToSocketAddrs { } impl FirstSocketAddr for A {} + +/// Creates a new future which will eventually be the same as the one created +/// by calling the closure provided with the arguments provided. +/// +/// The provided closure is only run once the future has a callback scheduled +/// on it, otherwise the callback never runs. Once run, however, this future is +/// the same as the one the closure creates. +pub fn lazy(f: F, args: A) -> Lazy + where F: FnOnce(A) -> R, + R: IntoFuture +{ + Lazy { + inner: _Lazy::First(f, args), + } +} + +/// A future which defers creation of the actual future until a callback is +/// scheduled. +/// +/// This is created by the `lazy` function. +#[derive(Debug)] +#[must_use = "futures do nothing unless polled"] +pub struct Lazy { + inner: _Lazy, +} + +#[derive(Debug)] +enum _Lazy { + First(F, A), + Second(R), + Moved, +} + +impl Lazy + where F: FnOnce(A) -> R, + R: IntoFuture, +{ + fn get(&mut self) -> &mut R::Future { + match self.inner { + _Lazy::First(..) => {} + _Lazy::Second(ref mut f) => return f, + _Lazy::Moved => panic!(), // can only happen if `f()` panics + } + match mem::replace(&mut self.inner, _Lazy::Moved) { + _Lazy::First(f, args) => self.inner = _Lazy::Second(f(args).into_future()), + _ => panic!(), // we already found First + } + match self.inner { + _Lazy::Second(ref mut f) => f, + _ => panic!(), // we just stored Second + } + } +} + +impl Future for Lazy + where F: FnOnce(A) -> R, + R: IntoFuture, +{ + type Item = R::Item; + type Error = R::Error; + + fn poll(&mut self) -> Poll { + self.get().poll() + } +}