diff --git a/tarpc/src/macros.rs b/tarpc/src/macros.rs index 7013882..46e04ff 100644 --- a/tarpc/src/macros.rs +++ b/tarpc/src/macros.rs @@ -21,7 +21,7 @@ macro_rules! client_methods { ) => ( $(#[$attr])* pub fn $fn_name(&self, $($arg: $in_),*) -> $crate::Result<$out> { - let reply = try!((self.0).rpc(&request_variant!($fn_name $($arg),*))); + let reply = try!((self.0).rpc(request_variant!($fn_name $($arg),*))); let __Reply::$fn_name(reply) = reply; Ok(reply) } @@ -32,7 +32,7 @@ macro_rules! client_methods { )*) => ( $( $(#[$attr])* pub fn $fn_name(&self, $($arg: $in_),*) -> $crate::Result<$out> { - let reply = try!((self.0).rpc(&request_variant!($fn_name $($arg),*))); + let reply = try!((self.0).rpc(request_variant!($fn_name $($arg),*))); if let __Reply::$fn_name(reply) = reply { Ok(reply) } else { @@ -57,7 +57,7 @@ macro_rules! async_client_methods { let __Reply::$fn_name(reply) = reply; reply } - let reply = (self.0).rpc_async(&request_variant!($fn_name $($arg),*)); + let reply = (self.0).rpc_async(request_variant!($fn_name $($arg),*)); Future { future: reply, mapper: mapper, @@ -77,7 +77,7 @@ macro_rules! async_client_methods { panic!("Incorrect reply variant returned from protocol::Clientrpc; expected `{}`, but got {:?}", stringify!($fn_name), reply); } } - let reply = (self.0).rpc_async(&request_variant!($fn_name $($arg),*)); + let reply = (self.0).rpc_async(request_variant!($fn_name $($arg),*)); Future { future: reply, mapper: mapper, diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index 253351c..4a553f0 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -13,7 +13,6 @@ use std::fmt; use std::io::{self, BufReader, BufWriter, Read, Write}; use std::convert; use std::collections::HashMap; -use std::marker::PhantomData; use std::mem; use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs}; use std::sync::{Arc, Condvar, Mutex}; @@ -61,7 +60,7 @@ pub type Result = ::std::result::Result; /// An asynchronous RPC call pub struct Future { - rx: Result>, + rx: Receiver>, requests: Arc>> } @@ -69,9 +68,9 @@ impl Future { /// Block until the result of the RPC call is available pub fn get(self) -> Result { let requests = self.requests; - try!(self.rx) - .recv() + self.rx.recv() .map_err(|_| requests.lock().unwrap().get_error()) + .and_then(|reply| reply) } } @@ -332,14 +331,14 @@ struct Packet { message: T, } -struct RpcFutures(Result>>); +struct RpcFutures(Result>>>); impl RpcFutures { fn new() -> RpcFutures { RpcFutures(Ok(HashMap::new())) } - fn insert_tx(&mut self, id: u64, tx: Sender) -> Result<()> { + fn insert_tx(&mut self, id: u64, tx: Sender>) -> Result<()> { match self.0 { Ok(ref mut requests) => { requests.insert(id, tx); @@ -361,7 +360,7 @@ impl RpcFutures { fn complete_reply(&mut self, id: u64, reply: Reply) { if let Some(tx) = self.0.as_mut().unwrap().remove(&id) { - if let Err(e) = tx.send(reply) { + if let Err(e) = tx.send(Ok(reply)) { info!("Reader: could not complete reply: {:?}", e); } } else { @@ -378,62 +377,104 @@ impl RpcFutures { } } -struct Reader { - requests: Arc>>, -} +fn write(outbound: Receiver<(Request, Sender>)>, + requests: Arc>>, + stream: TcpStream) + where Request: serde::Serialize, + Reply: serde::Deserialize, +{ + let mut next_id = 0; + let mut stream = BufWriter::new(stream); + loop { + let (request, tx) = match outbound.recv() { + Err(e) => { + debug!("Writer: all senders have exited ({:?}). Returning.", e); + return; + } + Ok(request) => request, + }; + if let Err(e) = requests.lock().unwrap().insert_tx(next_id, tx.clone()) { + report_error(&tx, e); + // Once insert_tx returns Err, it will continue to do so. However, continue here so + // that any other clients who sent requests will also recv the Err. + continue; + } + let id = next_id; + next_id += 1; + let packet = Packet { + rpc_id: id, + message: request, + }; + debug!("Writer: calling rpc({:?})", id); + if let Err(e) = bincode::serde::serialize_into(&mut stream, + &packet, + bincode::SizeLimit::Infinite) { + report_error(&tx, e.into()); + // Typically we'd want to notify the client of any Err returned by remove_tx, but in + // this case the client already hit an Err, and doesn't need to know about this one, as + // well. + let _ = requests.lock().unwrap().remove_tx(id); + continue; + } + if let Err(e) = stream.flush() { + report_error(&tx, e.into()); + } + } -impl Reader { - fn read(self, stream: TcpStream) + fn report_error(tx: &Sender>, e: Error) where Reply: serde::Deserialize { - let mut stream = BufReader::new(stream); - loop { - let packet: bincode::serde::DeserializeResult> = - bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); - match packet { - Ok(Packet { - rpc_id: id, - message: reply - }) => { - debug!("Client: received message, id={}", id); - self.requests.lock().unwrap().complete_reply(id, reply); - } - Err(err) => { - warn!("Client: reader thread encountered an unexpected error while parsing; \ - returning now. Error: {:?}", - err); - self.requests.lock().unwrap().set_error(err); - break; - } + // Clone the err so we can log it if sending fails + if let Err(e2) = tx.send(Err(e.clone())) { + debug!("Error encountered while trying to send an error. \ + Initial error: {:?}; Send error: {:?}", + e, + e2); + } + } + +} + +fn read(requests: Arc>>, stream: TcpStream) + where Reply: serde::Deserialize +{ + let mut stream = BufReader::new(stream); + loop { + let packet: bincode::serde::DeserializeResult> = + bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); + match packet { + Ok(Packet { + rpc_id: id, + message: reply + }) => { + debug!("Client: received message, id={}", id); + requests.lock().unwrap().complete_reply(id, reply); + } + Err(err) => { + warn!("Client: reader thread encountered an unexpected error while parsing; \ + returning now. Error: {:?}", + err); + requests.lock().unwrap().set_error(err); + break; } } } } -fn increment(cur_id: &mut u64) -> u64 { - let id = *cur_id; - *cur_id += 1; - id -} - -struct SyncedClientState { - next_id: u64, - stream: BufWriter, -} - /// A client stub that connects to a server to run rpcs. pub struct Client where Request: serde::ser::Serialize { - synced_state: Mutex, + // The guard is in an option so it can be joined in the drop fn + reader_guard: Arc>>, + outbound: Sender<(Request, Sender>)>, requests: Arc>>, - reader_guard: Option>, - _request: PhantomData, + shutdown: TcpStream, } impl Client - where Reply: serde::de::Deserialize + Send + 'static, - Request: serde::ser::Serialize + where Request: serde::ser::Serialize + Send + 'static, + Reply: serde::de::Deserialize + Send + 'static { /// Create a new client that connects to `addr`. The client uses the given timeout /// for both reads and writes. @@ -441,60 +482,52 @@ impl Client let stream = try!(TcpStream::connect(addr)); try!(stream.set_read_timeout(timeout)); try!(stream.set_write_timeout(timeout)); - let requests = Arc::new(Mutex::new(RpcFutures::new())); let reader_stream = try!(stream.try_clone()); - let reader = Reader { requests: requests.clone() }; - let reader_guard = thread::spawn(move || reader.read(reader_stream)); + let writer_stream = try!(stream.try_clone()); + let requests = Arc::new(Mutex::new(RpcFutures::new())); + let reader_requests = requests.clone(); + let writer_requests = requests.clone(); + let (tx, rx) = channel(); + let reader_guard = thread::spawn(move || read(reader_requests, reader_stream)); + thread::spawn(move || write(rx, writer_requests, writer_stream)); Ok(Client { - synced_state: Mutex::new(SyncedClientState { - next_id: 0, - stream: BufWriter::new(stream), - }), + reader_guard: Arc::new(Some(reader_guard)), + outbound: tx, requests: requests, - reader_guard: Some(reader_guard), - _request: PhantomData, + shutdown: stream, }) } - fn rpc_internal(&self, request: &Request) -> Result> + /// Clones the Client so that it can be shared across threads. + pub fn try_clone(&self) -> io::Result> { + Ok(Client { + reader_guard: self.reader_guard.clone(), + outbound: self.outbound.clone(), + requests: self.requests.clone(), + shutdown: try!(self.shutdown.try_clone()), + }) + } + + fn rpc_internal(&self, request: Request) -> Receiver> where Request: serde::ser::Serialize + fmt::Debug + Send + 'static { let (tx, rx) = channel(); - let mut state = self.synced_state.lock().unwrap(); - let id = increment(&mut state.next_id); - try!(self.requests.lock().unwrap().insert_tx(id, tx)); - let packet = Packet { - rpc_id: id, - message: request, - }; - debug!("Client: calling rpc({:?})", request); - if let Err(err) = bincode::serde::serialize_into(&mut state.stream, - &packet, - bincode::SizeLimit::Infinite) { - warn!("Client: failed to write packet.\nPacket: {:?}\nError: {:?}", - packet, - err); - try!(self.requests.lock().unwrap().remove_tx(id)); - } - if let Err(err) = state.stream.flush() { - warn!("Client: failed to flush packet.\nPacket: {:?}\nError: {:?}", - packet, - err); - } - Ok(rx) + self.outbound.send((request, tx)).unwrap(); + rx } /// Run the specified rpc method on the server this client is connected to - pub fn rpc(&self, request: &Request) -> Result + pub fn rpc(&self, request: Request) -> Result where Request: serde::ser::Serialize + fmt::Debug + Send + 'static { - try!(self.rpc_internal(request)) + self.rpc_internal(request) .recv() .map_err(|_| self.requests.lock().unwrap().get_error()) + .and_then(|reply| reply) } /// Asynchronously run the specified rpc method on the server this client is connected to - pub fn rpc_async(&self, request: &Request) -> Future + pub fn rpc_async(&self, request: Request) -> Future where Request: serde::ser::Serialize + fmt::Debug + Send + 'static { Future { @@ -508,15 +541,18 @@ impl Drop for Client where Request: serde::ser::Serialize { fn drop(&mut self) { - if let Err(e) = self.synced_state - .lock() - .unwrap() - .stream - .get_mut() - .shutdown(::std::net::Shutdown::Both) { - warn!("Client: couldn't shutdown reader thread: {:?}", e); - } else { - self.reader_guard.take().unwrap().join().unwrap(); + debug!("Dropping Client."); + if let Some(reader_guard) = Arc::get_mut(&mut self.reader_guard) { + debug!("Attempting to shut down writer and reader threads."); + if let Err(e) = self.shutdown.shutdown(::std::net::Shutdown::Both) { + warn!("Client: couldn't shutdown writer and reader threads: {:?}", e); + } else { + // We only join if we know the TcpStream was shut down. Otherwise we might never + // finish. + debug!("Joining writer and reader."); + reader_guard.take().unwrap().join().unwrap(); + debug!("Successfully joined writer and reader."); + } } } } @@ -524,8 +560,8 @@ impl Drop for Client #[cfg(test)] mod test { extern crate env_logger; - use super::{Client, Serve, serve_async}; + use scoped_pool::Pool; use std::sync::{Arc, Barrier, Mutex}; use std::thread; use std::time::Duration; @@ -588,10 +624,10 @@ mod test { let addr = serve_handle.local_addr().clone(); let client = Client::new(addr, None).unwrap(); assert_eq!(Reply::Increment(0), - client.rpc(&Request::Increment).unwrap()); + client.rpc(Request::Increment).unwrap()); assert_eq!(1, server.count()); assert_eq!(Reply::Increment(1), - client.rpc(&Request::Increment).unwrap()); + client.rpc(Request::Increment).unwrap()); assert_eq!(2, server.count()); drop(client); serve_handle.shutdown(); @@ -632,7 +668,7 @@ mod test { let addr = serve_handle.local_addr().clone(); let client: Client = Client::new(addr, None).unwrap(); let thread = thread::spawn(move || serve_handle.shutdown()); - info!("force_shutdown:: rpc1: {:?}", client.rpc(&Request::Increment)); + info!("force_shutdown:: rpc1: {:?}", client.rpc(Request::Increment)); thread.join().unwrap(); } @@ -644,34 +680,29 @@ mod test { let addr = serve_handle.local_addr().clone(); let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); serve_handle.shutdown(); - match client.rpc(&Request::Increment) { + match client.rpc(Request::Increment) { Err(super::Error::ConnectionBroken) => {} // success otherwise => panic!("Expected Err(ConnectionBroken), got {:?}", otherwise), } - let _ = client.rpc(&Request::Increment); // Test whether second failure hangs + let _ = client.rpc(Request::Increment); // Test whether second failure hangs } #[test] fn concurrent() { let _ = env_logger::init(); let concurrency = 10; + let pool = Pool::new(concurrency); let server = Arc::new(BarrierServer::new(concurrency)); let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); - let mut join_handles = vec![]; - for _ in 0..concurrency { - let my_client = client.clone(); - join_handles.push(thread::spawn(move || my_client.rpc(&Request::Increment).unwrap())); - } - for handle in join_handles.into_iter() { - handle.join().unwrap(); - } + let client: Client = Client::new(addr, None).unwrap(); + pool.scoped(|scope| { + for _ in 0..concurrency { + let client = client.try_clone().unwrap(); + scope.execute(move || { client.rpc(Request::Increment).unwrap(); }); + } + }); assert_eq!(concurrency as u64, server.count()); - let client = match Arc::try_unwrap(client) { - Err(_) => panic!("couldn't unwrap arc"), - Ok(c) => c, - }; drop(client); serve_handle.shutdown(); } @@ -685,9 +716,9 @@ mod test { let client: Client = Client::new(addr, None).unwrap(); // Drop future immediately; does the reader channel panic when sending? - client.rpc_async(&Request::Increment); + client.rpc_async(Request::Increment); // If the reader panicked, this won't succeed - client.rpc_async(&Request::Increment); + client.rpc_async(Request::Increment); drop(client); serve_handle.shutdown();