diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 94d9b73..6a35757 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -8,3 +8,4 @@ serde = "*" bincode = "*" serde_macros = "*" log = "*" +env_logger = "*" diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 272783f..8bf82b5 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -18,6 +18,7 @@ //! } //! //! use self::my_server::*; +//! use std::time::Duration; //! //! impl my_server::Service for () { //! fn hello(&self, s: String) -> String { @@ -30,8 +31,8 @@ //! //! fn main() { //! let addr = "127.0.0.1:9000"; -//! let shutdown = my_server::serve(addr, ()).unwrap(); -//! let client = Client::new(addr).unwrap(); +//! let shutdown = my_server::serve(addr, (), Some(Duration::from_secs(30))).unwrap(); +//! let client = Client::new(addr, None).unwrap(); //! assert_eq!(3, client.add(1, 2).unwrap()); //! assert_eq!("Hello, Mom!".to_string(), //! client.hello("Mom".to_string()).unwrap()); diff --git a/tarpc/src/macros.rs b/tarpc/src/macros.rs index 1148d3f..4f1183d 100644 --- a/tarpc/src/macros.rs +++ b/tarpc/src/macros.rs @@ -125,10 +125,10 @@ macro_rules! rpc { impl Client { #[doc="Create a new client that connects to the given address."] - pub fn new(addr: A) -> $crate::Result + pub fn new(addr: A, timeout: ::std::option::Option<::std::time::Duration>) -> $crate::Result where A: ::std::net::ToSocketAddrs, { - let inner = try!($crate::protocol::Client::new(addr)); + let inner = try!($crate::protocol::Client::new(addr, timeout)); Ok(Client(inner)) } @@ -151,12 +151,12 @@ macro_rules! rpc { } #[doc="Start a running service."] - pub fn serve(addr: A, service: S) -> $crate::Result<$crate::protocol::ServeHandle> + pub fn serve(addr: A, service: S, read_timeout: ::std::option::Option<::std::time::Duration>) -> $crate::Result<$crate::protocol::ServeHandle> where A: ::std::net::ToSocketAddrs, S: 'static + Service { let server = ::std::sync::Arc::new(__Server(service)); - Ok(try!($crate::protocol::serve_async(addr, server))) + Ok(try!($crate::protocol::serve_async(addr, server, read_timeout))) } } } @@ -165,6 +165,12 @@ macro_rules! rpc { #[cfg(test)] #[allow(dead_code)] mod test { + use std::time::Duration; + + fn test_timeout() -> Option { + Some(Duration::from_secs(5)) + } + rpc! { mod my_server { items { @@ -197,8 +203,8 @@ mod test { fn simple_test() { println!("Starting"); let addr = "127.0.0.1:9000"; - let shutdown = my_server::serve(addr, ()).unwrap(); - let client = Client::new(addr).unwrap(); + let shutdown = my_server::serve(addr, (), test_timeout()).unwrap(); + let client = Client::new(addr, None).unwrap(); assert_eq!(3, client.add(1, 2).unwrap()); let foo = Foo { message: "Adam".into() }; let want = Foo { message: format!("Hello, {}", &foo.message) }; diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index f24f6bd..d83569c 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -6,8 +6,10 @@ use std::convert; use std::collections::HashMap; use std::marker::PhantomData; use std::net::{TcpListener, TcpStream, SocketAddr, ToSocketAddrs}; -use std::sync::{self, Mutex, Arc}; +use std::sync::{self, Arc, Condvar, Mutex}; use std::sync::mpsc::{channel, Sender, TryRecvError}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Duration; use std::thread::{self, JoinHandle}; /// Client errors that can occur during rpc calls @@ -23,6 +25,10 @@ pub enum Error { /// Channels are used for the client's inter-thread communication. This message is /// propagated if the receiver unexpectedly hangs up. Sender, + /// An internal message failed to be received. + /// Channels are used for the client's inter-thread communication. This message is + /// propagated if the sender unexpectedly hangs up. + Receiver, /// The server hung up. ConnectionBroken, } @@ -57,44 +63,97 @@ impl convert::From> for Error { } } +impl convert::From for Error { + fn from(_: sync::mpsc::RecvError) -> Error { + Error::Receiver + } +} + /// Return type of rpc calls: either the successful return value, or a client error. pub type Result = ::std::result::Result; -fn handle_conn(stream: TcpStream, f: F) -> Result<()> - where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, - Reply: 'static + fmt::Debug + serde::ser::Serialize, - F: 'static + Clone + Serve -{ - let mut read_stream = try!(stream.try_clone()); - let stream = Arc::new(Mutex::new(stream)); - loop { - let request_packet: Packet = - try!(bincode::serde::deserialize_from(&mut read_stream, bincode::SizeLimit::Infinite)); - match request_packet { - Packet::Shutdown => { - let stream = stream.clone(); - let mut my_stream = stream.lock().unwrap(); - try!(bincode::serde::serialize_into(&mut *my_stream, - &request_packet, - bincode::SizeLimit::Infinite)); - break; - } - Packet::Message(id, message) => { - let f = f.clone(); - let arc_stream = stream.clone(); - thread::spawn(move || { - let reply = f.serve(message); - let reply_packet = Packet::Message(id, reply); - let mut my_stream = arc_stream.lock().unwrap(); - bincode::serde::serialize_into(&mut *my_stream, - &reply_packet, - bincode::SizeLimit::Infinite) - .unwrap(); - }); +struct ConnectionHandler { + shutdown: Arc, + open_connections: Arc<(Mutex, Condvar)>, + timeout: Option, +} + +impl Drop for ConnectionHandler { + fn drop(&mut self) { + let &(ref count, ref cvar) = &*self.open_connections; + *count.lock().unwrap() -= 1; + cvar.notify_one(); + trace!("ConnectionHandler: finished serving client."); + } +} + +impl ConnectionHandler { + fn handle_conn(&self, stream: TcpStream, f: F) -> Result<()> + where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, + Reply: 'static + fmt::Debug + serde::ser::Serialize, + F: 'static + Clone + Serve + { + trace!("ConnectionHandler: serving client..."); + let mut read_stream = try!(stream.try_clone()); + let stream = Arc::new(Mutex::new(stream)); + loop { + try!(read_stream.set_read_timeout(self.timeout)); + match bincode::serde::deserialize_from(&mut read_stream, bincode::SizeLimit::Infinite) { + Ok(request_packet @ Packet::Shutdown) => { + let stream = stream.clone(); + let mut my_stream = stream.lock().unwrap(); + try!(bincode::serde::serialize_into(&mut *my_stream, + &request_packet, + bincode::SizeLimit::Infinite)); + break; + } + Ok(Packet::Message(id, message)) => { + let f = f.clone(); + let arc_stream = stream.clone(); + thread::spawn(move || { + let reply = f.serve(message); + let reply_packet = Packet::Message(id, reply); + let mut my_stream = arc_stream.lock().unwrap(); + bincode::serde::serialize_into(&mut *my_stream, + &reply_packet, + bincode::SizeLimit::Infinite) + .unwrap(); + }); + } + Err(bincode::serde::DeserializeError::IoError(ref err)) + if Self::timed_out(err.kind()) => + { + if !self.shutdown() { + warn!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so retrying read.", err); + continue; + } else { + warn!("ConnectionHandler: read timed out ({:?}). Server shutdown, so closing connection.", err); + let mut stream = stream.lock().unwrap(); + try!(bincode::serde::serialize_into(&mut *stream, + &Packet::Shutdown::, + bincode::SizeLimit::Infinite)); + break; + } + } + Err(e) => { + warn!("ConnectionHandler: closing client connection due to error while serving: {:?}", e); + return Err(e.into()); + } } } + Ok(()) + } + + fn shutdown(&self) -> bool { + self.shutdown.load(Ordering::SeqCst) + } + + fn timed_out(error_kind: io::ErrorKind) -> bool { + match error_kind { + io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock => true, + _ => false, + } } - Ok(()) } /// Provides methods for blocking until the server completes, @@ -118,17 +177,18 @@ impl ServeHandle { /// Shutdown the server. Gracefully shuts down the serve thread but currently does not /// gracefully close open connections. pub fn shutdown(self) { + info!("ServeHandle: attempting to shut down the server."); self.tx.send(()).expect(&line!().to_string()); if let Ok(_) = TcpStream::connect(self.addr) { self.join_handle.join().expect(&line!().to_string()); } else { - warn!("Best effort shutdown of serve thread failed"); + warn!("ServeHandle: best effort shutdown of serve thread failed"); } } } /// Start -pub fn serve_async(addr: A, f: F) -> io::Result +pub fn serve_async(addr: A, f: F, read_timeout: Option) -> io::Result where A: ToSocketAddrs, Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, Reply: 'static + fmt::Debug + serde::ser::Serialize, @@ -136,29 +196,50 @@ pub fn serve_async(addr: A, f: F) -> io::Result break, + Ok(_) => { + info!("serve_async: shutdown received. Waiting for open connections to return..."); + shutdown.store(true, Ordering::SeqCst); + let &(ref count, ref cvar) = &*open_connections; + let mut count = count.lock().unwrap(); + while *count != 0 { + count = cvar.wait(count).unwrap(); + } + info!("serve_async: shutdown complete ({} connections alive)", *count); + break; + } Err(TryRecvError::Disconnected) => { - info!("Sender disconnected."); + info!("serve_async: sender disconnected."); break; } _ => (), } let conn = match conn { Err(err) => { - error!("Failed to accept connection: {:?}", err); + error!("serve_async: failed to accept connection: {:?}", err); return; } Ok(c) => c, }; let f = f.clone(); + let shutdown = shutdown.clone(); + let &(ref count, _) = &*open_connections; + *count.lock().unwrap() += 1; + let open_connections = open_connections.clone(); thread::spawn(move || { - if let Err(err) = handle_conn(conn, f) { - error!("Error in connection handling: {:?}", err); + let handler = ConnectionHandler { + shutdown: shutdown, + open_connections: open_connections, + timeout: read_timeout, + }; + if let Err(err) = handler.handle_conn(conn, f) { + error!("ConnectionHandler: error in connection handling: {:?}", err); } }); } @@ -198,11 +279,14 @@ fn reader(mut stream: TcpStream, requests: Arc { + debug!("Client: received message, id={}", id); let mut requests = requests.lock().unwrap(); let reply_tx = requests.remove(&id).unwrap(); reply_tx.send(reply).unwrap(); } Ok(Packet::Shutdown) => { + info!("Client: got shutdown message."); + requests.lock().unwrap().clear(); break; } // TODO: This shutdown logic is janky.. What's the right way to do this? @@ -229,6 +313,7 @@ pub struct Client synced_state: Mutex, requests: Arc>>>, reader_guard: Option>, + timeout: Option, _request: PhantomData, } @@ -237,7 +322,7 @@ impl Client Request: serde::ser::Serialize { /// Create a new client that connects to `addr` - pub fn new(addr: A) -> io::Result { + pub fn new(addr: A, timeout: Option) -> io::Result { let stream = try!(TcpStream::connect(addr)); let requests = Arc::new(Mutex::new(HashMap::new())); let reader_stream = try!(stream.try_clone()); @@ -250,6 +335,7 @@ impl Client }), requests: requests, reader_guard: Some(reader_guard), + timeout: timeout, _request: PhantomData, }) } @@ -266,17 +352,20 @@ impl Client requests.insert(id, tx); } let packet = Packet::Message(id, request); + try!(state.stream.set_write_timeout(self.timeout)); + try!(state.stream.set_read_timeout(self.timeout)); + debug!("Client: calling rpc({:?})", request); if let Err(err) = bincode::serde::serialize_into(&mut state.stream, &packet, bincode::SizeLimit::Infinite) { - warn!("Failed to write client packet.\nPacket: {:?}\nError: {:?}", + warn!("Client: failed to write packet.\nPacket: {:?}\nError: {:?}", packet, err); self.requests.lock().unwrap().remove(&id); return Err(err.into()); } drop(state); - Ok(rx.recv().unwrap()) + Ok(try!(rx.recv())) } } @@ -299,9 +388,16 @@ impl Drop for Client #[cfg(test)] mod test { + extern crate env_logger; + use super::*; use std::sync::{Arc, Mutex, Barrier}; use std::thread; + use std::time::Duration; + + fn test_timeout() -> Option { + Some(Duration::from_millis(1)) + } #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] enum Request { @@ -337,21 +433,24 @@ mod test { } #[test] - fn test_handle() { + fn handle() { + let _ = env_logger::init(); let server = Arc::new(Server::new()); - let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap(); - let client: Client = Client::new(serve_handle.local_addr().clone()) + let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); + let client: Client = Client::new(serve_handle.local_addr().clone(), + test_timeout()) .expect(&line!().to_string()); drop(client); serve_handle.shutdown(); } #[test] - fn test_simple() { + fn simple() { + let _ = env_logger::init(); let server = Arc::new(Server::new()); - let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap(); + let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); - let client = Client::new(addr).unwrap(); + let client = Client::new(addr, test_timeout()).unwrap(); assert_eq!(Reply::Increment(0), client.rpc(&Request::Increment).unwrap()); assert_eq!(1, server.count()); @@ -388,11 +487,25 @@ mod test { } #[test] - fn test_concurrent() { - let server = Arc::new(BarrierServer::new(10)); - let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap(); + fn force_shutdown() { + let _ = env_logger::init(); + let server = Arc::new(Server::new()); + let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr).unwrap()); + let client: Arc> = Arc::new(Client::new(addr, test_timeout()) + .unwrap()); + let thread = thread::spawn(move || serve_handle.shutdown()); + info!("force_shutdown::client: {:?}", client.rpc(&Request::Increment)); + thread.join().unwrap(); + } + + #[test] + fn concurrent() { + let _ = env_logger::init(); + let server = Arc::new(BarrierServer::new(10)); + let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); + let addr = serve_handle.local_addr().clone(); + let client: Arc> = Arc::new(Client::new(addr, test_timeout()).unwrap()); let mut join_handles = vec![]; for _ in 0..10 { let my_client = client.clone();