diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 6a35757..2e79b24 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -9,3 +9,4 @@ bincode = "*" serde_macros = "*" log = "*" env_logger = "*" +crossbeam = "*" diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 46809fa..fcba8c2 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -60,6 +60,7 @@ extern crate serde; extern crate bincode; #[macro_use] extern crate log; +extern crate crossbeam; /// Provides the tarpc client and server, which implements the tarpc protocol. /// The protocol is defined by the implementation. diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index 56a370b..13ad1ab 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -8,6 +8,7 @@ use bincode; use serde; +use crossbeam; use std::fmt; use std::io::{self, Read}; use std::convert; @@ -58,121 +59,121 @@ impl convert::From for Error { /// Return type of rpc calls: either the successful return value, or a client error. pub type Result = ::std::result::Result; -#[derive(Clone)] struct InflightRpcs { - inflight_rpcs: Arc<(Mutex, Condvar)>, + count: Mutex, + cvar: Condvar, } impl InflightRpcs { - fn new(mutex: Mutex, cvar: Condvar) -> InflightRpcs { - InflightRpcs { inflight_rpcs: Arc::new((mutex, cvar)) } + fn new() -> InflightRpcs { + InflightRpcs { + count: Mutex::new(0), + cvar: Condvar::new(), + } } fn wait_until_zero(&self) { - let &(ref count, ref cvar) = &*self.inflight_rpcs; - let mut count = count.lock().unwrap(); + let mut count = self.count.lock().unwrap(); while *count != 0 { - count = cvar.wait(count).unwrap(); + count = self.cvar.wait(count).unwrap(); } info!("serve_async: shutdown complete ({} connections alive)", *count); } fn increment(&self) { - let &(ref count, _) = &*self.inflight_rpcs; - *count.lock().unwrap() += 1; + *self.count.lock().unwrap() += 1; } fn decrement(&self) { - let &(ref count, _) = &*self.inflight_rpcs; - *count.lock().unwrap() -= 1; + *self.count.lock().unwrap() -= 1; } fn decrement_and_notify(&self) { - let &(ref count, ref cvar) = &*self.inflight_rpcs; - *count.lock().unwrap() -= 1; - cvar.notify_one(); + *self.count.lock().unwrap() -= 1; + self.cvar.notify_one(); } } -struct ConnectionHandler { - read_stream: TcpStream, - write_stream: Arc>, - shutdown: Arc, - inflight_rpcs: InflightRpcs, +struct ConnectionHandler<'a> { + write_stream: Mutex, + shutdown: &'a AtomicBool, + inflight_rpcs: &'a InflightRpcs, timeout: Option, } -impl Drop for ConnectionHandler { +impl<'a> Drop for ConnectionHandler<'a> { fn drop(&mut self) { trace!("ConnectionHandler: finished serving client."); self.inflight_rpcs.decrement_and_notify(); } } -impl ConnectionHandler { - fn read(&mut self) -> bincode::serde::DeserializeResult> +impl<'a> ConnectionHandler<'a> { + fn read(read_stream: &mut TcpStream, timeout: Option) -> bincode::serde::DeserializeResult> where Request: serde::de::Deserialize { - try!(self.read_stream.set_read_timeout(self.timeout)); - bincode::serde::deserialize_from(&mut self.read_stream, bincode::SizeLimit::Infinite) + try!(read_stream.set_read_timeout(timeout)); + bincode::serde::deserialize_from(read_stream, bincode::SizeLimit::Infinite) } - fn handle_conn(&mut self, f: F) -> Result<()> - where F: 'static + Clone + Serve + fn handle_conn(&mut self, mut read_stream: TcpStream, f: F) -> Result<()> + where F: Serve { + let f = &f; trace!("ConnectionHandler: serving client..."); - loop { - match self.read() { - Ok(Packet { rpc_id, message, }) => { - let f = f.clone(); - let inflight_rpcs = self.inflight_rpcs.clone(); - inflight_rpcs.increment(); - let stream = self.write_stream.clone(); - thread::spawn(move || { - let reply = f.serve(message); - let reply_packet = Packet { - rpc_id: rpc_id, - message: reply - }; - let mut stream = stream.lock().unwrap(); - if let Err(e) = - bincode::serde::serialize_into(&mut *stream, - &reply_packet, - bincode::SizeLimit::Infinite) { - warn!("ConnectionHandler: failed to write reply to Client: {:?}", - e); + crossbeam::scope(|scope| { + loop { + match Self::read(&mut read_stream, self.timeout) { + Ok(Packet { rpc_id, message, }) => { + let inflight_rpcs = &self.inflight_rpcs; + inflight_rpcs.increment(); + let stream = &self.write_stream; + scope.spawn(move || { + let reply = f.serve(message); + let reply_packet = Packet { + rpc_id: rpc_id, + message: reply + }; + let mut stream = stream.lock().unwrap(); + if let Err(e) = + bincode::serde::serialize_into(&mut *stream, + &reply_packet, + bincode::SizeLimit::Infinite) { + warn!("ConnectionHandler: failed to write reply to Client: {:?}", + e); + } + inflight_rpcs.decrement(); + }); + if self.shutdown.load(Ordering::SeqCst) { + info!("ConnectionHandler: server shutdown, so closing connection."); + break; } - inflight_rpcs.decrement(); - }); - if self.shutdown.load(Ordering::SeqCst) { - info!("ConnectionHandler: server shutdown, so closing connection."); - break; } - } - Err(bincode::serde::DeserializeError::IoError(ref err)) - if Self::timed_out(err.kind()) => { - if !self.shutdown.load(Ordering::SeqCst) { - info!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so \ - retrying read.", - err); - continue; - } else { - info!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \ - closing connection.", - err); - break; + Err(bincode::serde::DeserializeError::IoError(ref err)) + if Self::timed_out(err.kind()) => { + if !self.shutdown.load(Ordering::SeqCst) { + info!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so \ + retrying read.", + err); + continue; + } else { + info!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \ + closing connection.", + err); + break; + } + } + Err(e) => { + warn!("ConnectionHandler: closing client connection due to {:?}", e); + return Err(e.into()); } - } - Err(e) => { - warn!("ConnectionHandler: closing client connection due to {:?}", e); - return Err(e.into()); } } - } - Ok(()) + Ok(()) + }) } fn timed_out(error_kind: io::ErrorKind) -> bool { @@ -217,54 +218,54 @@ impl ServeHandle { /// Start pub fn serve_async(addr: A, f: F, read_timeout: Option) -> io::Result where A: ToSocketAddrs, - F: 'static + Clone + Send + Serve + F: 'static + Serve { let listener = try!(TcpListener::bind(&addr)); let addr = try!(listener.local_addr()); info!("serve_async: spinning up server on {:?}", addr); let (die_tx, die_rx) = channel(); let join_handle = thread::spawn(move || { - let shutdown = Arc::new(AtomicBool::new(false)); - let inflight_rpcs = InflightRpcs::new(Mutex::new(0), Condvar::new()); - for conn in listener.incoming() { - match die_rx.try_recv() { - Ok(_) => { - info!("serve_async: shutdown received. Waiting for open connections to \ - return..."); - shutdown.store(true, Ordering::SeqCst); - inflight_rpcs.wait_until_zero(); - break; + let shutdown = AtomicBool::new(false); + let inflight_rpcs = InflightRpcs::new(); + crossbeam::scope(|scope| { + for conn in listener.incoming() { + match die_rx.try_recv() { + Ok(_) => { + info!("serve_async: shutdown received. Waiting for open connections to \ + return..."); + shutdown.store(true, Ordering::SeqCst); + inflight_rpcs.wait_until_zero(); + break; + } + Err(TryRecvError::Disconnected) => { + info!("serve_async: sender disconnected."); + break; + } + _ => (), } - Err(TryRecvError::Disconnected) => { - info!("serve_async: sender disconnected."); - break; - } - _ => (), + let conn = match conn { + Err(err) => { + error!("serve_async: failed to accept connection: {:?}", err); + return; + } + Ok(c) => c, + }; + inflight_rpcs.increment(); + let read_stream = conn.try_clone().unwrap(); + let mut handler = ConnectionHandler { + write_stream: Mutex::new(conn), + shutdown: &shutdown, + inflight_rpcs: &inflight_rpcs, + timeout: read_timeout, + }; + let f = &f; + scope.spawn(move || { + if let Err(err) = handler.handle_conn(read_stream, f) { + info!("ConnectionHandler: err in connection handling: {:?}", err); + } + }); } - let conn = match conn { - Err(err) => { - error!("serve_async: failed to accept connection: {:?}", err); - return; - } - Ok(c) => c, - }; - let f = f.clone(); - let shutdown = shutdown.clone(); - inflight_rpcs.increment(); - let inflight_rpcs = inflight_rpcs.clone(); - let mut handler = ConnectionHandler { - read_stream: conn.try_clone().unwrap(), - write_stream: Arc::new(Mutex::new(conn)), - shutdown: shutdown, - inflight_rpcs: inflight_rpcs, - timeout: read_timeout, - }; - thread::spawn(move || { - if let Err(err) = handler.handle_conn(f) { - info!("ConnectionHandler: err in connection handling: {:?}", err); - } - }); - } + }); }); Ok(ServeHandle { tx: die_tx, @@ -284,7 +285,9 @@ pub trait Serve: Send + Sync { fn serve(&self, request: Self::Request) -> Self::Reply; } -impl Serve for Arc where S: Serve +impl Serve for P + where P: Send + Sync + ::std::ops::Deref, + S: Serve { type Request = S::Request; type Reply = S::Reply;