diff --git a/README.md b/README.md index 326b11a..01c430f 100644 --- a/README.md +++ b/README.md @@ -10,11 +10,9 @@ tarpc is an RPC framework for rust with a focus on ease of use. Defining and imp extern crate tarpc; extern crate serde; -rpc! { - mod hello_service { - service { - rpc hello(name: String) -> String; - } +mod hello_service { + service! { + rpc hello(name: String) -> String; } } @@ -34,11 +32,18 @@ fn main() { } ``` -The `rpc!` macro generates a module in the current module. In the above example, the module is named `hello_service`. This module will contain a `Client` type, a `Service` trait, and a `serve` function. `serve` can be used to start a server listening on a tcp port. A `Client` can connect to such a service. Any type implementing the `Service` trait can be passed to `serve`. These generated types are specific to the echo service, and make it easy and ergonomic to write servers without dealing with sockets or serialization directly. See the tarpc_examples package for more sophisticated examples. +The `service!` macro expands to a collection of items that collectively form an rpc service. In the +above example, the macro is called within the `hello_service` module. This module will contain a +`Client` type, a `Service` trait, and a `serve` function. `serve` can be used to start a server +listening on a tcp port. A `Client` can connect to such a service. Any type implementing the +`Service` trait can be passed to `serve`. These generated types are specific to the echo service, +and make it easy and ergonomic to write servers without dealing with sockets or serialization +directly. See the tarpc_examples package for more sophisticated examples. ## Additional Features - Imports can be specified in an `item {}` block that appears above the `service {}` block. -- Attributes can be specified on rpc methods. These will be included on both the `Service` trait methods as well as on the `Client`'s stub methods. +- Attributes can be specified on rpc methods. These will be included on both the `Service` trait + methods as well as on the `Client`'s stub methods. ## Planned Improvements (actively being worked on) - Automatically reconnect on the client side when the connection cuts out. diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index c6b0d69..37ed583 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -11,4 +11,5 @@ bincode = "*" serde_macros = "*" log = "*" env_logger = "*" -crossbeam = "*" +scoped-pool = "*" +lazy_static = "*" diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 2f6755b..a8e6486 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -15,12 +15,10 @@ //! # #![plugin(serde_macros)] //! # #[macro_use] extern crate tarpc; //! # extern crate serde; -//! rpc! { -//! mod my_server { -//! service { -//! rpc hello(name: String) -> String; -//! rpc add(x: i32, y: i32) -> i32; -//! } +//! mod my_server { +//! service! { +//! rpc hello(name: String) -> String; +//! rpc add(x: i32, y: i32) -> i32; //! } //! } //! @@ -60,8 +58,11 @@ extern crate serde; extern crate bincode; #[macro_use] extern crate log; -extern crate crossbeam; +extern crate scoped_pool; extern crate test; +#[cfg(test)] +#[macro_use] +extern crate lazy_static; /// Provides the tarpc client and server, which implements the tarpc protocol. /// The protocol is defined by the implementation. diff --git a/tarpc/src/macros.rs b/tarpc/src/macros.rs index b5195f1..656c63c 100644 --- a/tarpc/src/macros.rs +++ b/tarpc/src/macros.rs @@ -10,70 +10,6 @@ #[macro_export] macro_rules! as_item { ($i:item) => {$i} } -// Inserts a placeholder doc comment for the module if it's missing -#[doc(hidden)] -#[macro_export] -macro_rules! add_mod_doc { - // If nothing left, return - ( - @rec - { $(#[$done:meta])* } - { } - $i:item - ) => { - $(#[$done])* - #[doc="A module containing an rpc service and client stub."] - $i - }; - - // If we find a doc attribute, return - ( - @rec - { $(#[$done:meta])* } - { - #[doc=$doc:expr] - $(#[$rest:meta])* - } - $i:item - ) => { - $(#[$done])* - #[doc=$doc] - $(#[$rest])* - $i - }; - - // If we don't find a doc attribute, keep going - ( - @rec - { $(#[$($done:tt)*])* } - { - #[$($attr:tt)*] - $($rest:tt)* - } - $i:item - ) => { - add_mod_doc! { - @rec - { $(#[$($done)*])* #[$($attr)*] } - { $($rest)* } - $i - } - }; - - // Entry - ( - { $(#[$($attr:tt)*])* } - $i:item - ) => { - add_mod_doc! { - @rec - {} - { $(#[$($attr)*])* } - $i - } - }; -} - // Required because if-let can't be used with irrefutable patterns, so it needs // to be special cased. #[doc(hidden)] @@ -85,7 +21,7 @@ macro_rules! client_methods { ) => ( $(#[$attr])* pub fn $fn_name(&self, $($arg: $in_),*) -> $crate::Result<$out> { - let reply = try!((self.0).rpc(&request_variant!($fn_name $($arg),*))); + let reply = try!((self.0).rpc(request_variant!($fn_name $($arg),*))); let __Reply::$fn_name(reply) = reply; Ok(reply) } @@ -96,7 +32,7 @@ macro_rules! client_methods { )*) => ( $( $(#[$attr])* pub fn $fn_name(&self, $($arg: $in_),*) -> $crate::Result<$out> { - let reply = try!((self.0).rpc(&request_variant!($fn_name $($arg),*))); + let reply = try!((self.0).rpc(request_variant!($fn_name $($arg),*))); if let __Reply::$fn_name(reply) = reply { Ok(reply) } else { @@ -121,7 +57,7 @@ macro_rules! async_client_methods { let __Reply::$fn_name(reply) = reply; reply } - let reply = (self.0).rpc_async(&request_variant!($fn_name $($arg),*)); + let reply = (self.0).rpc_async(request_variant!($fn_name $($arg),*)); Future { future: reply, mapper: mapper, @@ -141,7 +77,7 @@ macro_rules! async_client_methods { panic!("Incorrect reply variant returned from protocol::Clientrpc; expected `{}`, but got {:?}", stringify!($fn_name), reply); } } - let reply = (self.0).rpc_async(&request_variant!($fn_name $($arg),*)); + let reply = (self.0).rpc_async(request_variant!($fn_name $($arg),*)); Future { future: reply, mapper: mapper, @@ -174,180 +110,162 @@ macro_rules! request_variant { ($x:ident $($y:ident),+) => (__Request::$x($($y),+)); } -// The main macro that creates RPC services. +/// The main macro that creates RPC services. +/// +/// Rpc methods are specified, mirroring trait syntax: +/// +/// ``` +/// # #![feature(custom_derive, plugin)] +/// # #![plugin(serde_macros)] +/// # #[macro_use] extern crate tarpc; +/// # extern crate serde; +/// # fn main() {} +/// # service! { +/// #[doc="Say hello"] +/// rpc hello(name: String) -> String; +/// # } +/// ``` +/// +/// Attributes can be attached to each rpc. These attributes +/// will then be attached to the generated `Service` trait's +/// corresponding method, as well as to the `Client` stub's rpcs methods. +/// +/// The following items are expanded in the enclosing module: +/// +/// * `Service` -- the trait defining the RPC service +/// * `Client` -- a client that makes synchronous requests to the RPC server +/// * `AsyncClient` -- a client that makes asynchronous requests to the RPC server +/// * `Future` -- a handle for asynchronously retrieving the result of an RPC +/// * `serve` -- the function that starts the RPC server +/// +/// **Warning**: In addition to the above items, there are a few expanded items that +/// are considered implementation details. As with the above items, shadowing +/// these item names in the enclosing module is likely to break things in confusing +/// ways: +/// +/// * `__Server` -- an implementation detail +/// * `__Request` -- an implementation detail +/// * `__Reply` -- an implementation detail #[macro_export] -macro_rules! rpc { +macro_rules! service { ( - $(#[$($service_attr:tt)*])* - mod $server:ident { - - service { - $( - $(#[$attr:meta])* - rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty; - )* - } - } + $( + $(#[$attr:meta])* + rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty; + )* ) => { - rpc! { - $(#[$($service_attr)*])* - mod $server { + #[doc="Defines the RPC service"] + pub trait Service: Send + Sync { + $( + $(#[$attr])* + fn $fn_name(&self, $($arg:$in_),*) -> $out; + )* + } - items { } - - service { - $( - $(#[$attr])* - rpc $fn_name($($arg: $in_),*) -> $out; - )* + impl Service for P + where P: Send + Sync + ::std::ops::Deref, + S: Service + { + $( + $(#[$attr])* + fn $fn_name(&self, $($arg:$in_),*) -> $out { + Service::$fn_name(&**self, $($arg),*) } + )* + } + + define_request!($($fn_name($($in_),*))*); + + #[allow(non_camel_case_types)] + #[derive(Debug, Serialize, Deserialize)] + enum __Reply { + $( + $fn_name($out), + )* + } + + /// An asynchronous RPC call + pub struct Future { + future: $crate::protocol::Future<__Reply>, + mapper: fn(__Reply) -> T, + } + + impl Future { + /// Block until the result of the RPC call is available + pub fn get(self) -> $crate::Result { + self.future.get().map(self.mapper) } } - }; - ( - // Names the service - $(#[$($service_attr:tt)*])* - mod $server:ident { + #[doc="The client stub that makes RPC calls to the server."] + pub struct Client($crate::protocol::Client<__Request, __Reply>); - // Include any desired or required items. Conflicts can arise with the following names: - // 1. Service - // 2. Client - // 3. serve - // 4. __Reply - // 5. __Request - items { $($i:item)* } + impl Client { + #[doc="Create a new client that connects to the given address."] + pub fn new(addr: A, timeout: ::std::option::Option<::std::time::Duration>) + -> $crate::Result + where A: ::std::net::ToSocketAddrs, + { + let inner = try!($crate::protocol::Client::new(addr, timeout)); + Ok(Client(inner)) + } - // List any rpc methods: rpc foo(arg1: Arg1, ..., argN: ArgN) -> Out - service { + client_methods!( $( - $(#[$attr:meta])* - rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty; + { $(#[$attr])* } + $fn_name($($arg: $in_),*) -> $out )* + ); + } + + #[doc="The client stub that makes asynchronous RPC calls to the server."] + pub struct AsyncClient($crate::protocol::Client<__Request, __Reply>); + + impl AsyncClient { + #[doc="Create a new asynchronous client that connects to the given address."] + pub fn new(addr: A, timeout: ::std::option::Option<::std::time::Duration>) + -> $crate::Result + where A: ::std::net::ToSocketAddrs, + { + let inner = try!($crate::protocol::Client::new(addr, timeout)); + Ok(AsyncClient(inner)) + } + + async_client_methods!( + $( + { $(#[$attr])* } + $fn_name($($arg: $in_),*) -> $out + )* + ); + } + + struct __Server(S); + + impl $crate::protocol::Serve for __Server + where S: 'static + Service + { + type Request = __Request; + type Reply = __Reply; + fn serve(&self, request: __Request) -> __Reply { + match request { + $( + request_variant!($fn_name $($arg),*) => + __Reply::$fn_name((self.0).$fn_name($($arg),*)), + )* + } } } - ) => { - add_mod_doc! { - { $(#[$($service_attr)*])* } - pub mod $server { - - $($i)* - - #[doc="The provided RPC service."] - pub trait Service: Send + Sync { - $( - $(#[$attr])* - fn $fn_name(&self, $($arg:$in_),*) -> $out; - )* - } - - impl Service for P - where P: Send + Sync + ::std::ops::Deref, - S: Service - { - $( - $(#[$attr])* - fn $fn_name(&self, $($arg:$in_),*) -> $out { - Service::$fn_name(&**self, $($arg),*) - } - )* - } - - define_request!($($fn_name($($in_),*))*); - - #[allow(non_camel_case_types)] - #[derive(Debug, Serialize, Deserialize)] - enum __Reply { - $( - $fn_name($out), - )* - } - - /// An asynchronous RPC call - pub struct Future { - future: $crate::protocol::Future<__Reply>, - mapper: fn(__Reply) -> T, - } - - impl Future { - /// Block until the result of the RPC call is available - pub fn get(self) -> $crate::Result { - self.future.get().map(self.mapper) - } - } - - #[doc="The client stub that makes RPC calls to the server."] - pub struct Client($crate::protocol::Client<__Request, __Reply>); - - impl Client { - #[doc="Create a new client that connects to the given address."] - pub fn new(addr: A, timeout: ::std::option::Option<::std::time::Duration>) - -> $crate::Result - where A: ::std::net::ToSocketAddrs, - { - let inner = try!($crate::protocol::Client::new(addr, timeout)); - Ok(Client(inner)) - } - - client_methods!( - $( - { $(#[$attr])* } - $fn_name($($arg: $in_),*) -> $out - )* - ); - } - - #[doc="The client stub that makes asynchronous RPC calls to the server."] - pub struct AsyncClient($crate::protocol::Client<__Request, __Reply>); - - impl AsyncClient { - #[doc="Create a new asynchronous client that connects to the given address."] - pub fn new(addr: A, timeout: ::std::option::Option<::std::time::Duration>) - -> $crate::Result - where A: ::std::net::ToSocketAddrs, - { - let inner = try!($crate::protocol::Client::new(addr, timeout)); - Ok(AsyncClient(inner)) - } - - async_client_methods!( - $( - { $(#[$attr])* } - $fn_name($($arg: $in_),*) -> $out - )* - ); - } - - struct __Server(S); - - impl $crate::protocol::Serve for __Server - where S: 'static + Service - { - type Request = __Request; - type Reply = __Reply; - fn serve(&self, request: __Request) -> __Reply { - match request { - $( - request_variant!($fn_name $($arg),*) => - __Reply::$fn_name((self.0).$fn_name($($arg),*)), - )* - } - } - } - - #[doc="Start a running service."] - pub fn serve(addr: A, - service: S, - read_timeout: ::std::option::Option<::std::time::Duration>) - -> $crate::Result<$crate::protocol::ServeHandle> - where A: ::std::net::ToSocketAddrs, - S: 'static + Service - { - let server = ::std::sync::Arc::new(__Server(service)); - Ok(try!($crate::protocol::serve_async(addr, server, read_timeout))) - } - } + #[doc="Start a running service."] + pub fn serve(addr: A, + service: S, + read_timeout: ::std::option::Option<::std::time::Duration>) + -> $crate::Result<$crate::protocol::ServeHandle> + where A: ::std::net::ToSocketAddrs, + S: 'static + Service + { + let server = ::std::sync::Arc::new(__Server(service)); + Ok(try!($crate::protocol::serve_async(addr, server, read_timeout))) } } } @@ -356,6 +274,8 @@ macro_rules! rpc { #[allow(dead_code)] mod test { extern crate env_logger; + use ServeHandle; + use std::sync::{Arc, Mutex}; use std::time::Duration; use test::Bencher; @@ -363,28 +283,23 @@ mod test { Some(Duration::from_secs(5)) } - rpc! { - #[deny(missing_docs)] - #[doc="Hello"] - mod my_server { - items { - #[derive(PartialEq, Debug, Serialize, Deserialize)] - pub struct Foo { - pub message: String - } - } + #[derive(PartialEq, Debug, Serialize, Deserialize)] + pub struct Foo { + pub message: String + } - service { - rpc hello(foo: Foo) -> Foo; - rpc add(x: i32, y: i32) -> i32; - } + mod my_server { + use super::Foo; + + service! { + rpc hello(foo: Foo) -> Foo; + rpc add(x: i32, y: i32) -> i32; } } - use self::my_server::*; - struct Server; - impl Service for Server { + + impl my_server::Service for Server { fn hello(&self, s: Foo) -> Foo { Foo { message: format!("Hello, {}", &s.message) } } @@ -396,7 +311,7 @@ mod test { #[test] fn serve_arc_server() { - serve("localhost:0", ::std::sync::Arc::new(Server), None) + my_server::serve("localhost:0", ::std::sync::Arc::new(Server), None) .unwrap() .shutdown(); } @@ -404,7 +319,7 @@ mod test { #[test] fn simple() { let handle = my_server::serve( "localhost:0", Server, test_timeout()).unwrap(); - let client = Client::new(handle.local_addr(), None).unwrap(); + let client = my_server::Client::new(handle.local_addr(), None).unwrap(); assert_eq!(3, client.add(1, 2).unwrap()); let foo = Foo { message: "Adam".into() }; let want = Foo { message: format!("Hello, {}", &foo.message) }; @@ -416,7 +331,7 @@ mod test { #[test] fn simple_async() { let handle = my_server::serve("localhost:0", Server, test_timeout()).unwrap(); - let client = AsyncClient::new(handle.local_addr(), None).unwrap(); + let client = my_server::AsyncClient::new(handle.local_addr(), None).unwrap(); assert_eq!(3, client.add(1, 2).get().unwrap()); let foo = Foo { message: "Adam".into() }; let want = Foo { message: format!("Hello, {}", &foo.message) }; @@ -425,63 +340,48 @@ mod test { handle.shutdown(); } - // Tests a service definition with a fn that takes no args - rpc! { - mod qux { - service { - rpc hello() -> String; - } + /// Tests a service definition with a fn that takes no args + mod qux { + service! { + rpc hello() -> String; } } - // Tests a service definition with an import - rpc! { - mod foo { - items { - use std::collections::HashMap; - } + /// Tests a service definition with an import + mod foo { + use std::collections::HashMap; - service { - #[doc="Hello bob"] - #[inline(always)] - rpc baz(s: String) -> HashMap; - } + service! { + #[doc="Hello bob"] + #[inline(always)] + rpc baz(s: String) -> HashMap; } } - // Tests a service definition with an attribute but no doc comment - rpc! { - #[deny(missing_docs)] - mod bar { - items { - use std::collections::HashMap; - } + /// Tests a service definition with an attribute but no doc comment + #[deny(missing_docs)] + mod bar { + use std::collections::HashMap; - service { - #[inline(always)] - rpc baz(s: String) -> HashMap; - } + service! { + #[inline(always)] + rpc baz(s: String) -> HashMap; } } - // Tests a service definition with an attribute and a doc comment - rpc! { - #[deny(missing_docs)] - #[doc="Hello bob"] - #[allow(unused)] - mod baz { - items { - use std::collections::HashMap; + /// Tests a service definition with an attribute and a doc comment + #[deny(missing_docs)] + #[allow(unused)] + mod baz { + use std::collections::HashMap; - #[derive(Debug)] - pub struct Debuggable; - } + #[derive(Debug)] + pub struct Debuggable; - service { - #[doc="Hello bob"] - #[inline(always)] - rpc baz(s: String) -> HashMap; - } + service! { + #[doc="Hello bob"] + #[inline(always)] + rpc baz(s: String) -> HashMap; } } @@ -490,37 +390,54 @@ mod test { println!("{:?}", baz::Debuggable); } - rpc! { - mod hello { - service { - rpc hello(s: String) -> String; - } + mod hi { + service! { + rpc hello(s: String) -> String; } } struct HelloServer; - impl hello::Service for HelloServer { + + impl hi::Service for HelloServer { fn hello(&self, s: String) -> String { format!("Hello, {}!", s) } } + // Prevents resource exhaustion when benching + lazy_static! { + static ref HANDLE: Arc> = { + let handle = hi::serve("localhost:0", HelloServer, None).unwrap(); + Arc::new(Mutex::new(handle)) + }; + static ref CLIENT: Arc> = { + let addr = HANDLE.lock().unwrap().local_addr().clone(); + let client = hi::AsyncClient::new(addr, None).unwrap(); + Arc::new(Mutex::new(client)) + }; + } + #[bench] fn hello(bencher: &mut Bencher) { let _ = env_logger::init(); - let handle = hello::serve("localhost:0", HelloServer, None).unwrap(); - let client = hello::AsyncClient::new(handle.local_addr(), None).unwrap(); + let client = CLIENT.lock().unwrap(); let concurrency = 100; - let mut rpcs = Vec::with_capacity(concurrency); + let mut futures = Vec::with_capacity(concurrency); + let mut count = 0; bencher.iter(|| { - for _ in 0..concurrency { - rpcs.push(client.hello("Bob".into())); - } - for _ in 0..concurrency { - rpcs.pop().unwrap().get().unwrap(); + futures.push(client.hello("Bob".into())); + count += 1; + if count % concurrency == 0 { + // We can't block on each rpc call, otherwise we'd be + // benchmarking latency instead of throughput. It's also + // not ideal to call more than one rpc per iteration, because + // it makes the output of the bencher harder to parse (you have + // to mentally divide the number by `concurrency` to get + // the ns / iter for one rpc + for f in futures.drain(..) { + f.get().unwrap(); + } } }); - drop(client); - handle.shutdown(); } } diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index bd8b67d..7cfdd00 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -8,12 +8,11 @@ use bincode; use serde; -use crossbeam; +use scoped_pool::Pool; use std::fmt; -use std::io::{self, Read}; +use std::io::{self, BufReader, BufWriter, Read, Write}; use std::convert; use std::collections::HashMap; -use std::marker::PhantomData; use std::mem; use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs}; use std::sync::{Arc, Condvar, Mutex}; @@ -61,7 +60,7 @@ pub type Result = ::std::result::Result; /// An asynchronous RPC call pub struct Future { - rx: Result>, + rx: Receiver>, requests: Arc>> } @@ -69,9 +68,9 @@ impl Future { /// Block until the result of the RPC call is available pub fn get(self) -> Result { let requests = self.requests; - try!(self.rx) - .recv() + self.rx.recv() .map_err(|_| requests.lock().unwrap().get_error()) + .and_then(|reply| reply) } } @@ -116,12 +115,12 @@ impl InflightRpcs { struct ConnectionHandler<'a, S> where S: Serve { - read_stream: TcpStream, - write_stream: Mutex, + read_stream: BufReader, + write_stream: Mutex>, shutdown: &'a AtomicBool, inflight_rpcs: &'a InflightRpcs, - timeout: Option, server: S, + pool: &'a Pool, } impl<'a, S> Drop for ConnectionHandler<'a, S> where S: Serve { @@ -132,32 +131,23 @@ impl<'a, S> Drop for ConnectionHandler<'a, S> where S: Serve { } impl<'a, S> ConnectionHandler<'a, S> where S: Serve { - fn read(read_stream: &mut TcpStream, - timeout: Option) - -> bincode::serde::DeserializeResult> - where Request: serde::de::Deserialize - { - try!(read_stream.set_read_timeout(timeout)); - bincode::serde::deserialize_from(read_stream, bincode::SizeLimit::Infinite) - } - fn handle_conn(&mut self) -> Result<()> { let ConnectionHandler { ref mut read_stream, ref write_stream, shutdown, inflight_rpcs, - timeout, ref server, + pool, } = *self; trace!("ConnectionHandler: serving client..."); - crossbeam::scope(|scope| { + pool.scoped(|scope| { loop { - match Self::read(read_stream, timeout) { + match bincode::serde::deserialize_from(read_stream, bincode::SizeLimit::Infinite) { Ok(Packet { rpc_id, message, }) => { debug!("ConnectionHandler: serving request, id: {}, message: {:?}", rpc_id, message); inflight_rpcs.increment(); - scope.spawn(move || { + scope.execute(move || { let reply = server.serve(message); let reply_packet = Packet { rpc_id: rpc_id, @@ -171,6 +161,10 @@ impl<'a, S> ConnectionHandler<'a, S> where S: Serve { warn!("ConnectionHandler: failed to write reply to Client: {:?}", e); } + if let Err(e) = write_stream.flush() { + warn!("ConnectionHandler: failed to flush reply to Client: {:?}", + e); + } inflight_rpcs.decrement(); }); if shutdown.load(Ordering::SeqCst) { @@ -252,9 +246,10 @@ pub fn serve_async(addr: A, info!("serve_async: spinning up server on {:?}", addr); let (die_tx, die_rx) = channel(); let join_handle = thread::spawn(move || { + let pool = Pool::new(100); // TODO(tjk): make this configurable, and expire idle threads let shutdown = AtomicBool::new(false); let inflight_rpcs = InflightRpcs::new(); - crossbeam::scope(|scope| { + pool.scoped(|scope| { for conn in listener.incoming() { match die_rx.try_recv() { Ok(_) => { @@ -277,15 +272,19 @@ pub fn serve_async(addr: A, } Ok(c) => c, }; + if let Err(err) = conn.set_read_timeout(read_timeout) { + info!("Server: could not set read timeout: {:?}", err); + return; + } inflight_rpcs.increment(); - scope.spawn(|| { + scope.execute(|| { let mut handler = ConnectionHandler { - read_stream: conn.try_clone().unwrap(), - write_stream: Mutex::new(conn), + read_stream: BufReader::new(conn.try_clone().unwrap()), + write_stream: Mutex::new(BufWriter::new(conn)), shutdown: &shutdown, inflight_rpcs: &inflight_rpcs, - timeout: read_timeout, server: &server, + pool: &pool, }; if let Err(err) = handler.handle_conn() { info!("ConnectionHandler: err in connection handling: {:?}", err); @@ -330,14 +329,14 @@ struct Packet { message: T, } -struct RpcFutures(Result>>); +struct RpcFutures(Result>>>); impl RpcFutures { fn new() -> RpcFutures { RpcFutures(Ok(HashMap::new())) } - fn insert_tx(&mut self, id: u64, tx: Sender) -> Result<()> { + fn insert_tx(&mut self, id: u64, tx: Sender>) -> Result<()> { match self.0 { Ok(ref mut requests) => { requests.insert(id, tx); @@ -359,7 +358,7 @@ impl RpcFutures { fn complete_reply(&mut self, id: u64, reply: Reply) { if let Some(tx) = self.0.as_mut().unwrap().remove(&id) { - if let Err(e) = tx.send(reply) { + if let Err(e) = tx.send(Ok(reply)) { info!("Reader: could not complete reply: {:?}", e); } } else { @@ -376,119 +375,157 @@ impl RpcFutures { } } -struct Reader { - requests: Arc>>, -} +fn write(outbound: Receiver<(Request, Sender>)>, + requests: Arc>>, + stream: TcpStream) + where Request: serde::Serialize, + Reply: serde::Deserialize, +{ + let mut next_id = 0; + let mut stream = BufWriter::new(stream); + loop { + let (request, tx) = match outbound.recv() { + Err(e) => { + debug!("Writer: all senders have exited ({:?}). Returning.", e); + return; + } + Ok(request) => request, + }; + if let Err(e) = requests.lock().unwrap().insert_tx(next_id, tx.clone()) { + report_error(&tx, e); + // Once insert_tx returns Err, it will continue to do so. However, continue here so + // that any other clients who sent requests will also recv the Err. + continue; + } + let id = next_id; + next_id += 1; + let packet = Packet { + rpc_id: id, + message: request, + }; + debug!("Writer: calling rpc({:?})", id); + if let Err(e) = bincode::serde::serialize_into(&mut stream, + &packet, + bincode::SizeLimit::Infinite) { + report_error(&tx, e.into()); + // Typically we'd want to notify the client of any Err returned by remove_tx, but in + // this case the client already hit an Err, and doesn't need to know about this one, as + // well. + let _ = requests.lock().unwrap().remove_tx(id); + continue; + } + if let Err(e) = stream.flush() { + report_error(&tx, e.into()); + } + } -impl Reader { - fn read(self, mut stream: TcpStream) + fn report_error(tx: &Sender>, e: Error) where Reply: serde::Deserialize { - loop { - let packet: bincode::serde::DeserializeResult> = - bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); - match packet { - Ok(Packet { - rpc_id: id, - message: reply - }) => { - debug!("Client: received message, id={}", id); - self.requests.lock().unwrap().complete_reply(id, reply); - } - Err(err) => { - warn!("Client: reader thread encountered an unexpected error while parsing; \ - returning now. Error: {:?}", - err); - self.requests.lock().unwrap().set_error(err); - break; - } + // Clone the err so we can log it if sending fails + if let Err(e2) = tx.send(Err(e.clone())) { + debug!("Error encountered while trying to send an error. \ + Initial error: {:?}; Send error: {:?}", + e, + e2); + } + } + +} + +fn read(requests: Arc>>, stream: TcpStream) + where Reply: serde::Deserialize +{ + let mut stream = BufReader::new(stream); + loop { + let packet: bincode::serde::DeserializeResult> = + bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); + match packet { + Ok(Packet { + rpc_id: id, + message: reply + }) => { + debug!("Client: received message, id={}", id); + requests.lock().unwrap().complete_reply(id, reply); + } + Err(err) => { + warn!("Client: reader thread encountered an unexpected error while parsing; \ + returning now. Error: {:?}", + err); + requests.lock().unwrap().set_error(err); + break; } } } } -fn increment(cur_id: &mut u64) -> u64 { - let id = *cur_id; - *cur_id += 1; - id -} - -struct SyncedClientState { - next_id: u64, - stream: TcpStream, -} - /// A client stub that connects to a server to run rpcs. pub struct Client where Request: serde::ser::Serialize { - synced_state: Mutex, + // The guard is in an option so it can be joined in the drop fn + reader_guard: Arc>>, + outbound: Sender<(Request, Sender>)>, requests: Arc>>, - reader_guard: Option>, - timeout: Option, - _request: PhantomData, + shutdown: TcpStream, } impl Client - where Reply: serde::de::Deserialize + Send + 'static, - Request: serde::ser::Serialize + where Request: serde::ser::Serialize + Send + 'static, + Reply: serde::de::Deserialize + Send + 'static { /// Create a new client that connects to `addr`. The client uses the given timeout /// for both reads and writes. pub fn new(addr: A, timeout: Option) -> io::Result { let stream = try!(TcpStream::connect(addr)); - let requests = Arc::new(Mutex::new(RpcFutures::new())); + try!(stream.set_read_timeout(timeout)); + try!(stream.set_write_timeout(timeout)); let reader_stream = try!(stream.try_clone()); - let reader = Reader { requests: requests.clone() }; - let reader_guard = thread::spawn(move || reader.read(reader_stream)); + let writer_stream = try!(stream.try_clone()); + let requests = Arc::new(Mutex::new(RpcFutures::new())); + let reader_requests = requests.clone(); + let writer_requests = requests.clone(); + let (tx, rx) = channel(); + let reader_guard = thread::spawn(move || read(reader_requests, reader_stream)); + thread::spawn(move || write(rx, writer_requests, writer_stream)); Ok(Client { - synced_state: Mutex::new(SyncedClientState { - next_id: 0, - stream: stream, - }), + reader_guard: Arc::new(Some(reader_guard)), + outbound: tx, requests: requests, - reader_guard: Some(reader_guard), - timeout: timeout, - _request: PhantomData, + shutdown: stream, }) } - fn rpc_internal(&self, request: &Request) -> Result> + /// Clones the Client so that it can be shared across threads. + pub fn try_clone(&self) -> io::Result> { + Ok(Client { + reader_guard: self.reader_guard.clone(), + outbound: self.outbound.clone(), + requests: self.requests.clone(), + shutdown: try!(self.shutdown.try_clone()), + }) + } + + fn rpc_internal(&self, request: Request) -> Receiver> where Request: serde::ser::Serialize + fmt::Debug + Send + 'static { let (tx, rx) = channel(); - let mut state = self.synced_state.lock().unwrap(); - let id = increment(&mut state.next_id); - try!(self.requests.lock().unwrap().insert_tx(id, tx)); - let packet = Packet { - rpc_id: id, - message: request, - }; - try!(state.stream.set_write_timeout(self.timeout)); - try!(state.stream.set_read_timeout(self.timeout)); - debug!("Client: calling rpc({:?})", request); - if let Err(err) = bincode::serde::serialize_into(&mut state.stream, - &packet, - bincode::SizeLimit::Infinite) { - warn!("Client: failed to write packet.\nPacket: {:?}\nError: {:?}", - packet, - err); - try!(self.requests.lock().unwrap().remove_tx(id)); - } - Ok(rx) + self.outbound.send((request, tx)).unwrap(); + rx } /// Run the specified rpc method on the server this client is connected to - pub fn rpc(&self, request: &Request) -> Result + pub fn rpc(&self, request: Request) -> Result where Request: serde::ser::Serialize + fmt::Debug + Send + 'static { - try!(self.rpc_internal(request)) + self.rpc_internal(request) .recv() .map_err(|_| self.requests.lock().unwrap().get_error()) + .and_then(|reply| reply) } /// Asynchronously run the specified rpc method on the server this client is connected to - pub fn rpc_async(&self, request: &Request) -> Future + pub fn rpc_async(&self, request: Request) -> Future where Request: serde::ser::Serialize + fmt::Debug + Send + 'static { Future { @@ -502,14 +539,18 @@ impl Drop for Client where Request: serde::ser::Serialize { fn drop(&mut self) { - if let Err(e) = self.synced_state - .lock() - .unwrap() - .stream - .shutdown(::std::net::Shutdown::Both) { - warn!("Client: couldn't shutdown reader thread: {:?}", e); - } else { - self.reader_guard.take().unwrap().join().unwrap(); + debug!("Dropping Client."); + if let Some(reader_guard) = Arc::get_mut(&mut self.reader_guard) { + debug!("Attempting to shut down writer and reader threads."); + if let Err(e) = self.shutdown.shutdown(::std::net::Shutdown::Both) { + warn!("Client: couldn't shutdown writer and reader threads: {:?}", e); + } else { + // We only join if we know the TcpStream was shut down. Otherwise we might never + // finish. + debug!("Joining writer and reader."); + reader_guard.take().unwrap().join().unwrap(); + debug!("Successfully joined writer and reader."); + } } } } @@ -517,8 +558,8 @@ impl Drop for Client #[cfg(test)] mod test { extern crate env_logger; - use super::{Client, Serve, serve_async}; + use scoped_pool::Pool; use std::sync::{Arc, Barrier, Mutex}; use std::thread; use std::time::Duration; @@ -581,10 +622,10 @@ mod test { let addr = serve_handle.local_addr().clone(); let client = Client::new(addr, None).unwrap(); assert_eq!(Reply::Increment(0), - client.rpc(&Request::Increment).unwrap()); + client.rpc(Request::Increment).unwrap()); assert_eq!(1, server.count()); assert_eq!(Reply::Increment(1), - client.rpc(&Request::Increment).unwrap()); + client.rpc(Request::Increment).unwrap()); assert_eq!(2, server.count()); drop(client); serve_handle.shutdown(); @@ -625,7 +666,7 @@ mod test { let addr = serve_handle.local_addr().clone(); let client: Client = Client::new(addr, None).unwrap(); let thread = thread::spawn(move || serve_handle.shutdown()); - info!("force_shutdown:: rpc1: {:?}", client.rpc(&Request::Increment)); + info!("force_shutdown:: rpc1: {:?}", client.rpc(Request::Increment)); thread.join().unwrap(); } @@ -637,34 +678,29 @@ mod test { let addr = serve_handle.local_addr().clone(); let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); serve_handle.shutdown(); - match client.rpc(&Request::Increment) { + match client.rpc(Request::Increment) { Err(super::Error::ConnectionBroken) => {} // success otherwise => panic!("Expected Err(ConnectionBroken), got {:?}", otherwise), } - let _ = client.rpc(&Request::Increment); // Test whether second failure hangs + let _ = client.rpc(Request::Increment); // Test whether second failure hangs } #[test] fn concurrent() { let _ = env_logger::init(); let concurrency = 10; + let pool = Pool::new(concurrency); let server = Arc::new(BarrierServer::new(concurrency)); let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); - let mut join_handles = vec![]; - for _ in 0..concurrency { - let my_client = client.clone(); - join_handles.push(thread::spawn(move || my_client.rpc(&Request::Increment).unwrap())); - } - for handle in join_handles.into_iter() { - handle.join().unwrap(); - } + let client: Client = Client::new(addr, None).unwrap(); + pool.scoped(|scope| { + for _ in 0..concurrency { + let client = client.try_clone().unwrap(); + scope.execute(move || { client.rpc(Request::Increment).unwrap(); }); + } + }); assert_eq!(concurrency as u64, server.count()); - let client = match Arc::try_unwrap(client) { - Err(_) => panic!("couldn't unwrap arc"), - Ok(c) => c, - }; drop(client); serve_handle.shutdown(); } @@ -678,12 +714,11 @@ mod test { let client: Client = Client::new(addr, None).unwrap(); // Drop future immediately; does the reader channel panic when sending? - client.rpc_async(&Request::Increment); + client.rpc_async(Request::Increment); // If the reader panicked, this won't succeed - client.rpc_async(&Request::Increment); + client.rpc_async(Request::Increment); drop(client); serve_handle.shutdown(); - assert_eq!(server.count(), 2); } }