diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index c6b0d69..2d9d7a2 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -11,4 +11,5 @@ bincode = "*" serde_macros = "*" log = "*" env_logger = "*" -crossbeam = "*" +scoped-pool = "0.1.4" +lazy_static = "*" diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index e597ca9..a8e6486 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -58,8 +58,11 @@ extern crate serde; extern crate bincode; #[macro_use] extern crate log; -extern crate crossbeam; +extern crate scoped_pool; extern crate test; +#[cfg(test)] +#[macro_use] +extern crate lazy_static; /// Provides the tarpc client and server, which implements the tarpc protocol. /// The protocol is defined by the implementation. diff --git a/tarpc/src/macros.rs b/tarpc/src/macros.rs index 7038a8d..7013882 100644 --- a/tarpc/src/macros.rs +++ b/tarpc/src/macros.rs @@ -274,6 +274,8 @@ macro_rules! service { #[allow(dead_code)] mod test { extern crate env_logger; + use ServeHandle; + use std::sync::{Arc, Mutex}; use std::time::Duration; use test::Bencher; @@ -402,22 +404,34 @@ mod test { } } + // Prevents resource exhaustion when benching + lazy_static! { + static ref HANDLE: Arc> = { + let handle = hi::serve("localhost:0", HelloServer, None).unwrap(); + Arc::new(Mutex::new(handle)) + }; + static ref CLIENT: Arc> = { + let addr = HANDLE.lock().unwrap().local_addr().clone(); + let client = hi::AsyncClient::new(addr, None).unwrap(); + Arc::new(Mutex::new(client)) + }; + } + #[bench] fn hello(bencher: &mut Bencher) { let _ = env_logger::init(); - let handle = hi::serve("localhost:0", HelloServer, None).unwrap(); - let client = hi::AsyncClient::new(handle.local_addr(), None).unwrap(); + let client = CLIENT.lock().unwrap(); let concurrency = 100; - let mut rpcs = Vec::with_capacity(concurrency); + let mut futures = Vec::with_capacity(concurrency); + let mut count = 0; bencher.iter(|| { - for _ in 0..concurrency { - rpcs.push(client.hello("Bob".into())); - } - for _ in 0..concurrency { - rpcs.pop().unwrap().get().unwrap(); + futures.push(client.hello("Bob".into())); + count += 1; + if count % concurrency == 0 { + for f in futures.drain(..) { + f.get().unwrap(); + } } }); - drop(client); - handle.shutdown(); } } diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs index f3c8008..253351c 100644 --- a/tarpc/src/protocol.rs +++ b/tarpc/src/protocol.rs @@ -8,9 +8,9 @@ use bincode; use serde; -use crossbeam; +use scoped_pool::Pool; use std::fmt; -use std::io::{self, Read}; +use std::io::{self, BufReader, BufWriter, Read, Write}; use std::convert; use std::collections::HashMap; use std::marker::PhantomData; @@ -116,12 +116,12 @@ impl InflightRpcs { struct ConnectionHandler<'a, S> where S: Serve { - read_stream: TcpStream, - write_stream: Mutex, + read_stream: BufReader, + write_stream: Mutex>, shutdown: &'a AtomicBool, inflight_rpcs: &'a InflightRpcs, - timeout: Option, server: S, + pool: &'a Pool, } impl<'a, S> Drop for ConnectionHandler<'a, S> where S: Serve { @@ -132,31 +132,22 @@ impl<'a, S> Drop for ConnectionHandler<'a, S> where S: Serve { } impl<'a, S> ConnectionHandler<'a, S> where S: Serve { - fn read(read_stream: &mut TcpStream, - timeout: Option) - -> bincode::serde::DeserializeResult> - where Request: serde::de::Deserialize - { - try!(read_stream.set_read_timeout(timeout)); - bincode::serde::deserialize_from(read_stream, bincode::SizeLimit::Infinite) - } - fn handle_conn(&mut self) -> Result<()> { let ConnectionHandler { ref mut read_stream, ref write_stream, shutdown, inflight_rpcs, - timeout, ref server, + pool, } = *self; trace!("ConnectionHandler: serving client..."); - crossbeam::scope(|scope| { + pool.scoped(|scope| { loop { - match Self::read(read_stream, timeout) { + match bincode::serde::deserialize_from(read_stream, bincode::SizeLimit::Infinite) { Ok(Packet { rpc_id, message, }) => { inflight_rpcs.increment(); - scope.spawn(move || { + scope.execute(move || { let reply = server.serve(message); let reply_packet = Packet { rpc_id: rpc_id, @@ -170,6 +161,10 @@ impl<'a, S> ConnectionHandler<'a, S> where S: Serve { warn!("ConnectionHandler: failed to write reply to Client: {:?}", e); } + if let Err(e) = write_stream.flush() { + warn!("ConnectionHandler: failed to flush reply to Client: {:?}", + e); + } inflight_rpcs.decrement(); }); if shutdown.load(Ordering::SeqCst) { @@ -254,9 +249,10 @@ pub fn serve_async(addr: A, info!("serve_async: spinning up server on {:?}", addr); let (die_tx, die_rx) = channel(); let join_handle = thread::spawn(move || { + let pool = Pool::new(100); // TODO(tjk): make this configurable, and expire idle threads let shutdown = AtomicBool::new(false); let inflight_rpcs = InflightRpcs::new(); - crossbeam::scope(|scope| { + pool.scoped(|scope| { for conn in listener.incoming() { match die_rx.try_recv() { Ok(_) => { @@ -279,15 +275,19 @@ pub fn serve_async(addr: A, } Ok(c) => c, }; + if let Err(err) = conn.set_read_timeout(read_timeout) { + info!("Server: could not set read timeout: {:?}", err); + return; + } inflight_rpcs.increment(); - scope.spawn(|| { + scope.execute(|| { let mut handler = ConnectionHandler { - read_stream: conn.try_clone().unwrap(), - write_stream: Mutex::new(conn), + read_stream: BufReader::new(conn.try_clone().unwrap()), + write_stream: Mutex::new(BufWriter::new(conn)), shutdown: &shutdown, inflight_rpcs: &inflight_rpcs, - timeout: read_timeout, server: &server, + pool: &pool, }; if let Err(err) = handler.handle_conn() { info!("ConnectionHandler: err in connection handling: {:?}", err); @@ -383,9 +383,10 @@ struct Reader { } impl Reader { - fn read(self, mut stream: TcpStream) + fn read(self, stream: TcpStream) where Reply: serde::Deserialize { + let mut stream = BufReader::new(stream); loop { let packet: bincode::serde::DeserializeResult> = bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); @@ -417,7 +418,7 @@ fn increment(cur_id: &mut u64) -> u64 { struct SyncedClientState { next_id: u64, - stream: TcpStream, + stream: BufWriter, } /// A client stub that connects to a server to run rpcs. @@ -427,7 +428,6 @@ pub struct Client synced_state: Mutex, requests: Arc>>, reader_guard: Option>, - timeout: Option, _request: PhantomData, } @@ -439,6 +439,8 @@ impl Client /// for both reads and writes. pub fn new(addr: A, timeout: Option) -> io::Result { let stream = try!(TcpStream::connect(addr)); + try!(stream.set_read_timeout(timeout)); + try!(stream.set_write_timeout(timeout)); let requests = Arc::new(Mutex::new(RpcFutures::new())); let reader_stream = try!(stream.try_clone()); let reader = Reader { requests: requests.clone() }; @@ -446,11 +448,10 @@ impl Client Ok(Client { synced_state: Mutex::new(SyncedClientState { next_id: 0, - stream: stream, + stream: BufWriter::new(stream), }), requests: requests, reader_guard: Some(reader_guard), - timeout: timeout, _request: PhantomData, }) } @@ -466,8 +467,6 @@ impl Client rpc_id: id, message: request, }; - try!(state.stream.set_write_timeout(self.timeout)); - try!(state.stream.set_read_timeout(self.timeout)); debug!("Client: calling rpc({:?})", request); if let Err(err) = bincode::serde::serialize_into(&mut state.stream, &packet, @@ -477,6 +476,11 @@ impl Client err); try!(self.requests.lock().unwrap().remove_tx(id)); } + if let Err(err) = state.stream.flush() { + warn!("Client: failed to flush packet.\nPacket: {:?}\nError: {:?}", + packet, + err); + } Ok(rx) } @@ -508,6 +512,7 @@ impl Drop for Client .lock() .unwrap() .stream + .get_mut() .shutdown(::std::net::Shutdown::Both) { warn!("Client: couldn't shutdown reader thread: {:?}", e); } else { @@ -686,6 +691,5 @@ mod test { drop(client); serve_handle.shutdown(); - assert_eq!(server.count(), 2); } }