From 91053b96c0ef3770071b4d381e01f6e692b4d8f2 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Thu, 14 Jan 2016 01:09:47 -0800 Subject: [PATCH 01/10] Orderly shutdown of serving threads when calling ServeHandle::shutdown. --- tarpc/Cargo.toml | 1 + tarpc/src/lib.rs | 5 +- tarpc/src/macros.rs | 18 ++-- tarpc/src/protocol.rs | 221 +++++++++++++++++++++++++++++++----------- 4 files changed, 183 insertions(+), 62 deletions(-) diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 94d9b73..6a35757 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -8,3 +8,4 @@ serde = "*" bincode = "*" serde_macros = "*" log = "*" +env_logger = "*" diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 272783f..8bf82b5 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -18,6 +18,7 @@ //! } //! //! use self::my_server::*; +//! use std::time::Duration; //! //! impl my_server::Service for () { //! fn hello(&self, s: String) -> String { @@ -30,8 +31,8 @@ //! //! fn main() { //! let addr = "127.0.0.1:9000"; -//! let shutdown = my_server::serve(addr, ()).unwrap(); -//! let client = Client::new(addr).unwrap(); +//! let shutdown = my_server::serve(addr, (), Some(Duration::from_secs(30))).unwrap(); +//! let client = Client::new(addr, None).unwrap(); //! assert_eq!(3, client.add(1, 2).unwrap()); //! assert_eq!("Hello, Mom!".to_string(), //! client.hello("Mom".to_string()).unwrap()); diff --git a/tarpc/src/macros.rs b/tarpc/src/macros.rs index 1148d3f..4f1183d 100644 --- a/tarpc/src/macros.rs +++ b/tarpc/src/macros.rs @@ -125,10 +125,10 @@ macro_rules! rpc { impl Client { #[doc="Create a new client that connects to the given address."] - pub fn new(addr: A) -> $crate::Result + pub fn new(addr: A, timeout: ::std::option::Option<::std::time::Duration>) -> $crate::Result where A: ::std::net::ToSocketAddrs, { - let inner = try!($crate::protocol::Client::new(addr)); + let inner = try!($crate::protocol::Client::new(addr, timeout)); Ok(Client(inner)) } @@ -151,12 +151,12 @@ macro_rules! rpc { } #[doc="Start a running service."] - pub fn serve(addr: A, service: S) -> $crate::Result<$crate::protocol::ServeHandle> + pub fn serve(addr: A, service: S, read_timeout: ::std::option::Option<::std::time::Duration>) -> $crate::Result<$crate::protocol::ServeHandle> where A: ::std::net::ToSocketAddrs, S: 'static + Service { let server = ::std::sync::Arc::new(__Server(service)); - Ok(try!($crate::protocol::serve_async(addr, server))) + Ok(try!($crate::protocol::serve_async(addr, server, read_timeout))) } } } @@ -165,6 +165,12 @@ macro_rules! rpc { #[cfg(test)] #[allow(dead_code)] mod test { + use std::time::Duration; + + fn test_timeout() -> Option { + Some(Duration::from_secs(5)) + } + rpc! { mod my_server { items { @@ -197,8 +203,8 @@ mod test { fn simple_test() { println!("Starting"); let addr = "127.0.0.1:9000"; - let shutdown = my_server::serve(addr, ()).unwrap(); - let client = Client::new(addr).unwrap(); + let shutdown = my_server::serve(addr, (), test_timeout()).unwrap(); + let client = Client::new(addr, None).unwrap(); assert_eq!(3, client.add(1, 2).unwrap()); let foo = Foo { message: "Adam".into() }; let want = Foo { message: format!("Hello, {}", &foo.message) }; diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index f24f6bd..d83569c 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -6,8 +6,10 @@ use std::convert; use std::collections::HashMap; use std::marker::PhantomData; use std::net::{TcpListener, TcpStream, SocketAddr, ToSocketAddrs}; -use std::sync::{self, Mutex, Arc}; +use std::sync::{self, Arc, Condvar, Mutex}; use std::sync::mpsc::{channel, Sender, TryRecvError}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Duration; use std::thread::{self, JoinHandle}; /// Client errors that can occur during rpc calls @@ -23,6 +25,10 @@ pub enum Error { /// Channels are used for the client's inter-thread communication. This message is /// propagated if the receiver unexpectedly hangs up. Sender, + /// An internal message failed to be received. + /// Channels are used for the client's inter-thread communication. This message is + /// propagated if the sender unexpectedly hangs up. + Receiver, /// The server hung up. ConnectionBroken, } @@ -57,44 +63,97 @@ impl convert::From> for Error { } } +impl convert::From for Error { + fn from(_: sync::mpsc::RecvError) -> Error { + Error::Receiver + } +} + /// Return type of rpc calls: either the successful return value, or a client error. pub type Result = ::std::result::Result; -fn handle_conn(stream: TcpStream, f: F) -> Result<()> - where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, - Reply: 'static + fmt::Debug + serde::ser::Serialize, - F: 'static + Clone + Serve -{ - let mut read_stream = try!(stream.try_clone()); - let stream = Arc::new(Mutex::new(stream)); - loop { - let request_packet: Packet = - try!(bincode::serde::deserialize_from(&mut read_stream, bincode::SizeLimit::Infinite)); - match request_packet { - Packet::Shutdown => { - let stream = stream.clone(); - let mut my_stream = stream.lock().unwrap(); - try!(bincode::serde::serialize_into(&mut *my_stream, - &request_packet, - bincode::SizeLimit::Infinite)); - break; - } - Packet::Message(id, message) => { - let f = f.clone(); - let arc_stream = stream.clone(); - thread::spawn(move || { - let reply = f.serve(message); - let reply_packet = Packet::Message(id, reply); - let mut my_stream = arc_stream.lock().unwrap(); - bincode::serde::serialize_into(&mut *my_stream, - &reply_packet, - bincode::SizeLimit::Infinite) - .unwrap(); - }); +struct ConnectionHandler { + shutdown: Arc, + open_connections: Arc<(Mutex, Condvar)>, + timeout: Option, +} + +impl Drop for ConnectionHandler { + fn drop(&mut self) { + let &(ref count, ref cvar) = &*self.open_connections; + *count.lock().unwrap() -= 1; + cvar.notify_one(); + trace!("ConnectionHandler: finished serving client."); + } +} + +impl ConnectionHandler { + fn handle_conn(&self, stream: TcpStream, f: F) -> Result<()> + where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, + Reply: 'static + fmt::Debug + serde::ser::Serialize, + F: 'static + Clone + Serve + { + trace!("ConnectionHandler: serving client..."); + let mut read_stream = try!(stream.try_clone()); + let stream = Arc::new(Mutex::new(stream)); + loop { + try!(read_stream.set_read_timeout(self.timeout)); + match bincode::serde::deserialize_from(&mut read_stream, bincode::SizeLimit::Infinite) { + Ok(request_packet @ Packet::Shutdown) => { + let stream = stream.clone(); + let mut my_stream = stream.lock().unwrap(); + try!(bincode::serde::serialize_into(&mut *my_stream, + &request_packet, + bincode::SizeLimit::Infinite)); + break; + } + Ok(Packet::Message(id, message)) => { + let f = f.clone(); + let arc_stream = stream.clone(); + thread::spawn(move || { + let reply = f.serve(message); + let reply_packet = Packet::Message(id, reply); + let mut my_stream = arc_stream.lock().unwrap(); + bincode::serde::serialize_into(&mut *my_stream, + &reply_packet, + bincode::SizeLimit::Infinite) + .unwrap(); + }); + } + Err(bincode::serde::DeserializeError::IoError(ref err)) + if Self::timed_out(err.kind()) => + { + if !self.shutdown() { + warn!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so retrying read.", err); + continue; + } else { + warn!("ConnectionHandler: read timed out ({:?}). Server shutdown, so closing connection.", err); + let mut stream = stream.lock().unwrap(); + try!(bincode::serde::serialize_into(&mut *stream, + &Packet::Shutdown::, + bincode::SizeLimit::Infinite)); + break; + } + } + Err(e) => { + warn!("ConnectionHandler: closing client connection due to error while serving: {:?}", e); + return Err(e.into()); + } } } + Ok(()) + } + + fn shutdown(&self) -> bool { + self.shutdown.load(Ordering::SeqCst) + } + + fn timed_out(error_kind: io::ErrorKind) -> bool { + match error_kind { + io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock => true, + _ => false, + } } - Ok(()) } /// Provides methods for blocking until the server completes, @@ -118,17 +177,18 @@ impl ServeHandle { /// Shutdown the server. Gracefully shuts down the serve thread but currently does not /// gracefully close open connections. pub fn shutdown(self) { + info!("ServeHandle: attempting to shut down the server."); self.tx.send(()).expect(&line!().to_string()); if let Ok(_) = TcpStream::connect(self.addr) { self.join_handle.join().expect(&line!().to_string()); } else { - warn!("Best effort shutdown of serve thread failed"); + warn!("ServeHandle: best effort shutdown of serve thread failed"); } } } /// Start -pub fn serve_async(addr: A, f: F) -> io::Result +pub fn serve_async(addr: A, f: F, read_timeout: Option) -> io::Result where A: ToSocketAddrs, Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, Reply: 'static + fmt::Debug + serde::ser::Serialize, @@ -136,29 +196,50 @@ pub fn serve_async(addr: A, f: F) -> io::Result break, + Ok(_) => { + info!("serve_async: shutdown received. Waiting for open connections to return..."); + shutdown.store(true, Ordering::SeqCst); + let &(ref count, ref cvar) = &*open_connections; + let mut count = count.lock().unwrap(); + while *count != 0 { + count = cvar.wait(count).unwrap(); + } + info!("serve_async: shutdown complete ({} connections alive)", *count); + break; + } Err(TryRecvError::Disconnected) => { - info!("Sender disconnected."); + info!("serve_async: sender disconnected."); break; } _ => (), } let conn = match conn { Err(err) => { - error!("Failed to accept connection: {:?}", err); + error!("serve_async: failed to accept connection: {:?}", err); return; } Ok(c) => c, }; let f = f.clone(); + let shutdown = shutdown.clone(); + let &(ref count, _) = &*open_connections; + *count.lock().unwrap() += 1; + let open_connections = open_connections.clone(); thread::spawn(move || { - if let Err(err) = handle_conn(conn, f) { - error!("Error in connection handling: {:?}", err); + let handler = ConnectionHandler { + shutdown: shutdown, + open_connections: open_connections, + timeout: read_timeout, + }; + if let Err(err) = handler.handle_conn(conn, f) { + error!("ConnectionHandler: error in connection handling: {:?}", err); } }); } @@ -198,11 +279,14 @@ fn reader(mut stream: TcpStream, requests: Arc { + debug!("Client: received message, id={}", id); let mut requests = requests.lock().unwrap(); let reply_tx = requests.remove(&id).unwrap(); reply_tx.send(reply).unwrap(); } Ok(Packet::Shutdown) => { + info!("Client: got shutdown message."); + requests.lock().unwrap().clear(); break; } // TODO: This shutdown logic is janky.. What's the right way to do this? @@ -229,6 +313,7 @@ pub struct Client synced_state: Mutex, requests: Arc>>>, reader_guard: Option>, + timeout: Option, _request: PhantomData, } @@ -237,7 +322,7 @@ impl Client Request: serde::ser::Serialize { /// Create a new client that connects to `addr` - pub fn new(addr: A) -> io::Result { + pub fn new(addr: A, timeout: Option) -> io::Result { let stream = try!(TcpStream::connect(addr)); let requests = Arc::new(Mutex::new(HashMap::new())); let reader_stream = try!(stream.try_clone()); @@ -250,6 +335,7 @@ impl Client }), requests: requests, reader_guard: Some(reader_guard), + timeout: timeout, _request: PhantomData, }) } @@ -266,17 +352,20 @@ impl Client requests.insert(id, tx); } let packet = Packet::Message(id, request); + try!(state.stream.set_write_timeout(self.timeout)); + try!(state.stream.set_read_timeout(self.timeout)); + debug!("Client: calling rpc({:?})", request); if let Err(err) = bincode::serde::serialize_into(&mut state.stream, &packet, bincode::SizeLimit::Infinite) { - warn!("Failed to write client packet.\nPacket: {:?}\nError: {:?}", + warn!("Client: failed to write packet.\nPacket: {:?}\nError: {:?}", packet, err); self.requests.lock().unwrap().remove(&id); return Err(err.into()); } drop(state); - Ok(rx.recv().unwrap()) + Ok(try!(rx.recv())) } } @@ -299,9 +388,16 @@ impl Drop for Client #[cfg(test)] mod test { + extern crate env_logger; + use super::*; use std::sync::{Arc, Mutex, Barrier}; use std::thread; + use std::time::Duration; + + fn test_timeout() -> Option { + Some(Duration::from_millis(1)) + } #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] enum Request { @@ -337,21 +433,24 @@ mod test { } #[test] - fn test_handle() { + fn handle() { + let _ = env_logger::init(); let server = Arc::new(Server::new()); - let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap(); - let client: Client = Client::new(serve_handle.local_addr().clone()) + let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); + let client: Client = Client::new(serve_handle.local_addr().clone(), + test_timeout()) .expect(&line!().to_string()); drop(client); serve_handle.shutdown(); } #[test] - fn test_simple() { + fn simple() { + let _ = env_logger::init(); let server = Arc::new(Server::new()); - let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap(); + let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); - let client = Client::new(addr).unwrap(); + let client = Client::new(addr, test_timeout()).unwrap(); assert_eq!(Reply::Increment(0), client.rpc(&Request::Increment).unwrap()); assert_eq!(1, server.count()); @@ -388,11 +487,25 @@ mod test { } #[test] - fn test_concurrent() { - let server = Arc::new(BarrierServer::new(10)); - let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap(); + fn force_shutdown() { + let _ = env_logger::init(); + let server = Arc::new(Server::new()); + let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr).unwrap()); + let client: Arc> = Arc::new(Client::new(addr, test_timeout()) + .unwrap()); + let thread = thread::spawn(move || serve_handle.shutdown()); + info!("force_shutdown::client: {:?}", client.rpc(&Request::Increment)); + thread.join().unwrap(); + } + + #[test] + fn concurrent() { + let _ = env_logger::init(); + let server = Arc::new(BarrierServer::new(10)); + let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); + let addr = serve_handle.local_addr().clone(); + let client: Arc> = Arc::new(Client::new(addr, test_timeout()).unwrap()); let mut join_handles = vec![]; for _ in 0..10 { let my_client = client.clone(); From 2644bf0d9bbb7eed2441a58d19fec52cf87d7110 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Thu, 14 Jan 2016 01:53:28 -0800 Subject: [PATCH 02/10] Ensure no rpc calls can be started once the reader thread returns. --- tarpc/src/protocol.rs | 75 +++++++++++++++++++++++++++---------------- 1 file changed, 48 insertions(+), 27 deletions(-) diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index d83569c..0e1cab7 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -271,30 +271,43 @@ enum Packet { Shutdown, } -fn reader(mut stream: TcpStream, requests: Arc>>>) - where Reply: serde::Deserialize -{ - loop { - let packet: bincode::serde::DeserializeResult> = - bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); - match packet { - Ok(Packet::Message(id, reply)) => { - debug!("Client: received message, id={}", id); - let mut requests = requests.lock().unwrap(); - let reply_tx = requests.remove(&id).unwrap(); - reply_tx.send(reply).unwrap(); +struct Reader { + requests: Arc>>>> +} + +impl Reader { + fn read(self, mut stream: TcpStream) where Reply: serde::Deserialize { + loop { + let packet: bincode::serde::DeserializeResult> = + bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); + match packet { + Ok(Packet::Message(id, reply)) => { + debug!("Client: received message, id={}", id); + let mut requests = self.requests.lock().unwrap(); + let mut requests = requests.as_mut().unwrap(); + let reply_tx = requests.remove(&id).unwrap(); + reply_tx.send(reply).unwrap(); + } + Ok(Packet::Shutdown) => { + info!("Client: got shutdown message."); + break; + } + // TODO: This shutdown logic is janky.. What's the right way to do this? + Err(err) => panic!("unexpected error while parsing!: {:?}", err), } - Ok(Packet::Shutdown) => { - info!("Client: got shutdown message."); - requests.lock().unwrap().clear(); - break; - } - // TODO: This shutdown logic is janky.. What's the right way to do this? - Err(err) => panic!("unexpected error while parsing!: {:?}", err), } } } +impl Drop for Reader { + fn drop(&mut self) { + let mut guard = self.requests.lock().unwrap(); + guard.as_mut().unwrap().clear(); + // remove the hashmap so no one can put more senders and accidentally block + guard.take(); + } +} + fn increment(cur_id: &mut u64) -> u64 { let id = *cur_id; *cur_id += 1; @@ -311,7 +324,7 @@ pub struct Client where Request: serde::ser::Serialize { synced_state: Mutex, - requests: Arc>>>, + requests: Arc>>>>, reader_guard: Option>, timeout: Option, _request: PhantomData, @@ -324,10 +337,10 @@ impl Client /// Create a new client that connects to `addr` pub fn new(addr: A, timeout: Option) -> io::Result { let stream = try!(TcpStream::connect(addr)); - let requests = Arc::new(Mutex::new(HashMap::new())); + let requests = Arc::new(Mutex::new(Some(HashMap::new()))); let reader_stream = try!(stream.try_clone()); - let reader_requests = requests.clone(); - let reader_guard = thread::spawn(move || reader(reader_stream, reader_requests)); + let reader = Reader { requests: requests.clone() }; + let reader_guard = thread::spawn(move || reader.read(reader_stream)); Ok(Client { synced_state: Mutex::new(SyncedClientState { next_id: 0, @@ -348,8 +361,11 @@ impl Client let mut state = self.synced_state.lock().unwrap(); let id = increment(&mut state.next_id); { - let mut requests = self.requests.lock().unwrap(); - requests.insert(id, tx); + if let Some(ref mut requests) = *self.requests.lock().unwrap() { + requests.insert(id, tx); + } else { + return Err(Error::ConnectionBroken); + } } let packet = Packet::Message(id, request); try!(state.stream.set_write_timeout(self.timeout)); @@ -361,7 +377,11 @@ impl Client warn!("Client: failed to write packet.\nPacket: {:?}\nError: {:?}", packet, err); - self.requests.lock().unwrap().remove(&id); + if let Some(requests) = self.requests.lock().unwrap().as_mut() { + requests.remove(&id); + } else { + warn!("Client: couldn't remove sender for request {} because reader thread returned.", id); + } return Err(err.into()); } drop(state); @@ -495,7 +515,8 @@ mod test { let client: Arc> = Arc::new(Client::new(addr, test_timeout()) .unwrap()); let thread = thread::spawn(move || serve_handle.shutdown()); - info!("force_shutdown::client: {:?}", client.rpc(&Request::Increment)); + info!("force_shutdown:: rpc1: {:?}", client.rpc(&Request::Increment)); + info!("force_shutdown:: rpc2: {:?}", client.rpc(&Request::Increment)); thread.join().unwrap(); } From 8c51d2ca1bc0508e92c0706eeaec2ccc8a15b63d Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Thu, 14 Jan 2016 01:55:25 -0800 Subject: [PATCH 03/10] Cargo fmt --- tarpc/src/lib.rs | 3 ++- tarpc/src/protocol.rs | 49 +++++++++++++++++++++++++++++-------------- 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 8bf82b5..05823a0 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -31,7 +31,8 @@ //! //! fn main() { //! let addr = "127.0.0.1:9000"; -//! let shutdown = my_server::serve(addr, (), Some(Duration::from_secs(30))).unwrap(); +//! let shutdown = my_server::serve(addr, (), +//! Some(Duration::from_secs(30))).unwrap(); //! let client = Client::new(addr, None).unwrap(); //! assert_eq!(3, client.add(1, 2).unwrap()); //! assert_eq!("Hello, Mom!".to_string(), diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index 0e1cab7..c4a4378 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -117,17 +117,20 @@ impl ConnectionHandler { bincode::serde::serialize_into(&mut *my_stream, &reply_packet, bincode::SizeLimit::Infinite) - .unwrap(); + .unwrap(); }); } Err(bincode::serde::DeserializeError::IoError(ref err)) - if Self::timed_out(err.kind()) => - { + if Self::timed_out(err.kind()) => { if !self.shutdown() { - warn!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so retrying read.", err); + warn!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so \ + retrying read.", + err); continue; } else { - warn!("ConnectionHandler: read timed out ({:?}). Server shutdown, so closing connection.", err); + warn!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \ + closing connection.", + err); let mut stream = stream.lock().unwrap(); try!(bincode::serde::serialize_into(&mut *stream, &Packet::Shutdown::, @@ -136,7 +139,9 @@ impl ConnectionHandler { } } Err(e) => { - warn!("ConnectionHandler: closing client connection due to error while serving: {:?}", e); + warn!("ConnectionHandler: closing client connection due to error while \ + serving: {:?}", + e); return Err(e.into()); } } @@ -188,7 +193,10 @@ impl ServeHandle { } /// Start -pub fn serve_async(addr: A, f: F, read_timeout: Option) -> io::Result +pub fn serve_async(addr: A, + f: F, + read_timeout: Option) + -> io::Result where A: ToSocketAddrs, Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, Reply: 'static + fmt::Debug + serde::ser::Serialize, @@ -204,14 +212,16 @@ pub fn serve_async(addr: A, f: F, read_timeout: Option { - info!("serve_async: shutdown received. Waiting for open connections to return..."); + info!("serve_async: shutdown received. Waiting for open connections to \ + return..."); shutdown.store(true, Ordering::SeqCst); let &(ref count, ref cvar) = &*open_connections; let mut count = count.lock().unwrap(); while *count != 0 { count = cvar.wait(count).unwrap(); } - info!("serve_async: shutdown complete ({} connections alive)", *count); + info!("serve_async: shutdown complete ({} connections alive)", + *count); break; } Err(TryRecvError::Disconnected) => { @@ -272,11 +282,13 @@ enum Packet { } struct Reader { - requests: Arc>>>> + requests: Arc>>>>, } impl Reader { - fn read(self, mut stream: TcpStream) where Reply: serde::Deserialize { + fn read(self, mut stream: TcpStream) + where Reply: serde::Deserialize + { loop { let packet: bincode::serde::DeserializeResult> = bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); @@ -380,7 +392,9 @@ impl Client if let Some(requests) = self.requests.lock().unwrap().as_mut() { requests.remove(&id); } else { - warn!("Client: couldn't remove sender for request {} because reader thread returned.", id); + warn!("Client: couldn't remove sender for request {} because reader thread \ + returned.", + id); } return Err(err.into()); } @@ -513,10 +527,12 @@ mod test { let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap(); let addr = serve_handle.local_addr().clone(); let client: Arc> = Arc::new(Client::new(addr, test_timeout()) - .unwrap()); + .unwrap()); let thread = thread::spawn(move || serve_handle.shutdown()); - info!("force_shutdown:: rpc1: {:?}", client.rpc(&Request::Increment)); - info!("force_shutdown:: rpc2: {:?}", client.rpc(&Request::Increment)); + info!("force_shutdown:: rpc1: {:?}", + client.rpc(&Request::Increment)); + info!("force_shutdown:: rpc2: {:?}", + client.rpc(&Request::Increment)); thread.join().unwrap(); } @@ -526,7 +542,8 @@ mod test { let server = Arc::new(BarrierServer::new(10)); let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr, test_timeout()).unwrap()); + let client: Arc> = Arc::new(Client::new(addr, test_timeout()) + .unwrap()); let mut join_handles = vec![]; for _ in 0..10 { let my_client = client.clone(); From 0df3cfdd9840263a928887e3afc43357f5c4eacf Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Thu, 14 Jan 2016 09:22:46 -0800 Subject: [PATCH 04/10] Properly wait for spawned connection handler threads to shutdown. Set client timeout to None in tests. --- tarpc/src/protocol.rs | 134 +++++++++++++++++++++++++++--------------- 1 file changed, 85 insertions(+), 49 deletions(-) diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index c4a4378..3e54e7d 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -72,56 +72,101 @@ impl convert::From for Error { /// Return type of rpc calls: either the successful return value, or a client error. pub type Result = ::std::result::Result; -struct ConnectionHandler { - shutdown: Arc, +#[derive(Clone)] +struct OpenConnections { open_connections: Arc<(Mutex, Condvar)>, +} + +impl OpenConnections { + fn new(mutex: Mutex, cvar: Condvar) -> OpenConnections { + OpenConnections { + open_connections: Arc::new((mutex, cvar)), + } + } + + fn wait_until_zero(&self) { + let &(ref count, ref cvar) = &*self.open_connections; + let mut count = count.lock().unwrap(); + while *count != 0 { + count = cvar.wait(count).unwrap(); + } + info!("serve_async: shutdown complete ({} connections alive)", *count); + } + + fn increment(&self) { + let &(ref count, _) = &*self.open_connections; + *count.lock().unwrap() += 1; + } + + fn decrement(&self) { + let &(ref count, _) = &*self.open_connections; + *count.lock().unwrap() -= 1; + } + + + fn decrement_and_notify(&self) { + let &(ref count, ref cvar) = &*self.open_connections; + *count.lock().unwrap() -= 1; + cvar.notify_one(); + } + +} + +struct ConnectionHandler { + read_stream: TcpStream, + write_stream: Arc>, + shutdown: Arc, + open_connections: OpenConnections, timeout: Option, } impl Drop for ConnectionHandler { fn drop(&mut self) { - let &(ref count, ref cvar) = &*self.open_connections; - *count.lock().unwrap() -= 1; - cvar.notify_one(); + if let Err(e) = bincode::serde::serialize_into(&mut self.read_stream, + &Packet::Shutdown::<()>, + bincode::SizeLimit::Infinite) { + warn!("ConnectionHandler: could not notify client of shutdown: {:?}", e); + } trace!("ConnectionHandler: finished serving client."); + self.open_connections.decrement_and_notify(); } } impl ConnectionHandler { - fn handle_conn(&self, stream: TcpStream, f: F) -> Result<()> + fn read(&mut self) -> bincode::serde::DeserializeResult> + where Request: serde::de::Deserialize + { + try!(self.read_stream.set_read_timeout(self.timeout)); + bincode::serde::deserialize_from(&mut self.read_stream, bincode::SizeLimit::Infinite) + } + + fn handle_conn(&mut self, f: F) -> Result<()> where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, Reply: 'static + fmt::Debug + serde::ser::Serialize, F: 'static + Clone + Serve { trace!("ConnectionHandler: serving client..."); - let mut read_stream = try!(stream.try_clone()); - let stream = Arc::new(Mutex::new(stream)); loop { - try!(read_stream.set_read_timeout(self.timeout)); - match bincode::serde::deserialize_from(&mut read_stream, bincode::SizeLimit::Infinite) { - Ok(request_packet @ Packet::Shutdown) => { - let stream = stream.clone(); - let mut my_stream = stream.lock().unwrap(); - try!(bincode::serde::serialize_into(&mut *my_stream, - &request_packet, - bincode::SizeLimit::Infinite)); - break; - } + match self.read() { + Ok(Packet::Shutdown) => break, Ok(Packet::Message(id, message)) => { let f = f.clone(); - let arc_stream = stream.clone(); + let open_connections = self.open_connections.clone(); + open_connections.increment(); + let stream = self.write_stream.clone(); thread::spawn(move || { let reply = f.serve(message); let reply_packet = Packet::Message(id, reply); - let mut my_stream = arc_stream.lock().unwrap(); - bincode::serde::serialize_into(&mut *my_stream, + let mut stream = stream.lock().unwrap(); + if let Err(e) = bincode::serde::serialize_into(&mut *stream, &reply_packet, - bincode::SizeLimit::Infinite) - .unwrap(); + bincode::SizeLimit::Infinite) { + warn!("ConnectionHandler: failed to write reply to Client: {:?}", e); + } + open_connections.decrement(); }); } - Err(bincode::serde::DeserializeError::IoError(ref err)) - if Self::timed_out(err.kind()) => { + Err(bincode::serde::DeserializeError::IoError(ref err)) if Self::timed_out(err.kind()) => { if !self.shutdown() { warn!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so \ retrying read.", @@ -131,10 +176,6 @@ impl ConnectionHandler { warn!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \ closing connection.", err); - let mut stream = stream.lock().unwrap(); - try!(bincode::serde::serialize_into(&mut *stream, - &Packet::Shutdown::, - bincode::SizeLimit::Infinite)); break; } } @@ -208,20 +249,14 @@ pub fn serve_async(addr: A, let (die_tx, die_rx) = channel(); let join_handle = thread::spawn(move || { let shutdown = Arc::new(AtomicBool::new(false)); - let open_connections = Arc::new((Mutex::new(0), Condvar::new())); + let open_connections = OpenConnections::new(Mutex::new(0), Condvar::new()); for conn in listener.incoming() { match die_rx.try_recv() { Ok(_) => { info!("serve_async: shutdown received. Waiting for open connections to \ return..."); shutdown.store(true, Ordering::SeqCst); - let &(ref count, ref cvar) = &*open_connections; - let mut count = count.lock().unwrap(); - while *count != 0 { - count = cvar.wait(count).unwrap(); - } - info!("serve_async: shutdown complete ({} connections alive)", - *count); + open_connections.wait_until_zero(); break; } Err(TryRecvError::Disconnected) => { @@ -239,16 +274,17 @@ pub fn serve_async(addr: A, }; let f = f.clone(); let shutdown = shutdown.clone(); - let &(ref count, _) = &*open_connections; - *count.lock().unwrap() += 1; + open_connections.increment(); let open_connections = open_connections.clone(); + let mut handler = ConnectionHandler { + read_stream: conn.try_clone().unwrap(), + write_stream: Arc::new(Mutex::new(conn)), + shutdown: shutdown, + open_connections: open_connections, + timeout: read_timeout, + }; thread::spawn(move || { - let handler = ConnectionHandler { - shutdown: shutdown, - open_connections: open_connections, - timeout: read_timeout, - }; - if let Err(err) = handler.handle_conn(conn, f) { + if let Err(err) = handler.handle_conn(f) { error!("ConnectionHandler: error in connection handling: {:?}", err); } }); @@ -472,7 +508,7 @@ mod test { let server = Arc::new(Server::new()); let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); let client: Client = Client::new(serve_handle.local_addr().clone(), - test_timeout()) + None) .expect(&line!().to_string()); drop(client); serve_handle.shutdown(); @@ -484,7 +520,7 @@ mod test { let server = Arc::new(Server::new()); let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); - let client = Client::new(addr, test_timeout()).unwrap(); + let client = Client::new(addr, None).unwrap(); assert_eq!(Reply::Increment(0), client.rpc(&Request::Increment).unwrap()); assert_eq!(1, server.count()); @@ -526,7 +562,7 @@ mod test { let server = Arc::new(Server::new()); let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr, test_timeout()) + let client: Arc> = Arc::new(Client::new(addr, None) .unwrap()); let thread = thread::spawn(move || serve_handle.shutdown()); info!("force_shutdown:: rpc1: {:?}", @@ -542,7 +578,7 @@ mod test { let server = Arc::new(BarrierServer::new(10)); let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr, test_timeout()) + let client: Arc> = Arc::new(Client::new(addr, None) .unwrap()); let mut join_handles = vec![]; for _ in 0..10 { From d81b37b35d601f881783e722b50f2c852c203d2b Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Thu, 14 Jan 2016 09:23:34 -0800 Subject: [PATCH 05/10] Cargo fmt --- tarpc/src/protocol.rs | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index 3e54e7d..96424ea 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -79,9 +79,7 @@ struct OpenConnections { impl OpenConnections { fn new(mutex: Mutex, cvar: Condvar) -> OpenConnections { - OpenConnections { - open_connections: Arc::new((mutex, cvar)), - } + OpenConnections { open_connections: Arc::new((mutex, cvar)) } } fn wait_until_zero(&self) { @@ -90,7 +88,8 @@ impl OpenConnections { while *count != 0 { count = cvar.wait(count).unwrap(); } - info!("serve_async: shutdown complete ({} connections alive)", *count); + info!("serve_async: shutdown complete ({} connections alive)", + *count); } fn increment(&self) { @@ -123,9 +122,10 @@ struct ConnectionHandler { impl Drop for ConnectionHandler { fn drop(&mut self) { if let Err(e) = bincode::serde::serialize_into(&mut self.read_stream, - &Packet::Shutdown::<()>, - bincode::SizeLimit::Infinite) { - warn!("ConnectionHandler: could not notify client of shutdown: {:?}", e); + &Packet::Shutdown::<()>, + bincode::SizeLimit::Infinite) { + warn!("ConnectionHandler: could not notify client of shutdown: {:?}", + e); } trace!("ConnectionHandler: finished serving client."); self.open_connections.decrement_and_notify(); @@ -158,15 +158,18 @@ impl ConnectionHandler { let reply = f.serve(message); let reply_packet = Packet::Message(id, reply); let mut stream = stream.lock().unwrap(); - if let Err(e) = bincode::serde::serialize_into(&mut *stream, - &reply_packet, - bincode::SizeLimit::Infinite) { - warn!("ConnectionHandler: failed to write reply to Client: {:?}", e); + if let Err(e) = + bincode::serde::serialize_into(&mut *stream, + &reply_packet, + bincode::SizeLimit::Infinite) { + warn!("ConnectionHandler: failed to write reply to Client: {:?}", + e); } open_connections.decrement(); }); } - Err(bincode::serde::DeserializeError::IoError(ref err)) if Self::timed_out(err.kind()) => { + Err(bincode::serde::DeserializeError::IoError(ref err)) + if Self::timed_out(err.kind()) => { if !self.shutdown() { warn!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so \ retrying read.", @@ -507,8 +510,7 @@ mod test { let _ = env_logger::init(); let server = Arc::new(Server::new()); let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); - let client: Client = Client::new(serve_handle.local_addr().clone(), - None) + let client: Client = Client::new(serve_handle.local_addr().clone(), None) .expect(&line!().to_string()); drop(client); serve_handle.shutdown(); @@ -562,8 +564,7 @@ mod test { let server = Arc::new(Server::new()); let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr, None) - .unwrap()); + let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); let thread = thread::spawn(move || serve_handle.shutdown()); info!("force_shutdown:: rpc1: {:?}", client.rpc(&Request::Increment)); @@ -578,8 +579,7 @@ mod test { let server = Arc::new(BarrierServer::new(10)); let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr, None) - .unwrap()); + let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); let mut join_handles = vec![]; for _ in 0..10 { let my_client = client.clone(); From 06e1eaa27a7e7b388b02ce50451d55c401736652 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Thu, 14 Jan 2016 20:31:28 -0800 Subject: [PATCH 06/10] Make Serve's Request and Reply associated types. --- tarpc/src/macros.rs | 4 +++- tarpc/src/protocol.rs | 46 +++++++++++++++++++++++++------------------ 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/tarpc/src/macros.rs b/tarpc/src/macros.rs index 4f1183d..ccaa025 100644 --- a/tarpc/src/macros.rs +++ b/tarpc/src/macros.rs @@ -137,9 +137,11 @@ macro_rules! rpc { struct __Server(S); - impl $crate::protocol::Serve<__Request, __Reply> for __Server + impl $crate::protocol::Serve for __Server where S: 'static + Service { + type Request = __Request; + type Reply = __Reply; fn serve(&self, request: __Request) -> __Reply { match request { $( diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index 96424ea..2c19d4a 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -140,10 +140,8 @@ impl ConnectionHandler { bincode::serde::deserialize_from(&mut self.read_stream, bincode::SizeLimit::Infinite) } - fn handle_conn(&mut self, f: F) -> Result<()> - where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, - Reply: 'static + fmt::Debug + serde::ser::Serialize, - F: 'static + Clone + Serve + fn handle_conn(&mut self, f: F) -> Result<()> + where F: 'static + Clone + Serve { trace!("ConnectionHandler: serving client..."); loop { @@ -237,14 +235,12 @@ impl ServeHandle { } /// Start -pub fn serve_async(addr: A, - f: F, - read_timeout: Option) - -> io::Result +pub fn serve_async(addr: A, + f: F, + read_timeout: Option) + -> io::Result where A: ToSocketAddrs, - Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, - Reply: 'static + fmt::Debug + serde::ser::Serialize, - F: 'static + Clone + Send + Serve + F: 'static + Clone + Send + Serve { let listener = try!(TcpListener::bind(&addr)); let addr = try!(listener.local_addr()); @@ -301,15 +297,22 @@ pub fn serve_async(addr: A, } /// A service provided by a server -pub trait Serve: Send + Sync { +pub trait Serve: Send + Sync { + /// The type of request received by the server + type Request: 'static + fmt::Debug + serde::ser::Serialize + serde::de::Deserialize + Send; + /// The type of reply sent by the server + type Reply: 'static + fmt::Debug + serde::ser::Serialize + serde::de::Deserialize; + /// Return a reply for a given request - fn serve(&self, request: Request) -> Reply; + fn serve(&self, request: Self::Request) -> Self::Reply; } -impl Serve for Arc - where S: Serve +impl Serve for Arc where S: Serve { - fn serve(&self, request: Request) -> Reply { + type Request = S::Request; + type Reply = S::Reply; + + fn serve(&self, request: S::Request) -> S::Reply { S::serve(self, request) } } @@ -463,7 +466,7 @@ impl Drop for Client mod test { extern crate env_logger; - use super::*; + use super::{Client, Serve, serve_async}; use std::sync::{Arc, Mutex, Barrier}; use std::thread; use std::time::Duration; @@ -486,7 +489,10 @@ mod test { counter: Mutex, } - impl Serve for Server { + impl Serve for Server { + type Request = Request; + type Reply = Reply; + fn serve(&self, _: Request) -> Reply { let mut counter = self.counter.lock().unwrap(); let reply = Reply::Increment(*counter); @@ -538,7 +544,9 @@ mod test { inner: Server, } - impl Serve for BarrierServer { + impl Serve for BarrierServer { + type Request = Request; + type Reply = Reply; fn serve(&self, request: Request) -> Reply { self.barrier.wait(); self.inner.serve(request) From b3ed2ef0baef8d727072407bdb4051b3eaaba286 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Thu, 14 Jan 2016 20:41:29 -0800 Subject: [PATCH 07/10] Remove all shutdown logic. Just exit and deal with it. --- tarpc/src/protocol.rs | 39 ++++++++++++--------------------------- 1 file changed, 12 insertions(+), 27 deletions(-) diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index 2c19d4a..279bf6e 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -121,12 +121,6 @@ struct ConnectionHandler { impl Drop for ConnectionHandler { fn drop(&mut self) { - if let Err(e) = bincode::serde::serialize_into(&mut self.read_stream, - &Packet::Shutdown::<()>, - bincode::SizeLimit::Infinite) { - warn!("ConnectionHandler: could not notify client of shutdown: {:?}", - e); - } trace!("ConnectionHandler: finished serving client."); self.open_connections.decrement_and_notify(); } @@ -146,15 +140,14 @@ impl ConnectionHandler { trace!("ConnectionHandler: serving client..."); loop { match self.read() { - Ok(Packet::Shutdown) => break, - Ok(Packet::Message(id, message)) => { + Ok(Packet(id, message)) => { let f = f.clone(); let open_connections = self.open_connections.clone(); open_connections.increment(); let stream = self.write_stream.clone(); thread::spawn(move || { let reply = f.serve(message); - let reply_packet = Packet::Message(id, reply); + let reply_packet = Packet(id, reply); let mut stream = stream.lock().unwrap(); if let Err(e) = bincode::serde::serialize_into(&mut *stream, @@ -318,10 +311,7 @@ impl Serve for Arc where S: Serve } #[derive(Debug, Clone, Serialize, Deserialize)] -enum Packet { - Message(u64, T), - Shutdown, -} +struct Packet(u64, T); struct Reader { requests: Arc>>>>, @@ -335,19 +325,20 @@ impl Reader { let packet: bincode::serde::DeserializeResult> = bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); match packet { - Ok(Packet::Message(id, reply)) => { + Ok(Packet(id, reply)) => { debug!("Client: received message, id={}", id); let mut requests = self.requests.lock().unwrap(); let mut requests = requests.as_mut().unwrap(); let reply_tx = requests.remove(&id).unwrap(); reply_tx.send(reply).unwrap(); } - Ok(Packet::Shutdown) => { - info!("Client: got shutdown message."); + // TODO: This shutdown logic is janky.. What's the right way to do this? + Err(err) => { + warn!("Client: reader thread encountered an unexpected error while parsing; \ + returning now. Error: {:?}", + err); break; } - // TODO: This shutdown logic is janky.. What's the right way to do this? - Err(err) => panic!("unexpected error while parsing!: {:?}", err), } } } @@ -421,7 +412,7 @@ impl Client return Err(Error::ConnectionBroken); } } - let packet = Packet::Message(id, request); + let packet = Packet(id, request); try!(state.stream.set_write_timeout(self.timeout)); try!(state.stream.set_read_timeout(self.timeout)); debug!("Client: calling rpc({:?})", request); @@ -449,14 +440,8 @@ impl Drop for Client where Request: serde::ser::Serialize { fn drop(&mut self) { - { - let mut state = self.synced_state.lock().unwrap(); - let packet: Packet = Packet::Shutdown; - if let Err(err) = bincode::serde::serialize_into(&mut state.stream, - &packet, - bincode::SizeLimit::Infinite) { - warn!("While disconnecting client from server: {:?}", err); - } + if let Err(e) = self.synced_state.lock().unwrap().stream.shutdown(::std::net::Shutdown::Both) { + warn!("Client: couldn't shutdown reader thread: {:?}", e); } self.reader_guard.take().unwrap().join().unwrap(); } From e4faff74be29b30547c1c0427b14053991faa09d Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Thu, 14 Jan 2016 20:45:09 -0800 Subject: [PATCH 08/10] reformat --- tarpc/src/macros.rs | 16 ++++++++++------ tarpc/src/protocol.rs | 18 +++++++++--------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/tarpc/src/macros.rs b/tarpc/src/macros.rs index ccaa025..a71d412 100644 --- a/tarpc/src/macros.rs +++ b/tarpc/src/macros.rs @@ -49,15 +49,15 @@ macro_rules! request_variant { // The main macro that creates RPC services. #[macro_export] -macro_rules! rpc { +macro_rules! rpc { ( mod $server:ident { service { - $( + $( $(#[$attr:meta])* rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty; - )* + )* } } ) => { @@ -66,7 +66,7 @@ macro_rules! rpc { items { } - service { + service { $( $(#[$attr])* rpc $fn_name($($arg: $in_),*) -> $out; @@ -125,7 +125,8 @@ macro_rules! rpc { impl Client { #[doc="Create a new client that connects to the given address."] - pub fn new(addr: A, timeout: ::std::option::Option<::std::time::Duration>) -> $crate::Result + pub fn new(addr: A, timeout: ::std::option::Option<::std::time::Duration>) + -> $crate::Result where A: ::std::net::ToSocketAddrs, { let inner = try!($crate::protocol::Client::new(addr, timeout)); @@ -153,7 +154,10 @@ macro_rules! rpc { } #[doc="Start a running service."] - pub fn serve(addr: A, service: S, read_timeout: ::std::option::Option<::std::time::Duration>) -> $crate::Result<$crate::protocol::ServeHandle> + pub fn serve(addr: A, + service: S, + read_timeout: ::std::option::Option<::std::time::Duration>) + -> $crate::Result<$crate::protocol::ServeHandle> where A: ::std::net::ToSocketAddrs, S: 'static + Service { diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index 279bf6e..d498f5f 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -196,7 +196,7 @@ impl ConnectionHandler { } } -/// Provides methods for blocking until the server completes, +/// Provides methods for blocking until the server completes, pub struct ServeHandle { tx: Sender<()>, join_handle: JoinHandle<()>, @@ -227,11 +227,8 @@ impl ServeHandle { } } -/// Start -pub fn serve_async(addr: A, - f: F, - read_timeout: Option) - -> io::Result +/// Start +pub fn serve_async(addr: A, f: F, read_timeout: Option) -> io::Result where A: ToSocketAddrs, F: 'static + Clone + Send + Serve { @@ -332,10 +329,9 @@ impl Reader { let reply_tx = requests.remove(&id).unwrap(); reply_tx.send(reply).unwrap(); } - // TODO: This shutdown logic is janky.. What's the right way to do this? Err(err) => { warn!("Client: reader thread encountered an unexpected error while parsing; \ - returning now. Error: {:?}", + returning now. Error: {:?}", err); break; } @@ -440,7 +436,11 @@ impl Drop for Client where Request: serde::ser::Serialize { fn drop(&mut self) { - if let Err(e) = self.synced_state.lock().unwrap().stream.shutdown(::std::net::Shutdown::Both) { + if let Err(e) = self.synced_state + .lock() + .unwrap() + .stream + .shutdown(::std::net::Shutdown::Both) { warn!("Client: couldn't shutdown reader thread: {:?}", e); } self.reader_guard.take().unwrap().join().unwrap(); From ebd825e679165360d62b7eea74cc8a04711853b7 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Fri, 15 Jan 2016 00:28:42 -0800 Subject: [PATCH 09/10] Bundle of small changes. 1. Rename OpenConnections => InflightRpcs, as it represents all current rpc calls being processed. 2. Change Packet from a tuple struct to a regular struct, to clarify its fields. 3. Lower log statements from WARN to INFO where appropriate. 4. Remove shutdown method on ConnectionHandler to disambiguate with the shutdown field. 5. Add a test of client behavior when calling rpc on a client whose stream closed. --- tarpc/src/protocol.rs | 94 +++++++++++++++++++++++++++---------------- 1 file changed, 59 insertions(+), 35 deletions(-) diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index d498f5f..f0d22a1 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -73,17 +73,17 @@ impl convert::From for Error { pub type Result = ::std::result::Result; #[derive(Clone)] -struct OpenConnections { - open_connections: Arc<(Mutex, Condvar)>, +struct InflightRpcs { + inflight_rpcs: Arc<(Mutex, Condvar)>, } -impl OpenConnections { - fn new(mutex: Mutex, cvar: Condvar) -> OpenConnections { - OpenConnections { open_connections: Arc::new((mutex, cvar)) } +impl InflightRpcs { + fn new(mutex: Mutex, cvar: Condvar) -> InflightRpcs { + InflightRpcs { inflight_rpcs: Arc::new((mutex, cvar)) } } fn wait_until_zero(&self) { - let &(ref count, ref cvar) = &*self.open_connections; + let &(ref count, ref cvar) = &*self.inflight_rpcs; let mut count = count.lock().unwrap(); while *count != 0 { count = cvar.wait(count).unwrap(); @@ -93,18 +93,18 @@ impl OpenConnections { } fn increment(&self) { - let &(ref count, _) = &*self.open_connections; + let &(ref count, _) = &*self.inflight_rpcs; *count.lock().unwrap() += 1; } fn decrement(&self) { - let &(ref count, _) = &*self.open_connections; + let &(ref count, _) = &*self.inflight_rpcs; *count.lock().unwrap() -= 1; } fn decrement_and_notify(&self) { - let &(ref count, ref cvar) = &*self.open_connections; + let &(ref count, ref cvar) = &*self.inflight_rpcs; *count.lock().unwrap() -= 1; cvar.notify_one(); } @@ -115,14 +115,14 @@ struct ConnectionHandler { read_stream: TcpStream, write_stream: Arc>, shutdown: Arc, - open_connections: OpenConnections, + inflight_rpcs: InflightRpcs, timeout: Option, } impl Drop for ConnectionHandler { fn drop(&mut self) { trace!("ConnectionHandler: finished serving client."); - self.open_connections.decrement_and_notify(); + self.inflight_rpcs.decrement_and_notify(); } } @@ -140,14 +140,20 @@ impl ConnectionHandler { trace!("ConnectionHandler: serving client..."); loop { match self.read() { - Ok(Packet(id, message)) => { + Ok(Packet { + rpc_id: id, + message: message + }) => { let f = f.clone(); - let open_connections = self.open_connections.clone(); - open_connections.increment(); + let inflight_rpcs = self.inflight_rpcs.clone(); + inflight_rpcs.increment(); let stream = self.write_stream.clone(); thread::spawn(move || { let reply = f.serve(message); - let reply_packet = Packet(id, reply); + let reply_packet = Packet { + rpc_id: id, + message: reply + }; let mut stream = stream.lock().unwrap(); if let Err(e) = bincode::serde::serialize_into(&mut *stream, @@ -156,18 +162,22 @@ impl ConnectionHandler { warn!("ConnectionHandler: failed to write reply to Client: {:?}", e); } - open_connections.decrement(); + inflight_rpcs.decrement(); }); + if self.shutdown.load(Ordering::SeqCst) { + info!("ConnectionHandler: server shutdown, so closing connection."); + break; + } } Err(bincode::serde::DeserializeError::IoError(ref err)) if Self::timed_out(err.kind()) => { - if !self.shutdown() { - warn!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so \ + if !self.shutdown.load(Ordering::SeqCst) { + info!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so \ retrying read.", err); continue; } else { - warn!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \ + info!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \ closing connection.", err); break; @@ -184,10 +194,6 @@ impl ConnectionHandler { Ok(()) } - fn shutdown(&self) -> bool { - self.shutdown.load(Ordering::SeqCst) - } - fn timed_out(error_kind: io::ErrorKind) -> bool { match error_kind { io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock => true, @@ -238,14 +244,14 @@ pub fn serve_async(addr: A, f: F, read_timeout: Option) -> io::R let (die_tx, die_rx) = channel(); let join_handle = thread::spawn(move || { let shutdown = Arc::new(AtomicBool::new(false)); - let open_connections = OpenConnections::new(Mutex::new(0), Condvar::new()); + let inflight_rpcs = InflightRpcs::new(Mutex::new(0), Condvar::new()); for conn in listener.incoming() { match die_rx.try_recv() { Ok(_) => { info!("serve_async: shutdown received. Waiting for open connections to \ return..."); shutdown.store(true, Ordering::SeqCst); - open_connections.wait_until_zero(); + inflight_rpcs.wait_until_zero(); break; } Err(TryRecvError::Disconnected) => { @@ -263,13 +269,13 @@ pub fn serve_async(addr: A, f: F, read_timeout: Option) -> io::R }; let f = f.clone(); let shutdown = shutdown.clone(); - open_connections.increment(); - let open_connections = open_connections.clone(); + inflight_rpcs.increment(); + let inflight_rpcs = inflight_rpcs.clone(); let mut handler = ConnectionHandler { read_stream: conn.try_clone().unwrap(), write_stream: Arc::new(Mutex::new(conn)), shutdown: shutdown, - open_connections: open_connections, + inflight_rpcs: inflight_rpcs, timeout: read_timeout, }; thread::spawn(move || { @@ -308,7 +314,10 @@ impl Serve for Arc where S: Serve } #[derive(Debug, Clone, Serialize, Deserialize)] -struct Packet(u64, T); +struct Packet { + rpc_id: u64, + message: T +} struct Reader { requests: Arc>>>>, @@ -322,7 +331,10 @@ impl Reader { let packet: bincode::serde::DeserializeResult> = bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); match packet { - Ok(Packet(id, reply)) => { + Ok(Packet { + rpc_id: id, + message: reply + }) => { debug!("Client: received message, id={}", id); let mut requests = self.requests.lock().unwrap(); let mut requests = requests.as_mut().unwrap(); @@ -408,7 +420,10 @@ impl Client return Err(Error::ConnectionBroken); } } - let packet = Packet(id, request); + let packet = Packet { + rpc_id: id, + message: request, + }; try!(state.stream.set_write_timeout(self.timeout)); try!(state.stream.set_read_timeout(self.timeout)); debug!("Client: calling rpc({:?})", request); @@ -559,13 +574,22 @@ mod test { let addr = serve_handle.local_addr().clone(); let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); let thread = thread::spawn(move || serve_handle.shutdown()); - info!("force_shutdown:: rpc1: {:?}", - client.rpc(&Request::Increment)); - info!("force_shutdown:: rpc2: {:?}", - client.rpc(&Request::Increment)); + info!("force_shutdown:: rpc1: {:?}", client.rpc(&Request::Increment)); thread.join().unwrap(); } + #[test] + fn client_failed_rpc() { + let _ = env_logger::init(); + let server = Arc::new(Server::new()); + let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap(); + let addr = serve_handle.local_addr().clone(); + let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); + serve_handle.shutdown(); + let _ = client.rpc(&Request::Increment); // First failure will trigger reader to shutdown + let _ = client.rpc(&Request::Increment); // Test whether second failure hangs + } + #[test] fn concurrent() { let _ = env_logger::init(); From 001b1b1e43b7d8d0fd4540a1c39d97d3cdaeb9af Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Fri, 15 Jan 2016 00:34:36 -0800 Subject: [PATCH 10/10] Only join reader thread if tcp stream shutdown succeeded. --- tarpc/src/protocol.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index f0d22a1..1414570 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -457,8 +457,9 @@ impl Drop for Client .stream .shutdown(::std::net::Shutdown::Both) { warn!("Client: couldn't shutdown reader thread: {:?}", e); + } else { + self.reader_guard.take().unwrap().join().unwrap(); } - self.reader_guard.take().unwrap().join().unwrap(); } }