Merge pull request #8 from tikue/stabilize

Manually implement serialization code...
This commit is contained in:
shaladdle
2016-02-14 20:19:29 -08:00
6 changed files with 292 additions and 96 deletions

View File

@@ -14,8 +14,7 @@ description = "An RPC framework for Rust with a focus on ease of use."
bincode = "^0.4.0"
log = "^0.3.5"
scoped-pool = "^0.1.4"
serde = "^0.6.11"
serde_macros = "^0.6.11"
serde = "^0.6.13"
[dev-dependencies]
lazy_static = "^0.1.15"

View File

@@ -8,10 +8,7 @@
//! Example usage:
//!
//! ```
//! # #![feature(custom_derive, plugin)]
//! # #![plugin(serde_macros)]
//! # #[macro_use] extern crate tarpc;
//! # extern crate serde;
//! mod my_server {
//! service! {
//! rpc hello(name: String) -> String;
@@ -48,17 +45,18 @@
//! ```
#![deny(missing_docs)]
#![feature(custom_derive, plugin, test, type_ascription)]
#![plugin(serde_macros)]
#![cfg_attr(test, feature(test))]
extern crate serde;
extern crate bincode;
#[cfg(test)]
#[macro_use]
extern crate lazy_static;
#[macro_use]
extern crate log;
extern crate scoped_pool;
#[cfg(test)]
#[macro_use]
extern crate lazy_static;
#[cfg(test)]
extern crate test;
macro_rules! pos {

View File

@@ -3,9 +3,14 @@
// Licensed under the MIT License, <LICENSE or http://opensource.org/licenses/MIT>.
// This file may not be copied, modified, or distributed except according to those terms.
#[doc(hidden)]
#[macro_export]
macro_rules! as_item { ($i:item) => {$i} }
/// Serde re-exports required by macros. Not for general use.
pub mod serde {
pub use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// Deserialization re-exports required by macros. Not for general use.
pub mod de {
pub use serde::de::{EnumVisitor, Error, Visitor, VariantVisitor};
}
}
// Required because if-let can't be used with irrefutable patterns, so it needs
// to be special cased.
@@ -14,24 +19,24 @@ macro_rules! as_item { ($i:item) => {$i} }
macro_rules! client_methods {
(
{ $(#[$attr:meta])* }
$fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty
$fn_name:ident( ($($arg:ident,)*) : ($($in_:ty,)*) ) -> $out:ty
) => (
$(#[$attr])*
pub fn $fn_name(&self, $($arg: $in_),*) -> $crate::Result<$out> {
let reply = try!((self.0).rpc(request_variant!($fn_name $($arg),*)));
let reply = try!((self.0).rpc(__Request::$fn_name(($($arg,)*))));
let __Reply::$fn_name(reply) = reply;
Ok(reply)
::std::result::Result::Ok(reply)
}
);
($(
{ $(#[$attr:meta])* }
$fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty
$fn_name:ident( ($( $arg:ident,)*) : ($($in_:ty, )*) ) -> $out:ty
)*) => ( $(
$(#[$attr])*
pub fn $fn_name(&self, $($arg: $in_),*) -> $crate::Result<$out> {
let reply = try!((self.0).rpc(request_variant!($fn_name $($arg),*)));
let reply = try!((self.0).rpc(__Request::$fn_name(($($arg,)*))));
if let __Reply::$fn_name(reply) = reply {
Ok(reply)
::std::result::Result::Ok(reply)
} else {
panic!("Incorrect reply variant returned from protocol::Clientrpc; expected `{}`, but got {:?}", stringify!($fn_name), reply);
}
@@ -46,7 +51,7 @@ macro_rules! client_methods {
macro_rules! async_client_methods {
(
{ $(#[$attr:meta])* }
$fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty
$fn_name:ident( ($( $arg:ident, )*) : ($( $in_:ty, )*) ) -> $out:ty
) => (
$(#[$attr])*
pub fn $fn_name(&self, $($arg: $in_),*) -> Future<$out> {
@@ -54,7 +59,7 @@ macro_rules! async_client_methods {
let __Reply::$fn_name(reply) = reply;
reply
}
let reply = (self.0).rpc_async(request_variant!($fn_name $($arg),*));
let reply = (self.0).rpc_async(__Request::$fn_name(($($arg,)*)));
Future {
future: reply,
mapper: mapper,
@@ -63,7 +68,7 @@ macro_rules! async_client_methods {
);
($(
{ $(#[$attr:meta])* }
$fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty
$fn_name:ident( ($( $arg:ident, )*) : ($( $in_:ty, )*) ) -> $out:ty
)*) => ( $(
$(#[$attr])*
pub fn $fn_name(&self, $($arg: $in_),*) -> Future<$out> {
@@ -74,7 +79,7 @@ macro_rules! async_client_methods {
panic!("Incorrect reply variant returned from protocol::Clientrpc; expected `{}`, but got {:?}", stringify!($fn_name), reply);
}
}
let reply = (self.0).rpc_async(request_variant!($fn_name $($arg),*));
let reply = (self.0).rpc_async(__Request::$fn_name(($($arg,)*)));
Future {
future: reply,
mapper: mapper,
@@ -83,28 +88,113 @@ macro_rules! async_client_methods {
)*);
}
// Required because enum variants with no fields can't be suffixed by parens
#[doc(hidden)]
#[macro_export]
macro_rules! define_request {
($(@($($finished:tt)*))* --) => (as_item!(
#[allow(non_camel_case_types)]
#[derive(Debug, Serialize, Deserialize)]
enum __Request { $($($finished)*),* }
););
($(@$finished:tt)* -- $name:ident() $($req:tt)*) =>
(define_request!($(@$finished)* @($name) -- $($req)*););
($(@$finished:tt)* -- $name:ident $args: tt $($req:tt)*) =>
(define_request!($(@$finished)* @($name $args) -- $($req)*););
($($started:tt)*) => (define_request!(-- $($started)*););
macro_rules! impl_serialize {
($impler:ident, $(@($name:ident $n:expr))* -- #($_n:expr) ) => (
impl $crate::macros::serde::Serialize for $impler {
#[inline]
fn serialize<S>(&self, serializer: &mut S) -> ::std::result::Result<(), S::Error>
where S: $crate::macros::serde::Serializer
{
match *self {
$(
$impler::$name(ref field) =>
$crate::macros::serde::Serializer::visit_newtype_variant(
serializer,
stringify!($impler),
$n,
stringify!($name),
field,
)
),*
}
}
}
);
// All args are wrapped in a tuple so we can use the newtype variant for each one.
($impler:ident, $(@$finished:tt)* -- #($n:expr) $name:ident($field:ty) $($req:tt)*) => (
impl_serialize!($impler, $(@$finished)* @($name $n) -- #($n + 1) $($req)*);
);
// Entry
($impler:ident, $($started:tt)*) => (impl_serialize!($impler, -- #(0) $($started)*););
}
// Required because enum variants with no fields can't be suffixed by parens
#[doc(hidden)]
#[macro_export]
macro_rules! request_variant {
($x:ident) => (__Request::$x);
($x:ident $($y:ident),+) => (__Request::$x($($y),+));
macro_rules! impl_deserialize {
($impler:ident, $(@($name:ident $n:expr))* -- #($_n:expr) ) => (
impl $crate::macros::serde::Deserialize for $impler {
#[inline]
fn deserialize<D>(deserializer: &mut D)
-> ::std::result::Result<$impler, D::Error>
where D: $crate::macros::serde::Deserializer
{
#[allow(non_camel_case_types)]
enum __Field {
$($name),*
}
impl $crate::macros::serde::Deserialize for __Field {
#[inline]
fn deserialize<D>(deserializer: &mut D)
-> ::std::result::Result<__Field, D::Error>
where D: $crate::macros::serde::Deserializer
{
struct __FieldVisitor;
impl $crate::macros::serde::de::Visitor for __FieldVisitor {
type Value = __Field;
fn visit_usize<E>(&mut self, value: usize)
-> ::std::result::Result<__Field, E>
where E: $crate::macros::serde::de::Error,
{
$(
if value == $n {
return ::std::result::Result::Ok(__Field::$name);
}
)*
return ::std::result::Result::Err(
$crate::macros::serde::de::Error::syntax("expected a field")
);
}
}
deserializer.visit_struct_field(__FieldVisitor)
}
}
struct __Visitor;
impl $crate::macros::serde::de::EnumVisitor for __Visitor {
type Value = $impler;
fn visit<__V>(&mut self, mut visitor: __V)
-> ::std::result::Result<$impler, __V::Error>
where __V: $crate::macros::serde::de::VariantVisitor
{
match try!(visitor.visit_variant()) {
$(
__Field::$name => {
let val = try!(visitor.visit_newtype());
Ok($impler::$name(val))
}
),*
}
}
}
const VARIANTS: &'static [&'static str] = &[
$(
stringify!($name)
),*
];
deserializer.visit_enum(stringify!($impler), VARIANTS, __Visitor)
}
}
);
// All args are wrapped in a tuple so we can use the newtype variant for each one.
($impler:ident, $(@$finished:tt)* -- #($n:expr) $name:ident($field:ty) $($req:tt)*) => (
impl_deserialize!($impler, $(@$finished)* @($name $n) -- #($n + 1) $($req)*);
);
// Entry
($impler:ident, $($started:tt)*) => (impl_deserialize!($impler, -- #(0) $($started)*););
}
/// The main macro that creates RPC services.
@@ -112,10 +202,7 @@ macro_rules! request_variant {
/// Rpc methods are specified, mirroring trait syntax:
///
/// ```
/// # #![feature(custom_derive, plugin)]
/// # #![plugin(serde_macros)]
/// # #[macro_use] extern crate tarpc;
/// # extern crate serde;
/// # fn main() {}
/// # service! {
/// #[doc="Say hello"]
@@ -200,7 +287,7 @@ macro_rules! service_inner {
{ } // none left to expand
$(
$(#[$attr:meta])*
rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty;
rpc $fn_name:ident ( $( $arg:ident : $in_:ty ),* ) -> $out:ty;
)*
) => {
#[doc="Defines the RPC service"]
@@ -223,16 +310,28 @@ macro_rules! service_inner {
)*
}
define_request!($($fn_name($($in_),*))*);
#[allow(non_camel_case_types)]
#[derive(Debug)]
enum __Request {
$(
$fn_name(( $($in_,)* ))
),*
}
impl_serialize!(__Request, $($fn_name(($($in_),*)))*);
impl_deserialize!(__Request, $($fn_name(($($in_),*)))*);
#[allow(non_camel_case_types)]
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug)]
enum __Reply {
$(
$fn_name($out),
)*
}
impl_serialize!(__Reply, $($fn_name($out))*);
impl_deserialize!(__Reply, $($fn_name($out))*);
/// An asynchronous RPC call
pub struct Future<T> {
future: $crate::protocol::Future<__Reply>,
@@ -256,13 +355,13 @@ macro_rules! service_inner {
where A: ::std::net::ToSocketAddrs,
{
let inner = try!($crate::protocol::Client::new(addr, timeout));
Ok(Client(inner))
::std::result::Result::Ok(Client(inner))
}
client_methods!(
$(
{ $(#[$attr])* }
$fn_name($($arg: $in_),*) -> $out
$fn_name(($($arg,)*) : ($($in_,)*)) -> $out
)*
);
@@ -283,13 +382,13 @@ macro_rules! service_inner {
where A: ::std::net::ToSocketAddrs,
{
let inner = try!($crate::protocol::Client::new(addr, timeout));
Ok(AsyncClient(inner))
::std::result::Result::Ok(AsyncClient(inner))
}
async_client_methods!(
$(
{ $(#[$attr])* }
$fn_name($($arg: $in_),*) -> $out
$fn_name(($($arg,)*): ($($in_,)*)) -> $out
)*
);
@@ -310,9 +409,9 @@ macro_rules! service_inner {
fn serve(&self, request: __Request) -> __Reply {
match request {
$(
request_variant!($fn_name $($arg),*) =>
__Request::$fn_name(( $($arg,)* )) =>
__Reply::$fn_name((self.0).$fn_name($($arg),*)),
)*
)*
}
}
}
@@ -326,13 +425,13 @@ macro_rules! service_inner {
S: 'static + Service
{
let server = ::std::sync::Arc::new(__Server(service));
Ok(try!($crate::protocol::serve_async(addr, server, read_timeout)))
::std::result::Result::Ok(try!($crate::protocol::serve_async(addr, server, read_timeout)))
}
}
}
#[allow(dead_code)] // because we're just testing that the macro expansion compiles
#[cfg(test)]
#[allow(dead_code)] // because we're testing that the macro expansion compiles
mod syntax_test {
// Tests a service definition with a fn that takes no args
mod qux {
@@ -340,7 +439,6 @@ mod syntax_test {
rpc hello() -> String;
}
}
// Tests a service definition with an attribute.
mod bar {
service! {
@@ -355,6 +453,7 @@ mod syntax_test {
rpc ack();
rpc apply(foo: String) -> i32;
rpc bi_consume(bar: String, baz: u64);
rpc bi_fn(bar: String, baz: u64) -> String;
}
}
}
@@ -382,6 +481,7 @@ mod functional_test {
#[test]
fn simple() {
let _ = env_logger::init();
let handle = serve( "localhost:0", Server, test_timeout()).unwrap();
let client = Client::new(handle.local_addr(), None).unwrap();
assert_eq!(3, client.add(1, 2).unwrap());
@@ -391,6 +491,7 @@ mod functional_test {
#[test]
fn simple_async() {
let _ = env_logger::init();
let handle = serve("localhost:0", Server, test_timeout()).unwrap();
let client = AsyncClient::new(handle.local_addr(), None).unwrap();
assert_eq!(3, client.add(1, 2).get().unwrap());
@@ -421,6 +522,21 @@ mod functional_test {
fn serve_arc_server() {
let _ = serve("localhost:0", ::std::sync::Arc::new(Server), None);
}
#[test]
fn serde() {
let _ = env_logger::init();
use bincode;
let request = __Request::add((1, 2));
let ser = bincode::serde::serialize(&request, bincode::SizeLimit::Infinite).unwrap();
let de = bincode::serde::deserialize(&ser).unwrap();
if let __Request::add((1, 2)) = de {
// success
} else {
panic!("Expected __Request::add, got {:?}", de);
}
}
}
#[cfg(test)]

View File

@@ -205,7 +205,7 @@ fn write<Request, Reply>(outbound: Receiver<(Request, Sender<Result<Reply>>)>,
rpc_id: id,
message: request,
};
debug!("Writer: calling rpc({:?})", id);
debug!("Writer: writing rpc, id={:?}", id);
if let Err(e) = stream.serialize(&packet) {
report_error(&tx, e.into());
// Typically we'd want to notify the client of any Err returned by remove_tx, but in
@@ -235,7 +235,7 @@ fn read<Reply>(requests: Arc<Mutex<RpcFutures<Reply>>>, stream: TcpStream)
{
let mut stream = BufReader::new(stream);
loop {
match stream.deserialize() : Result<Packet<Reply>> {
match stream.deserialize::<Packet<Reply>>() {
Ok(packet) => {
debug!("Client: received message, id={}", packet.rpc_id);
requests.lock().expect(pos!()).complete_reply(packet);

View File

@@ -12,7 +12,9 @@ use std::sync::Arc;
mod client;
mod server;
mod packet;
pub use self::packet::Packet;
pub use self::client::{Client, Future};
pub use self::server::{Serve, ServeHandle, serve_async};
@@ -55,12 +57,6 @@ impl convert::From<io::Error> for Error {
/// Return type of rpc calls: either the successful return value, or a client error.
pub type Result<T> = ::std::result::Result<T, Error>;
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Packet<T> {
rpc_id: u64,
message: T,
}
trait Deserialize: Read + Sized {
fn deserialize<T: serde::Deserialize>(&mut self) -> Result<T> {
deserialize_from(self, SizeLimit::Infinite)
@@ -93,27 +89,17 @@ mod test {
Some(Duration::from_secs(1))
}
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
enum Request {
Increment,
}
#[derive(Debug, PartialEq, Serialize, Deserialize)]
enum Reply {
Increment(u64),
}
struct Server {
counter: Mutex<u64>,
}
impl Serve for Server {
type Request = Request;
type Reply = Reply;
type Request = ();
type Reply = u64;
fn serve(&self, _: Request) -> Reply {
fn serve(&self, _: ()) -> u64 {
let mut counter = self.counter.lock().unwrap();
let reply = Reply::Increment(*counter);
let reply = *counter;
*counter += 1;
reply
}
@@ -134,7 +120,7 @@ mod test {
let _ = env_logger::init();
let server = Arc::new(Server::new());
let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap();
let client: Client<Request, Reply> = Client::new(serve_handle.local_addr(), None).unwrap();
let client: Client<(), u64> = Client::new(serve_handle.local_addr(), None).unwrap();
drop(client);
serve_handle.shutdown();
}
@@ -145,12 +131,11 @@ mod test {
let server = Arc::new(Server::new());
let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap();
let addr = serve_handle.local_addr().clone();
let client = Client::new(addr, None).unwrap();
assert_eq!(Reply::Increment(0),
client.rpc(Request::Increment).unwrap());
// The explicit type is required so that it doesn't deserialize a u32 instead of u64
let client: Client<(), u64> = Client::new(addr, None).unwrap();
assert_eq!(0, client.rpc(()).unwrap());
assert_eq!(1, server.count());
assert_eq!(Reply::Increment(1),
client.rpc(Request::Increment).unwrap());
assert_eq!(1, client.rpc(()).unwrap());
assert_eq!(2, server.count());
drop(client);
serve_handle.shutdown();
@@ -162,9 +147,9 @@ mod test {
}
impl Serve for BarrierServer {
type Request = Request;
type Reply = Reply;
fn serve(&self, request: Request) -> Reply {
type Request = ();
type Reply = u64;
fn serve(&self, request: ()) -> u64 {
self.barrier.wait();
self.inner.serve(request)
}
@@ -189,9 +174,9 @@ mod test {
let server = Arc::new(Server::new());
let serve_handle = serve_async("localhost:0", server, Some(Duration::new(0, 10))).unwrap();
let addr = serve_handle.local_addr().clone();
let client: Client<Request, Reply> = Client::new(addr, None).unwrap();
let client: Client<(), u64> = Client::new(addr, None).unwrap();
let thread = thread::spawn(move || serve_handle.shutdown());
info!("force_shutdown:: rpc1: {:?}", client.rpc(Request::Increment));
info!("force_shutdown:: rpc1: {:?}", client.rpc(()));
thread.join().unwrap();
}
@@ -201,14 +186,14 @@ mod test {
let server = Arc::new(Server::new());
let serve_handle = serve_async("localhost:0", server, test_timeout()).unwrap();
let addr = serve_handle.local_addr().clone();
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, None).unwrap());
client.rpc(Request::Increment).unwrap();
let client: Arc<Client<(), u64>> = Arc::new(Client::new(addr, None).unwrap());
client.rpc(()).unwrap();
serve_handle.shutdown();
match client.rpc(Request::Increment) {
match client.rpc(()) {
Err(super::Error::ConnectionBroken) => {} // success
otherwise => panic!("Expected Err(ConnectionBroken), got {:?}", otherwise),
}
let _ = client.rpc(Request::Increment); // Test whether second failure hangs
let _ = client.rpc(()); // Test whether second failure hangs
}
#[test]
@@ -219,11 +204,11 @@ mod test {
let server = Arc::new(BarrierServer::new(concurrency));
let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap();
let addr = serve_handle.local_addr().clone();
let client: Client<Request, Reply> = Client::new(addr, None).unwrap();
let client: Client<(), u64> = Client::new(addr, None).unwrap();
pool.scoped(|scope| {
for _ in 0..concurrency {
let client = client.try_clone().unwrap();
scope.execute(move || { client.rpc(Request::Increment).unwrap(); });
scope.execute(move || { client.rpc(()).unwrap(); });
}
});
assert_eq!(concurrency as u64, server.count());
@@ -237,12 +222,12 @@ mod test {
let server = Arc::new(Server::new());
let serve_handle = serve_async("localhost:0", server.clone(), None).unwrap();
let addr = serve_handle.local_addr().clone();
let client: Client<Request, Reply> = Client::new(addr, None).unwrap();
let client: Client<(), u64> = Client::new(addr, None).unwrap();
// Drop future immediately; does the reader channel panic when sending?
client.rpc_async(Request::Increment);
client.rpc_async(());
// If the reader panicked, this won't succeed
client.rpc_async(Request::Increment);
client.rpc_async(());
drop(client);
serve_handle.shutdown();

View File

@@ -0,0 +1,98 @@
use serde::{Deserialize, Deserializer, Serialize, Serializer, de, ser};
use std::marker::PhantomData;
/// Packet shared between client and server.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Packet<T> {
/// Packet id to map response to request.
pub rpc_id: u64,
/// Packet payload.
pub message: T,
}
const PACKET: &'static str = "Packet";
const RPC_ID: &'static str = "rpc_id";
const MESSAGE: &'static str = "message";
impl<T: Serialize> Serialize for Packet<T> {
#[inline]
fn serialize<S>(&self, serializer: &mut S) -> Result<(), S::Error>
where S: Serializer
{
serializer.visit_struct(PACKET, MapVisitor {
value: self,
state: 0,
})
}
}
struct MapVisitor<'a, T: 'a> {
value: &'a Packet<T>,
state: u8,
}
impl <'a, T: Serialize> ser::MapVisitor for MapVisitor<'a, T> {
fn visit<S>(&mut self, serializer: &mut S) -> Result<Option<()>, S::Error>
where S: Serializer
{
match self.state {
0 => {
self.state += 1;
Ok(Some(try!(serializer.visit_struct_elt(RPC_ID, &self.value.rpc_id))))
}
1 => {
self.state += 1;
Ok(Some(try!(serializer.visit_struct_elt(MESSAGE, &self.value.message))))
}
_ => {
Ok(None)
}
}
}
}
impl<T: Deserialize> Deserialize for Packet<T> {
fn deserialize<D>(deserializer: &mut D) -> Result<Self, D::Error>
where D: Deserializer
{
const FIELDS: &'static [&'static str] = &[RPC_ID, MESSAGE];
deserializer.visit_struct(PACKET, FIELDS, Visitor(PhantomData))
}
}
struct Visitor<T>(PhantomData<T>);
impl<T: Deserialize> de::Visitor for Visitor<T> {
type Value = Packet<T>;
fn visit_seq<V>(&mut self, mut visitor: V) -> Result<Packet<T>, V::Error>
where V: de::SeqVisitor
{
let packet = Packet {
rpc_id: match try!(visitor.visit()) {
Some(rpc_id) => rpc_id,
None => return Err(de::Error::end_of_stream()),
},
message: match try!(visitor.visit()) {
Some(message) => message,
None => return Err(de::Error::end_of_stream()),
},
};
try!(visitor.end());
Ok(packet)
}
}
#[cfg(test)]
extern crate env_logger;
#[test]
fn serde() {
let _ = env_logger::init();
use bincode;
let packet = Packet { rpc_id: 1, message: () };
let ser = bincode::serde::serialize(&packet, bincode::SizeLimit::Infinite).unwrap();
let de = bincode::serde::deserialize(&ser);
assert_eq!(packet, de.unwrap());
}