diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 94d9b73..6a35757 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -8,3 +8,4 @@ serde = "*" bincode = "*" serde_macros = "*" log = "*" +env_logger = "*" diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 272783f..05823a0 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -18,6 +18,7 @@ //! } //! //! use self::my_server::*; +//! use std::time::Duration; //! //! impl my_server::Service for () { //! fn hello(&self, s: String) -> String { @@ -30,8 +31,9 @@ //! //! fn main() { //! let addr = "127.0.0.1:9000"; -//! let shutdown = my_server::serve(addr, ()).unwrap(); -//! let client = Client::new(addr).unwrap(); +//! let shutdown = my_server::serve(addr, (), +//! Some(Duration::from_secs(30))).unwrap(); +//! let client = Client::new(addr, None).unwrap(); //! assert_eq!(3, client.add(1, 2).unwrap()); //! assert_eq!("Hello, Mom!".to_string(), //! client.hello("Mom".to_string()).unwrap()); diff --git a/tarpc/src/macros.rs b/tarpc/src/macros.rs index 1148d3f..a71d412 100644 --- a/tarpc/src/macros.rs +++ b/tarpc/src/macros.rs @@ -49,15 +49,15 @@ macro_rules! request_variant { // The main macro that creates RPC services. #[macro_export] -macro_rules! rpc { +macro_rules! rpc { ( mod $server:ident { service { - $( + $( $(#[$attr:meta])* rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty; - )* + )* } } ) => { @@ -66,7 +66,7 @@ macro_rules! rpc { items { } - service { + service { $( $(#[$attr])* rpc $fn_name($($arg: $in_),*) -> $out; @@ -125,10 +125,11 @@ macro_rules! rpc { impl Client { #[doc="Create a new client that connects to the given address."] - pub fn new(addr: A) -> $crate::Result + pub fn new(addr: A, timeout: ::std::option::Option<::std::time::Duration>) + -> $crate::Result where A: ::std::net::ToSocketAddrs, { - let inner = try!($crate::protocol::Client::new(addr)); + let inner = try!($crate::protocol::Client::new(addr, timeout)); Ok(Client(inner)) } @@ -137,9 +138,11 @@ macro_rules! rpc { struct __Server(S); - impl $crate::protocol::Serve<__Request, __Reply> for __Server + impl $crate::protocol::Serve for __Server where S: 'static + Service { + type Request = __Request; + type Reply = __Reply; fn serve(&self, request: __Request) -> __Reply { match request { $( @@ -151,12 +154,15 @@ macro_rules! rpc { } #[doc="Start a running service."] - pub fn serve(addr: A, service: S) -> $crate::Result<$crate::protocol::ServeHandle> + pub fn serve(addr: A, + service: S, + read_timeout: ::std::option::Option<::std::time::Duration>) + -> $crate::Result<$crate::protocol::ServeHandle> where A: ::std::net::ToSocketAddrs, S: 'static + Service { let server = ::std::sync::Arc::new(__Server(service)); - Ok(try!($crate::protocol::serve_async(addr, server))) + Ok(try!($crate::protocol::serve_async(addr, server, read_timeout))) } } } @@ -165,6 +171,12 @@ macro_rules! rpc { #[cfg(test)] #[allow(dead_code)] mod test { + use std::time::Duration; + + fn test_timeout() -> Option { + Some(Duration::from_secs(5)) + } + rpc! { mod my_server { items { @@ -197,8 +209,8 @@ mod test { fn simple_test() { println!("Starting"); let addr = "127.0.0.1:9000"; - let shutdown = my_server::serve(addr, ()).unwrap(); - let client = Client::new(addr).unwrap(); + let shutdown = my_server::serve(addr, (), test_timeout()).unwrap(); + let client = Client::new(addr, None).unwrap(); assert_eq!(3, client.add(1, 2).unwrap()); let foo = Foo { message: "Adam".into() }; let want = Foo { message: format!("Hello, {}", &foo.message) }; diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index f24f6bd..1414570 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -6,8 +6,10 @@ use std::convert; use std::collections::HashMap; use std::marker::PhantomData; use std::net::{TcpListener, TcpStream, SocketAddr, ToSocketAddrs}; -use std::sync::{self, Mutex, Arc}; +use std::sync::{self, Arc, Condvar, Mutex}; use std::sync::mpsc::{channel, Sender, TryRecvError}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Duration; use std::thread::{self, JoinHandle}; /// Client errors that can occur during rpc calls @@ -23,6 +25,10 @@ pub enum Error { /// Channels are used for the client's inter-thread communication. This message is /// propagated if the receiver unexpectedly hangs up. Sender, + /// An internal message failed to be received. + /// Channels are used for the client's inter-thread communication. This message is + /// propagated if the sender unexpectedly hangs up. + Receiver, /// The server hung up. ConnectionBroken, } @@ -57,47 +63,146 @@ impl convert::From> for Error { } } +impl convert::From for Error { + fn from(_: sync::mpsc::RecvError) -> Error { + Error::Receiver + } +} + /// Return type of rpc calls: either the successful return value, or a client error. pub type Result = ::std::result::Result; -fn handle_conn(stream: TcpStream, f: F) -> Result<()> - where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, - Reply: 'static + fmt::Debug + serde::ser::Serialize, - F: 'static + Clone + Serve -{ - let mut read_stream = try!(stream.try_clone()); - let stream = Arc::new(Mutex::new(stream)); - loop { - let request_packet: Packet = - try!(bincode::serde::deserialize_from(&mut read_stream, bincode::SizeLimit::Infinite)); - match request_packet { - Packet::Shutdown => { - let stream = stream.clone(); - let mut my_stream = stream.lock().unwrap(); - try!(bincode::serde::serialize_into(&mut *my_stream, - &request_packet, - bincode::SizeLimit::Infinite)); - break; - } - Packet::Message(id, message) => { - let f = f.clone(); - let arc_stream = stream.clone(); - thread::spawn(move || { - let reply = f.serve(message); - let reply_packet = Packet::Message(id, reply); - let mut my_stream = arc_stream.lock().unwrap(); - bincode::serde::serialize_into(&mut *my_stream, - &reply_packet, - bincode::SizeLimit::Infinite) - .unwrap(); - }); - } - } - } - Ok(()) +#[derive(Clone)] +struct InflightRpcs { + inflight_rpcs: Arc<(Mutex, Condvar)>, } -/// Provides methods for blocking until the server completes, +impl InflightRpcs { + fn new(mutex: Mutex, cvar: Condvar) -> InflightRpcs { + InflightRpcs { inflight_rpcs: Arc::new((mutex, cvar)) } + } + + fn wait_until_zero(&self) { + let &(ref count, ref cvar) = &*self.inflight_rpcs; + let mut count = count.lock().unwrap(); + while *count != 0 { + count = cvar.wait(count).unwrap(); + } + info!("serve_async: shutdown complete ({} connections alive)", + *count); + } + + fn increment(&self) { + let &(ref count, _) = &*self.inflight_rpcs; + *count.lock().unwrap() += 1; + } + + fn decrement(&self) { + let &(ref count, _) = &*self.inflight_rpcs; + *count.lock().unwrap() -= 1; + } + + + fn decrement_and_notify(&self) { + let &(ref count, ref cvar) = &*self.inflight_rpcs; + *count.lock().unwrap() -= 1; + cvar.notify_one(); + } + +} + +struct ConnectionHandler { + read_stream: TcpStream, + write_stream: Arc>, + shutdown: Arc, + inflight_rpcs: InflightRpcs, + timeout: Option, +} + +impl Drop for ConnectionHandler { + fn drop(&mut self) { + trace!("ConnectionHandler: finished serving client."); + self.inflight_rpcs.decrement_and_notify(); + } +} + +impl ConnectionHandler { + fn read(&mut self) -> bincode::serde::DeserializeResult> + where Request: serde::de::Deserialize + { + try!(self.read_stream.set_read_timeout(self.timeout)); + bincode::serde::deserialize_from(&mut self.read_stream, bincode::SizeLimit::Infinite) + } + + fn handle_conn(&mut self, f: F) -> Result<()> + where F: 'static + Clone + Serve + { + trace!("ConnectionHandler: serving client..."); + loop { + match self.read() { + Ok(Packet { + rpc_id: id, + message: message + }) => { + let f = f.clone(); + let inflight_rpcs = self.inflight_rpcs.clone(); + inflight_rpcs.increment(); + let stream = self.write_stream.clone(); + thread::spawn(move || { + let reply = f.serve(message); + let reply_packet = Packet { + rpc_id: id, + message: reply + }; + let mut stream = stream.lock().unwrap(); + if let Err(e) = + bincode::serde::serialize_into(&mut *stream, + &reply_packet, + bincode::SizeLimit::Infinite) { + warn!("ConnectionHandler: failed to write reply to Client: {:?}", + e); + } + inflight_rpcs.decrement(); + }); + if self.shutdown.load(Ordering::SeqCst) { + info!("ConnectionHandler: server shutdown, so closing connection."); + break; + } + } + Err(bincode::serde::DeserializeError::IoError(ref err)) + if Self::timed_out(err.kind()) => { + if !self.shutdown.load(Ordering::SeqCst) { + info!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so \ + retrying read.", + err); + continue; + } else { + info!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \ + closing connection.", + err); + break; + } + } + Err(e) => { + warn!("ConnectionHandler: closing client connection due to error while \ + serving: {:?}", + e); + return Err(e.into()); + } + } + } + Ok(()) + } + + fn timed_out(error_kind: io::ErrorKind) -> bool { + match error_kind { + io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock => true, + _ => false, + } + } +} + +/// Provides methods for blocking until the server completes, pub struct ServeHandle { tx: Sender<()>, join_handle: JoinHandle<()>, @@ -118,47 +223,64 @@ impl ServeHandle { /// Shutdown the server. Gracefully shuts down the serve thread but currently does not /// gracefully close open connections. pub fn shutdown(self) { + info!("ServeHandle: attempting to shut down the server."); self.tx.send(()).expect(&line!().to_string()); if let Ok(_) = TcpStream::connect(self.addr) { self.join_handle.join().expect(&line!().to_string()); } else { - warn!("Best effort shutdown of serve thread failed"); + warn!("ServeHandle: best effort shutdown of serve thread failed"); } } } -/// Start -pub fn serve_async(addr: A, f: F) -> io::Result +/// Start +pub fn serve_async(addr: A, f: F, read_timeout: Option) -> io::Result where A: ToSocketAddrs, - Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, - Reply: 'static + fmt::Debug + serde::ser::Serialize, - F: 'static + Clone + Send + Serve + F: 'static + Clone + Send + Serve { let listener = try!(TcpListener::bind(&addr)); let addr = try!(listener.local_addr()); - info!("Spinning up server on {:?}", addr); + info!("serve_async: spinning up server on {:?}", addr); let (die_tx, die_rx) = channel(); let join_handle = thread::spawn(move || { + let shutdown = Arc::new(AtomicBool::new(false)); + let inflight_rpcs = InflightRpcs::new(Mutex::new(0), Condvar::new()); for conn in listener.incoming() { match die_rx.try_recv() { - Ok(_) => break, + Ok(_) => { + info!("serve_async: shutdown received. Waiting for open connections to \ + return..."); + shutdown.store(true, Ordering::SeqCst); + inflight_rpcs.wait_until_zero(); + break; + } Err(TryRecvError::Disconnected) => { - info!("Sender disconnected."); + info!("serve_async: sender disconnected."); break; } _ => (), } let conn = match conn { Err(err) => { - error!("Failed to accept connection: {:?}", err); + error!("serve_async: failed to accept connection: {:?}", err); return; } Ok(c) => c, }; let f = f.clone(); + let shutdown = shutdown.clone(); + inflight_rpcs.increment(); + let inflight_rpcs = inflight_rpcs.clone(); + let mut handler = ConnectionHandler { + read_stream: conn.try_clone().unwrap(), + write_stream: Arc::new(Mutex::new(conn)), + shutdown: shutdown, + inflight_rpcs: inflight_rpcs, + timeout: read_timeout, + }; thread::spawn(move || { - if let Err(err) = handle_conn(conn, f) { - error!("Error in connection handling: {:?}", err); + if let Err(err) = handler.handle_conn(f) { + error!("ConnectionHandler: error in connection handling: {:?}", err); } }); } @@ -171,46 +293,74 @@ pub fn serve_async(addr: A, f: F) -> io::Result: Send + Sync { +pub trait Serve: Send + Sync { + /// The type of request received by the server + type Request: 'static + fmt::Debug + serde::ser::Serialize + serde::de::Deserialize + Send; + /// The type of reply sent by the server + type Reply: 'static + fmt::Debug + serde::ser::Serialize + serde::de::Deserialize; + /// Return a reply for a given request - fn serve(&self, request: Request) -> Reply; + fn serve(&self, request: Self::Request) -> Self::Reply; } -impl Serve for Arc - where S: Serve +impl Serve for Arc where S: Serve { - fn serve(&self, request: Request) -> Reply { + type Request = S::Request; + type Reply = S::Reply; + + fn serve(&self, request: S::Request) -> S::Reply { S::serve(self, request) } } #[derive(Debug, Clone, Serialize, Deserialize)] -enum Packet { - Message(u64, T), - Shutdown, +struct Packet { + rpc_id: u64, + message: T } -fn reader(mut stream: TcpStream, requests: Arc>>>) - where Reply: serde::Deserialize -{ - loop { - let packet: bincode::serde::DeserializeResult> = - bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); - match packet { - Ok(Packet::Message(id, reply)) => { - let mut requests = requests.lock().unwrap(); - let reply_tx = requests.remove(&id).unwrap(); - reply_tx.send(reply).unwrap(); +struct Reader { + requests: Arc>>>>, +} + +impl Reader { + fn read(self, mut stream: TcpStream) + where Reply: serde::Deserialize + { + loop { + let packet: bincode::serde::DeserializeResult> = + bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); + match packet { + Ok(Packet { + rpc_id: id, + message: reply + }) => { + debug!("Client: received message, id={}", id); + let mut requests = self.requests.lock().unwrap(); + let mut requests = requests.as_mut().unwrap(); + let reply_tx = requests.remove(&id).unwrap(); + reply_tx.send(reply).unwrap(); + } + Err(err) => { + warn!("Client: reader thread encountered an unexpected error while parsing; \ + returning now. Error: {:?}", + err); + break; + } } - Ok(Packet::Shutdown) => { - break; - } - // TODO: This shutdown logic is janky.. What's the right way to do this? - Err(err) => panic!("unexpected error while parsing!: {:?}", err), } } } +impl Drop for Reader { + fn drop(&mut self) { + let mut guard = self.requests.lock().unwrap(); + guard.as_mut().unwrap().clear(); + // remove the hashmap so no one can put more senders and accidentally block + guard.take(); + } +} + fn increment(cur_id: &mut u64) -> u64 { let id = *cur_id; *cur_id += 1; @@ -227,8 +377,9 @@ pub struct Client where Request: serde::ser::Serialize { synced_state: Mutex, - requests: Arc>>>, + requests: Arc>>>>, reader_guard: Option>, + timeout: Option, _request: PhantomData, } @@ -237,12 +388,12 @@ impl Client Request: serde::ser::Serialize { /// Create a new client that connects to `addr` - pub fn new(addr: A) -> io::Result { + pub fn new(addr: A, timeout: Option) -> io::Result { let stream = try!(TcpStream::connect(addr)); - let requests = Arc::new(Mutex::new(HashMap::new())); + let requests = Arc::new(Mutex::new(Some(HashMap::new()))); let reader_stream = try!(stream.try_clone()); - let reader_requests = requests.clone(); - let reader_guard = thread::spawn(move || reader(reader_stream, reader_requests)); + let reader = Reader { requests: requests.clone() }; + let reader_guard = thread::spawn(move || reader.read(reader_stream)); Ok(Client { synced_state: Mutex::new(SyncedClientState { next_id: 0, @@ -250,6 +401,7 @@ impl Client }), requests: requests, reader_guard: Some(reader_guard), + timeout: timeout, _request: PhantomData, }) } @@ -262,21 +414,36 @@ impl Client let mut state = self.synced_state.lock().unwrap(); let id = increment(&mut state.next_id); { - let mut requests = self.requests.lock().unwrap(); - requests.insert(id, tx); + if let Some(ref mut requests) = *self.requests.lock().unwrap() { + requests.insert(id, tx); + } else { + return Err(Error::ConnectionBroken); + } } - let packet = Packet::Message(id, request); + let packet = Packet { + rpc_id: id, + message: request, + }; + try!(state.stream.set_write_timeout(self.timeout)); + try!(state.stream.set_read_timeout(self.timeout)); + debug!("Client: calling rpc({:?})", request); if let Err(err) = bincode::serde::serialize_into(&mut state.stream, &packet, bincode::SizeLimit::Infinite) { - warn!("Failed to write client packet.\nPacket: {:?}\nError: {:?}", + warn!("Client: failed to write packet.\nPacket: {:?}\nError: {:?}", packet, err); - self.requests.lock().unwrap().remove(&id); + if let Some(requests) = self.requests.lock().unwrap().as_mut() { + requests.remove(&id); + } else { + warn!("Client: couldn't remove sender for request {} because reader thread \ + returned.", + id); + } return Err(err.into()); } drop(state); - Ok(rx.recv().unwrap()) + Ok(try!(rx.recv())) } } @@ -284,24 +451,30 @@ impl Drop for Client where Request: serde::ser::Serialize { fn drop(&mut self) { - { - let mut state = self.synced_state.lock().unwrap(); - let packet: Packet = Packet::Shutdown; - if let Err(err) = bincode::serde::serialize_into(&mut state.stream, - &packet, - bincode::SizeLimit::Infinite) { - warn!("While disconnecting client from server: {:?}", err); - } + if let Err(e) = self.synced_state + .lock() + .unwrap() + .stream + .shutdown(::std::net::Shutdown::Both) { + warn!("Client: couldn't shutdown reader thread: {:?}", e); + } else { + self.reader_guard.take().unwrap().join().unwrap(); } - self.reader_guard.take().unwrap().join().unwrap(); } } #[cfg(test)] mod test { - use super::*; + extern crate env_logger; + + use super::{Client, Serve, serve_async}; use std::sync::{Arc, Mutex, Barrier}; use std::thread; + use std::time::Duration; + + fn test_timeout() -> Option { + Some(Duration::from_millis(1)) + } #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] enum Request { @@ -317,7 +490,10 @@ mod test { counter: Mutex, } - impl Serve for Server { + impl Serve for Server { + type Request = Request; + type Reply = Reply; + fn serve(&self, _: Request) -> Reply { let mut counter = self.counter.lock().unwrap(); let reply = Reply::Increment(*counter); @@ -337,21 +513,23 @@ mod test { } #[test] - fn test_handle() { + fn handle() { + let _ = env_logger::init(); let server = Arc::new(Server::new()); - let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap(); - let client: Client = Client::new(serve_handle.local_addr().clone()) + let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); + let client: Client = Client::new(serve_handle.local_addr().clone(), None) .expect(&line!().to_string()); drop(client); serve_handle.shutdown(); } #[test] - fn test_simple() { + fn simple() { + let _ = env_logger::init(); let server = Arc::new(Server::new()); - let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap(); + let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); - let client = Client::new(addr).unwrap(); + let client = Client::new(addr, None).unwrap(); assert_eq!(Reply::Increment(0), client.rpc(&Request::Increment).unwrap()); assert_eq!(1, server.count()); @@ -367,7 +545,9 @@ mod test { inner: Server, } - impl Serve for BarrierServer { + impl Serve for BarrierServer { + type Request = Request; + type Reply = Reply; fn serve(&self, request: Request) -> Reply { self.barrier.wait(); self.inner.serve(request) @@ -388,11 +568,36 @@ mod test { } #[test] - fn test_concurrent() { - let server = Arc::new(BarrierServer::new(10)); - let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap(); + fn force_shutdown() { + let _ = env_logger::init(); + let server = Arc::new(Server::new()); + let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr).unwrap()); + let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); + let thread = thread::spawn(move || serve_handle.shutdown()); + info!("force_shutdown:: rpc1: {:?}", client.rpc(&Request::Increment)); + thread.join().unwrap(); + } + + #[test] + fn client_failed_rpc() { + let _ = env_logger::init(); + let server = Arc::new(Server::new()); + let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap(); + let addr = serve_handle.local_addr().clone(); + let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); + serve_handle.shutdown(); + let _ = client.rpc(&Request::Increment); // First failure will trigger reader to shutdown + let _ = client.rpc(&Request::Increment); // Test whether second failure hangs + } + + #[test] + fn concurrent() { + let _ = env_logger::init(); + let server = Arc::new(BarrierServer::new(10)); + let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); + let addr = serve_handle.local_addr().clone(); + let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); let mut join_handles = vec![]; for _ in 0..10 { let my_client = client.clone();