Add an async client

This commit is contained in:
Tim Kuehn
2016-01-27 01:09:01 -08:00
parent 489ab555c3
commit 6109d825f6
3 changed files with 152 additions and 9 deletions

View File

@@ -53,7 +53,7 @@
//! ```
#![deny(missing_docs)]
#![feature(custom_derive, plugin)]
#![feature(custom_derive, plugin, test)]
#![plugin(serde_macros)]
extern crate serde;
@@ -61,6 +61,7 @@ extern crate bincode;
#[macro_use]
extern crate log;
extern crate crossbeam;
extern crate test;
/// Provides the tarpc client and server, which implements the tarpc protocol.
/// The protocol is defined by the implementation.

View File

@@ -106,6 +106,50 @@ macro_rules! client_methods {
)*);
}
// Required because if-let can't be used with irrefutable patterns, so it needs
// to be special cased.
#[doc(hidden)]
#[macro_export]
macro_rules! async_client_methods {
(
{ $(#[$attr:meta])* }
$fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty
) => (
$(#[$attr])*
pub fn $fn_name(&self, $($arg: $in_),*) -> Future<$out> {
fn mapper(reply: __Reply) -> $out {
let __Reply::$fn_name(reply) = reply;
reply
}
let reply = (self.0).rpc_async(&request_variant!($fn_name $($arg),*));
Future {
future: reply,
mapper: mapper,
}
}
);
($(
{ $(#[$attr:meta])* }
$fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty
)*) => ( $(
$(#[$attr])*
pub fn $fn_name(&self, $($arg: $in_),*) -> Future<$out> {
fn mapper(reply: __Reply) -> $out {
if let __Reply::$fn_name(reply) = reply {
reply
} else {
panic!("Incorrect reply variant returned from protocol::Clientrpc; expected `{}`, but got {:?}", stringify!($fn_name), reply);
}
}
let reply = (self.0).rpc_async(&request_variant!($fn_name $($arg),*));
Future {
future: reply,
mapper: mapper,
}
}
)*);
}
// Required because enum variants with no fields can't be suffixed by parens
#[doc(hidden)]
#[macro_export]
@@ -220,6 +264,19 @@ macro_rules! rpc {
)*
}
/// An asynchronous RPC call
pub struct Future<T> {
future: $crate::protocol::Future<__Reply>,
mapper: fn(__Reply) -> T,
}
impl<T> Future<T> {
/// Block until the result of the RPC call is available
pub fn get(self) -> $crate::Result<T> {
self.future.get().map(self.mapper)
}
}
#[doc="The client stub that makes RPC calls to the server."]
pub struct Client($crate::protocol::Client<__Request, __Reply>);
@@ -241,6 +298,27 @@ macro_rules! rpc {
);
}
#[doc="The client stub that makes asynchronous RPC calls to the server."]
pub struct AsyncClient($crate::protocol::Client<__Request, __Reply>);
impl AsyncClient {
#[doc="Create a new asynchronous client that connects to the given address."]
pub fn new<A>(addr: A, timeout: ::std::option::Option<::std::time::Duration>)
-> $crate::Result<Self>
where A: ::std::net::ToSocketAddrs,
{
let inner = try!($crate::protocol::Client::new(addr, timeout));
Ok(AsyncClient(inner))
}
async_client_methods!(
$(
{ $(#[$attr])* }
$fn_name($($arg: $in_),*) -> $out
)*
);
}
struct __Server<S: 'static + Service>(S);
impl<S> $crate::protocol::Serve for __Server<S>
@@ -277,7 +355,9 @@ macro_rules! rpc {
#[cfg(test)]
#[allow(dead_code)]
mod test {
extern crate env_logger;
use std::time::Duration;
use test::Bencher;
fn test_timeout() -> Option<Duration> {
Some(Duration::from_secs(5))
@@ -368,7 +448,6 @@ mod test {
}
service {
#[doc="Hello bob"]
#[inline(always)]
rpc baz(s: String) -> HashMap<String, String>;
}
@@ -400,4 +479,37 @@ mod test {
fn debug() {
println!("{:?}", baz::Debuggable);
}
rpc! {
mod hello {
service {
rpc hello(s: String) -> String;
}
}
}
struct HelloServer;
impl hello::Service for HelloServer {
fn hello(&self, s: String) -> String {
format!("Hello, {}!", s)
}
}
#[bench]
fn hello(bencher: &mut Bencher) {
let _ = env_logger::init();
let handle = hello::serve("localhost:0", HelloServer, None).unwrap();
let client = hello::AsyncClient::new(handle.local_addr(), None).unwrap();
let mut rpcs = Vec::with_capacity(100);
bencher.iter(|| {
for _ in 0..1000 {
rpcs.push(client.hello("Bob".into()));
}
for _ in 0..1000 {
rpcs.pop().unwrap().get().unwrap();
}
});
drop(client);
handle.shutdown();
}
}

View File

@@ -17,7 +17,7 @@ use std::marker::PhantomData;
use std::mem;
use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
use std::sync::{Arc, Condvar, Mutex};
use std::sync::mpsc::{Sender, TryRecvError, channel};
use std::sync::mpsc::{Receiver, Sender, TryRecvError, channel};
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use std::thread::{self, JoinHandle};
@@ -59,6 +59,22 @@ impl convert::From<io::Error> for Error {
/// Return type of rpc calls: either the successful return value, or a client error.
pub type Result<T> = ::std::result::Result<T, Error>;
/// An asynchronous RPC call
pub struct Future<T> {
rx: Result<Receiver<T>>,
requests: Arc<Mutex<RpcFutures<T>>>
}
impl<T> Future<T> {
/// Block until the result of the RPC call is available
pub fn get(self) -> Result<T> {
let requests = self.requests;
try!(self.rx)
.recv()
.map_err(|_| requests.lock().unwrap().get_error())
}
}
struct InflightRpcs {
count: Mutex<u64>,
cvar: Condvar,
@@ -437,8 +453,7 @@ impl<Request, Reply> Client<Request, Reply>
})
}
/// Run the specified rpc method on the server this client is connected to
pub fn rpc(&self, request: &Request) -> Result<Reply>
fn rpc_internal(&self, request: &Request) -> Result<Receiver<Reply>>
where Request: serde::ser::Serialize + fmt::Debug + Send + 'static
{
let (tx, rx) = channel();
@@ -460,10 +475,25 @@ impl<Request, Reply> Client<Request, Reply>
err);
try!(self.requests.lock().unwrap().remove_tx(id));
}
drop(state);
match rx.recv() {
Ok(msg) => Ok(msg),
Err(_) => Err(self.requests.lock().unwrap().get_error()),
Ok(rx)
}
/// Run the specified rpc method on the server this client is connected to
pub fn rpc(&self, request: &Request) -> Result<Reply>
where Request: serde::ser::Serialize + fmt::Debug + Send + 'static
{
try!(self.rpc_internal(request))
.recv()
.map_err(|_| self.requests.lock().unwrap().get_error())
}
/// Asynchronously run the specified rpc method on the server this client is connected to
pub fn rpc_async(&self, request: &Request) -> Future<Reply>
where Request: serde::ser::Serialize + fmt::Debug + Send + 'static
{
Future {
rx: self.rpc_internal(request),
requests: self.requests.clone(),
}
}
}