diff --git a/examples/concurrency.rs b/examples/concurrency.rs index fc26935..0ab4ff9 100644 --- a/examples/concurrency.rs +++ b/examples/concurrency.rs @@ -59,7 +59,7 @@ fn run_once(clients: &[FutureClient], concurrency: u32, print: bool) { .take(concurrency as usize) .map(|client| { let start = SystemTime::now(); - let future = client.read(&CHUNK_SIZE).map(move |_| start.elapsed().unwrap()); + let future = client.read(CHUNK_SIZE).map(move |_| start.elapsed().unwrap()); thread::yield_now(); future }) diff --git a/examples/pubsub.rs b/examples/pubsub.rs index dc1b2b4..a34e9a9 100644 --- a/examples/pubsub.rs +++ b/examples/pubsub.rs @@ -65,7 +65,7 @@ impl Subscriber { .listen("localhost:0") .wait() .unwrap(); - publisher.subscribe(&id, &subscriber.local_addr()).unwrap(); + publisher.subscribe(id, *subscriber.local_addr()).unwrap(); subscriber } } @@ -90,7 +90,7 @@ impl publisher::FutureService for Publisher { .unwrap() .values() // Ignore failing subscribers. - .map(move |client| client.receive(&message).then(|_| Ok(()))) + .map(move |client| client.receive(message.clone()).then(|_| Ok(()))) .collect::>()) .map(|_| ()) .boxed() @@ -127,8 +127,8 @@ fn main() { let _subscriber2 = Subscriber::new(1, publisher.clone()); println!("Broadcasting..."); - publisher.broadcast(&"hello to all".to_string()).unwrap(); - publisher.unsubscribe(&1).unwrap(); - publisher.broadcast(&"hello again".to_string()).unwrap(); + publisher.broadcast("hello to all".to_string()).unwrap(); + publisher.unsubscribe(1).unwrap(); + publisher.broadcast("hello again".to_string()).unwrap(); thread::sleep(Duration::from_millis(300)); } diff --git a/examples/readme.rs b/examples/readme.rs index 04f9e98..b982e91 100644 --- a/examples/readme.rs +++ b/examples/readme.rs @@ -30,5 +30,5 @@ fn main() { let addr = "localhost:10000"; let _server = HelloServer.listen(addr); let client = SyncClient::connect(addr).unwrap(); - println!("{}", client.hello(&"Mom".to_string()).unwrap()); + println!("{}", client.hello("Mom".to_string()).unwrap()); } diff --git a/examples/readme2.rs b/examples/readme2.rs index 8810df9..e28211c 100644 --- a/examples/readme2.rs +++ b/examples/readme2.rs @@ -52,6 +52,6 @@ fn main() { let addr = "localhost:10000"; let _server = HelloServer.listen(addr); let client = SyncClient::connect(addr).unwrap(); - println!("{}", client.hello(&"Mom".to_string()).unwrap()); - println!("{}", client.hello(&"".to_string()).unwrap_err()); + println!("{}", client.hello("Mom".to_string()).unwrap()); + println!("{}", client.hello("".to_string()).unwrap_err()); } diff --git a/examples/server_calling_server.rs b/examples/server_calling_server.rs index 4e71e0c..aebdfe3 100644 --- a/examples/server_calling_server.rs +++ b/examples/server_calling_server.rs @@ -54,7 +54,7 @@ impl DoubleFutureService for DoubleServer { fn double(&self, x: i32) -> Self::DoubleFut { self.client - .add(&x, &x) + .add(x, x) .map_err(|e| e.to_string().into()) .boxed() } @@ -68,6 +68,6 @@ fn main() { let double_client = double::SyncClient::connect(double.local_addr()).unwrap(); for i in 0..5 { - println!("{:?}", double_client.double(&i).unwrap()); + println!("{:?}", double_client.double(i).unwrap()); } } diff --git a/examples/two_clients.rs b/examples/two_clients.rs index 9fe3820..feaad18 100644 --- a/examples/two_clients.rs +++ b/examples/two_clients.rs @@ -63,14 +63,14 @@ fn main() { let bar_client = bar::SyncClient::connect(bar.local_addr()).unwrap(); let baz_client = baz::SyncClient::connect(baz.local_addr()).unwrap(); - info!("Result: {:?}", bar_client.bar(&17)); + info!("Result: {:?}", bar_client.bar(17)); let total = 20; for i in 1..(total + 1) { if i % 2 == 0 { - info!("Result 1: {:?}", bar_client.bar(&i)); + info!("Result 1: {:?}", bar_client.bar(i)); } else { - info!("Result 2: {:?}", baz_client.baz(&i.to_string())); + info!("Result 2: {:?}", baz_client.baz(i.to_string())); } } diff --git a/src/client.rs b/src/client.rs index 2e044b7..4ec8be7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -3,8 +3,9 @@ // Licensed under the MIT License, . // This file may not be copied, modified, or distributed except according to those terms. -use Packet; -use futures::{Async, BoxFuture}; +use WireError; +use bincode::serde::DeserializeError; +use futures::{Async, BoxFuture, Future}; use futures::stream::Empty; use std::fmt; use std::io; @@ -16,27 +17,44 @@ use util::Never; /// /// Typically, this would be combined with a serialization pre-processing step /// and a deserialization post-processing step. -#[derive(Clone)] -pub struct Client { - inner: pipeline::Client, Empty, io::Error>, +pub struct Client { + inner: pipeline::Client>, + DeserializeError>, + Empty, + io::Error>, } -impl Service for Client { - type Request = Packet; - type Response = Vec; +impl Clone for Client { + fn clone(&self) -> Self { + Client { inner: self.inner.clone() } + } +} + +impl Service for Client + where Req: Send + 'static, + Resp: Send + 'static, + E: Send + 'static, +{ + type Request = Req; + type Response = Result>; type Error = io::Error; - type Future = BoxFuture, io::Error>; + type Future = BoxFuture; fn poll_ready(&self) -> Async<()> { Async::Ready(()) } - fn call(&self, request: Packet) -> Self::Future { + fn call(&self, request: Self::Request) -> Self::Future { self.inner.call(pipeline::Message::WithoutBody(request)) + .map(|r| r.map(|r| r.map_err(::Error::from)) + .map_err(::Error::ClientDeserialize) + .and_then(|r| r)) + .boxed() } } -impl fmt::Debug for Client { +impl fmt::Debug for Client { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { write!(f, "Client {{ .. }}") } @@ -45,7 +63,8 @@ impl fmt::Debug for Client { /// Exposes a trait for connecting asynchronously to servers. pub mod future { use futures::{self, Async, Future}; - use protocol::{LOOP_HANDLE, TarpcTransport}; + use protocol::{LOOP_HANDLE, new_transport}; + use serde::{Deserialize, Serialize}; use std::cell::RefCell; use std::io; use std::net::SocketAddr; @@ -64,12 +83,12 @@ pub mod future { } /// A future that resolves to a `Client` or an `io::Error`. - pub struct ClientFuture { - inner: futures::Oneshot>, + pub struct ClientFuture { + inner: futures::Oneshot>>, } - impl Future for ClientFuture { - type Item = Client; + impl Future for ClientFuture { + type Item = Client; type Error = io::Error; fn poll(&mut self) -> futures::Poll { @@ -81,12 +100,16 @@ pub mod future { } } - impl Connect for Client { - type Fut = ClientFuture; + impl Connect for Client + where Req: Serialize + Send + 'static, + Resp: Deserialize + Send + 'static, + E: Deserialize + Send + 'static, + { + type Fut = ClientFuture; /// Starts an event loop on a thread and registers a new client /// connected to the given address. - fn connect(addr: &SocketAddr) -> ClientFuture { + fn connect(addr: &SocketAddr) -> ClientFuture { let addr = *addr; let (tx, rx) = futures::oneshot(); LOOP_HANDLE.spawn(move |handle| { @@ -95,7 +118,7 @@ pub mod future { .and_then(move |tcp| { let tcp = RefCell::new(Some(tcp)); let c = try!(pipeline::connect(&handle2, move || { - Ok(TarpcTransport::new(tcp.borrow_mut().take().unwrap())) + Ok(new_transport(tcp.borrow_mut().take().unwrap())) })); Ok(Client { inner: c }) }) @@ -109,6 +132,7 @@ pub mod future { /// Exposes a trait for connecting synchronously to servers. pub mod sync { use futures::Future; + use serde::{Deserialize, Serialize}; use std::io; use std::net::ToSocketAddrs; use super::Client; @@ -119,7 +143,11 @@ pub mod sync { fn connect(addr: A) -> Result where A: ToSocketAddrs; } - impl Connect for Client { + impl Connect for Client + where Req: Serialize + Send + 'static, + Resp: Deserialize + Send + 'static, + E: Deserialize + Send + 'static, + { fn connect(addr: A) -> Result where A: ToSocketAddrs { diff --git a/src/errors.rs b/src/errors.rs index 6c28d3f..f0cab7e 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -7,13 +7,11 @@ use bincode; use serde::{Deserialize, Serialize}; use std::{fmt, io}; use std::error::Error as StdError; -use tokio_proto::pipeline; +use tokio_proto::{multiplex, pipeline}; /// All errors that can occur during the use of tarpc. #[derive(Debug)] -pub enum Error - where E: SerializableError -{ +pub enum Error { /// Any IO error. Io(io::Error), /// Error in deserializing a server response. @@ -78,7 +76,7 @@ impl StdError for Error { } } -impl From>> for Error { +impl From>> for Error { fn from(err: pipeline::Error>) -> Self { match err { pipeline::Error::Transport(e) => e, @@ -87,13 +85,22 @@ impl From>> for Error { } } -impl From for Error { +impl From>> for Error { + fn from(err: multiplex::Error>) -> Self { + match err { + multiplex::Error::Transport(e) => e, + multiplex::Error::Io(e) => e.into(), + } + } +} + +impl From for Error { fn from(err: io::Error) -> Self { Error::Io(err) } } -impl From> for Error { +impl From> for Error { fn from(err: WireError) -> Self { match err { WireError::ServerDeserialize(s) => Error::ServerDeserialize(s), @@ -106,9 +113,7 @@ impl From> for Error { /// A serializable, server-supplied error. #[doc(hidden)] #[derive(Deserialize, Serialize, Clone, Debug)] -pub enum WireError - where E: SerializableError -{ +pub enum WireError { /// Error in deserializing a client request. ServerDeserialize(String), /// Error in serializing server response. diff --git a/src/lib.rs b/src/lib.rs index 9975ef6..7308e2b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,7 +54,7 @@ //! let addr = "localhost:10000"; //! let _server = HelloServer.listen(addr); //! let client = SyncClient::connect(addr).unwrap(); -//! println!("{}", client.hello(&"Mom".to_string()).unwrap()); +//! println!("{}", client.hello("Mom".to_string()).unwrap()); //! } //! ``` //! @@ -62,7 +62,6 @@ #![feature(plugin, question_mark, conservative_impl_trait, never_type, rustc_macro)] #![plugin(tarpc_plugins)] -extern crate bincode; extern crate byteorder; extern crate bytes; #[macro_use] @@ -73,6 +72,8 @@ extern crate log; extern crate serde_derive; extern crate take; +#[doc(hidden)] +pub extern crate bincode; #[doc(hidden)] pub extern crate futures; #[doc(hidden)] @@ -94,9 +95,9 @@ pub use client::future::ClientFuture; #[doc(hidden)] pub use errors::{WireError}; #[doc(hidden)] -pub use protocol::{Packet, deserialize}; +pub use protocol::{new_transport, Framed}; #[doc(hidden)] -pub use server::{ListenFuture, SerializeFuture, SerializedReply, listen, serialize_reply}; +pub use server::{ListenFuture, Response, listen_pipeline}; /// Provides some utility error types, as well as a trait for spawning futures on the default event /// loop. diff --git a/src/macros.rs b/src/macros.rs index e2ef0b7..077ac96 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -337,39 +337,42 @@ macro_rules! service { rpc $fn_name:ident ( $( $arg:ident : $in_:ty ),* ) -> $out:ty | $error:ty; )* ) => { - service! { - { } + #[allow(non_camel_case_types, unused)] + #[derive(Debug)] + enum __tarpc_service_Request { + NotIrrefutable(()), $( - $(#[$attr])* - rpc $fn_name( $( $arg : $in_ ),* ) -> $out | $error; - )* - - { - #[allow(non_camel_case_types, unused)] - #[derive(Debug)] - enum __ClientSideRequest<'a> { - $( - $fn_name(&'a ( $(&'a $in_,)* )) - ),* - } - - impl_serialize!(__ClientSideRequest, { <'__a> }, $($fn_name(($($in_),*)))*); - } + $fn_name(( $($in_,)* )) + ),* } - }; - // Pattern for when all return types and the client request have been expanded - ( - { } // none left to expand - $( - $(#[$attr:meta])* - rpc $fn_name:ident ( $( $arg:ident : $in_:ty ),* ) -> $out:ty | $error:ty; - )* - { - $client_req:item - $client_serialize_impl:item + + impl_deserialize!(__tarpc_service_Request, NotIrrefutable(()) $($fn_name(($($in_),*)))*); + impl_serialize!(__tarpc_service_Request, {}, NotIrrefutable(()) $($fn_name(($($in_),*)))*); + + #[allow(non_camel_case_types, unused)] + #[derive(Debug)] + enum __tarpc_service_Response { + NotIrrefutable(()), + $( + $fn_name($out) + ),* } - ) => { + + impl_deserialize!(__tarpc_service_Response, NotIrrefutable(()) $($fn_name($out))*); + impl_serialize!(__tarpc_service_Response, {}, NotIrrefutable(()) $($fn_name($out))*); + + #[allow(non_camel_case_types, unused)] + #[derive(Debug)] + enum __tarpc_service_Error { + NotIrrefutable(()), + $( + $fn_name($error) + ),* + } + + impl_deserialize!(__tarpc_service_Error, NotIrrefutable(()) $($fn_name($error))*); + impl_serialize!(__tarpc_service_Error, {}, NotIrrefutable(()) $($fn_name($error))*); /// Defines the `Future` RPC service. Implementors must be `Clone`, `Send`, and `'static`, /// as required by `tokio_proto::NewService`. This is required so that the service can be used @@ -399,7 +402,7 @@ macro_rules! service { fn listen(self, addr: L) -> $crate::ListenFuture where L: ::std::net::ToSocketAddrs { - return $crate::listen(addr, __tarpc_service_AsyncServer(self)); + return $crate::listen_pipeline(addr, __tarpc_service_AsyncServer(self)); #[allow(non_camel_case_types)] #[derive(Clone)] @@ -411,30 +414,35 @@ macro_rules! service { } } + #[allow(non_camel_case_types)] - enum __tarpc_service_Reply<__tarpc_service_S: FutureService> { - DeserializeError($crate::SerializeFuture), - $($fn_name($crate::futures::Then< - $crate::futures::MapErr< - ty_snake_to_camel!(__tarpc_service_S::$fn_name), - fn($error) -> $crate::WireError<$error>>, - $crate::SerializeFuture, - fn(::std::result::Result<$out, $crate::WireError<$error>>) - -> $crate::SerializeFuture>)),* + type __tarpc_service_Future = + $crate::futures::Finished<$crate::Response<__tarpc_service_Response, + __tarpc_service_Error>, + ::std::io::Error>; + + #[allow(non_camel_case_types)] + enum __tarpc_service_FutureReply<__tarpc_service_S: FutureService> { + DeserializeError(__tarpc_service_Future), + $($fn_name($crate::futures::Then) + -> __tarpc_service_Future>)),* } - impl $crate::futures::Future for __tarpc_service_Reply { - type Item = $crate::SerializedReply; + impl $crate::futures::Future for __tarpc_service_FutureReply { + type Item = $crate::Response<__tarpc_service_Response, __tarpc_service_Error>; + type Error = ::std::io::Error; fn poll(&mut self) -> $crate::futures::Poll { match *self { - __tarpc_service_Reply::DeserializeError(ref mut f) => { - $crate::futures::Future::poll(f) + __tarpc_service_FutureReply::DeserializeError(ref mut __tarpc_service_future) => { + $crate::futures::Future::poll(__tarpc_service_future) } $( - __tarpc_service_Reply::$fn_name(ref mut f) => { - $crate::futures::Future::poll(f) + __tarpc_service_FutureReply::$fn_name(ref mut __tarpc_service_future) => { + $crate::futures::Future::poll(__tarpc_service_future) } ),* } @@ -447,66 +455,48 @@ macro_rules! service { for __tarpc_service_AsyncServer<__tarpc_service_S> where __tarpc_service_S: FutureService { - type Request = ::std::vec::Vec; - type Response = $crate::SerializedReply; + type Request = ::std::result::Result<__tarpc_service_Request, + $crate::bincode::serde::DeserializeError>; + type Response = $crate::Response<__tarpc_service_Response, __tarpc_service_Error>; type Error = ::std::io::Error; - type Future = __tarpc_service_Reply<__tarpc_service_S>; + type Future = __tarpc_service_FutureReply<__tarpc_service_S>; fn poll_ready(&self) -> $crate::futures::Async<()> { $crate::futures::Async::Ready(()) } - fn call(&self, __tarpc_service_req: Self::Request) -> Self::Future { - #[allow(non_camel_case_types, unused)] - #[derive(Debug)] - enum __tarpc_service_ServerSideRequest { - $( - $fn_name(( $($in_,)* )) - ),* - } - - impl_deserialize!(__tarpc_service_ServerSideRequest, - $($fn_name(($($in_),*)))*); - - let __tarpc_service_request = $crate::deserialize(&__tarpc_service_req); - let __tarpc_service_request: __tarpc_service_ServerSideRequest = - match __tarpc_service_request { - ::std::result::Result::Ok(__tarpc_service_request) => { - __tarpc_service_request - } - ::std::result::Result::Err(__tarpc_service_e) => { - return __tarpc_service_Reply::DeserializeError( - deserialize_error(__tarpc_service_e)); - } - }; - match __tarpc_service_request {$( - __tarpc_service_ServerSideRequest::$fn_name(( $($arg,)* )) => { - const SERIALIZE: - fn(::std::result::Result<$out, $crate::WireError<$error>>) - -> $crate::SerializeFuture - = $crate::serialize_reply; - const TO_APP: fn($error) -> $crate::WireError<$error> = - $crate::WireError::App; - - return __tarpc_service_Reply::$fn_name( - $crate::futures::Future::then( - $crate::futures::Future::map_err( - FutureService::$fn_name(&self.0, $($arg),*), - TO_APP), - SERIALIZE)); + fn call(&self, __tarpc_service_request: Self::Request) -> Self::Future { + let __tarpc_service_request = match __tarpc_service_request { + Ok(__tarpc_service_request) => __tarpc_service_request, + Err(__tarpc_service_deserialize_err) => { + return __tarpc_service_FutureReply::DeserializeError( + $crate::futures::finished( + $crate::tokio_proto::pipeline::Message::WithoutBody( + ::std::result::Result::Err( + $crate::WireError::ServerDeserialize( + ::std::string::ToString::to_string(&__tarpc_service_deserialize_err)))))); } - )*} - - #[inline] - fn deserialize_error(__tarpc_service_e: E) - -> $crate::SerializeFuture - { - $crate::serialize_reply( - // The type param is only used in the Error::App variant, so it - // doesn't matter what we specify it as here. - ::std::result::Result::Err::<(), _>( - $crate::WireError::ServerDeserialize::<$crate::util::Never>( - __tarpc_service_e.to_string()))) + }; + match __tarpc_service_request { + __tarpc_service_Request::NotIrrefutable(()) => unreachable!(), + $( + __tarpc_service_Request::$fn_name(( $($arg,)* )) => { + fn __tarpc_service_wrap( + __tarpc_service_response: ::std::result::Result<$out, $error>) + -> __tarpc_service_Future + { + $crate::futures::finished($crate::tokio_proto::pipeline::Message::WithoutBody( + __tarpc_service_response + .map(__tarpc_service_Response::$fn_name) + .map_err(|__tarpc_service_error| $crate::WireError::App(__tarpc_service_Error::$fn_name(__tarpc_service_error))) + )) + } + return __tarpc_service_FutureReply::$fn_name( + $crate::futures::Future::then( + FutureService::$fn_name(&self.0, $($arg),*), + __tarpc_service_wrap)); + } + )* } } } @@ -622,7 +612,7 @@ macro_rules! service { #[allow(unused)] $(#[$attr])* #[inline] - pub fn $fn_name(&self, $($arg: &$in_),*) + pub fn $fn_name(&self, $($arg: $in_),*) -> ::std::result::Result<$out, $crate::Error<$error>> { let rpc = (self.0).$fn_name($($arg),*); @@ -631,17 +621,36 @@ macro_rules! service { )* } + #[allow(non_camel_case_types)] + type __tarpc_service_Client = $crate::Client<__tarpc_service_Request, __tarpc_service_Response, __tarpc_service_Error>; + + #[allow(non_camel_case_types)] + pub struct __tarpc_service_ConnectFuture { + inner: $crate::futures::Map<$crate::ClientFuture<__tarpc_service_Request, __tarpc_service_Response, __tarpc_service_Error>, fn(__tarpc_service_Client) -> T>, + } + + impl $crate::futures::Future for __tarpc_service_ConnectFuture { + type Item = T; + type Error = ::std::io::Error; + + fn poll(&mut self) -> $crate::futures::Poll { + $crate::futures::Future::poll(&mut self.inner) + } + } + #[allow(unused)] #[derive(Clone, Debug)] /// The client stub that makes RPC calls to the server. Exposes a Future interface. - pub struct FutureClient($crate::Client); + pub struct FutureClient(__tarpc_service_Client); impl $crate::future::Connect for FutureClient { - type Fut = $crate::futures::Map<$crate::ClientFuture, fn($crate::Client) -> Self>; + type Fut = __tarpc_service_ConnectFuture; fn connect(addr: &::std::net::SocketAddr) -> Self::Fut { - let client = <$crate::Client as $crate::future::Connect>::connect(addr); - $crate::futures::Future::map(client, FutureClient) + let client = <__tarpc_service_Client as $crate::future::Connect>::connect(addr); + __tarpc_service_ConnectFuture { + inner: $crate::futures::Future::map(client, FutureClient) + } } } @@ -650,51 +659,46 @@ macro_rules! service { #[allow(unused)] $(#[$attr])* #[inline] - pub fn $fn_name(&self, $($arg: &$in_),*) + pub fn $fn_name(&self, $($arg: $in_),*) -> impl $crate::futures::Future> + 'static { - $client_req - $client_serialize_impl - - future_enum! { - enum Fut { - Called(C), - Failed(F) - } - } - - let __tarpc_service_args = ($($arg,)*); - let __tarpc_service_req = &__ClientSideRequest::$fn_name(&__tarpc_service_args); - let __tarpc_service_req = - match $crate::Packet::serialize(&__tarpc_service_req) - { - ::std::result::Result::Err(__tarpc_service_e) => { - return Fut::Failed( - $crate::futures::failed( - $crate::Error::ClientSerialize(__tarpc_service_e))) - } - ::std::result::Result::Ok(__tarpc_service_req) => __tarpc_service_req, - }; + let __tarpc_service_req = __tarpc_service_Request::$fn_name(($($arg,)*)); let __tarpc_service_fut = $crate::tokio_service::Service::call(&self.0, __tarpc_service_req); - Fut::Called($crate::futures::Future::then(__tarpc_service_fut, + $crate::futures::Future::then(__tarpc_service_fut, move |__tarpc_service_msg| { - let __tarpc_service_msg: Vec = try!(__tarpc_service_msg); - let __tarpc_service_msg: - ::std::result::Result< - ::std::result::Result<$out, $crate::WireError<$error>>, _> - = $crate::deserialize(&__tarpc_service_msg); + let __tarpc_service_msg = try!(__tarpc_service_msg); match __tarpc_service_msg { ::std::result::Result::Ok(__tarpc_service_msg) => { - ::std::result::Result::Ok(try!(__tarpc_service_msg)) + if let __tarpc_service_Response::$fn_name(__tarpc_service_msg) = + __tarpc_service_msg + { + ::std::result::Result::Ok(__tarpc_service_msg) + } else { + unreachable!() + } } - ::std::result::Result::Err(__tarpc_service_e) => { - ::std::result::Result::Err( - $crate::Error::ClientDeserialize(__tarpc_service_e)) + ::std::result::Result::Err(__tarpc_service_err) => { + ::std::result::Result::Err(match __tarpc_service_err { + $crate::Error::App(__tarpc_service_err) => { + if let __tarpc_service_Error::$fn_name(__tarpc_service_err) = + __tarpc_service_err + { + $crate::Error::App(__tarpc_service_err) + } else { + unreachable!() + } + } + $crate::Error::ServerDeserialize(__tarpc_service_err) => $crate::Error::ServerDeserialize(__tarpc_service_err), + $crate::Error::ServerSerialize(__tarpc_service_err) => $crate::Error::ServerSerialize(__tarpc_service_err), + $crate::Error::ClientDeserialize(__tarpc_service_err) => $crate::Error::ClientDeserialize(__tarpc_service_err), + $crate::Error::ClientSerialize(__tarpc_service_err) => $crate::Error::ClientSerialize(__tarpc_service_err), + $crate::Error::Io(__tarpc_service_error) => $crate::Error::Io(__tarpc_service_error), + }) } } - })) + }) } )* @@ -760,8 +764,8 @@ mod functional_test { let _ = env_logger::init(); let handle = Server.listen("localhost:0"); let client = SyncClient::connect(handle.local_addr()).unwrap(); - assert_eq!(3, client.add(&1, &2).unwrap()); - assert_eq!("Hey, Tim.", client.hey(&"Tim".to_string()).unwrap()); + assert_eq!(3, client.add(1, 2).unwrap()); + assert_eq!("Hey, Tim.", client.hey("Tim".to_string()).unwrap()); } #[test] @@ -769,8 +773,8 @@ mod functional_test { let handle = Server.listen("localhost:0"); let client1 = SyncClient::connect(handle.local_addr()).unwrap(); let client2 = client1.clone(); - assert_eq!(3, client1.add(&1, &2).unwrap()); - assert_eq!(3, client2.add(&1, &2).unwrap()); + assert_eq!(3, client1.add(1, 2).unwrap()); + assert_eq!(3, client2.add(1, 2).unwrap()); } #[test] @@ -814,8 +818,8 @@ mod functional_test { let _ = env_logger::init(); let handle = Server.listen("localhost:0").wait().unwrap(); let client = FutureClient::connect(handle.local_addr()).wait().unwrap(); - assert_eq!(3, client.add(&1, &2).wait().unwrap()); - assert_eq!("Hey, Tim.", client.hey(&"Tim".to_string()).wait().unwrap()); + assert_eq!(3, client.add(1, 2).wait().unwrap()); + assert_eq!("Hey, Tim.", client.hey("Tim".to_string()).wait().unwrap()); } #[test] @@ -824,8 +828,8 @@ mod functional_test { let handle = Server.listen("localhost:0").wait().unwrap(); let client1 = FutureClient::connect(handle.local_addr()).wait().unwrap(); let client2 = client1.clone(); - assert_eq!(3, client1.add(&1, &2).wait().unwrap()); - assert_eq!(3, client2.add(&1, &2).wait().unwrap()); + assert_eq!(3, client1.add(1, 2).wait().unwrap()); + assert_eq!(3, client2.add(1, 2).wait().unwrap()); } #[test] diff --git a/src/protocol.rs b/src/protocol.rs new file mode 100644 index 0000000..da18c9b --- /dev/null +++ b/src/protocol.rs @@ -0,0 +1,240 @@ +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Licensed under the MIT License, . +// This file may not be copied, modified, or distributed except according to those terms. + +use serde; +use futures::{self, Async}; +use bincode::{SizeLimit, serde as bincode}; +use byteorder::BigEndian; +use bytes::{BlockBuf, BlockBufCursor, Buf, MutBuf}; +use std::{cmp, io, mem, thread}; +use std::marker::PhantomData; +use std::sync::mpsc; +use util::Never; +use tokio_core::io::{FramedIo, Io}; +use tokio_core::reactor::{Core, Remote}; +use tokio_proto::{self as proto, pipeline}; + +lazy_static! { + #[doc(hidden)] + pub static ref LOOP_HANDLE: Remote = { + let (tx, rx) = mpsc::channel(); + thread::spawn(move || { + let mut lupe = Core::new().unwrap(); + tx.send(lupe.handle().remote().clone()).unwrap(); + // Run forever + lupe.run(futures::empty::<(), !>()).unwrap(); + }); + rx.recv().unwrap() + }; +} + +/// Handles the IO of tarpc messages. +pub struct Framed { + inner: proto::Framed, Serializer>, +} + +/// The type of message sent and received by the transport. +pub type Frame = pipeline::Frame; + +impl FramedIo for Framed + where I: Io, + In: serde::Serialize, + Out: serde::Deserialize, +{ + type In = Frame; + type Out = Frame>; + + fn poll_read(&mut self) -> Async<()> { + self.inner.poll_read() + } + + fn poll_write(&mut self) -> Async<()> { + self.inner.poll_write() + } + + fn read(&mut self) -> io::Result> { + self.inner.read() + } + + fn write(&mut self, req: Self::In) -> io::Result> { + self.inner.write(req) + } + + fn flush(&mut self) -> io::Result> { + self.inner.flush() + } +} + +/// Constructs a new tarpc FramedIo +pub fn new_transport(upstream: I) -> Framed + where I: Io, + In: serde::Serialize, + Out: serde::Deserialize, +{ + Framed { + inner: proto::Framed::new(upstream, + Parser::new(), + Serializer::new(), + BlockBuf::new(128, 8_192), + BlockBuf::new(128, 8_192)) + } +} + +struct Parser { + state: ParserState, + _phantom_data: PhantomData +} + +enum ParserState { + Len, + Payload { + len: u64, + } +} + +impl Parser { + fn new() -> Self { + Parser { + state: ParserState::Len, + _phantom_data: PhantomData, + } + } +} + +impl proto::Parse for Parser + where T: serde::Deserialize, +{ + type Out = Frame>; + + fn parse(&mut self, buf: &mut BlockBuf) -> Option { + use self::ParserState::*; + + loop { + match self.state { + Len if buf.len() < mem::size_of::() => return None, + Len => { + self.state = Payload { len: buf.buf().read_u64::() }; + buf.shift(mem::size_of::()); + } + Payload { len } if buf.len() < len as usize => return None, + Payload { len } => { + match bincode::deserialize_from(&mut BlockBufReader::new(buf), + SizeLimit::Infinite) + { + Ok(msg) => { + buf.shift(len as usize); + self.state = Len; + return Some(pipeline::Frame::Message(Ok(msg))); + } + Err(err) => { + // Clear any unread bytes so we don't read garbage on next request. + let buf_len = buf.len(); + buf.shift(buf_len); + return Some(pipeline::Frame::Message(Err(err))); + } + } + } + } + } + } +} + +struct Serializer(PhantomData); + +impl Serializer { + fn new() -> Self { + Serializer(PhantomData) + } +} + +impl proto::Serialize for Serializer + where T: serde::Serialize, +{ + type In = Frame; + + fn serialize(&mut self, msg: Self::In, buf: &mut BlockBuf) { + use tokio_proto::pipeline::Frame::*; + + match msg { + Message(msg) => { + buf.write_u64::(bincode::serialized_size(&msg)); + bincode::serialize_into(&mut BlockBufWriter::new(buf), + &msg, + SizeLimit::Infinite) + // TODO(tikue): handle err + .expect("In bincode::serialize_into"); + } + Error(e) => panic!("Unexpected error in Serializer::serialize: {}", e), + MessageWithBody(..) | Body(..) | Done => unreachable!(), + } + + } +} + +// == Scaffolding from Buf/MutBuf to Read/Write == + +struct BlockBufReader<'a> { + cursor: BlockBufCursor<'a>, +} + +impl<'a> BlockBufReader<'a> { + fn new(buf: &'a mut BlockBuf) -> Self { + BlockBufReader { cursor: buf.buf() } + } +} + +impl<'a> io::Read for BlockBufReader<'a> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let init_remaining = self.cursor.remaining(); + let buf_len = buf.len(); + self.cursor.read_slice(&mut buf[..cmp::min(init_remaining, buf_len)]); + Ok(init_remaining - self.cursor.remaining()) + } +} + +struct BlockBufWriter<'a> { + buf: &'a mut BlockBuf, +} + +impl<'a> BlockBufWriter<'a> { + fn new(buf: &'a mut BlockBuf) -> Self { + BlockBufWriter { buf: buf } + } +} + +impl<'a> io::Write for BlockBufWriter<'a> { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.buf.write_slice(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + // Always writes immediately, so there's never anything to flush. + Ok(()) + } +} + +#[test] +fn serialize() { + use tokio_proto::{Parse, Serialize}; + + const MSG: Frame<(char, char, char)> = pipeline::Frame::Message(('a', 'b', 'c')); + let mut buf = BlockBuf::default(); + + // Serialize twice to check for idempotence. + for _ in 0..2 { + Serializer::new().serialize(MSG, &mut buf); + let actual: Option>> = Parser::new().parse(&mut buf); + + match actual { + Some(pipeline::Frame::Message(ref v)) if *v.as_ref().unwrap() == MSG.unwrap_msg() => {} // good, + bad => panic!("Expected {:?}, but got {:?}", Some(MSG), bad), + } + + assert!(buf.is_empty(), + "Expected empty buf but got {:?}", + {buf.compact(); buf.bytes().unwrap()}); + } +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs deleted file mode 100644 index 8ec8032..0000000 --- a/src/protocol/mod.rs +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -use serde; -use futures::{self, Async}; -use bincode::{SizeLimit, serde as bincode}; -use std::{io, thread}; -use std::collections::VecDeque; -use std::sync::mpsc; -use util::Never; -use tokio_core::io::{FramedIo, Io}; -use tokio_core::reactor::{Core, Remote}; -use tokio_proto::pipeline::Frame; - -lazy_static! { - #[doc(hidden)] - pub static ref LOOP_HANDLE: Remote = { - let (tx, rx) = mpsc::channel(); - thread::spawn(move || { - let mut lupe = Core::new().unwrap(); - tx.send(lupe.handle().remote().clone()).unwrap(); - // Run forever - lupe.run(futures::empty::<(), !>()).unwrap(); - }); - rx.recv().unwrap() - }; -} - -pub use self::writer::Packet; - -pub mod reader; -pub mod writer; - -/// A helper trait to provide the `map_non_block` function on Results. -trait MapNonBlock { - /// Maps a `Result` to a `Result>` by converting - /// operation-would-block errors into `Ok(None)`. - fn map_non_block(self) -> io::Result>; -} - -impl MapNonBlock for io::Result { - fn map_non_block(self) -> io::Result> { - use std::io::ErrorKind::WouldBlock; - - match self { - Ok(value) => Ok(Some(value)), - Err(err) => { - if let WouldBlock = err.kind() { - Ok(None) - } else { - Err(err) - } - } - } - } -} - -/// Deserialize a buffer into a `D` and its ID. On error, returns `tarpc::Error`. -pub fn deserialize(mut buf: &[u8]) -> Result { - bincode::deserialize_from(&mut buf, SizeLimit::Infinite) -} - -pub struct TarpcTransport { - stream: T, - read_state: reader::ReadState, - outbound: VecDeque, - head: Option, -} - -impl TarpcTransport { - pub fn new(stream: T) -> Self { - TarpcTransport { - stream: stream, - read_state: reader::ReadState::init(), - outbound: VecDeque::new(), - head: None, - } - } -} - -impl FramedIo for TarpcTransport - where T: Io -{ - type In = Frame; - type Out = Frame, Never, io::Error>; - - fn poll_read(&mut self) -> Async<()> { - self.stream.poll_read() - } - - fn poll_write(&mut self) -> Async<()> { - self.stream.poll_write() - } - - fn read(&mut self) -> io::Result, Never, io::Error>>> { - self.read_state.next(&mut self.stream) - } - - fn write(&mut self, req: Self::In) -> io::Result> { - self.outbound.push_back(req.unwrap_msg()); - self.flush() - } - - fn flush(&mut self) -> io::Result> { - writer::NextWriteState::next(&mut self.head, &mut self.stream, &mut self.outbound) - } -} diff --git a/src/protocol/reader.rs b/src/protocol/reader.rs deleted file mode 100644 index e6c3808..0000000 --- a/src/protocol/reader.rs +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -use byteorder::{BigEndian, ReadBytesExt}; -use bytes::{MutBuf, Take}; -use futures::Async; -use std::io; -use std::mem; -use tokio_proto::TryRead; -use tokio_proto::pipeline::Frame; -use util::Never; - -#[derive(Debug)] -pub struct U64Reader { - read: usize, - data: [u8; 8], -} - -impl U64Reader { - fn new() -> Self { - U64Reader { - read: 0, - data: [0; 8], - } - } -} - -impl MutBuf for U64Reader { - fn remaining(&self) -> usize { - 8 - self.read - } - - unsafe fn advance(&mut self, count: usize) { - self.read += count; - } - - unsafe fn mut_bytes(&mut self) -> &mut [u8] { - &mut self.data[self.read..] - } -} - -#[derive(Debug)] -enum NextReadAction { - Continue, - Stop(Result), -} - -trait MutBufExt: MutBuf + Sized { - type Inner; - - fn take(&mut self) -> Self::Inner; - - fn try_read(&mut self, stream: &mut R) -> io::Result> { - while let Async::Ready(bytes_read) = stream.try_read_buf(self)? { - debug!("Reader: read {} bytes, {} remaining.", - bytes_read, - self.remaining()); - if bytes_read == 0 { - debug!("Reader: connection broken."); - let err = io::Error::new(io::ErrorKind::BrokenPipe, "The connection was closed."); - return Ok(NextReadAction::Stop(Err(err))); - } - - if !self.has_remaining() { - trace!("Reader: finished."); - return Ok(NextReadAction::Stop(Ok(self.take()))); - } - } - Ok(NextReadAction::Continue) - } -} - -impl MutBufExt for U64Reader { - type Inner = u64; - - fn take(&mut self) -> u64 { - (&self.data as &[u8]).read_u64::().unwrap() - } -} - -impl MutBufExt for Take> { - type Inner = Vec; - - fn take(&mut self) -> Vec { - mem::replace(self.get_mut(), vec![]) - } -} - -/// A state machine that reads packets in non-blocking fashion. -#[derive(Debug)] -pub enum ReadState { - /// Tracks how many bytes of the message size have been read. - Len(U64Reader), - /// Tracks read progress. - Data(Take>), -} - -#[derive(Debug)] -enum NextReadState { - Same, - Next(ReadState), - Reset(Vec), -} - -impl ReadState { - pub fn init() -> ReadState { - ReadState::Len(U64Reader::new()) - } - - pub fn next(&mut self, - socket: &mut R) - -> io::Result, Never, io::Error>>> { - loop { - let next = match *self { - ReadState::Len(ref mut len) => { - match len.try_read(socket)? { - NextReadAction::Continue => NextReadState::Same, - NextReadAction::Stop(result) => { - match result { - Ok(len) => { - let buf = Vec::with_capacity(len as usize); - NextReadState::Next(ReadState::Data(Take::new(buf, - len as usize))) - } - Err(e) => return Ok(Async::Ready(Frame::Error(e))), - } - } - } - } - ReadState::Data(ref mut buf) => { - match buf.try_read(socket)? { - NextReadAction::Continue => NextReadState::Same, - NextReadAction::Stop(result) => { - match result { - Ok(buf) => NextReadState::Reset(buf), - Err(e) => return Ok(Async::Ready(Frame::Error(e))), - } - } - } - } - }; - match next { - NextReadState::Same => return Ok(Async::NotReady), - NextReadState::Next(next) => *self = next, - NextReadState::Reset(packet) => { - *self = ReadState::init(); - return Ok(Async::Ready(Frame::Message(packet))); - } - } - } - } -} diff --git a/src/protocol/writer.rs b/src/protocol/writer.rs deleted file mode 100644 index 44e67a9..0000000 --- a/src/protocol/writer.rs +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -use bincode::SizeLimit; -use bincode::serde as bincode; -use byteorder::{BigEndian, WriteBytesExt}; -use bytes::Buf; -use futures::Async; -use serde::Serialize; -use std::collections::VecDeque; -use std::io::{self, Cursor}; -use std::mem; -use tokio_proto::TryWrite; - -/// The means of communication between client and server. -#[derive(Clone, Debug)] -pub struct Packet { - /// (payload_len: u64, payload) - /// - /// The payload is typically a serialized message. - pub buf: Cursor>, -} - -impl Packet { - /// Creates a new packet, (len, payload) - pub fn serialize(message: &S) -> Result - where S: Serialize - { - let payload_len = bincode::serialized_size(message); - - // (len, message) - let mut buf = Vec::with_capacity(mem::size_of::() + payload_len as usize); - - buf.write_u64::(payload_len).unwrap(); - bincode::serialize_into(&mut buf, message, SizeLimit::Infinite)?; - Ok(Packet { buf: Cursor::new(buf) }) - } -} - -#[derive(Debug)] -enum NextWriteAction { - Stop, - Continue, -} - -trait BufExt: Buf + Sized { - /// Writes data to stream. Returns Ok(true) if all data has been written or Ok(false) if - /// there's still data to write. - fn try_write(&mut self, stream: &mut W) -> io::Result { - while let Async::Ready(bytes_written) = stream.try_write_buf(self)? { - debug!("Writer: wrote {} bytes; {} remaining.", - bytes_written, - self.remaining()); - if bytes_written == 0 { - trace!("Writer: would block."); - return Ok(NextWriteAction::Continue); - } - if !self.has_remaining() { - return Ok(NextWriteAction::Stop); - } - } - Ok(NextWriteAction::Continue) - } -} - -impl BufExt for B {} - -#[derive(Debug)] -pub enum NextWriteState { - Nothing, - Next(Packet), -} - -impl NextWriteState { - pub fn next(state: &mut Option, - socket: &mut W, - outbound: &mut VecDeque) - -> io::Result> { - loop { - let update = match *state { - None => { - match outbound.pop_front() { - Some(packet) => { - let size = packet.buf.remaining() as u64; - debug_assert!(size >= mem::size_of::() as u64); - NextWriteState::Next(packet) - } - None => return Ok(Async::Ready(())), - } - } - Some(ref mut packet) => { - match BufExt::try_write(&mut packet.buf, socket)? { - NextWriteAction::Stop => NextWriteState::Nothing, - NextWriteAction::Continue => return Ok(Async::NotReady), - } - } - }; - match update { - NextWriteState::Next(next) => *state = Some(next), - NextWriteState::Nothing => { - *state = None; - } - } - } - } -} diff --git a/src/server.rs b/src/server.rs index c5b44b6..3df7e86 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,12 +3,12 @@ // Licensed under the MIT License, . // This file may not be copied, modified, or distributed except according to those terms. -use errors::{SerializableError, WireError}; +use bincode::serde::DeserializeError; +use errors::WireError; use futures::{self, Async, Future}; use futures::stream::Empty; -use protocol::{LOOP_HANDLE, TarpcTransport}; -use protocol::writer::Packet; -use serde::Serialize; +use protocol::{LOOP_HANDLE, new_transport}; +use serde::{Deserialize, Serialize}; use std::io; use std::net::ToSocketAddrs; use tokio_proto::pipeline; @@ -16,12 +16,18 @@ use tokio_proto::server::{self, ServerHandle}; use tokio_service::NewService; use util::Never; +/// A message from server to client. +pub type Response = pipeline::Message>, Empty>; + /// Spawns a service that binds to the given address and runs on the default tokio `Loop`. -pub fn listen(addr: A, new_service: T) -> ListenFuture - where T: NewService, - Response = pipeline::Message>, +pub fn listen_pipeline(addr: A, new_service: S) -> ListenFuture + where S: NewService, + Response = Response, Error = io::Error> + Send + 'static, - A: ToSocketAddrs + A: ToSocketAddrs, + Req: Deserialize, + Resp: Serialize, + E: Serialize, { // TODO(tikue): don't use ToSocketAddrs, or don't unwrap. let addr = addr.to_socket_addrs().unwrap().next().unwrap(); @@ -29,7 +35,7 @@ pub fn listen(addr: A, new_service: T) -> ListenFuture let (tx, rx) = futures::oneshot(); LOOP_HANDLE.spawn(move |handle| { Ok(tx.complete(server::listen(handle, addr, move |stream| { - pipeline::Server::new(new_service.new_service()?, TarpcTransport::new(stream)) + pipeline::Server::new(new_service.new_service()?, new_transport(stream)) }).unwrap())) }); ListenFuture { inner: rx } @@ -51,29 +57,3 @@ impl Future for ListenFuture { } } } - -/// Returns a future containing the serialized reply. -/// -/// Because serialization can take a non-trivial -/// amount of cpu time, it is run on a thread pool. -#[doc(hidden)] -#[inline] -pub fn serialize_reply(result: Result>) - -> SerializeFuture -{ - let packet = match Packet::serialize(&result) { - Ok(packet) => packet, - Err(e) => { - let err: Result> = Err(WireError::ServerSerialize(e.to_string())); - Packet::serialize(&err).unwrap() - } - }; - futures::finished(pipeline::Message::WithoutBody(packet)) -} - -#[doc(hidden)] -pub type SerializeFuture = futures::Finished; - -#[doc(hidden)] -pub type SerializedReply = pipeline::Message>; diff --git a/src/util.rs b/src/util.rs index b0710b1..aeb5539 100644 --- a/src/util.rs +++ b/src/util.rs @@ -3,6 +3,8 @@ // Licensed under the MIT License, . // This file may not be copied, modified, or distributed except according to those terms. +use futures::{Future, Poll}; +use futures::stream::Stream; use std::fmt; use std::error::Error; use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -14,13 +16,43 @@ pub struct Never(!); impl Error for Never { fn description(&self) -> &str { - unreachable!() + match self.0 { + // TODO(tikue): remove when https://github.com/rust-lang/rust/issues/12609 lands + _ => unreachable!(), + } } } impl fmt::Display for Never { fn fmt(&self, _: &mut fmt::Formatter) -> fmt::Result { - unreachable!() + match self.0 { + // TODO(tikue): remove when https://github.com/rust-lang/rust/issues/12609 lands + _ => unreachable!(), + } + } +} + +impl Future for Never { + type Item = Never; + type Error = Never; + + fn poll(&mut self) -> Poll { + match self.0 { + // TODO(tikue): remove when https://github.com/rust-lang/rust/issues/12609 lands + _ => unreachable!(), + } + } +} + +impl Stream for Never { + type Item = Never; + type Error = Never; + + fn poll(&mut self) -> Poll, Self::Error> { + match self.0 { + // TODO(tikue): remove when https://github.com/rust-lang/rust/issues/12609 lands + _ => unreachable!(), + } } } @@ -28,7 +60,10 @@ impl Serialize for Never { fn serialize(&self, _: &mut S) -> Result<(), S::Error> where S: Serializer { - unreachable!() + match self.0 { + // TODO(tikue): remove when https://github.com/rust-lang/rust/issues/12609 lands + _ => unreachable!(), + } } }