diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index d498f5f..f0d22a1 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -73,17 +73,17 @@ impl convert::From for Error { pub type Result = ::std::result::Result; #[derive(Clone)] -struct OpenConnections { - open_connections: Arc<(Mutex, Condvar)>, +struct InflightRpcs { + inflight_rpcs: Arc<(Mutex, Condvar)>, } -impl OpenConnections { - fn new(mutex: Mutex, cvar: Condvar) -> OpenConnections { - OpenConnections { open_connections: Arc::new((mutex, cvar)) } +impl InflightRpcs { + fn new(mutex: Mutex, cvar: Condvar) -> InflightRpcs { + InflightRpcs { inflight_rpcs: Arc::new((mutex, cvar)) } } fn wait_until_zero(&self) { - let &(ref count, ref cvar) = &*self.open_connections; + let &(ref count, ref cvar) = &*self.inflight_rpcs; let mut count = count.lock().unwrap(); while *count != 0 { count = cvar.wait(count).unwrap(); @@ -93,18 +93,18 @@ impl OpenConnections { } fn increment(&self) { - let &(ref count, _) = &*self.open_connections; + let &(ref count, _) = &*self.inflight_rpcs; *count.lock().unwrap() += 1; } fn decrement(&self) { - let &(ref count, _) = &*self.open_connections; + let &(ref count, _) = &*self.inflight_rpcs; *count.lock().unwrap() -= 1; } fn decrement_and_notify(&self) { - let &(ref count, ref cvar) = &*self.open_connections; + let &(ref count, ref cvar) = &*self.inflight_rpcs; *count.lock().unwrap() -= 1; cvar.notify_one(); } @@ -115,14 +115,14 @@ struct ConnectionHandler { read_stream: TcpStream, write_stream: Arc>, shutdown: Arc, - open_connections: OpenConnections, + inflight_rpcs: InflightRpcs, timeout: Option, } impl Drop for ConnectionHandler { fn drop(&mut self) { trace!("ConnectionHandler: finished serving client."); - self.open_connections.decrement_and_notify(); + self.inflight_rpcs.decrement_and_notify(); } } @@ -140,14 +140,20 @@ impl ConnectionHandler { trace!("ConnectionHandler: serving client..."); loop { match self.read() { - Ok(Packet(id, message)) => { + Ok(Packet { + rpc_id: id, + message: message + }) => { let f = f.clone(); - let open_connections = self.open_connections.clone(); - open_connections.increment(); + let inflight_rpcs = self.inflight_rpcs.clone(); + inflight_rpcs.increment(); let stream = self.write_stream.clone(); thread::spawn(move || { let reply = f.serve(message); - let reply_packet = Packet(id, reply); + let reply_packet = Packet { + rpc_id: id, + message: reply + }; let mut stream = stream.lock().unwrap(); if let Err(e) = bincode::serde::serialize_into(&mut *stream, @@ -156,18 +162,22 @@ impl ConnectionHandler { warn!("ConnectionHandler: failed to write reply to Client: {:?}", e); } - open_connections.decrement(); + inflight_rpcs.decrement(); }); + if self.shutdown.load(Ordering::SeqCst) { + info!("ConnectionHandler: server shutdown, so closing connection."); + break; + } } Err(bincode::serde::DeserializeError::IoError(ref err)) if Self::timed_out(err.kind()) => { - if !self.shutdown() { - warn!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so \ + if !self.shutdown.load(Ordering::SeqCst) { + info!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so \ retrying read.", err); continue; } else { - warn!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \ + info!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \ closing connection.", err); break; @@ -184,10 +194,6 @@ impl ConnectionHandler { Ok(()) } - fn shutdown(&self) -> bool { - self.shutdown.load(Ordering::SeqCst) - } - fn timed_out(error_kind: io::ErrorKind) -> bool { match error_kind { io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock => true, @@ -238,14 +244,14 @@ pub fn serve_async(addr: A, f: F, read_timeout: Option) -> io::R let (die_tx, die_rx) = channel(); let join_handle = thread::spawn(move || { let shutdown = Arc::new(AtomicBool::new(false)); - let open_connections = OpenConnections::new(Mutex::new(0), Condvar::new()); + let inflight_rpcs = InflightRpcs::new(Mutex::new(0), Condvar::new()); for conn in listener.incoming() { match die_rx.try_recv() { Ok(_) => { info!("serve_async: shutdown received. Waiting for open connections to \ return..."); shutdown.store(true, Ordering::SeqCst); - open_connections.wait_until_zero(); + inflight_rpcs.wait_until_zero(); break; } Err(TryRecvError::Disconnected) => { @@ -263,13 +269,13 @@ pub fn serve_async(addr: A, f: F, read_timeout: Option) -> io::R }; let f = f.clone(); let shutdown = shutdown.clone(); - open_connections.increment(); - let open_connections = open_connections.clone(); + inflight_rpcs.increment(); + let inflight_rpcs = inflight_rpcs.clone(); let mut handler = ConnectionHandler { read_stream: conn.try_clone().unwrap(), write_stream: Arc::new(Mutex::new(conn)), shutdown: shutdown, - open_connections: open_connections, + inflight_rpcs: inflight_rpcs, timeout: read_timeout, }; thread::spawn(move || { @@ -308,7 +314,10 @@ impl Serve for Arc where S: Serve } #[derive(Debug, Clone, Serialize, Deserialize)] -struct Packet(u64, T); +struct Packet { + rpc_id: u64, + message: T +} struct Reader { requests: Arc>>>>, @@ -322,7 +331,10 @@ impl Reader { let packet: bincode::serde::DeserializeResult> = bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); match packet { - Ok(Packet(id, reply)) => { + Ok(Packet { + rpc_id: id, + message: reply + }) => { debug!("Client: received message, id={}", id); let mut requests = self.requests.lock().unwrap(); let mut requests = requests.as_mut().unwrap(); @@ -408,7 +420,10 @@ impl Client return Err(Error::ConnectionBroken); } } - let packet = Packet(id, request); + let packet = Packet { + rpc_id: id, + message: request, + }; try!(state.stream.set_write_timeout(self.timeout)); try!(state.stream.set_read_timeout(self.timeout)); debug!("Client: calling rpc({:?})", request); @@ -559,13 +574,22 @@ mod test { let addr = serve_handle.local_addr().clone(); let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); let thread = thread::spawn(move || serve_handle.shutdown()); - info!("force_shutdown:: rpc1: {:?}", - client.rpc(&Request::Increment)); - info!("force_shutdown:: rpc2: {:?}", - client.rpc(&Request::Increment)); + info!("force_shutdown:: rpc1: {:?}", client.rpc(&Request::Increment)); thread.join().unwrap(); } + #[test] + fn client_failed_rpc() { + let _ = env_logger::init(); + let server = Arc::new(Server::new()); + let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap(); + let addr = serve_handle.local_addr().clone(); + let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); + serve_handle.shutdown(); + let _ = client.rpc(&Request::Increment); // First failure will trigger reader to shutdown + let _ = client.rpc(&Request::Increment); // Test whether second failure hangs + } + #[test] fn concurrent() { let _ = env_logger::init();