diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index d042383..2f6755b 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -53,7 +53,7 @@ //! ``` #![deny(missing_docs)] -#![feature(custom_derive, plugin)] +#![feature(custom_derive, plugin, test)] #![plugin(serde_macros)] extern crate serde; @@ -61,6 +61,7 @@ extern crate bincode; #[macro_use] extern crate log; extern crate crossbeam; +extern crate test; /// Provides the tarpc client and server, which implements the tarpc protocol. /// The protocol is defined by the implementation. diff --git a/tarpc/src/macros.rs b/tarpc/src/macros.rs index 26c02a6..b5195f1 100644 --- a/tarpc/src/macros.rs +++ b/tarpc/src/macros.rs @@ -106,6 +106,50 @@ macro_rules! client_methods { )*); } +// Required because if-let can't be used with irrefutable patterns, so it needs +// to be special cased. +#[doc(hidden)] +#[macro_export] +macro_rules! async_client_methods { + ( + { $(#[$attr:meta])* } + $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty + ) => ( + $(#[$attr])* + pub fn $fn_name(&self, $($arg: $in_),*) -> Future<$out> { + fn mapper(reply: __Reply) -> $out { + let __Reply::$fn_name(reply) = reply; + reply + } + let reply = (self.0).rpc_async(&request_variant!($fn_name $($arg),*)); + Future { + future: reply, + mapper: mapper, + } + } + ); + ($( + { $(#[$attr:meta])* } + $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty + )*) => ( $( + $(#[$attr])* + pub fn $fn_name(&self, $($arg: $in_),*) -> Future<$out> { + fn mapper(reply: __Reply) -> $out { + if let __Reply::$fn_name(reply) = reply { + reply + } else { + panic!("Incorrect reply variant returned from protocol::Clientrpc; expected `{}`, but got {:?}", stringify!($fn_name), reply); + } + } + let reply = (self.0).rpc_async(&request_variant!($fn_name $($arg),*)); + Future { + future: reply, + mapper: mapper, + } + } + )*); +} + // Required because enum variants with no fields can't be suffixed by parens #[doc(hidden)] #[macro_export] @@ -220,6 +264,19 @@ macro_rules! rpc { )* } + /// An asynchronous RPC call + pub struct Future { + future: $crate::protocol::Future<__Reply>, + mapper: fn(__Reply) -> T, + } + + impl Future { + /// Block until the result of the RPC call is available + pub fn get(self) -> $crate::Result { + self.future.get().map(self.mapper) + } + } + #[doc="The client stub that makes RPC calls to the server."] pub struct Client($crate::protocol::Client<__Request, __Reply>); @@ -241,6 +298,27 @@ macro_rules! rpc { ); } + #[doc="The client stub that makes asynchronous RPC calls to the server."] + pub struct AsyncClient($crate::protocol::Client<__Request, __Reply>); + + impl AsyncClient { + #[doc="Create a new asynchronous client that connects to the given address."] + pub fn new(addr: A, timeout: ::std::option::Option<::std::time::Duration>) + -> $crate::Result + where A: ::std::net::ToSocketAddrs, + { + let inner = try!($crate::protocol::Client::new(addr, timeout)); + Ok(AsyncClient(inner)) + } + + async_client_methods!( + $( + { $(#[$attr])* } + $fn_name($($arg: $in_),*) -> $out + )* + ); + } + struct __Server(S); impl $crate::protocol::Serve for __Server @@ -277,7 +355,9 @@ macro_rules! rpc { #[cfg(test)] #[allow(dead_code)] mod test { + extern crate env_logger; use std::time::Duration; + use test::Bencher; fn test_timeout() -> Option { Some(Duration::from_secs(5)) @@ -322,17 +402,27 @@ mod test { } #[test] - fn simple_test() { - println!("Starting"); - let addr = "127.0.0.1:9000"; - let shutdown = my_server::serve(addr, Server, test_timeout()).unwrap(); - let client = Client::new(addr, None).unwrap(); + fn simple() { + let handle = my_server::serve( "localhost:0", Server, test_timeout()).unwrap(); + let client = Client::new(handle.local_addr(), None).unwrap(); assert_eq!(3, client.add(1, 2).unwrap()); let foo = Foo { message: "Adam".into() }; let want = Foo { message: format!("Hello, {}", &foo.message) }; assert_eq!(want, client.hello(Foo { message: "Adam".into() }).unwrap()); drop(client); - shutdown.shutdown(); + handle.shutdown(); + } + + #[test] + fn simple_async() { + let handle = my_server::serve("localhost:0", Server, test_timeout()).unwrap(); + let client = AsyncClient::new(handle.local_addr(), None).unwrap(); + assert_eq!(3, client.add(1, 2).get().unwrap()); + let foo = Foo { message: "Adam".into() }; + let want = Foo { message: format!("Hello, {}", &foo.message) }; + assert_eq!(want, client.hello(Foo { message: "Adam".into() }).get().unwrap()); + drop(client); + handle.shutdown(); } // Tests a service definition with a fn that takes no args @@ -368,7 +458,6 @@ mod test { } service { - #[doc="Hello bob"] #[inline(always)] rpc baz(s: String) -> HashMap; } @@ -400,4 +489,38 @@ mod test { fn debug() { println!("{:?}", baz::Debuggable); } + + rpc! { + mod hello { + service { + rpc hello(s: String) -> String; + } + } + } + + struct HelloServer; + impl hello::Service for HelloServer { + fn hello(&self, s: String) -> String { + format!("Hello, {}!", s) + } + } + + #[bench] + fn hello(bencher: &mut Bencher) { + let _ = env_logger::init(); + let handle = hello::serve("localhost:0", HelloServer, None).unwrap(); + let client = hello::AsyncClient::new(handle.local_addr(), None).unwrap(); + let concurrency = 100; + let mut rpcs = Vec::with_capacity(concurrency); + bencher.iter(|| { + for _ in 0..concurrency { + rpcs.push(client.hello("Bob".into())); + } + for _ in 0..concurrency { + rpcs.pop().unwrap().get().unwrap(); + } + }); + drop(client); + handle.shutdown(); + } } diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index b6ebfef..bd8b67d 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -17,7 +17,7 @@ use std::marker::PhantomData; use std::mem; use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs}; use std::sync::{Arc, Condvar, Mutex}; -use std::sync::mpsc::{Sender, TryRecvError, channel}; +use std::sync::mpsc::{Receiver, Sender, TryRecvError, channel}; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use std::thread::{self, JoinHandle}; @@ -59,6 +59,22 @@ impl convert::From for Error { /// Return type of rpc calls: either the successful return value, or a client error. pub type Result = ::std::result::Result; +/// An asynchronous RPC call +pub struct Future { + rx: Result>, + requests: Arc>> +} + +impl Future { + /// Block until the result of the RPC call is available + pub fn get(self) -> Result { + let requests = self.requests; + try!(self.rx) + .recv() + .map_err(|_| requests.lock().unwrap().get_error()) + } +} + struct InflightRpcs { count: Mutex, cvar: Condvar, @@ -214,9 +230,9 @@ impl ServeHandle { /// gracefully close open connections. pub fn shutdown(self) { info!("ServeHandle: attempting to shut down the server."); - self.tx.send(()).expect(&line!().to_string()); + self.tx.send(()).unwrap(); if let Ok(_) = TcpStream::connect(self.addr) { - self.join_handle.join().expect(&line!().to_string()); + self.join_handle.join().unwrap(); } else { warn!("ServeHandle: best effort shutdown of serve thread failed"); } @@ -343,7 +359,9 @@ impl RpcFutures { fn complete_reply(&mut self, id: u64, reply: Reply) { if let Some(tx) = self.0.as_mut().unwrap().remove(&id) { - tx.send(reply).unwrap(); + if let Err(e) = tx.send(reply) { + info!("Reader: could not complete reply: {:?}", e); + } } else { warn!("RpcFutures: expected sender for id {} but got None!", id); } @@ -435,8 +453,7 @@ impl Client }) } - /// Run the specified rpc method on the server this client is connected to - pub fn rpc(&self, request: &Request) -> Result + fn rpc_internal(&self, request: &Request) -> Result> where Request: serde::ser::Serialize + fmt::Debug + Send + 'static { let (tx, rx) = channel(); @@ -458,14 +475,25 @@ impl Client err); try!(self.requests.lock().unwrap().remove_tx(id)); } - debug!("Client: finishing rpc({:?})", request); - drop(state); - match rx.recv() { - Ok(msg) => Ok(msg), - Err(_) => { - debug!("locking requests map"); - Err(self.requests.lock().unwrap().get_error()) - } + Ok(rx) + } + + /// Run the specified rpc method on the server this client is connected to + pub fn rpc(&self, request: &Request) -> Result + where Request: serde::ser::Serialize + fmt::Debug + Send + 'static + { + try!(self.rpc_internal(request)) + .recv() + .map_err(|_| self.requests.lock().unwrap().get_error()) + } + + /// Asynchronously run the specified rpc method on the server this client is connected to + pub fn rpc_async(&self, request: &Request) -> Future + where Request: serde::ser::Serialize + fmt::Debug + Send + 'static + { + Future { + rx: self.rpc_internal(request), + requests: self.requests.clone(), } } } @@ -539,9 +567,8 @@ mod test { fn handle() { let _ = env_logger::init(); let server = Arc::new(Server::new()); - let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); - let client: Client = Client::new(serve_handle.local_addr().clone(), None) - .expect(&line!().to_string()); + let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap(); + let client: Client = Client::new(serve_handle.local_addr(), None).unwrap(); drop(client); serve_handle.shutdown(); } @@ -550,7 +577,7 @@ mod test { fn simple() { let _ = env_logger::init(); let server = Arc::new(Server::new()); - let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); + let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); let client = Client::new(addr, None).unwrap(); assert_eq!(Reply::Increment(0), @@ -594,12 +621,11 @@ mod test { fn force_shutdown() { let _ = env_logger::init(); let server = Arc::new(Server::new()); - let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap(); + let serve_handle = serve_async("localhost:0", server, Some(Duration::new(0, 10))).unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); + let client: Client = Client::new(addr, None).unwrap(); let thread = thread::spawn(move || serve_handle.shutdown()); - info!("force_shutdown:: rpc1: {:?}", - client.rpc(&Request::Increment)); + info!("force_shutdown:: rpc1: {:?}", client.rpc(&Request::Increment)); thread.join().unwrap(); } @@ -607,7 +633,7 @@ mod test { fn client_failed_rpc() { let _ = env_logger::init(); let server = Arc::new(Server::new()); - let serve_handle = serve_async("0.0.0.0:0", server, test_timeout()).unwrap(); + let serve_handle = serve_async("localhost:0", server, test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); serve_handle.shutdown(); @@ -621,19 +647,20 @@ mod test { #[test] fn concurrent() { let _ = env_logger::init(); - let server = Arc::new(BarrierServer::new(10)); - let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); + let concurrency = 10; + let server = Arc::new(BarrierServer::new(concurrency)); + let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); let mut join_handles = vec![]; - for _ in 0..10 { + for _ in 0..concurrency { let my_client = client.clone(); join_handles.push(thread::spawn(move || my_client.rpc(&Request::Increment).unwrap())); } for handle in join_handles.into_iter() { handle.join().unwrap(); } - assert_eq!(10, server.count()); + assert_eq!(concurrency as u64, server.count()); let client = match Arc::try_unwrap(client) { Err(_) => panic!("couldn't unwrap arc"), Ok(c) => c, @@ -641,4 +668,22 @@ mod test { drop(client); serve_handle.shutdown(); } + + #[test] + fn async() { + let _ = env_logger::init(); + let server = Arc::new(Server::new()); + let serve_handle = serve_async("localhost:0", server.clone(), None).unwrap(); + let addr = serve_handle.local_addr().clone(); + let client: Client = Client::new(addr, None).unwrap(); + + // Drop future immediately; does the reader channel panic when sending? + client.rpc_async(&Request::Increment); + // If the reader panicked, this won't succeed + client.rpc_async(&Request::Increment); + + drop(client); + serve_handle.shutdown(); + assert_eq!(server.count(), 2); + } }