diff --git a/tarpc/src/protocol/mod.rs b/tarpc/src/protocol/mod.rs index 216cde9..45eb4de 100644 --- a/tarpc/src/protocol/mod.rs +++ b/tarpc/src/protocol/mod.rs @@ -35,8 +35,10 @@ impl convert::From for Error { impl convert::From for Error { fn from(err: bincode::serde::DeserializeError) -> Error { match err { - bincode::serde::DeserializeError::IoError(err) => Error::Io(Arc::new(err)), + bincode::serde::DeserializeError::IoError(ref err) + if err.kind() == io::ErrorKind::ConnectionReset => Error::ConnectionBroken, bincode::serde::DeserializeError::EndOfStreamError => Error::ConnectionBroken, + bincode::serde::DeserializeError::IoError(err) => Error::Io(Arc::new(err)), err => panic!("Unexpected error during deserialization: {:?}", err), } } @@ -179,6 +181,7 @@ mod test { let serve_handle = serve_async("localhost:0", server, test_timeout()).unwrap(); let addr = serve_handle.local_addr().clone(); let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); + client.rpc(Request::Increment).unwrap(); serve_handle.shutdown(); match client.rpc(Request::Increment) { Err(super::Error::ConnectionBroken) => {} // success diff --git a/tarpc/src/protocol/server.rs b/tarpc/src/protocol/server.rs index 4145923..37b5e77 100644 --- a/tarpc/src/protocol/server.rs +++ b/tarpc/src/protocol/server.rs @@ -5,11 +5,10 @@ use bincode; use serde; -use scoped_pool::Pool; +use scoped_pool::{Pool, Scope}; use std::fmt; use std::io::{self, BufReader, BufWriter, Write}; use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs}; -use std::sync::{Condvar, Mutex}; use std::sync::mpsc::{Receiver, Sender, TryRecvError, channel}; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; @@ -21,74 +20,62 @@ struct ConnectionHandler<'a, S> { read_stream: BufReader, write_stream: BufWriter, - shutdown: &'a AtomicBool, - inflight_rpcs: &'a InflightRpcs, server: S, - pool: &'a Pool, + shutdown: &'a AtomicBool, } -impl<'a, S> Drop for ConnectionHandler<'a, S> where S: Serve { - fn drop(&mut self) { - trace!("ConnectionHandler: finished serving client."); - self.inflight_rpcs.decrement_and_notify(); - } -} - -impl<'a, S> ConnectionHandler<'a, S> where S: Serve { - fn handle_conn(&mut self) -> Result<()> { +impl<'a, S> ConnectionHandler<'a, S> + where S: Serve +{ + fn handle_conn<'b>(&'b mut self, scope: &Scope<'b>) -> Result<()> { let ConnectionHandler { ref mut read_stream, ref mut write_stream, - shutdown, - ref inflight_rpcs, ref server, - pool, + shutdown, } = *self; trace!("ConnectionHandler: serving client..."); - pool.scoped(|scope| { - let (tx, rx) = channel(); - scope.execute(|| Self::write(rx, write_stream, inflight_rpcs)); - loop { - match bincode::serde::deserialize_from(read_stream, bincode::SizeLimit::Infinite) { - Ok(Packet { rpc_id, message, }) => { - inflight_rpcs.increment(); - let tx = tx.clone(); - scope.execute(move || { - let reply = server.serve(message); - let reply_packet = Packet { - rpc_id: rpc_id, - message: reply - }; - tx.send(reply_packet).unwrap(); - }); - if shutdown.load(Ordering::SeqCst) { - info!("ConnectionHandler: server shutdown, so closing connection."); - break; - } - } - Err(bincode::serde::DeserializeError::IoError(ref err)) - if Self::timed_out(err.kind()) => { - if !shutdown.load(Ordering::SeqCst) { - info!("ConnectionHandler: read timed out ({:?}). Server not \ - shutdown, so retrying read.", - err); - continue; - } else { - info!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \ - closing connection.", - err); - break; - } - } - Err(e) => { - warn!("ConnectionHandler: closing client connection due to {:?}", - e); - return Err(e.into()); + let (tx, rx) = channel(); + scope.execute(move || Self::write(rx, write_stream)); + loop { + match bincode::serde::deserialize_from(read_stream, bincode::SizeLimit::Infinite) { + Ok(Packet { rpc_id, message, }) => { + let tx = tx.clone(); + scope.execute(move || { + let reply = server.serve(message); + let reply_packet = Packet { + rpc_id: rpc_id, + message: reply + }; + tx.send(reply_packet).expect(pos!()); + }); + if shutdown.load(Ordering::SeqCst) { + info!("ConnectionHandler: server shutdown, so closing connection."); + break; } } + Err(bincode::serde::DeserializeError::IoError(ref err)) + if Self::timed_out(err.kind()) => { + if !shutdown.load(Ordering::SeqCst) { + info!("ConnectionHandler: read timed out ({:?}). Server not \ + shutdown, so retrying read.", + err); + continue; + } else { + info!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \ + closing connection.", + err); + break; + } + } + Err(e) => { + warn!("ConnectionHandler: closing client connection due to {:?}", + e); + return Err(e.into()); + } } - Ok(()) - }) + } + Ok(()) } fn timed_out(error_kind: io::ErrorKind) -> bool { @@ -98,9 +85,7 @@ impl<'a, S> ConnectionHandler<'a, S> where S: Serve { } } - fn write(rx: Receiver::Reply>>, - stream: &mut BufWriter, - inflight_rpcs: &InflightRpcs) { + fn write(rx: Receiver::Reply>>, stream: &mut BufWriter) { loop { match rx.recv() { Err(e) => { @@ -119,51 +104,12 @@ impl<'a, S> ConnectionHandler<'a, S> where S: Serve { warn!("Writer: failed to flush reply to Client: {:?}", e); } - inflight_rpcs.decrement(); } } } } } -struct InflightRpcs { - count: Mutex, - cvar: Condvar, -} - -impl InflightRpcs { - fn new() -> InflightRpcs { - InflightRpcs { - count: Mutex::new(0), - cvar: Condvar::new(), - } - } - - fn wait_until_zero(&self) { - let mut count = self.count.lock().unwrap(); - while *count != 0 { - count = self.cvar.wait(count).unwrap(); - } - info!("serve_async: shutdown complete ({} connections alive)", - *count); - } - - fn increment(&self) { - *self.count.lock().unwrap() += 1; - } - - fn decrement(&self) { - *self.count.lock().unwrap() -= 1; - } - - - fn decrement_and_notify(&self) { - *self.count.lock().unwrap() -= 1; - self.cvar.notify_one(); - } - -} - /// Provides methods for blocking until the server completes, pub struct ServeHandle { tx: Sender<()>, @@ -195,6 +141,74 @@ impl ServeHandle { } } +struct Server<'a, S: 'a> { + server: &'a S, + listener: TcpListener, + read_timeout: Option, + die_rx: Receiver<()>, + shutdown: &'a AtomicBool, +} + +impl<'a, S: 'a> Server<'a, S> + where S: Serve + 'static +{ + fn serve<'b>(self, scope: &Scope<'b>) where 'a: 'b { + for conn in self.listener.incoming() { + match self.die_rx.try_recv() { + Ok(_) => { + info!("serve: shutdown received."); + return; + } + Err(TryRecvError::Disconnected) => { + info!("serve: shutdown sender disconnected."); + return; + } + _ => (), + } + let conn = match conn { + Err(err) => { + error!("serve: failed to accept connection: {:?}", err); + return; + } + Ok(c) => c, + }; + if let Err(err) = conn.set_read_timeout(self.read_timeout) { + info!("serve: could not set read timeout: {:?}", err); + continue; + } + let read_conn = match conn.try_clone() { + Err(err) => { + error!("serve: could not clone tcp stream; possibly out of file descriptors? \ + Err: {:?}", + err); + continue; + } + Ok(conn) => conn, + }; + let mut handler = ConnectionHandler { + read_stream: BufReader::new(read_conn), + write_stream: BufWriter::new(conn), + server: self.server, + shutdown: self.shutdown, + }; + scope.recurse(move |scope| { + scope.zoom(|scope| { + if let Err(err) = handler.handle_conn(scope) { + info!("ConnectionHandler: err in connection handling: {:?}", err); + } + }); + }); + } + } +} + +impl<'a, S> Drop for Server<'a, S> { + fn drop(&mut self) { + debug!("Shutting down connection handlers."); + self.shutdown.store(true, Ordering::SeqCst); + } +} + /// Start pub fn serve_async(addr: A, server: S, @@ -210,49 +224,15 @@ pub fn serve_async(addr: A, let join_handle = thread::spawn(move || { let pool = Pool::new(100); // TODO(tjk): make this configurable, and expire idle threads let shutdown = AtomicBool::new(false); - let inflight_rpcs = InflightRpcs::new(); + let server = Server { + server: &server, + listener: listener, + read_timeout: read_timeout, + die_rx: die_rx, + shutdown: &shutdown, + }; pool.scoped(|scope| { - for conn in listener.incoming() { - match die_rx.try_recv() { - Ok(_) => { - info!("serve_async: shutdown received. Waiting for open connections to \ - return..."); - shutdown.store(true, Ordering::SeqCst); - inflight_rpcs.wait_until_zero(); - break; - } - Err(TryRecvError::Disconnected) => { - info!("serve_async: sender disconnected."); - break; - } - _ => (), - } - let conn = match conn { - Err(err) => { - error!("serve_async: failed to accept connection: {:?}", err); - return; - } - Ok(c) => c, - }; - if let Err(err) = conn.set_read_timeout(read_timeout) { - info!("Server: could not set read timeout: {:?}", err); - return; - } - inflight_rpcs.increment(); - scope.execute(|| { - let mut handler = ConnectionHandler { - read_stream: BufReader::new(conn.try_clone().expect(pos!())), - write_stream: BufWriter::new(conn), - shutdown: &shutdown, - inflight_rpcs: &inflight_rpcs, - server: &server, - pool: &pool, - }; - if let Err(err) = handler.handle_conn() { - info!("ConnectionHandler: err in connection handling: {:?}", err); - } - }); - } + server.serve(scope); }); }); Ok(ServeHandle {