diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index a3475a5..3f47601 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -14,8 +14,7 @@ description = "An RPC framework for Rust with a focus on ease of use." bincode = "^0.4.0" log = "^0.3.5" scoped-pool = "^0.1.4" -serde = "^0.6.11" -serde_macros = "^0.6.11" +serde = "^0.6.13" [dev-dependencies] lazy_static = "^0.1.15" diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 55e9ee1..dd6055a 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -8,10 +8,7 @@ //! Example usage: //! //! ``` -//! # #![feature(custom_derive, plugin)] -//! # #![plugin(serde_macros)] //! # #[macro_use] extern crate tarpc; -//! # extern crate serde; //! mod my_server { //! service! { //! rpc hello(name: String) -> String; @@ -48,17 +45,18 @@ //! ``` #![deny(missing_docs)] -#![feature(custom_derive, plugin, test, type_ascription)] -#![plugin(serde_macros)] +#![cfg_attr(test, feature(test))] extern crate serde; extern crate bincode; -#[cfg(test)] -#[macro_use] -extern crate lazy_static; #[macro_use] extern crate log; extern crate scoped_pool; + +#[cfg(test)] +#[macro_use] +extern crate lazy_static; +#[cfg(test)] extern crate test; macro_rules! pos { diff --git a/tarpc/src/macros.rs b/tarpc/src/macros.rs index 6304183..4baa737 100644 --- a/tarpc/src/macros.rs +++ b/tarpc/src/macros.rs @@ -3,9 +3,14 @@ // Licensed under the MIT License, . // This file may not be copied, modified, or distributed except according to those terms. -#[doc(hidden)] -#[macro_export] -macro_rules! as_item { ($i:item) => {$i} } +/// Serde re-exports required by macros. Not for general use. +pub mod serde { + pub use serde::{Deserialize, Deserializer, Serialize, Serializer}; + /// Deserialization re-exports required by macros. Not for general use. + pub mod de { + pub use serde::de::{EnumVisitor, Error, Visitor, VariantVisitor}; + } +} // Required because if-let can't be used with irrefutable patterns, so it needs // to be special cased. @@ -14,24 +19,24 @@ macro_rules! as_item { ($i:item) => {$i} } macro_rules! client_methods { ( { $(#[$attr:meta])* } - $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty + $fn_name:ident( ($($arg:ident,)*) : ($($in_:ty,)*) ) -> $out:ty ) => ( $(#[$attr])* pub fn $fn_name(&self, $($arg: $in_),*) -> $crate::Result<$out> { - let reply = try!((self.0).rpc(request_variant!($fn_name $($arg),*))); + let reply = try!((self.0).rpc(__Request::$fn_name(($($arg,)*)))); let __Reply::$fn_name(reply) = reply; - Ok(reply) + ::std::result::Result::Ok(reply) } ); ($( { $(#[$attr:meta])* } - $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty + $fn_name:ident( ($( $arg:ident,)*) : ($($in_:ty, )*) ) -> $out:ty )*) => ( $( $(#[$attr])* pub fn $fn_name(&self, $($arg: $in_),*) -> $crate::Result<$out> { - let reply = try!((self.0).rpc(request_variant!($fn_name $($arg),*))); + let reply = try!((self.0).rpc(__Request::$fn_name(($($arg,)*)))); if let __Reply::$fn_name(reply) = reply { - Ok(reply) + ::std::result::Result::Ok(reply) } else { panic!("Incorrect reply variant returned from protocol::Clientrpc; expected `{}`, but got {:?}", stringify!($fn_name), reply); } @@ -46,7 +51,7 @@ macro_rules! client_methods { macro_rules! async_client_methods { ( { $(#[$attr:meta])* } - $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty + $fn_name:ident( ($( $arg:ident, )*) : ($( $in_:ty, )*) ) -> $out:ty ) => ( $(#[$attr])* pub fn $fn_name(&self, $($arg: $in_),*) -> Future<$out> { @@ -54,7 +59,7 @@ macro_rules! async_client_methods { let __Reply::$fn_name(reply) = reply; reply } - let reply = (self.0).rpc_async(request_variant!($fn_name $($arg),*)); + let reply = (self.0).rpc_async(__Request::$fn_name(($($arg,)*))); Future { future: reply, mapper: mapper, @@ -63,7 +68,7 @@ macro_rules! async_client_methods { ); ($( { $(#[$attr:meta])* } - $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty + $fn_name:ident( ($( $arg:ident, )*) : ($( $in_:ty, )*) ) -> $out:ty )*) => ( $( $(#[$attr])* pub fn $fn_name(&self, $($arg: $in_),*) -> Future<$out> { @@ -74,7 +79,7 @@ macro_rules! async_client_methods { panic!("Incorrect reply variant returned from protocol::Clientrpc; expected `{}`, but got {:?}", stringify!($fn_name), reply); } } - let reply = (self.0).rpc_async(request_variant!($fn_name $($arg),*)); + let reply = (self.0).rpc_async(__Request::$fn_name(($($arg,)*))); Future { future: reply, mapper: mapper, @@ -83,28 +88,113 @@ macro_rules! async_client_methods { )*); } -// Required because enum variants with no fields can't be suffixed by parens #[doc(hidden)] #[macro_export] -macro_rules! define_request { - ($(@($($finished:tt)*))* --) => (as_item!( - #[allow(non_camel_case_types)] - #[derive(Debug, Serialize, Deserialize)] - enum __Request { $($($finished)*),* } - );); - ($(@$finished:tt)* -- $name:ident() $($req:tt)*) => - (define_request!($(@$finished)* @($name) -- $($req)*);); - ($(@$finished:tt)* -- $name:ident $args: tt $($req:tt)*) => - (define_request!($(@$finished)* @($name $args) -- $($req)*);); - ($($started:tt)*) => (define_request!(-- $($started)*);); +macro_rules! impl_serialize { + ($impler:ident, $(@($name:ident $n:expr))* -- #($_n:expr) ) => ( + impl $crate::macros::serde::Serialize for $impler { + #[inline] + fn serialize(&self, serializer: &mut S) -> ::std::result::Result<(), S::Error> + where S: $crate::macros::serde::Serializer + { + match *self { + $( + $impler::$name(ref field) => + $crate::macros::serde::Serializer::visit_newtype_variant( + serializer, + stringify!($impler), + $n, + stringify!($name), + field, + ) + ),* + } + } + } + ); + // All args are wrapped in a tuple so we can use the newtype variant for each one. + ($impler:ident, $(@$finished:tt)* -- #($n:expr) $name:ident($field:ty) $($req:tt)*) => ( + impl_serialize!($impler, $(@$finished)* @($name $n) -- #($n + 1) $($req)*); + ); + // Entry + ($impler:ident, $($started:tt)*) => (impl_serialize!($impler, -- #(0) $($started)*);); } -// Required because enum variants with no fields can't be suffixed by parens #[doc(hidden)] #[macro_export] -macro_rules! request_variant { - ($x:ident) => (__Request::$x); - ($x:ident $($y:ident),+) => (__Request::$x($($y),+)); +macro_rules! impl_deserialize { + ($impler:ident, $(@($name:ident $n:expr))* -- #($_n:expr) ) => ( + impl $crate::macros::serde::Deserialize for $impler { + #[inline] + fn deserialize(deserializer: &mut D) + -> ::std::result::Result<$impler, D::Error> + where D: $crate::macros::serde::Deserializer + { + #[allow(non_camel_case_types)] + enum __Field { + $($name),* + } + impl $crate::macros::serde::Deserialize for __Field { + #[inline] + fn deserialize(deserializer: &mut D) + -> ::std::result::Result<__Field, D::Error> + where D: $crate::macros::serde::Deserializer + { + struct __FieldVisitor; + impl $crate::macros::serde::de::Visitor for __FieldVisitor { + type Value = __Field; + + fn visit_usize(&mut self, value: usize) + -> ::std::result::Result<__Field, E> + where E: $crate::macros::serde::de::Error, + { + $( + if value == $n { + return ::std::result::Result::Ok(__Field::$name); + } + )* + return ::std::result::Result::Err( + $crate::macros::serde::de::Error::syntax("expected a field") + ); + } + } + deserializer.visit_struct_field(__FieldVisitor) + } + } + + struct __Visitor; + impl $crate::macros::serde::de::EnumVisitor for __Visitor { + type Value = $impler; + + fn visit<__V>(&mut self, mut visitor: __V) + -> ::std::result::Result<$impler, __V::Error> + where __V: $crate::macros::serde::de::VariantVisitor + { + match try!(visitor.visit_variant()) { + $( + __Field::$name => { + let val = try!(visitor.visit_newtype()); + Ok($impler::$name(val)) + } + ),* + } + } + } + const VARIANTS: &'static [&'static str] = &[ + $( + stringify!($name) + ),* + ]; + deserializer.visit_enum(stringify!($impler), VARIANTS, __Visitor) + } + } + ); + // All args are wrapped in a tuple so we can use the newtype variant for each one. + ($impler:ident, $(@$finished:tt)* -- #($n:expr) $name:ident($field:ty) $($req:tt)*) => ( + impl_deserialize!($impler, $(@$finished)* @($name $n) -- #($n + 1) $($req)*); + ); + // Entry + ($impler:ident, $($started:tt)*) => (impl_deserialize!($impler, -- #(0) $($started)*);); } /// The main macro that creates RPC services. @@ -112,10 +202,7 @@ macro_rules! request_variant { /// Rpc methods are specified, mirroring trait syntax: /// /// ``` -/// # #![feature(custom_derive, plugin)] -/// # #![plugin(serde_macros)] /// # #[macro_use] extern crate tarpc; -/// # extern crate serde; /// # fn main() {} /// # service! { /// #[doc="Say hello"] @@ -200,7 +287,7 @@ macro_rules! service_inner { { } // none left to expand $( $(#[$attr:meta])* - rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty; + rpc $fn_name:ident ( $( $arg:ident : $in_:ty ),* ) -> $out:ty; )* ) => { #[doc="Defines the RPC service"] @@ -223,16 +310,28 @@ macro_rules! service_inner { )* } - define_request!($($fn_name($($in_),*))*); + #[allow(non_camel_case_types)] + #[derive(Debug)] + enum __Request { + $( + $fn_name(( $($in_,)* )) + ),* + } + + impl_serialize!(__Request, $($fn_name(($($in_),*)))*); + impl_deserialize!(__Request, $($fn_name(($($in_),*)))*); #[allow(non_camel_case_types)] - #[derive(Debug, Serialize, Deserialize)] + #[derive(Debug)] enum __Reply { $( $fn_name($out), )* } + impl_serialize!(__Reply, $($fn_name($out))*); + impl_deserialize!(__Reply, $($fn_name($out))*); + /// An asynchronous RPC call pub struct Future { future: $crate::protocol::Future<__Reply>, @@ -256,13 +355,13 @@ macro_rules! service_inner { where A: ::std::net::ToSocketAddrs, { let inner = try!($crate::protocol::Client::new(addr, timeout)); - Ok(Client(inner)) + ::std::result::Result::Ok(Client(inner)) } client_methods!( $( { $(#[$attr])* } - $fn_name($($arg: $in_),*) -> $out + $fn_name(($($arg,)*) : ($($in_,)*)) -> $out )* ); @@ -283,13 +382,13 @@ macro_rules! service_inner { where A: ::std::net::ToSocketAddrs, { let inner = try!($crate::protocol::Client::new(addr, timeout)); - Ok(AsyncClient(inner)) + ::std::result::Result::Ok(AsyncClient(inner)) } async_client_methods!( $( { $(#[$attr])* } - $fn_name($($arg: $in_),*) -> $out + $fn_name(($($arg,)*): ($($in_,)*)) -> $out )* ); @@ -310,9 +409,9 @@ macro_rules! service_inner { fn serve(&self, request: __Request) -> __Reply { match request { $( - request_variant!($fn_name $($arg),*) => + __Request::$fn_name(( $($arg,)* )) => __Reply::$fn_name((self.0).$fn_name($($arg),*)), - )* + )* } } } @@ -326,13 +425,13 @@ macro_rules! service_inner { S: 'static + Service { let server = ::std::sync::Arc::new(__Server(service)); - Ok(try!($crate::protocol::serve_async(addr, server, read_timeout))) + ::std::result::Result::Ok(try!($crate::protocol::serve_async(addr, server, read_timeout))) } } } +#[allow(dead_code)] // because we're just testing that the macro expansion compiles #[cfg(test)] -#[allow(dead_code)] // because we're testing that the macro expansion compiles mod syntax_test { // Tests a service definition with a fn that takes no args mod qux { @@ -340,7 +439,6 @@ mod syntax_test { rpc hello() -> String; } } - // Tests a service definition with an attribute. mod bar { service! { @@ -355,6 +453,7 @@ mod syntax_test { rpc ack(); rpc apply(foo: String) -> i32; rpc bi_consume(bar: String, baz: u64); + rpc bi_fn(bar: String, baz: u64) -> String; } } } @@ -382,6 +481,7 @@ mod functional_test { #[test] fn simple() { + let _ = env_logger::init(); let handle = serve( "localhost:0", Server, test_timeout()).unwrap(); let client = Client::new(handle.local_addr(), None).unwrap(); assert_eq!(3, client.add(1, 2).unwrap()); @@ -391,6 +491,7 @@ mod functional_test { #[test] fn simple_async() { + let _ = env_logger::init(); let handle = serve("localhost:0", Server, test_timeout()).unwrap(); let client = AsyncClient::new(handle.local_addr(), None).unwrap(); assert_eq!(3, client.add(1, 2).get().unwrap()); @@ -421,6 +522,21 @@ mod functional_test { fn serve_arc_server() { let _ = serve("localhost:0", ::std::sync::Arc::new(Server), None); } + + #[test] + fn serde() { + let _ = env_logger::init(); + use bincode; + + let request = __Request::add((1, 2)); + let ser = bincode::serde::serialize(&request, bincode::SizeLimit::Infinite).unwrap(); + let de = bincode::serde::deserialize(&ser).unwrap(); + if let __Request::add((1, 2)) = de { + // success + } else { + panic!("Expected __Request::add, got {:?}", de); + } + } } #[cfg(test)] diff --git a/tarpc/src/protocol/client.rs b/tarpc/src/protocol/client.rs index db04ea9..f1250b0 100644 --- a/tarpc/src/protocol/client.rs +++ b/tarpc/src/protocol/client.rs @@ -205,7 +205,7 @@ fn write(outbound: Receiver<(Request, Sender>)>, rpc_id: id, message: request, }; - debug!("Writer: calling rpc({:?})", id); + debug!("Writer: writing rpc, id={:?}", id); if let Err(e) = stream.serialize(&packet) { report_error(&tx, e.into()); // Typically we'd want to notify the client of any Err returned by remove_tx, but in @@ -235,7 +235,7 @@ fn read(requests: Arc>>, stream: TcpStream) { let mut stream = BufReader::new(stream); loop { - match stream.deserialize() : Result> { + match stream.deserialize::>() { Ok(packet) => { debug!("Client: received message, id={}", packet.rpc_id); requests.lock().expect(pos!()).complete_reply(packet); diff --git a/tarpc/src/protocol/mod.rs b/tarpc/src/protocol/mod.rs index 2ce4369..45f322f 100644 --- a/tarpc/src/protocol/mod.rs +++ b/tarpc/src/protocol/mod.rs @@ -12,7 +12,9 @@ use std::sync::Arc; mod client; mod server; +mod packet; +pub use self::packet::Packet; pub use self::client::{Client, Future}; pub use self::server::{Serve, ServeHandle, serve_async}; @@ -55,12 +57,6 @@ impl convert::From for Error { /// Return type of rpc calls: either the successful return value, or a client error. pub type Result = ::std::result::Result; -#[derive(Debug, Clone, Serialize, Deserialize)] -struct Packet { - rpc_id: u64, - message: T, -} - trait Deserialize: Read + Sized { fn deserialize(&mut self) -> Result { deserialize_from(self, SizeLimit::Infinite) @@ -93,27 +89,17 @@ mod test { Some(Duration::from_secs(1)) } - #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] - enum Request { - Increment, - } - - #[derive(Debug, PartialEq, Serialize, Deserialize)] - enum Reply { - Increment(u64), - } - struct Server { counter: Mutex, } impl Serve for Server { - type Request = Request; - type Reply = Reply; + type Request = (); + type Reply = u64; - fn serve(&self, _: Request) -> Reply { + fn serve(&self, _: ()) -> u64 { let mut counter = self.counter.lock().unwrap(); - let reply = Reply::Increment(*counter); + let reply = *counter; *counter += 1; reply } @@ -134,7 +120,7 @@ mod test { let _ = env_logger::init(); let server = Arc::new(Server::new()); let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap(); - let client: Client = Client::new(serve_handle.local_addr(), None).unwrap(); + let client: Client<(), u64> = Client::new(serve_handle.local_addr(), None).unwrap(); drop(client); serve_handle.shutdown(); } @@ -145,12 +131,11 @@ mod test { let server = Arc::new(Server::new()); let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); - let client = Client::new(addr, None).unwrap(); - assert_eq!(Reply::Increment(0), - client.rpc(Request::Increment).unwrap()); + // The explicit type is required so that it doesn't deserialize a u32 instead of u64 + let client: Client<(), u64> = Client::new(addr, None).unwrap(); + assert_eq!(0, client.rpc(()).unwrap()); assert_eq!(1, server.count()); - assert_eq!(Reply::Increment(1), - client.rpc(Request::Increment).unwrap()); + assert_eq!(1, client.rpc(()).unwrap()); assert_eq!(2, server.count()); drop(client); serve_handle.shutdown(); @@ -162,9 +147,9 @@ mod test { } impl Serve for BarrierServer { - type Request = Request; - type Reply = Reply; - fn serve(&self, request: Request) -> Reply { + type Request = (); + type Reply = u64; + fn serve(&self, request: ()) -> u64 { self.barrier.wait(); self.inner.serve(request) } @@ -189,9 +174,9 @@ mod test { let server = Arc::new(Server::new()); let serve_handle = serve_async("localhost:0", server, Some(Duration::new(0, 10))).unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Client = Client::new(addr, None).unwrap(); + let client: Client<(), u64> = Client::new(addr, None).unwrap(); let thread = thread::spawn(move || serve_handle.shutdown()); - info!("force_shutdown:: rpc1: {:?}", client.rpc(Request::Increment)); + info!("force_shutdown:: rpc1: {:?}", client.rpc(())); thread.join().unwrap(); } @@ -201,14 +186,14 @@ mod test { let server = Arc::new(Server::new()); let serve_handle = serve_async("localhost:0", server, test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); - client.rpc(Request::Increment).unwrap(); + let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); + client.rpc(()).unwrap(); serve_handle.shutdown(); - match client.rpc(Request::Increment) { + match client.rpc(()) { Err(super::Error::ConnectionBroken) => {} // success otherwise => panic!("Expected Err(ConnectionBroken), got {:?}", otherwise), } - let _ = client.rpc(Request::Increment); // Test whether second failure hangs + let _ = client.rpc(()); // Test whether second failure hangs } #[test] @@ -219,11 +204,11 @@ mod test { let server = Arc::new(BarrierServer::new(concurrency)); let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Client = Client::new(addr, None).unwrap(); + let client: Client<(), u64> = Client::new(addr, None).unwrap(); pool.scoped(|scope| { for _ in 0..concurrency { let client = client.try_clone().unwrap(); - scope.execute(move || { client.rpc(Request::Increment).unwrap(); }); + scope.execute(move || { client.rpc(()).unwrap(); }); } }); assert_eq!(concurrency as u64, server.count()); @@ -237,12 +222,12 @@ mod test { let server = Arc::new(Server::new()); let serve_handle = serve_async("localhost:0", server.clone(), None).unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Client = Client::new(addr, None).unwrap(); + let client: Client<(), u64> = Client::new(addr, None).unwrap(); // Drop future immediately; does the reader channel panic when sending? - client.rpc_async(Request::Increment); + client.rpc_async(()); // If the reader panicked, this won't succeed - client.rpc_async(Request::Increment); + client.rpc_async(()); drop(client); serve_handle.shutdown(); diff --git a/tarpc/src/protocol/packet.rs b/tarpc/src/protocol/packet.rs new file mode 100644 index 0000000..ae59472 --- /dev/null +++ b/tarpc/src/protocol/packet.rs @@ -0,0 +1,98 @@ +use serde::{Deserialize, Deserializer, Serialize, Serializer, de, ser}; +use std::marker::PhantomData; + +/// Packet shared between client and server. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Packet { + /// Packet id to map response to request. + pub rpc_id: u64, + /// Packet payload. + pub message: T, +} + +const PACKET: &'static str = "Packet"; +const RPC_ID: &'static str = "rpc_id"; +const MESSAGE: &'static str = "message"; + +impl Serialize for Packet { + #[inline] + fn serialize(&self, serializer: &mut S) -> Result<(), S::Error> + where S: Serializer + { + serializer.visit_struct(PACKET, MapVisitor { + value: self, + state: 0, + }) + } +} + +struct MapVisitor<'a, T: 'a> { + value: &'a Packet, + state: u8, +} + +impl <'a, T: Serialize> ser::MapVisitor for MapVisitor<'a, T> { + fn visit(&mut self, serializer: &mut S) -> Result, S::Error> + where S: Serializer + { + match self.state { + 0 => { + self.state += 1; + Ok(Some(try!(serializer.visit_struct_elt(RPC_ID, &self.value.rpc_id)))) + } + 1 => { + self.state += 1; + Ok(Some(try!(serializer.visit_struct_elt(MESSAGE, &self.value.message)))) + } + _ => { + Ok(None) + } + } + } +} + +impl Deserialize for Packet { + fn deserialize(deserializer: &mut D) -> Result + where D: Deserializer + { + const FIELDS: &'static [&'static str] = &[RPC_ID, MESSAGE]; + deserializer.visit_struct(PACKET, FIELDS, Visitor(PhantomData)) + } +} + +struct Visitor(PhantomData); + +impl de::Visitor for Visitor { + type Value = Packet; + + fn visit_seq(&mut self, mut visitor: V) -> Result, V::Error> + where V: de::SeqVisitor + { + let packet = Packet { + rpc_id: match try!(visitor.visit()) { + Some(rpc_id) => rpc_id, + None => return Err(de::Error::end_of_stream()), + }, + message: match try!(visitor.visit()) { + Some(message) => message, + None => return Err(de::Error::end_of_stream()), + }, + }; + try!(visitor.end()); + Ok(packet) + } +} + +#[cfg(test)] +extern crate env_logger; + +#[test] +fn serde() { + let _ = env_logger::init(); + use bincode; + + let packet = Packet { rpc_id: 1, message: () }; + let ser = bincode::serde::serialize(&packet, bincode::SizeLimit::Infinite).unwrap(); + let de = bincode::serde::deserialize(&ser); + assert_eq!(packet, de.unwrap()); +}