From 277e707db955809599a3e2f99bffad2c6fe1c04a Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Fri, 29 Jan 2016 22:40:22 -0800 Subject: [PATCH 1/2] Separate client and server code into protocol submodules --- tarpc/src/protocol.rs | 726 ----------------------------------- tarpc/src/protocol/client.rs | 254 ++++++++++++ tarpc/src/protocol/mod.rs | 229 +++++++++++ tarpc/src/protocol/server.rs | 263 +++++++++++++ 4 files changed, 746 insertions(+), 726 deletions(-) delete mode 100644 tarpc/src/protocol.rs create mode 100644 tarpc/src/protocol/client.rs create mode 100644 tarpc/src/protocol/mod.rs create mode 100644 tarpc/src/protocol/server.rs diff --git a/tarpc/src/protocol.rs b/tarpc/src/protocol.rs deleted file mode 100644 index 4a553f0..0000000 --- a/tarpc/src/protocol.rs +++ /dev/null @@ -1,726 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use bincode; -use serde; -use scoped_pool::Pool; -use std::fmt; -use std::io::{self, BufReader, BufWriter, Read, Write}; -use std::convert; -use std::collections::HashMap; -use std::mem; -use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs}; -use std::sync::{Arc, Condvar, Mutex}; -use std::sync::mpsc::{Receiver, Sender, TryRecvError, channel}; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::time::Duration; -use std::thread::{self, JoinHandle}; - -/// Client errors that can occur during rpc calls -#[derive(Debug, Clone)] -pub enum Error { - /// An IO-related error - Io(Arc), - /// The server hung up. 
- ConnectionBroken, -} - -impl convert::From for Error { - fn from(err: bincode::serde::SerializeError) -> Error { - match err { - bincode::serde::SerializeError::IoError(err) => Error::Io(Arc::new(err)), - err => panic!("Unexpected error during serialization: {:?}", err), - } - } -} - -impl convert::From for Error { - fn from(err: bincode::serde::DeserializeError) -> Error { - match err { - bincode::serde::DeserializeError::IoError(err) => Error::Io(Arc::new(err)), - bincode::serde::DeserializeError::EndOfStreamError => Error::ConnectionBroken, - err => panic!("Unexpected error during deserialization: {:?}", err), - } - } -} - -impl convert::From for Error { - fn from(err: io::Error) -> Error { - Error::Io(Arc::new(err)) - } -} - -/// Return type of rpc calls: either the successful return value, or a client error. -pub type Result = ::std::result::Result; - -/// An asynchronous RPC call -pub struct Future { - rx: Receiver>, - requests: Arc>> -} - -impl Future { - /// Block until the result of the RPC call is available - pub fn get(self) -> Result { - let requests = self.requests; - self.rx.recv() - .map_err(|_| requests.lock().unwrap().get_error()) - .and_then(|reply| reply) - } -} - -struct InflightRpcs { - count: Mutex, - cvar: Condvar, -} - -impl InflightRpcs { - fn new() -> InflightRpcs { - InflightRpcs { - count: Mutex::new(0), - cvar: Condvar::new(), - } - } - - fn wait_until_zero(&self) { - let mut count = self.count.lock().unwrap(); - while *count != 0 { - count = self.cvar.wait(count).unwrap(); - } - info!("serve_async: shutdown complete ({} connections alive)", - *count); - } - - fn increment(&self) { - *self.count.lock().unwrap() += 1; - } - - fn decrement(&self) { - *self.count.lock().unwrap() -= 1; - } - - - fn decrement_and_notify(&self) { - *self.count.lock().unwrap() -= 1; - self.cvar.notify_one(); - } - -} - -struct ConnectionHandler<'a, S> - where S: Serve -{ - read_stream: BufReader, - write_stream: Mutex>, - shutdown: &'a AtomicBool, - 
inflight_rpcs: &'a InflightRpcs, - server: S, - pool: &'a Pool, -} - -impl<'a, S> Drop for ConnectionHandler<'a, S> where S: Serve { - fn drop(&mut self) { - trace!("ConnectionHandler: finished serving client."); - self.inflight_rpcs.decrement_and_notify(); - } -} - -impl<'a, S> ConnectionHandler<'a, S> where S: Serve { - fn handle_conn(&mut self) -> Result<()> { - let ConnectionHandler { - ref mut read_stream, - ref write_stream, - shutdown, - inflight_rpcs, - ref server, - pool, - } = *self; - trace!("ConnectionHandler: serving client..."); - pool.scoped(|scope| { - loop { - match bincode::serde::deserialize_from(read_stream, bincode::SizeLimit::Infinite) { - Ok(Packet { rpc_id, message, }) => { - inflight_rpcs.increment(); - scope.execute(move || { - let reply = server.serve(message); - let reply_packet = Packet { - rpc_id: rpc_id, - message: reply - }; - let mut write_stream = write_stream.lock().unwrap(); - if let Err(e) = - bincode::serde::serialize_into(&mut *write_stream, - &reply_packet, - bincode::SizeLimit::Infinite) { - warn!("ConnectionHandler: failed to write reply to Client: {:?}", - e); - } - if let Err(e) = write_stream.flush() { - warn!("ConnectionHandler: failed to flush reply to Client: {:?}", - e); - } - inflight_rpcs.decrement(); - }); - if shutdown.load(Ordering::SeqCst) { - info!("ConnectionHandler: server shutdown, so closing connection."); - break; - } - } - Err(bincode::serde::DeserializeError::IoError(ref err)) - if Self::timed_out(err.kind()) => { - if !shutdown.load(Ordering::SeqCst) { - info!("ConnectionHandler: read timed out ({:?}). Server not \ - shutdown, so retrying read.", - err); - continue; - } else { - info!("ConnectionHandler: read timed out ({:?}). 
Server shutdown, so \ - closing connection.", - err); - break; - } - } - Err(e) => { - warn!("ConnectionHandler: closing client connection due to {:?}", - e); - return Err(e.into()); - } - } - } - Ok(()) - }) - } - - fn timed_out(error_kind: io::ErrorKind) -> bool { - match error_kind { - io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock => true, - _ => false, - } - } -} - -/// Provides methods for blocking until the server completes, -pub struct ServeHandle { - tx: Sender<()>, - join_handle: JoinHandle<()>, - addr: SocketAddr, -} - -impl ServeHandle { - /// Block until the server completes - pub fn wait(self) { - self.join_handle.join().unwrap(); - } - - /// Returns the address the server is bound to - pub fn local_addr(&self) -> &SocketAddr { - &self.addr - } - - /// Shutdown the server. Gracefully shuts down the serve thread but currently does not - /// gracefully close open connections. - pub fn shutdown(self) { - info!("ServeHandle: attempting to shut down the server."); - self.tx.send(()).unwrap(); - if let Ok(_) = TcpStream::connect(self.addr) { - self.join_handle.join().unwrap(); - } else { - warn!("ServeHandle: best effort shutdown of serve thread failed"); - } - } -} - -/// Start -pub fn serve_async(addr: A, - server: S, - read_timeout: Option) - -> io::Result - where A: ToSocketAddrs, - S: 'static + Serve -{ - let listener = try!(TcpListener::bind(&addr)); - let addr = try!(listener.local_addr()); - info!("serve_async: spinning up server on {:?}", addr); - let (die_tx, die_rx) = channel(); - let join_handle = thread::spawn(move || { - let pool = Pool::new(100); // TODO(tjk): make this configurable, and expire idle threads - let shutdown = AtomicBool::new(false); - let inflight_rpcs = InflightRpcs::new(); - pool.scoped(|scope| { - for conn in listener.incoming() { - match die_rx.try_recv() { - Ok(_) => { - info!("serve_async: shutdown received. 
Waiting for open connections to \ - return..."); - shutdown.store(true, Ordering::SeqCst); - inflight_rpcs.wait_until_zero(); - break; - } - Err(TryRecvError::Disconnected) => { - info!("serve_async: sender disconnected."); - break; - } - _ => (), - } - let conn = match conn { - Err(err) => { - error!("serve_async: failed to accept connection: {:?}", err); - return; - } - Ok(c) => c, - }; - if let Err(err) = conn.set_read_timeout(read_timeout) { - info!("Server: could not set read timeout: {:?}", err); - return; - } - inflight_rpcs.increment(); - scope.execute(|| { - let mut handler = ConnectionHandler { - read_stream: BufReader::new(conn.try_clone().unwrap()), - write_stream: Mutex::new(BufWriter::new(conn)), - shutdown: &shutdown, - inflight_rpcs: &inflight_rpcs, - server: &server, - pool: &pool, - }; - if let Err(err) = handler.handle_conn() { - info!("ConnectionHandler: err in connection handling: {:?}", err); - } - }); - } - }); - }); - Ok(ServeHandle { - tx: die_tx, - join_handle: join_handle, - addr: addr.clone(), - }) -} - -/// A service provided by a server -pub trait Serve: Send + Sync { - /// The type of request received by the server - type Request: 'static + fmt::Debug + serde::ser::Serialize + serde::de::Deserialize + Send; - /// The type of reply sent by the server - type Reply: 'static + fmt::Debug + serde::ser::Serialize + serde::de::Deserialize; - - /// Return a reply for a given request - fn serve(&self, request: Self::Request) -> Self::Reply; -} - -impl Serve for P - where P: Send + Sync + ::std::ops::Deref, - S: Serve -{ - type Request = S::Request; - type Reply = S::Reply; - - fn serve(&self, request: S::Request) -> S::Reply { - S::serve(self, request) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct Packet { - rpc_id: u64, - message: T, -} - -struct RpcFutures(Result>>>); - -impl RpcFutures { - fn new() -> RpcFutures { - RpcFutures(Ok(HashMap::new())) - } - - fn insert_tx(&mut self, id: u64, tx: Sender>) -> Result<()> { - 
match self.0 { - Ok(ref mut requests) => { - requests.insert(id, tx); - Ok(()) - } - Err(ref e) => Err(e.clone()), - } - } - - fn remove_tx(&mut self, id: u64) -> Result<()> { - match self.0 { - Ok(ref mut requests) => { - requests.remove(&id); - Ok(()) - } - Err(ref e) => Err(e.clone()), - } - } - - fn complete_reply(&mut self, id: u64, reply: Reply) { - if let Some(tx) = self.0.as_mut().unwrap().remove(&id) { - if let Err(e) = tx.send(Ok(reply)) { - info!("Reader: could not complete reply: {:?}", e); - } - } else { - warn!("RpcFutures: expected sender for id {} but got None!", id); - } - } - - fn set_error(&mut self, err: bincode::serde::DeserializeError) { - let _ = mem::replace(&mut self.0, Err(err.into())); - } - - fn get_error(&self) -> Error { - self.0.as_ref().err().unwrap().clone() - } -} - -fn write(outbound: Receiver<(Request, Sender>)>, - requests: Arc>>, - stream: TcpStream) - where Request: serde::Serialize, - Reply: serde::Deserialize, -{ - let mut next_id = 0; - let mut stream = BufWriter::new(stream); - loop { - let (request, tx) = match outbound.recv() { - Err(e) => { - debug!("Writer: all senders have exited ({:?}). Returning.", e); - return; - } - Ok(request) => request, - }; - if let Err(e) = requests.lock().unwrap().insert_tx(next_id, tx.clone()) { - report_error(&tx, e); - // Once insert_tx returns Err, it will continue to do so. However, continue here so - // that any other clients who sent requests will also recv the Err. - continue; - } - let id = next_id; - next_id += 1; - let packet = Packet { - rpc_id: id, - message: request, - }; - debug!("Writer: calling rpc({:?})", id); - if let Err(e) = bincode::serde::serialize_into(&mut stream, - &packet, - bincode::SizeLimit::Infinite) { - report_error(&tx, e.into()); - // Typically we'd want to notify the client of any Err returned by remove_tx, but in - // this case the client already hit an Err, and doesn't need to know about this one, as - // well. 
- let _ = requests.lock().unwrap().remove_tx(id); - continue; - } - if let Err(e) = stream.flush() { - report_error(&tx, e.into()); - } - } - - fn report_error(tx: &Sender>, e: Error) - where Reply: serde::Deserialize - { - // Clone the err so we can log it if sending fails - if let Err(e2) = tx.send(Err(e.clone())) { - debug!("Error encountered while trying to send an error. \ - Initial error: {:?}; Send error: {:?}", - e, - e2); - } - } - -} - -fn read(requests: Arc>>, stream: TcpStream) - where Reply: serde::Deserialize -{ - let mut stream = BufReader::new(stream); - loop { - let packet: bincode::serde::DeserializeResult> = - bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); - match packet { - Ok(Packet { - rpc_id: id, - message: reply - }) => { - debug!("Client: received message, id={}", id); - requests.lock().unwrap().complete_reply(id, reply); - } - Err(err) => { - warn!("Client: reader thread encountered an unexpected error while parsing; \ - returning now. Error: {:?}", - err); - requests.lock().unwrap().set_error(err); - break; - } - } - } -} - -/// A client stub that connects to a server to run rpcs. -pub struct Client - where Request: serde::ser::Serialize -{ - // The guard is in an option so it can be joined in the drop fn - reader_guard: Arc>>, - outbound: Sender<(Request, Sender>)>, - requests: Arc>>, - shutdown: TcpStream, -} - -impl Client - where Request: serde::ser::Serialize + Send + 'static, - Reply: serde::de::Deserialize + Send + 'static -{ - /// Create a new client that connects to `addr`. The client uses the given timeout - /// for both reads and writes. 
- pub fn new(addr: A, timeout: Option) -> io::Result { - let stream = try!(TcpStream::connect(addr)); - try!(stream.set_read_timeout(timeout)); - try!(stream.set_write_timeout(timeout)); - let reader_stream = try!(stream.try_clone()); - let writer_stream = try!(stream.try_clone()); - let requests = Arc::new(Mutex::new(RpcFutures::new())); - let reader_requests = requests.clone(); - let writer_requests = requests.clone(); - let (tx, rx) = channel(); - let reader_guard = thread::spawn(move || read(reader_requests, reader_stream)); - thread::spawn(move || write(rx, writer_requests, writer_stream)); - Ok(Client { - reader_guard: Arc::new(Some(reader_guard)), - outbound: tx, - requests: requests, - shutdown: stream, - }) - } - - /// Clones the Client so that it can be shared across threads. - pub fn try_clone(&self) -> io::Result> { - Ok(Client { - reader_guard: self.reader_guard.clone(), - outbound: self.outbound.clone(), - requests: self.requests.clone(), - shutdown: try!(self.shutdown.try_clone()), - }) - } - - fn rpc_internal(&self, request: Request) -> Receiver> - where Request: serde::ser::Serialize + fmt::Debug + Send + 'static - { - let (tx, rx) = channel(); - self.outbound.send((request, tx)).unwrap(); - rx - } - - /// Run the specified rpc method on the server this client is connected to - pub fn rpc(&self, request: Request) -> Result - where Request: serde::ser::Serialize + fmt::Debug + Send + 'static - { - self.rpc_internal(request) - .recv() - .map_err(|_| self.requests.lock().unwrap().get_error()) - .and_then(|reply| reply) - } - - /// Asynchronously run the specified rpc method on the server this client is connected to - pub fn rpc_async(&self, request: Request) -> Future - where Request: serde::ser::Serialize + fmt::Debug + Send + 'static - { - Future { - rx: self.rpc_internal(request), - requests: self.requests.clone(), - } - } -} - -impl Drop for Client - where Request: serde::ser::Serialize -{ - fn drop(&mut self) { - debug!("Dropping Client."); - if 
let Some(reader_guard) = Arc::get_mut(&mut self.reader_guard) { - debug!("Attempting to shut down writer and reader threads."); - if let Err(e) = self.shutdown.shutdown(::std::net::Shutdown::Both) { - warn!("Client: couldn't shutdown writer and reader threads: {:?}", e); - } else { - // We only join if we know the TcpStream was shut down. Otherwise we might never - // finish. - debug!("Joining writer and reader."); - reader_guard.take().unwrap().join().unwrap(); - debug!("Successfully joined writer and reader."); - } - } - } -} - -#[cfg(test)] -mod test { - extern crate env_logger; - use super::{Client, Serve, serve_async}; - use scoped_pool::Pool; - use std::sync::{Arc, Barrier, Mutex}; - use std::thread; - use std::time::Duration; - - fn test_timeout() -> Option { - Some(Duration::from_secs(1)) - } - - #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] - enum Request { - Increment, - } - - #[derive(Debug, PartialEq, Serialize, Deserialize)] - enum Reply { - Increment(u64), - } - - struct Server { - counter: Mutex, - } - - impl Serve for Server { - type Request = Request; - type Reply = Reply; - - fn serve(&self, _: Request) -> Reply { - let mut counter = self.counter.lock().unwrap(); - let reply = Reply::Increment(*counter); - *counter += 1; - reply - } - } - - impl Server { - fn new() -> Server { - Server { counter: Mutex::new(0) } - } - - fn count(&self) -> u64 { - *self.counter.lock().unwrap() - } - } - - #[test] - fn handle() { - let _ = env_logger::init(); - let server = Arc::new(Server::new()); - let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap(); - let client: Client = Client::new(serve_handle.local_addr(), None).unwrap(); - drop(client); - serve_handle.shutdown(); - } - - #[test] - fn simple() { - let _ = env_logger::init(); - let server = Arc::new(Server::new()); - let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap(); - let addr = serve_handle.local_addr().clone(); - let client 
= Client::new(addr, None).unwrap(); - assert_eq!(Reply::Increment(0), - client.rpc(Request::Increment).unwrap()); - assert_eq!(1, server.count()); - assert_eq!(Reply::Increment(1), - client.rpc(Request::Increment).unwrap()); - assert_eq!(2, server.count()); - drop(client); - serve_handle.shutdown(); - } - - struct BarrierServer { - barrier: Barrier, - inner: Server, - } - - impl Serve for BarrierServer { - type Request = Request; - type Reply = Reply; - fn serve(&self, request: Request) -> Reply { - self.barrier.wait(); - self.inner.serve(request) - } - } - - impl BarrierServer { - fn new(n: usize) -> BarrierServer { - BarrierServer { - barrier: Barrier::new(n), - inner: Server::new(), - } - } - - fn count(&self) -> u64 { - self.inner.count() - } - } - - #[test] - fn force_shutdown() { - let _ = env_logger::init(); - let server = Arc::new(Server::new()); - let serve_handle = serve_async("localhost:0", server, Some(Duration::new(0, 10))).unwrap(); - let addr = serve_handle.local_addr().clone(); - let client: Client = Client::new(addr, None).unwrap(); - let thread = thread::spawn(move || serve_handle.shutdown()); - info!("force_shutdown:: rpc1: {:?}", client.rpc(Request::Increment)); - thread.join().unwrap(); - } - - #[test] - fn client_failed_rpc() { - let _ = env_logger::init(); - let server = Arc::new(Server::new()); - let serve_handle = serve_async("localhost:0", server, test_timeout()).unwrap(); - let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr, None).unwrap()); - serve_handle.shutdown(); - match client.rpc(Request::Increment) { - Err(super::Error::ConnectionBroken) => {} // success - otherwise => panic!("Expected Err(ConnectionBroken), got {:?}", otherwise), - } - let _ = client.rpc(Request::Increment); // Test whether second failure hangs - } - - #[test] - fn concurrent() { - let _ = env_logger::init(); - let concurrency = 10; - let pool = Pool::new(concurrency); - let server = 
Arc::new(BarrierServer::new(concurrency)); - let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap(); - let addr = serve_handle.local_addr().clone(); - let client: Client = Client::new(addr, None).unwrap(); - pool.scoped(|scope| { - for _ in 0..concurrency { - let client = client.try_clone().unwrap(); - scope.execute(move || { client.rpc(Request::Increment).unwrap(); }); - } - }); - assert_eq!(concurrency as u64, server.count()); - drop(client); - serve_handle.shutdown(); - } - - #[test] - fn async() { - let _ = env_logger::init(); - let server = Arc::new(Server::new()); - let serve_handle = serve_async("localhost:0", server.clone(), None).unwrap(); - let addr = serve_handle.local_addr().clone(); - let client: Client = Client::new(addr, None).unwrap(); - - // Drop future immediately; does the reader channel panic when sending? - client.rpc_async(Request::Increment); - // If the reader panicked, this won't succeed - client.rpc_async(Request::Increment); - - drop(client); - serve_handle.shutdown(); - } -} diff --git a/tarpc/src/protocol/client.rs b/tarpc/src/protocol/client.rs new file mode 100644 index 0000000..ee895e9 --- /dev/null +++ b/tarpc/src/protocol/client.rs @@ -0,0 +1,254 @@ +use bincode; +use serde; +use std::fmt; +use std::io::{self, BufReader, BufWriter, Read, Write}; +use std::collections::HashMap; +use std::mem; +use std::net::{TcpStream, ToSocketAddrs}; +use std::sync::{Arc, Mutex}; +use std::sync::mpsc::{Receiver, Sender, channel}; +use std::time::Duration; +use std::thread; +use super::{Error, Packet, Result}; + +/// A client stub that connects to a server to run rpcs. 
+pub struct Client + where Request: serde::ser::Serialize +{ + // The guard is in an option so it can be joined in the drop fn + reader_guard: Arc>>, + outbound: Sender<(Request, Sender>)>, + requests: Arc>>, + shutdown: TcpStream, +} + +impl Client + where Request: serde::ser::Serialize + Send + 'static, + Reply: serde::de::Deserialize + Send + 'static +{ + /// Create a new client that connects to `addr`. The client uses the given timeout + /// for both reads and writes. + pub fn new(addr: A, timeout: Option) -> io::Result { + let stream = try!(TcpStream::connect(addr)); + try!(stream.set_read_timeout(timeout)); + try!(stream.set_write_timeout(timeout)); + let reader_stream = try!(stream.try_clone()); + let writer_stream = try!(stream.try_clone()); + let requests = Arc::new(Mutex::new(RpcFutures::new())); + let reader_requests = requests.clone(); + let writer_requests = requests.clone(); + let (tx, rx) = channel(); + let reader_guard = thread::spawn(move || read(reader_requests, reader_stream)); + thread::spawn(move || write(rx, writer_requests, writer_stream)); + Ok(Client { + reader_guard: Arc::new(Some(reader_guard)), + outbound: tx, + requests: requests, + shutdown: stream, + }) + } + + /// Clones the Client so that it can be shared across threads. 
+ pub fn try_clone(&self) -> io::Result> { + Ok(Client { + reader_guard: self.reader_guard.clone(), + outbound: self.outbound.clone(), + requests: self.requests.clone(), + shutdown: try!(self.shutdown.try_clone()), + }) + } + + fn rpc_internal(&self, request: Request) -> Receiver> + where Request: serde::ser::Serialize + fmt::Debug + Send + 'static + { + let (tx, rx) = channel(); + self.outbound.send((request, tx)).unwrap(); + rx + } + + /// Run the specified rpc method on the server this client is connected to + pub fn rpc(&self, request: Request) -> Result + where Request: serde::ser::Serialize + fmt::Debug + Send + 'static + { + self.rpc_internal(request) + .recv() + .map_err(|_| self.requests.lock().unwrap().get_error()) + .and_then(|reply| reply) + } + + /// Asynchronously run the specified rpc method on the server this client is connected to + pub fn rpc_async(&self, request: Request) -> Future + where Request: serde::ser::Serialize + fmt::Debug + Send + 'static + { + Future { + rx: self.rpc_internal(request), + requests: self.requests.clone(), + } + } +} + +impl Drop for Client + where Request: serde::ser::Serialize +{ + fn drop(&mut self) { + debug!("Dropping Client."); + if let Some(reader_guard) = Arc::get_mut(&mut self.reader_guard) { + debug!("Attempting to shut down writer and reader threads."); + if let Err(e) = self.shutdown.shutdown(::std::net::Shutdown::Both) { + warn!("Client: couldn't shutdown writer and reader threads: {:?}", e); + } else { + // We only join if we know the TcpStream was shut down. Otherwise we might never + // finish. 
+ debug!("Joining writer and reader."); + reader_guard.take().unwrap().join().unwrap(); + debug!("Successfully joined writer and reader."); + } + } + } +} + +/// An asynchronous RPC call +pub struct Future { + rx: Receiver>, + requests: Arc>> +} + +impl Future { + /// Block until the result of the RPC call is available + pub fn get(self) -> Result { + let requests = self.requests; + self.rx.recv() + .map_err(|_| requests.lock().unwrap().get_error()) + .and_then(|reply| reply) + } +} + +struct RpcFutures(Result>>>); + +impl RpcFutures { + fn new() -> RpcFutures { + RpcFutures(Ok(HashMap::new())) + } + + fn insert_tx(&mut self, id: u64, tx: Sender>) -> Result<()> { + match self.0 { + Ok(ref mut requests) => { + requests.insert(id, tx); + Ok(()) + } + Err(ref e) => Err(e.clone()), + } + } + + fn remove_tx(&mut self, id: u64) -> Result<()> { + match self.0 { + Ok(ref mut requests) => { + requests.remove(&id); + Ok(()) + } + Err(ref e) => Err(e.clone()), + } + } + + fn complete_reply(&mut self, id: u64, reply: Reply) { + if let Some(tx) = self.0.as_mut().unwrap().remove(&id) { + if let Err(e) = tx.send(Ok(reply)) { + info!("Reader: could not complete reply: {:?}", e); + } + } else { + warn!("RpcFutures: expected sender for id {} but got None!", id); + } + } + + fn set_error(&mut self, err: bincode::serde::DeserializeError) { + let _ = mem::replace(&mut self.0, Err(err.into())); + } + + fn get_error(&self) -> Error { + self.0.as_ref().err().unwrap().clone() + } +} + +fn write(outbound: Receiver<(Request, Sender>)>, + requests: Arc>>, + stream: TcpStream) + where Request: serde::Serialize, + Reply: serde::Deserialize, +{ + let mut next_id = 0; + let mut stream = BufWriter::new(stream); + loop { + let (request, tx) = match outbound.recv() { + Err(e) => { + debug!("Writer: all senders have exited ({:?}). 
Returning.", e); + return; + } + Ok(request) => request, + }; + if let Err(e) = requests.lock().unwrap().insert_tx(next_id, tx.clone()) { + report_error(&tx, e); + // Once insert_tx returns Err, it will continue to do so. However, continue here so + // that any other clients who sent requests will also recv the Err. + continue; + } + let id = next_id; + next_id += 1; + let packet = Packet { + rpc_id: id, + message: request, + }; + debug!("Writer: calling rpc({:?})", id); + if let Err(e) = bincode::serde::serialize_into(&mut stream, + &packet, + bincode::SizeLimit::Infinite) { + report_error(&tx, e.into()); + // Typically we'd want to notify the client of any Err returned by remove_tx, but in + // this case the client already hit an Err, and doesn't need to know about this one, as + // well. + let _ = requests.lock().unwrap().remove_tx(id); + continue; + } + if let Err(e) = stream.flush() { + report_error(&tx, e.into()); + } + } + + fn report_error(tx: &Sender>, e: Error) + where Reply: serde::Deserialize + { + // Clone the err so we can log it if sending fails + if let Err(e2) = tx.send(Err(e.clone())) { + debug!("Error encountered while trying to send an error. \ + Initial error: {:?}; Send error: {:?}", + e, + e2); + } + } + +} + +fn read(requests: Arc>>, stream: TcpStream) + where Reply: serde::Deserialize +{ + let mut stream = BufReader::new(stream); + loop { + let packet: bincode::serde::DeserializeResult> = + bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); + match packet { + Ok(Packet { + rpc_id: id, + message: reply + }) => { + debug!("Client: received message, id={}", id); + requests.lock().unwrap().complete_reply(id, reply); + } + Err(err) => { + warn!("Client: reader thread encountered an unexpected error while parsing; \ + returning now. 
Error: {:?}", + err); + requests.lock().unwrap().set_error(err); + break; + } + } + } +} diff --git a/tarpc/src/protocol/mod.rs b/tarpc/src/protocol/mod.rs new file mode 100644 index 0000000..d2137c6 --- /dev/null +++ b/tarpc/src/protocol/mod.rs @@ -0,0 +1,229 @@ +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use bincode; +use std::io; +use std::convert; +use std::sync::Arc; + +mod client; +mod server; + +pub use self::client::{Client, Future}; +pub use self::server::{Serve, ServeHandle, serve_async}; + +/// Client errors that can occur during rpc calls +#[derive(Debug, Clone)] +pub enum Error { + /// An IO-related error + Io(Arc), + /// The server hung up. + ConnectionBroken, +} + +impl convert::From for Error { + fn from(err: bincode::serde::SerializeError) -> Error { + match err { + bincode::serde::SerializeError::IoError(err) => Error::Io(Arc::new(err)), + err => panic!("Unexpected error during serialization: {:?}", err), + } + } +} + +impl convert::From for Error { + fn from(err: bincode::serde::DeserializeError) -> Error { + match err { + bincode::serde::DeserializeError::IoError(err) => Error::Io(Arc::new(err)), + bincode::serde::DeserializeError::EndOfStreamError => Error::ConnectionBroken, + err => panic!("Unexpected error during deserialization: {:?}", err), + } + } +} + +impl convert::From for Error { + fn from(err: io::Error) -> Error { + Error::Io(Arc::new(err)) + } +} + +/// Return type of rpc calls: either the successful return value, or a client error. 
+pub type Result<T> = ::std::result::Result<T, Error>;
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+struct Packet<T> {
+    rpc_id: u64,
+    message: T,
+}
+
+#[cfg(test)]
+mod test {
+    extern crate env_logger;
+    use super::{Client, Serve, serve_async};
+    use scoped_pool::Pool;
+    use std::sync::{Arc, Barrier, Mutex};
+    use std::thread;
+    use std::time::Duration;
+
+    fn test_timeout() -> Option<Duration> {
+        Some(Duration::from_secs(1))
+    }
+
+    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
+    enum Request {
+        Increment,
+    }
+
+    #[derive(Debug, PartialEq, Serialize, Deserialize)]
+    enum Reply {
+        Increment(u64),
+    }
+
+    struct Server {
+        counter: Mutex<u64>,
+    }
+
+    impl Serve for Server {
+        type Request = Request;
+        type Reply = Reply;
+
+        fn serve(&self, _: Request) -> Reply {
+            let mut counter = self.counter.lock().unwrap();
+            let reply = Reply::Increment(*counter);
+            *counter += 1;
+            reply
+        }
+    }
+
+    impl Server {
+        fn new() -> Server {
+            Server { counter: Mutex::new(0) }
+        }
+
+        fn count(&self) -> u64 {
+            *self.counter.lock().unwrap()
+        }
+    }
+
+    #[test]
+    fn handle() {
+        let _ = env_logger::init();
+        let server = Arc::new(Server::new());
+        let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap();
+        let client: Client<Request, Reply> = Client::new(serve_handle.local_addr(), None).unwrap();
+        drop(client);
+        serve_handle.shutdown();
+    }
+
+    #[test]
+    fn simple() {
+        let _ = env_logger::init();
+        let server = Arc::new(Server::new());
+        let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap();
+        let addr = serve_handle.local_addr().clone();
+        let client = Client::new(addr, None).unwrap();
+        assert_eq!(Reply::Increment(0),
+                   client.rpc(Request::Increment).unwrap());
+        assert_eq!(1, server.count());
+        assert_eq!(Reply::Increment(1),
+                   client.rpc(Request::Increment).unwrap());
+        assert_eq!(2, server.count());
+        drop(client);
+        serve_handle.shutdown();
+    }
+
+    struct BarrierServer {
+        barrier: Barrier,
+        inner: Server,
+    }
+
+    impl Serve for BarrierServer {
+        type Request = Request;
+        type Reply = Reply;
+        fn serve(&self, request: Request) -> Reply {
+            self.barrier.wait();
+            self.inner.serve(request)
+        }
+    }
+
+    impl BarrierServer {
+        fn new(n: usize) -> BarrierServer {
+            BarrierServer {
+                barrier: Barrier::new(n),
+                inner: Server::new(),
+            }
+        }
+
+        fn count(&self) -> u64 {
+            self.inner.count()
+        }
+    }
+
+    #[test]
+    fn force_shutdown() {
+        let _ = env_logger::init();
+        let server = Arc::new(Server::new());
+        let serve_handle = serve_async("localhost:0", server, Some(Duration::new(0, 10))).unwrap();
+        let addr = serve_handle.local_addr().clone();
+        let client: Client<Request, Reply> = Client::new(addr, None).unwrap();
+        let thread = thread::spawn(move || serve_handle.shutdown());
+        info!("force_shutdown:: rpc1: {:?}", client.rpc(Request::Increment));
+        thread.join().unwrap();
+    }
+
+    #[test]
+    fn client_failed_rpc() {
+        let _ = env_logger::init();
+        let server = Arc::new(Server::new());
+        let serve_handle = serve_async("localhost:0", server, test_timeout()).unwrap();
+        let addr = serve_handle.local_addr().clone();
+        let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, None).unwrap());
+        serve_handle.shutdown();
+        match client.rpc(Request::Increment) {
+            Err(super::Error::ConnectionBroken) => {} // success
+            otherwise => panic!("Expected Err(ConnectionBroken), got {:?}", otherwise),
+        }
+        let _ = client.rpc(Request::Increment); // Test whether second failure hangs
+    }
+
+    #[test]
+    fn concurrent() {
+        let _ = env_logger::init();
+        let concurrency = 10;
+        let pool = Pool::new(concurrency);
+        let server = Arc::new(BarrierServer::new(concurrency));
+        let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap();
+        let addr = serve_handle.local_addr().clone();
+        let client: Client<Request, Reply> = Client::new(addr, None).unwrap();
+        pool.scoped(|scope| {
+            for _ in 0..concurrency {
+                let client = client.try_clone().unwrap();
+                scope.execute(move || { client.rpc(Request::Increment).unwrap(); });
+            }
+        });
+        assert_eq!(concurrency as u64, server.count());
+        drop(client);
+        serve_handle.shutdown();
+    }
+
+    #[test]
+    fn async() {
+        let _ = env_logger::init();
+        let server = Arc::new(Server::new());
+        let serve_handle = serve_async("localhost:0", server.clone(), None).unwrap();
+        let addr = serve_handle.local_addr().clone();
+        let client: Client<Request, Reply> = Client::new(addr, None).unwrap();
+
+        // Drop future immediately; does the reader channel panic when sending?
+        client.rpc_async(Request::Increment);
+        // If the reader panicked, this won't succeed
+        client.rpc_async(Request::Increment);
+
+        drop(client);
+        serve_handle.shutdown();
+    }
+}
diff --git a/tarpc/src/protocol/server.rs b/tarpc/src/protocol/server.rs
new file mode 100644
index 0000000..ce165d1
--- /dev/null
+++ b/tarpc/src/protocol/server.rs
@@ -0,0 +1,263 @@
+use bincode;
+use serde;
+use scoped_pool::Pool;
+use std::fmt;
+use std::io::{self, BufReader, BufWriter, Write};
+use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
+use std::sync::{Condvar, Mutex};
+use std::sync::mpsc::{Sender, TryRecvError, channel};
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::time::Duration;
+use std::thread::{self, JoinHandle};
+use super::{Packet, Result};
+
+struct ConnectionHandler<'a, S>
+    where S: Serve
+{
+    read_stream: BufReader<TcpStream>,
+    write_stream: Mutex<BufWriter<TcpStream>>,
+    shutdown: &'a AtomicBool,
+    inflight_rpcs: &'a InflightRpcs,
+    server: S,
+    pool: &'a Pool,
+}
+
+impl<'a, S> Drop for ConnectionHandler<'a, S> where S: Serve {
+    fn drop(&mut self) {
+        trace!("ConnectionHandler: finished serving client.");
+        self.inflight_rpcs.decrement_and_notify();
+    }
+}
+
+impl<'a, S> ConnectionHandler<'a, S> where S: Serve {
+    fn handle_conn(&mut self) -> Result<()> {
+        let ConnectionHandler {
+            ref mut read_stream,
+            ref write_stream,
+            shutdown,
+            inflight_rpcs,
+            ref server,
+            pool,
+        } = *self;
+        trace!("ConnectionHandler: serving client...");
+        pool.scoped(|scope| {
+            loop {
+                match bincode::serde::deserialize_from(read_stream, bincode::SizeLimit::Infinite) {
+                    Ok(Packet { rpc_id, message, }) => {
+                        inflight_rpcs.increment();
+                        scope.execute(move || {
+                            let reply = server.serve(message);
+                            let reply_packet = Packet {
+                                rpc_id: rpc_id,
+                                message: reply
+                            };
+                            let mut write_stream = write_stream.lock().unwrap();
+                            if let Err(e) =
+                                   bincode::serde::serialize_into(&mut *write_stream,
+                                                                  &reply_packet,
+                                                                  bincode::SizeLimit::Infinite) {
+                                warn!("ConnectionHandler: failed to write reply to Client: {:?}",
+                                      e);
+                            }
+                            if let Err(e) = write_stream.flush() {
+                                warn!("ConnectionHandler: failed to flush reply to Client: {:?}",
+                                      e);
+                            }
+                            inflight_rpcs.decrement();
+                        });
+                        if shutdown.load(Ordering::SeqCst) {
+                            info!("ConnectionHandler: server shutdown, so closing connection.");
+                            break;
+                        }
+                    }
+                    Err(bincode::serde::DeserializeError::IoError(ref err))
+                        if Self::timed_out(err.kind()) => {
+                        if !shutdown.load(Ordering::SeqCst) {
+                            info!("ConnectionHandler: read timed out ({:?}). Server not \
+                                   shutdown, so retrying read.",
+                                  err);
+                            continue;
+                        } else {
+                            info!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \
+                                   closing connection.",
+                                  err);
+                            break;
+                        }
+                    }
+                    Err(e) => {
+                        warn!("ConnectionHandler: closing client connection due to {:?}",
+                              e);
+                        return Err(e.into());
+                    }
+                }
+            }
+            Ok(())
+        })
+    }
+
+    fn timed_out(error_kind: io::ErrorKind) -> bool {
+        match error_kind {
+            io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock => true,
+            _ => false,
+        }
+    }
+}
+
+struct InflightRpcs {
+    count: Mutex<u64>,
+    cvar: Condvar,
+}
+
+impl InflightRpcs {
+    fn new() -> InflightRpcs {
+        InflightRpcs {
+            count: Mutex::new(0),
+            cvar: Condvar::new(),
+        }
+    }
+
+    fn wait_until_zero(&self) {
+        let mut count = self.count.lock().unwrap();
+        while *count != 0 {
+            count = self.cvar.wait(count).unwrap();
+        }
+        info!("serve_async: shutdown complete ({} connections alive)",
+              *count);
+    }
+
+    fn increment(&self) {
+        *self.count.lock().unwrap() += 1;
+    }
+
+    fn decrement(&self) {
+        *self.count.lock().unwrap() -= 1;
+    }
+
+
+    fn decrement_and_notify(&self) {
+        *self.count.lock().unwrap() -= 1;
+        self.cvar.notify_one();
+    }
+
+}
+
+/// Provides methods for blocking until the server completes.
+pub struct ServeHandle {
+    tx: Sender<()>,
+    join_handle: JoinHandle<()>,
+    addr: SocketAddr,
+}
+
+impl ServeHandle {
+    /// Block until the server completes
+    pub fn wait(self) {
+        self.join_handle.join().unwrap();
+    }
+
+    /// Returns the address the server is bound to
+    pub fn local_addr(&self) -> &SocketAddr {
+        &self.addr
+    }
+
+    /// Shutdown the server. Gracefully shuts down the serve thread but currently does not
+    /// gracefully close open connections.
+    pub fn shutdown(self) {
+        info!("ServeHandle: attempting to shut down the server.");
+        self.tx.send(()).unwrap();
+        if let Ok(_) = TcpStream::connect(self.addr) {
+            self.join_handle.join().unwrap();
+        } else {
+            warn!("ServeHandle: best effort shutdown of serve thread failed");
+        }
+    }
+}
+
+/// Starts a service listening on the given address, returning a handle for shutting it down.
+pub fn serve_async<A, S>(addr: A,
+                         server: S,
+                         read_timeout: Option<Duration>)
+                         -> io::Result<ServeHandle>
+    where A: ToSocketAddrs,
+          S: 'static + Serve
+{
+    let listener = try!(TcpListener::bind(&addr));
+    let addr = try!(listener.local_addr());
+    info!("serve_async: spinning up server on {:?}", addr);
+    let (die_tx, die_rx) = channel();
+    let join_handle = thread::spawn(move || {
+        let pool = Pool::new(100); // TODO(tjk): make this configurable, and expire idle threads
+        let shutdown = AtomicBool::new(false);
+        let inflight_rpcs = InflightRpcs::new();
+        pool.scoped(|scope| {
+            for conn in listener.incoming() {
+                match die_rx.try_recv() {
+                    Ok(_) => {
+                        info!("serve_async: shutdown received. Waiting for open connections to \
+                               return...");
+                        shutdown.store(true, Ordering::SeqCst);
+                        inflight_rpcs.wait_until_zero();
+                        break;
+                    }
+                    Err(TryRecvError::Disconnected) => {
+                        info!("serve_async: sender disconnected.");
+                        break;
+                    }
+                    _ => (),
+                }
+                let conn = match conn {
+                    Err(err) => {
+                        error!("serve_async: failed to accept connection: {:?}", err);
+                        return;
+                    }
+                    Ok(c) => c,
+                };
+                if let Err(err) = conn.set_read_timeout(read_timeout) {
+                    info!("Server: could not set read timeout: {:?}", err);
+                    return;
+                }
+                inflight_rpcs.increment();
+                scope.execute(|| {
+                    let mut handler = ConnectionHandler {
+                        read_stream: BufReader::new(conn.try_clone().unwrap()),
+                        write_stream: Mutex::new(BufWriter::new(conn)),
+                        shutdown: &shutdown,
+                        inflight_rpcs: &inflight_rpcs,
+                        server: &server,
+                        pool: &pool,
+                    };
+                    if let Err(err) = handler.handle_conn() {
+                        info!("ConnectionHandler: err in connection handling: {:?}", err);
+                    }
+                });
+            }
+        });
+    });
+    Ok(ServeHandle {
+        tx: die_tx,
+        join_handle: join_handle,
+        addr: addr.clone(),
+    })
+}
+
+/// A service provided by a server
+pub trait Serve: Send + Sync {
+    /// The type of request received by the server
+    type Request: 'static + fmt::Debug + serde::ser::Serialize + serde::de::Deserialize + Send;
+    /// The type of reply sent by the server
+    type Reply: 'static + fmt::Debug + serde::ser::Serialize + serde::de::Deserialize;
+
+    /// Return a reply for a given request
+    fn serve(&self, request: Self::Request) -> Self::Reply;
+}
+
+impl<P, S> Serve for P
+    where P: Send + Sync + ::std::ops::Deref<Target = S>,
+          S: Serve
+{
+    type Request = S::Request;
+    type Reply = S::Reply;
+
+    fn serve(&self, request: S::Request) -> S::Reply {
+        S::serve(self, request)
+    }
+}

From 5d8d04d52121c0181d2c6d5b343755c8422d474f Mon Sep 17 00:00:00 2001
From: Tim Kuehn
Date: Sun, 31 Jan 2016 22:05:04 -0800
Subject: [PATCH 2/2] Use expect() instead of unwrap()

---
 tarpc/src/lib.rs             |  4 ++++
 tarpc/src/protocol/client.rs | 23 +++++++++++++----------
 tarpc/src/protocol/server.rs | 10 +++++-----
 3 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs
index a8e6486..e57bc8b 100644
--- a/tarpc/src/lib.rs
+++ b/tarpc/src/lib.rs
@@ -64,6 +64,10 @@ extern crate test;
 #[macro_use]
 extern crate lazy_static;
 
+macro_rules! pos {
+    () => (concat!(file!(), ":", line!()))
+}
+
 /// Provides the tarpc client and server, which implements the tarpc protocol.
 /// The protocol is defined by the implementation.
 pub mod protocol;
diff --git a/tarpc/src/protocol/client.rs b/tarpc/src/protocol/client.rs
index ee895e9..4a83ecb 100644
--- a/tarpc/src/protocol/client.rs
+++ b/tarpc/src/protocol/client.rs
@@ -62,7 +62,7 @@ impl<Request, Reply> Client<Request, Reply>
         where Request: serde::ser::Serialize + fmt::Debug + Send + 'static
     {
         let (tx, rx) = channel();
-        self.outbound.send((request, tx)).unwrap();
+        self.outbound.send((request, tx)).expect(pos!());
         rx
     }
 
@@ -72,7 +72,7 @@ impl<Request, Reply> Client<Request, Reply>
     {
         self.rpc_internal(request)
             .recv()
-            .map_err(|_| self.requests.lock().unwrap().get_error())
+            .map_err(|_| self.requests.lock().expect(pos!()).get_error())
             .and_then(|reply| reply)
     }
 
@@ -100,7 +100,10 @@ impl<Request, Reply> Drop for Client<Request, Reply>
             // We only join if we know the TcpStream was shut down. Otherwise we might never
             // finish.
             debug!("Joining writer and reader.");
-            reader_guard.take().unwrap().join().unwrap();
+            reader_guard.take()
+                        .expect(pos!())
+                        .join()
+                        .expect(pos!());
             debug!("Successfully joined writer and reader.");
         }
     }
@@ -118,7 +121,7 @@ impl<T> Future<T> {
     pub fn get(self) -> Result<T> {
         let requests = self.requests;
         self.rx.recv()
-            .map_err(|_| requests.lock().unwrap().get_error())
+            .map_err(|_| requests.lock().expect(pos!()).get_error())
             .and_then(|reply| reply)
     }
 }
@@ -151,7 +154,7 @@ impl<Reply> RpcFutures<Reply> {
     }
 
     fn complete_reply(&mut self, id: u64, reply: Reply) {
-        if let Some(tx) = self.0.as_mut().unwrap().remove(&id) {
+        if let Some(tx) = self.0.as_mut().expect(pos!()).remove(&id) {
             if let Err(e) = tx.send(Ok(reply)) {
                 info!("Reader: could not complete reply: {:?}", e);
             }
@@ -165,7 +168,7 @@ impl<Reply> RpcFutures<Reply> {
     }
 
     fn get_error(&self) -> Error {
-        self.0.as_ref().err().unwrap().clone()
+        self.0.as_ref().err().expect(pos!()).clone()
    }
 }
 
@@ -185,7 +188,7 @@ fn write<Request, Reply>(outbound: Receiver<(Request, Sender<Result<Reply>>)>,
             }
             Ok(request) => request,
         };
-        if let Err(e) = requests.lock().unwrap().insert_tx(next_id, tx.clone()) {
+        if let Err(e) = requests.lock().expect(pos!()).insert_tx(next_id, tx.clone()) {
             report_error(&tx, e);
             // Once insert_tx returns Err, it will continue to do so. However, continue here so
             // that any other clients who sent requests will also recv the Err.
@@ -205,7 +208,7 @@ fn write<Request, Reply>(outbound: Receiver<(Request, Sender<Result<Reply>>)>,
             // Typically we'd want to notify the client of any Err returned by remove_tx, but in
             // this case the client already hit an Err, and doesn't need to know about this one, as
             // well.
-            let _ = requests.lock().unwrap().remove_tx(id);
+            let _ = requests.lock().expect(pos!()).remove_tx(id);
             continue;
         }
         if let Err(e) = stream.flush() {
@@ -240,13 +243,13 @@ fn read<Reply>(requests: Arc<Mutex<RpcFutures<Reply>>>, stream: TcpStream)
                message: reply }) => {
                 debug!("Client: received message, id={}", id);
-                requests.lock().unwrap().complete_reply(id, reply);
+                requests.lock().expect(pos!()).complete_reply(id, reply);
             }
             Err(err) => {
                 warn!("Client: reader thread encountered an unexpected error while parsing; \
                        returning now. Error: {:?}",
                      err);
-                requests.lock().unwrap().set_error(err);
+                requests.lock().expect(pos!()).set_error(err);
                 break;
             }
         }
diff --git a/tarpc/src/protocol/server.rs b/tarpc/src/protocol/server.rs
index ce165d1..15eb068 100644
--- a/tarpc/src/protocol/server.rs
+++ b/tarpc/src/protocol/server.rs
@@ -51,7 +51,7 @@ impl<'a, S> ConnectionHandler<'a, S> where S: Serve {
                                 rpc_id: rpc_id,
                                 message: reply
                             };
-                            let mut write_stream = write_stream.lock().unwrap();
+                            let mut write_stream = write_stream.lock().expect(pos!());
                             if let Err(e) =
                                    bincode::serde::serialize_into(&mut *write_stream,
                                                                   &reply_packet,
@@ -151,7 +151,7 @@ pub struct ServeHandle {
 impl ServeHandle {
     /// Block until the server completes
     pub fn wait(self) {
-        self.join_handle.join().unwrap();
+        self.join_handle.join().expect(pos!());
     }
 
     /// Returns the address the server is bound to
@@ -163,9 +163,9 @@ impl ServeHandle {
     /// gracefully close open connections.
     pub fn shutdown(self) {
         info!("ServeHandle: attempting to shut down the server.");
-        self.tx.send(()).unwrap();
+        self.tx.send(()).expect(pos!());
         if let Ok(_) = TcpStream::connect(self.addr) {
-            self.join_handle.join().unwrap();
+            self.join_handle.join().expect(pos!());
         } else {
             warn!("ServeHandle: best effort shutdown of serve thread failed");
         }
@@ -218,7 +218,7 @@ pub fn serve_async<A, S>(addr: A,
                 inflight_rpcs.increment();
                 scope.execute(|| {
                     let mut handler = ConnectionHandler {
-                        read_stream: BufReader::new(conn.try_clone().unwrap()),
+                        read_stream: BufReader::new(conn.try_clone().expect(pos!())),
                         write_stream: Mutex::new(BufWriter::new(conn)),
                         shutdown: &shutdown,
                         inflight_rpcs: &inflight_rpcs,