From 0df3cfdd9840263a928887e3afc43357f5c4eacf Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Thu, 14 Jan 2016 09:22:46 -0800 Subject: [PATCH] Properly wait for spawned connection handler threads to shutdown. Set client timeout to None in tests. --- tarpc/src/protocol.rs | 134 +++++++++++++++++++++++++++--------------- 1 file changed, 85 insertions(+), 49 deletions(-) diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index c4a4378..3e54e7d 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -72,56 +72,101 @@ impl convert::From for Error { /// Return type of rpc calls: either the successful return value, or a client error. pub type Result = ::std::result::Result; -struct ConnectionHandler { - shutdown: Arc, +#[derive(Clone)] +struct OpenConnections { open_connections: Arc<(Mutex, Condvar)>, +} + +impl OpenConnections { + fn new(mutex: Mutex, cvar: Condvar) -> OpenConnections { + OpenConnections { + open_connections: Arc::new((mutex, cvar)), + } + } + + fn wait_until_zero(&self) { + let &(ref count, ref cvar) = &*self.open_connections; + let mut count = count.lock().unwrap(); + while *count != 0 { + count = cvar.wait(count).unwrap(); + } + info!("serve_async: shutdown complete ({} connections alive)", *count); + } + + fn increment(&self) { + let &(ref count, _) = &*self.open_connections; + *count.lock().unwrap() += 1; + } + + fn decrement(&self) { + let &(ref count, _) = &*self.open_connections; + *count.lock().unwrap() -= 1; + } + + + fn decrement_and_notify(&self) { + let &(ref count, ref cvar) = &*self.open_connections; + *count.lock().unwrap() -= 1; + cvar.notify_one(); + } + +} + +struct ConnectionHandler { + read_stream: TcpStream, + write_stream: Arc>, + shutdown: Arc, + open_connections: OpenConnections, timeout: Option, } impl Drop for ConnectionHandler { fn drop(&mut self) { - let &(ref count, ref cvar) = &*self.open_connections; - *count.lock().unwrap() -= 1; - cvar.notify_one(); + if let Err(e) = bincode::serde::serialize_into(&mut self.read_stream, + &Packet::Shutdown::<()>, + bincode::SizeLimit::Infinite) { + warn!("ConnectionHandler: could not notify client of shutdown: {:?}", e); + } trace!("ConnectionHandler: finished serving client."); + self.open_connections.decrement_and_notify(); } } impl ConnectionHandler { - fn handle_conn(&self, stream: TcpStream, f: F) -> Result<()> + fn read(&mut self) -> bincode::serde::DeserializeResult> + where Request: serde::de::Deserialize + { + try!(self.read_stream.set_read_timeout(self.timeout)); + bincode::serde::deserialize_from(&mut self.read_stream, bincode::SizeLimit::Infinite) + } + + fn handle_conn(&mut self, f: F) -> Result<()> where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, Reply: 'static + fmt::Debug + serde::ser::Serialize, F: 'static + Clone + Serve { trace!("ConnectionHandler: serving client..."); - let mut read_stream = try!(stream.try_clone()); - let stream = Arc::new(Mutex::new(stream)); loop { - try!(read_stream.set_read_timeout(self.timeout)); - match bincode::serde::deserialize_from(&mut read_stream, bincode::SizeLimit::Infinite) { - Ok(request_packet @ Packet::Shutdown) => { - let stream = stream.clone(); - let mut my_stream = stream.lock().unwrap(); - try!(bincode::serde::serialize_into(&mut *my_stream, - &request_packet, - bincode::SizeLimit::Infinite)); - break; - } + match self.read() { + Ok(Packet::Shutdown) => break, Ok(Packet::Message(id, message)) => { let f = f.clone(); - let arc_stream = stream.clone(); + let open_connections = self.open_connections.clone(); + open_connections.increment(); + let stream = self.write_stream.clone(); thread::spawn(move || { let reply = f.serve(message); let reply_packet = Packet::Message(id, reply); - let mut my_stream = arc_stream.lock().unwrap(); - bincode::serde::serialize_into(&mut *my_stream, + let mut stream = stream.lock().unwrap(); + if let Err(e) = bincode::serde::serialize_into(&mut *stream, &reply_packet, - bincode::SizeLimit::Infinite) - .unwrap(); + bincode::SizeLimit::Infinite) { + warn!("ConnectionHandler: failed to write reply to Client: {:?}", e); + } + open_connections.decrement(); }); } - Err(bincode::serde::DeserializeError::IoError(ref err)) - if Self::timed_out(err.kind()) => { + Err(bincode::serde::DeserializeError::IoError(ref err)) if Self::timed_out(err.kind()) => { if !self.shutdown() { warn!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so \ retrying read.", @@ -131,10 +176,6 @@ impl ConnectionHandler { warn!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \ closing connection.", err); - let mut stream = stream.lock().unwrap(); - try!(bincode::serde::serialize_into(&mut *stream, - &Packet::Shutdown::, - bincode::SizeLimit::Infinite)); break; } } @@ -208,20 +249,14 @@ pub fn serve_async(addr: A, let (die_tx, die_rx) = channel(); let join_handle = thread::spawn(move || { let shutdown = Arc::new(AtomicBool::new(false)); - let open_connections = Arc::new((Mutex::new(0), Condvar::new())); + let open_connections = OpenConnections::new(Mutex::new(0), Condvar::new()); for conn in listener.incoming() { match die_rx.try_recv() { Ok(_) => { info!("serve_async: shutdown received. Waiting for open connections to \ return..."); shutdown.store(true, Ordering::SeqCst); - let &(ref count, ref cvar) = &*open_connections; - let mut count = count.lock().unwrap(); - while *count != 0 { - count = cvar.wait(count).unwrap(); - } - info!("serve_async: shutdown complete ({} connections alive)", - *count); + open_connections.wait_until_zero(); break; } Err(TryRecvError::Disconnected) => { @@ -239,16 +274,17 @@ pub fn serve_async(addr: A, }; let f = f.clone(); let shutdown = shutdown.clone(); - let &(ref count, _) = &*open_connections; - *count.lock().unwrap() += 1; + open_connections.increment(); let open_connections = open_connections.clone(); + let mut handler = ConnectionHandler { + read_stream: conn.try_clone().unwrap(), + write_stream: Arc::new(Mutex::new(conn)), + shutdown: shutdown, + open_connections: open_connections, + timeout: read_timeout, + }; thread::spawn(move || { - let handler = ConnectionHandler { - shutdown: shutdown, - open_connections: open_connections, - timeout: read_timeout, - }; - if let Err(err) = handler.handle_conn(conn, f) { + if let Err(err) = handler.handle_conn(f) { error!("ConnectionHandler: error in connection handling: {:?}", err); } }); @@ -472,7 +508,7 @@ mod test { let server = Arc::new(Server::new()); let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); let client: Client = Client::new(serve_handle.local_addr().clone(), - test_timeout()) + None) .expect(&line!().to_string()); drop(client); serve_handle.shutdown(); @@ -484,7 +520,7 @@ mod test { let server = Arc::new(Server::new()); let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); - let client = Client::new(addr, test_timeout()).unwrap(); + let client = Client::new(addr, None).unwrap(); assert_eq!(Reply::Increment(0), client.rpc(&Request::Increment).unwrap()); assert_eq!(1, server.count()); @@ -526,7 +562,7 @@ mod test { let server = Arc::new(Server::new()); let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr, test_timeout()) + let client: Arc> = Arc::new(Client::new(addr, None) .unwrap()); let thread = thread::spawn(move || serve_handle.shutdown()); info!("force_shutdown:: rpc1: {:?}", @@ -542,7 +578,7 @@ mod test { let server = Arc::new(BarrierServer::new(10)); let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr, test_timeout()) + let client: Arc> = Arc::new(Client::new(addr, None) .unwrap()); let mut join_handles = vec![]; for _ in 0..10 {