mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-01-06 19:45:25 +01:00
Merge branch 'async' into 'master'
Add an AsyncClient generated by rpc! Returns `Future<T>` instead of `Result<T>`. `Future<T>` has one method, `get()`, which returns a `Result<T>`. See merge request !15
This commit is contained in:
@@ -53,7 +53,7 @@
|
||||
//! ```
|
||||
|
||||
#![deny(missing_docs)]
|
||||
#![feature(custom_derive, plugin)]
|
||||
#![feature(custom_derive, plugin, test)]
|
||||
#![plugin(serde_macros)]
|
||||
|
||||
extern crate serde;
|
||||
@@ -61,6 +61,7 @@ extern crate bincode;
|
||||
#[macro_use]
|
||||
extern crate log;
|
||||
extern crate crossbeam;
|
||||
extern crate test;
|
||||
|
||||
/// Provides the tarpc client and server, which implements the tarpc protocol.
|
||||
/// The protocol is defined by the implementation.
|
||||
|
||||
@@ -106,6 +106,50 @@ macro_rules! client_methods {
|
||||
)*);
|
||||
}
|
||||
|
||||
// Required because if-let can't be used with irrefutable patterns, so it needs
|
||||
// to be special cased.
|
||||
#[doc(hidden)]
|
||||
#[macro_export]
|
||||
macro_rules! async_client_methods {
|
||||
(
|
||||
{ $(#[$attr:meta])* }
|
||||
$fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty
|
||||
) => (
|
||||
$(#[$attr])*
|
||||
pub fn $fn_name(&self, $($arg: $in_),*) -> Future<$out> {
|
||||
fn mapper(reply: __Reply) -> $out {
|
||||
let __Reply::$fn_name(reply) = reply;
|
||||
reply
|
||||
}
|
||||
let reply = (self.0).rpc_async(&request_variant!($fn_name $($arg),*));
|
||||
Future {
|
||||
future: reply,
|
||||
mapper: mapper,
|
||||
}
|
||||
}
|
||||
);
|
||||
($(
|
||||
{ $(#[$attr:meta])* }
|
||||
$fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty
|
||||
)*) => ( $(
|
||||
$(#[$attr])*
|
||||
pub fn $fn_name(&self, $($arg: $in_),*) -> Future<$out> {
|
||||
fn mapper(reply: __Reply) -> $out {
|
||||
if let __Reply::$fn_name(reply) = reply {
|
||||
reply
|
||||
} else {
|
||||
panic!("Incorrect reply variant returned from protocol::Clientrpc; expected `{}`, but got {:?}", stringify!($fn_name), reply);
|
||||
}
|
||||
}
|
||||
let reply = (self.0).rpc_async(&request_variant!($fn_name $($arg),*));
|
||||
Future {
|
||||
future: reply,
|
||||
mapper: mapper,
|
||||
}
|
||||
}
|
||||
)*);
|
||||
}
|
||||
|
||||
// Required because enum variants with no fields can't be suffixed by parens
|
||||
#[doc(hidden)]
|
||||
#[macro_export]
|
||||
@@ -220,6 +264,19 @@ macro_rules! rpc {
|
||||
)*
|
||||
}
|
||||
|
||||
/// An asynchronous RPC call
|
||||
pub struct Future<T> {
|
||||
future: $crate::protocol::Future<__Reply>,
|
||||
mapper: fn(__Reply) -> T,
|
||||
}
|
||||
|
||||
impl<T> Future<T> {
|
||||
/// Block until the result of the RPC call is available
|
||||
pub fn get(self) -> $crate::Result<T> {
|
||||
self.future.get().map(self.mapper)
|
||||
}
|
||||
}
|
||||
|
||||
#[doc="The client stub that makes RPC calls to the server."]
|
||||
pub struct Client($crate::protocol::Client<__Request, __Reply>);
|
||||
|
||||
@@ -241,6 +298,27 @@ macro_rules! rpc {
|
||||
);
|
||||
}
|
||||
|
||||
#[doc="The client stub that makes asynchronous RPC calls to the server."]
|
||||
pub struct AsyncClient($crate::protocol::Client<__Request, __Reply>);
|
||||
|
||||
impl AsyncClient {
|
||||
#[doc="Create a new asynchronous client that connects to the given address."]
|
||||
pub fn new<A>(addr: A, timeout: ::std::option::Option<::std::time::Duration>)
|
||||
-> $crate::Result<Self>
|
||||
where A: ::std::net::ToSocketAddrs,
|
||||
{
|
||||
let inner = try!($crate::protocol::Client::new(addr, timeout));
|
||||
Ok(AsyncClient(inner))
|
||||
}
|
||||
|
||||
async_client_methods!(
|
||||
$(
|
||||
{ $(#[$attr])* }
|
||||
$fn_name($($arg: $in_),*) -> $out
|
||||
)*
|
||||
);
|
||||
}
|
||||
|
||||
struct __Server<S: 'static + Service>(S);
|
||||
|
||||
impl<S> $crate::protocol::Serve for __Server<S>
|
||||
@@ -277,7 +355,9 @@ macro_rules! rpc {
|
||||
#[cfg(test)]
|
||||
#[allow(dead_code)]
|
||||
mod test {
|
||||
extern crate env_logger;
|
||||
use std::time::Duration;
|
||||
use test::Bencher;
|
||||
|
||||
fn test_timeout() -> Option<Duration> {
|
||||
Some(Duration::from_secs(5))
|
||||
@@ -322,17 +402,27 @@ mod test {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn simple_test() {
|
||||
println!("Starting");
|
||||
let addr = "127.0.0.1:9000";
|
||||
let shutdown = my_server::serve(addr, Server, test_timeout()).unwrap();
|
||||
let client = Client::new(addr, None).unwrap();
|
||||
fn simple() {
|
||||
let handle = my_server::serve( "localhost:0", Server, test_timeout()).unwrap();
|
||||
let client = Client::new(handle.local_addr(), None).unwrap();
|
||||
assert_eq!(3, client.add(1, 2).unwrap());
|
||||
let foo = Foo { message: "Adam".into() };
|
||||
let want = Foo { message: format!("Hello, {}", &foo.message) };
|
||||
assert_eq!(want, client.hello(Foo { message: "Adam".into() }).unwrap());
|
||||
drop(client);
|
||||
shutdown.shutdown();
|
||||
handle.shutdown();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn simple_async() {
|
||||
let handle = my_server::serve("localhost:0", Server, test_timeout()).unwrap();
|
||||
let client = AsyncClient::new(handle.local_addr(), None).unwrap();
|
||||
assert_eq!(3, client.add(1, 2).get().unwrap());
|
||||
let foo = Foo { message: "Adam".into() };
|
||||
let want = Foo { message: format!("Hello, {}", &foo.message) };
|
||||
assert_eq!(want, client.hello(Foo { message: "Adam".into() }).get().unwrap());
|
||||
drop(client);
|
||||
handle.shutdown();
|
||||
}
|
||||
|
||||
// Tests a service definition with a fn that takes no args
|
||||
@@ -368,7 +458,6 @@ mod test {
|
||||
}
|
||||
|
||||
service {
|
||||
#[doc="Hello bob"]
|
||||
#[inline(always)]
|
||||
rpc baz(s: String) -> HashMap<String, String>;
|
||||
}
|
||||
@@ -400,4 +489,38 @@ mod test {
|
||||
fn debug() {
|
||||
println!("{:?}", baz::Debuggable);
|
||||
}
|
||||
|
||||
rpc! {
|
||||
mod hello {
|
||||
service {
|
||||
rpc hello(s: String) -> String;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct HelloServer;
|
||||
impl hello::Service for HelloServer {
|
||||
fn hello(&self, s: String) -> String {
|
||||
format!("Hello, {}!", s)
|
||||
}
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn hello(bencher: &mut Bencher) {
|
||||
let _ = env_logger::init();
|
||||
let handle = hello::serve("localhost:0", HelloServer, None).unwrap();
|
||||
let client = hello::AsyncClient::new(handle.local_addr(), None).unwrap();
|
||||
let concurrency = 100;
|
||||
let mut rpcs = Vec::with_capacity(concurrency);
|
||||
bencher.iter(|| {
|
||||
for _ in 0..concurrency {
|
||||
rpcs.push(client.hello("Bob".into()));
|
||||
}
|
||||
for _ in 0..concurrency {
|
||||
rpcs.pop().unwrap().get().unwrap();
|
||||
}
|
||||
});
|
||||
drop(client);
|
||||
handle.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ use std::marker::PhantomData;
|
||||
use std::mem;
|
||||
use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
|
||||
use std::sync::{Arc, Condvar, Mutex};
|
||||
use std::sync::mpsc::{Sender, TryRecvError, channel};
|
||||
use std::sync::mpsc::{Receiver, Sender, TryRecvError, channel};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::time::Duration;
|
||||
use std::thread::{self, JoinHandle};
|
||||
@@ -59,6 +59,22 @@ impl convert::From<io::Error> for Error {
|
||||
/// Return type of rpc calls: either the successful return value, or a client error.
|
||||
pub type Result<T> = ::std::result::Result<T, Error>;
|
||||
|
||||
/// An asynchronous RPC call
|
||||
pub struct Future<T> {
|
||||
rx: Result<Receiver<T>>,
|
||||
requests: Arc<Mutex<RpcFutures<T>>>
|
||||
}
|
||||
|
||||
impl<T> Future<T> {
|
||||
/// Block until the result of the RPC call is available
|
||||
pub fn get(self) -> Result<T> {
|
||||
let requests = self.requests;
|
||||
try!(self.rx)
|
||||
.recv()
|
||||
.map_err(|_| requests.lock().unwrap().get_error())
|
||||
}
|
||||
}
|
||||
|
||||
struct InflightRpcs {
|
||||
count: Mutex<u64>,
|
||||
cvar: Condvar,
|
||||
@@ -214,9 +230,9 @@ impl ServeHandle {
|
||||
/// gracefully close open connections.
|
||||
pub fn shutdown(self) {
|
||||
info!("ServeHandle: attempting to shut down the server.");
|
||||
self.tx.send(()).expect(&line!().to_string());
|
||||
self.tx.send(()).unwrap();
|
||||
if let Ok(_) = TcpStream::connect(self.addr) {
|
||||
self.join_handle.join().expect(&line!().to_string());
|
||||
self.join_handle.join().unwrap();
|
||||
} else {
|
||||
warn!("ServeHandle: best effort shutdown of serve thread failed");
|
||||
}
|
||||
@@ -343,7 +359,9 @@ impl<Reply> RpcFutures<Reply> {
|
||||
|
||||
fn complete_reply(&mut self, id: u64, reply: Reply) {
|
||||
if let Some(tx) = self.0.as_mut().unwrap().remove(&id) {
|
||||
tx.send(reply).unwrap();
|
||||
if let Err(e) = tx.send(reply) {
|
||||
info!("Reader: could not complete reply: {:?}", e);
|
||||
}
|
||||
} else {
|
||||
warn!("RpcFutures: expected sender for id {} but got None!", id);
|
||||
}
|
||||
@@ -435,8 +453,7 @@ impl<Request, Reply> Client<Request, Reply>
|
||||
})
|
||||
}
|
||||
|
||||
/// Run the specified rpc method on the server this client is connected to
|
||||
pub fn rpc(&self, request: &Request) -> Result<Reply>
|
||||
fn rpc_internal(&self, request: &Request) -> Result<Receiver<Reply>>
|
||||
where Request: serde::ser::Serialize + fmt::Debug + Send + 'static
|
||||
{
|
||||
let (tx, rx) = channel();
|
||||
@@ -458,14 +475,25 @@ impl<Request, Reply> Client<Request, Reply>
|
||||
err);
|
||||
try!(self.requests.lock().unwrap().remove_tx(id));
|
||||
}
|
||||
debug!("Client: finishing rpc({:?})", request);
|
||||
drop(state);
|
||||
match rx.recv() {
|
||||
Ok(msg) => Ok(msg),
|
||||
Err(_) => {
|
||||
debug!("locking requests map");
|
||||
Err(self.requests.lock().unwrap().get_error())
|
||||
}
|
||||
Ok(rx)
|
||||
}
|
||||
|
||||
/// Run the specified rpc method on the server this client is connected to
|
||||
pub fn rpc(&self, request: &Request) -> Result<Reply>
|
||||
where Request: serde::ser::Serialize + fmt::Debug + Send + 'static
|
||||
{
|
||||
try!(self.rpc_internal(request))
|
||||
.recv()
|
||||
.map_err(|_| self.requests.lock().unwrap().get_error())
|
||||
}
|
||||
|
||||
/// Asynchronously run the specified rpc method on the server this client is connected to
|
||||
pub fn rpc_async(&self, request: &Request) -> Future<Reply>
|
||||
where Request: serde::ser::Serialize + fmt::Debug + Send + 'static
|
||||
{
|
||||
Future {
|
||||
rx: self.rpc_internal(request),
|
||||
requests: self.requests.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -539,9 +567,8 @@ mod test {
|
||||
fn handle() {
|
||||
let _ = env_logger::init();
|
||||
let server = Arc::new(Server::new());
|
||||
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
|
||||
let client: Client<Request, Reply> = Client::new(serve_handle.local_addr().clone(), None)
|
||||
.expect(&line!().to_string());
|
||||
let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap();
|
||||
let client: Client<Request, Reply> = Client::new(serve_handle.local_addr(), None).unwrap();
|
||||
drop(client);
|
||||
serve_handle.shutdown();
|
||||
}
|
||||
@@ -550,7 +577,7 @@ mod test {
|
||||
fn simple() {
|
||||
let _ = env_logger::init();
|
||||
let server = Arc::new(Server::new());
|
||||
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
|
||||
let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap();
|
||||
let addr = serve_handle.local_addr().clone();
|
||||
let client = Client::new(addr, None).unwrap();
|
||||
assert_eq!(Reply::Increment(0),
|
||||
@@ -594,12 +621,11 @@ mod test {
|
||||
fn force_shutdown() {
|
||||
let _ = env_logger::init();
|
||||
let server = Arc::new(Server::new());
|
||||
let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap();
|
||||
let serve_handle = serve_async("localhost:0", server, Some(Duration::new(0, 10))).unwrap();
|
||||
let addr = serve_handle.local_addr().clone();
|
||||
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, None).unwrap());
|
||||
let client: Client<Request, Reply> = Client::new(addr, None).unwrap();
|
||||
let thread = thread::spawn(move || serve_handle.shutdown());
|
||||
info!("force_shutdown:: rpc1: {:?}",
|
||||
client.rpc(&Request::Increment));
|
||||
info!("force_shutdown:: rpc1: {:?}", client.rpc(&Request::Increment));
|
||||
thread.join().unwrap();
|
||||
}
|
||||
|
||||
@@ -607,7 +633,7 @@ mod test {
|
||||
fn client_failed_rpc() {
|
||||
let _ = env_logger::init();
|
||||
let server = Arc::new(Server::new());
|
||||
let serve_handle = serve_async("0.0.0.0:0", server, test_timeout()).unwrap();
|
||||
let serve_handle = serve_async("localhost:0", server, test_timeout()).unwrap();
|
||||
let addr = serve_handle.local_addr().clone();
|
||||
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, None).unwrap());
|
||||
serve_handle.shutdown();
|
||||
@@ -621,19 +647,20 @@ mod test {
|
||||
#[test]
|
||||
fn concurrent() {
|
||||
let _ = env_logger::init();
|
||||
let server = Arc::new(BarrierServer::new(10));
|
||||
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
|
||||
let concurrency = 10;
|
||||
let server = Arc::new(BarrierServer::new(concurrency));
|
||||
let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap();
|
||||
let addr = serve_handle.local_addr().clone();
|
||||
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, None).unwrap());
|
||||
let mut join_handles = vec![];
|
||||
for _ in 0..10 {
|
||||
for _ in 0..concurrency {
|
||||
let my_client = client.clone();
|
||||
join_handles.push(thread::spawn(move || my_client.rpc(&Request::Increment).unwrap()));
|
||||
}
|
||||
for handle in join_handles.into_iter() {
|
||||
handle.join().unwrap();
|
||||
}
|
||||
assert_eq!(10, server.count());
|
||||
assert_eq!(concurrency as u64, server.count());
|
||||
let client = match Arc::try_unwrap(client) {
|
||||
Err(_) => panic!("couldn't unwrap arc"),
|
||||
Ok(c) => c,
|
||||
@@ -641,4 +668,22 @@ mod test {
|
||||
drop(client);
|
||||
serve_handle.shutdown();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn async() {
|
||||
let _ = env_logger::init();
|
||||
let server = Arc::new(Server::new());
|
||||
let serve_handle = serve_async("localhost:0", server.clone(), None).unwrap();
|
||||
let addr = serve_handle.local_addr().clone();
|
||||
let client: Client<Request, Reply> = Client::new(addr, None).unwrap();
|
||||
|
||||
// Drop future immediately; does the reader channel panic when sending?
|
||||
client.rpc_async(&Request::Increment);
|
||||
// If the reader panicked, this won't succeed
|
||||
client.rpc_async(&Request::Increment);
|
||||
|
||||
drop(client);
|
||||
serve_handle.shutdown();
|
||||
assert_eq!(server.count(), 2);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user