Use a writer thread that handles all outbound requests.

Tim Kuehn
2016-01-29 01:24:19 -08:00
parent e711bb006c
commit 84d402ebf5
2 changed files with 147 additions and 116 deletions
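The core of the change is a dedicated writer thread that owns the outbound half of the connection: callers hand it a request plus a one-shot reply Sender over a channel, and the writer assigns the rpc id, registers the reply sender in the shared RpcFutures map, and serializes the packet. Below is a minimal, std-only sketch of that shape, not the commit's code: names such as `writer` and the `Reply = String` stand-in are illustrative, and the real client serializes bincode `Packet`s over a `TcpStream` and completes replies from a separate reader thread.

use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

type Reply = String; // stand-in for the deserialized reply type

fn writer(outbound: Receiver<(String, Sender<Reply>)>) {
    let mut next_id = 0u64;
    // Each message pairs a request with the one-shot channel that will carry its reply.
    while let Ok((request, reply_tx)) = outbound.recv() {
        // In the real client this is where the request is tagged with `next_id`,
        // registered in the shared `RpcFutures` map, serialized with bincode, and flushed.
        let id = next_id;
        next_id += 1;
        // Here we simply echo a reply instead of waiting for a reader thread.
        let _ = reply_tx.send(format!("reply #{} to {:?}", id, request));
    }
    // All senders dropped: every client handle is gone, so the writer exits.
}

fn main() {
    let (tx, rx) = channel();
    let guard = thread::spawn(move || writer(rx));

    // Each caller gets its own reply channel, so many clones of the client
    // can share the single writer thread without locking a stream themselves.
    let (reply_tx, reply_rx) = channel();
    tx.send(("increment".to_string(), reply_tx)).unwrap();
    println!("{}", reply_rx.recv().unwrap());

    drop(tx);              // close the outbound channel...
    guard.join().unwrap(); // ...so the writer thread can be joined.
}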


@@ -21,7 +21,7 @@ macro_rules! client_methods {
) => (
$(#[$attr])*
pub fn $fn_name(&self, $($arg: $in_),*) -> $crate::Result<$out> {
let reply = try!((self.0).rpc(&request_variant!($fn_name $($arg),*)));
let reply = try!((self.0).rpc(request_variant!($fn_name $($arg),*)));
let __Reply::$fn_name(reply) = reply;
Ok(reply)
}
@@ -32,7 +32,7 @@ macro_rules! client_methods {
)*) => ( $(
$(#[$attr])*
pub fn $fn_name(&self, $($arg: $in_),*) -> $crate::Result<$out> {
let reply = try!((self.0).rpc(&request_variant!($fn_name $($arg),*)));
let reply = try!((self.0).rpc(request_variant!($fn_name $($arg),*)));
if let __Reply::$fn_name(reply) = reply {
Ok(reply)
} else {
@@ -57,7 +57,7 @@ macro_rules! async_client_methods {
let __Reply::$fn_name(reply) = reply;
reply
}
let reply = (self.0).rpc_async(&request_variant!($fn_name $($arg),*));
let reply = (self.0).rpc_async(request_variant!($fn_name $($arg),*));
Future {
future: reply,
mapper: mapper,
@@ -77,7 +77,7 @@ macro_rules! async_client_methods {
panic!("Incorrect reply variant returned from protocol::Clientrpc; expected `{}`, but got {:?}", stringify!($fn_name), reply);
}
}
let reply = (self.0).rpc_async(&request_variant!($fn_name $($arg),*));
let reply = (self.0).rpc_async(request_variant!($fn_name $($arg),*));
Future {
future: reply,
mapper: mapper,


@@ -13,7 +13,6 @@ use std::fmt;
use std::io::{self, BufReader, BufWriter, Read, Write};
use std::convert;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::mem;
use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
use std::sync::{Arc, Condvar, Mutex};
@@ -61,7 +60,7 @@ pub type Result<T> = ::std::result::Result<T, Error>;
/// An asynchronous RPC call
pub struct Future<T> {
rx: Result<Receiver<T>>,
rx: Receiver<Result<T>>,
requests: Arc<Mutex<RpcFutures<T>>>
}
@@ -69,9 +68,9 @@ impl<T> Future<T> {
/// Block until the result of the RPC call is available
pub fn get(self) -> Result<T> {
let requests = self.requests;
try!(self.rx)
.recv()
self.rx.recv()
.map_err(|_| requests.lock().unwrap().get_error())
.and_then(|reply| reply)
}
}
@@ -332,14 +331,14 @@ struct Packet<T> {
message: T,
}
struct RpcFutures<Reply>(Result<HashMap<u64, Sender<Reply>>>);
struct RpcFutures<Reply>(Result<HashMap<u64, Sender<Result<Reply>>>>);
impl<Reply> RpcFutures<Reply> {
fn new() -> RpcFutures<Reply> {
RpcFutures(Ok(HashMap::new()))
}
fn insert_tx(&mut self, id: u64, tx: Sender<Reply>) -> Result<()> {
fn insert_tx(&mut self, id: u64, tx: Sender<Result<Reply>>) -> Result<()> {
match self.0 {
Ok(ref mut requests) => {
requests.insert(id, tx);
@@ -361,7 +360,7 @@ impl<Reply> RpcFutures<Reply> {
fn complete_reply(&mut self, id: u64, reply: Reply) {
if let Some(tx) = self.0.as_mut().unwrap().remove(&id) {
if let Err(e) = tx.send(reply) {
if let Err(e) = tx.send(Ok(reply)) {
info!("Reader: could not complete reply: {:?}", e);
}
} else {
@@ -378,62 +377,104 @@ impl<Reply> RpcFutures<Reply> {
}
}
struct Reader<Reply> {
requests: Arc<Mutex<RpcFutures<Reply>>>,
}
fn write<Request, Reply>(outbound: Receiver<(Request, Sender<Result<Reply>>)>,
requests: Arc<Mutex<RpcFutures<Reply>>>,
stream: TcpStream)
where Request: serde::Serialize,
Reply: serde::Deserialize,
{
let mut next_id = 0;
let mut stream = BufWriter::new(stream);
loop {
let (request, tx) = match outbound.recv() {
Err(e) => {
debug!("Writer: all senders have exited ({:?}). Returning.", e);
return;
}
Ok(request) => request,
};
if let Err(e) = requests.lock().unwrap().insert_tx(next_id, tx.clone()) {
report_error(&tx, e);
// Once insert_tx returns Err, it will continue to do so. However, continue here so
// that any other clients who sent requests will also recv the Err.
continue;
}
let id = next_id;
next_id += 1;
let packet = Packet {
rpc_id: id,
message: request,
};
debug!("Writer: calling rpc({:?})", id);
if let Err(e) = bincode::serde::serialize_into(&mut stream,
&packet,
bincode::SizeLimit::Infinite) {
report_error(&tx, e.into());
// Typically we'd want to notify the client of any Err returned by remove_tx, but in
// this case the client already hit an Err, and doesn't need to know about this one, as
// well.
let _ = requests.lock().unwrap().remove_tx(id);
continue;
}
if let Err(e) = stream.flush() {
report_error(&tx, e.into());
}
}
impl<Reply> Reader<Reply> {
fn read(self, stream: TcpStream)
fn report_error<Reply>(tx: &Sender<Result<Reply>>, e: Error)
where Reply: serde::Deserialize
{
let mut stream = BufReader::new(stream);
loop {
let packet: bincode::serde::DeserializeResult<Packet<Reply>> =
bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite);
match packet {
Ok(Packet {
rpc_id: id,
message: reply
}) => {
debug!("Client: received message, id={}", id);
self.requests.lock().unwrap().complete_reply(id, reply);
}
Err(err) => {
warn!("Client: reader thread encountered an unexpected error while parsing; \
returning now. Error: {:?}",
err);
self.requests.lock().unwrap().set_error(err);
break;
}
// Clone the err so we can log it if sending fails
if let Err(e2) = tx.send(Err(e.clone())) {
debug!("Error encountered while trying to send an error. \
Initial error: {:?}; Send error: {:?}",
e,
e2);
}
}
}
fn read<Reply>(requests: Arc<Mutex<RpcFutures<Reply>>>, stream: TcpStream)
where Reply: serde::Deserialize
{
let mut stream = BufReader::new(stream);
loop {
let packet: bincode::serde::DeserializeResult<Packet<Reply>> =
bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite);
match packet {
Ok(Packet {
rpc_id: id,
message: reply
}) => {
debug!("Client: received message, id={}", id);
requests.lock().unwrap().complete_reply(id, reply);
}
Err(err) => {
warn!("Client: reader thread encountered an unexpected error while parsing; \
returning now. Error: {:?}",
err);
requests.lock().unwrap().set_error(err);
break;
}
}
}
}
fn increment(cur_id: &mut u64) -> u64 {
let id = *cur_id;
*cur_id += 1;
id
}
struct SyncedClientState {
next_id: u64,
stream: BufWriter<TcpStream>,
}
/// A client stub that connects to a server to run rpcs.
pub struct Client<Request, Reply>
where Request: serde::ser::Serialize
{
synced_state: Mutex<SyncedClientState>,
// The guard is in an option so it can be joined in the drop fn
reader_guard: Arc<Option<thread::JoinHandle<()>>>,
outbound: Sender<(Request, Sender<Result<Reply>>)>,
requests: Arc<Mutex<RpcFutures<Reply>>>,
reader_guard: Option<thread::JoinHandle<()>>,
_request: PhantomData<Request>,
shutdown: TcpStream,
}
impl<Request, Reply> Client<Request, Reply>
where Reply: serde::de::Deserialize + Send + 'static,
Request: serde::ser::Serialize
where Request: serde::ser::Serialize + Send + 'static,
Reply: serde::de::Deserialize + Send + 'static
{
/// Create a new client that connects to `addr`. The client uses the given timeout
/// for both reads and writes.
@@ -441,60 +482,52 @@ impl<Request, Reply> Client<Request, Reply>
let stream = try!(TcpStream::connect(addr));
try!(stream.set_read_timeout(timeout));
try!(stream.set_write_timeout(timeout));
let requests = Arc::new(Mutex::new(RpcFutures::new()));
let reader_stream = try!(stream.try_clone());
let reader = Reader { requests: requests.clone() };
let reader_guard = thread::spawn(move || reader.read(reader_stream));
let writer_stream = try!(stream.try_clone());
let requests = Arc::new(Mutex::new(RpcFutures::new()));
let reader_requests = requests.clone();
let writer_requests = requests.clone();
let (tx, rx) = channel();
let reader_guard = thread::spawn(move || read(reader_requests, reader_stream));
thread::spawn(move || write(rx, writer_requests, writer_stream));
Ok(Client {
synced_state: Mutex::new(SyncedClientState {
next_id: 0,
stream: BufWriter::new(stream),
}),
reader_guard: Arc::new(Some(reader_guard)),
outbound: tx,
requests: requests,
reader_guard: Some(reader_guard),
_request: PhantomData,
shutdown: stream,
})
}
fn rpc_internal(&self, request: &Request) -> Result<Receiver<Reply>>
/// Clones the Client so that it can be shared across threads.
pub fn try_clone(&self) -> io::Result<Client<Request, Reply>> {
Ok(Client {
reader_guard: self.reader_guard.clone(),
outbound: self.outbound.clone(),
requests: self.requests.clone(),
shutdown: try!(self.shutdown.try_clone()),
})
}
fn rpc_internal(&self, request: Request) -> Receiver<Result<Reply>>
where Request: serde::ser::Serialize + fmt::Debug + Send + 'static
{
let (tx, rx) = channel();
let mut state = self.synced_state.lock().unwrap();
let id = increment(&mut state.next_id);
try!(self.requests.lock().unwrap().insert_tx(id, tx));
let packet = Packet {
rpc_id: id,
message: request,
};
debug!("Client: calling rpc({:?})", request);
if let Err(err) = bincode::serde::serialize_into(&mut state.stream,
&packet,
bincode::SizeLimit::Infinite) {
warn!("Client: failed to write packet.\nPacket: {:?}\nError: {:?}",
packet,
err);
try!(self.requests.lock().unwrap().remove_tx(id));
}
if let Err(err) = state.stream.flush() {
warn!("Client: failed to flush packet.\nPacket: {:?}\nError: {:?}",
packet,
err);
}
Ok(rx)
self.outbound.send((request, tx)).unwrap();
rx
}
/// Run the specified rpc method on the server this client is connected to
pub fn rpc(&self, request: &Request) -> Result<Reply>
pub fn rpc(&self, request: Request) -> Result<Reply>
where Request: serde::ser::Serialize + fmt::Debug + Send + 'static
{
try!(self.rpc_internal(request))
self.rpc_internal(request)
.recv()
.map_err(|_| self.requests.lock().unwrap().get_error())
.and_then(|reply| reply)
}
/// Asynchronously run the specified rpc method on the server this client is connected to
pub fn rpc_async(&self, request: &Request) -> Future<Reply>
pub fn rpc_async(&self, request: Request) -> Future<Reply>
where Request: serde::ser::Serialize + fmt::Debug + Send + 'static
{
Future {
@@ -508,15 +541,18 @@ impl<Request, Reply> Drop for Client<Request, Reply>
where Request: serde::ser::Serialize
{
fn drop(&mut self) {
if let Err(e) = self.synced_state
.lock()
.unwrap()
.stream
.get_mut()
.shutdown(::std::net::Shutdown::Both) {
warn!("Client: couldn't shutdown reader thread: {:?}", e);
} else {
self.reader_guard.take().unwrap().join().unwrap();
debug!("Dropping Client.");
if let Some(reader_guard) = Arc::get_mut(&mut self.reader_guard) {
debug!("Attempting to shut down writer and reader threads.");
if let Err(e) = self.shutdown.shutdown(::std::net::Shutdown::Both) {
warn!("Client: couldn't shutdown writer and reader threads: {:?}", e);
} else {
// We only join if we know the TcpStream was shut down. Otherwise we might never
// finish.
debug!("Joining writer and reader.");
reader_guard.take().unwrap().join().unwrap();
debug!("Successfully joined writer and reader.");
}
}
}
}
@@ -524,8 +560,8 @@ impl<Request, Reply> Drop for Client<Request, Reply>
#[cfg(test)]
mod test {
extern crate env_logger;
use super::{Client, Serve, serve_async};
use scoped_pool::Pool;
use std::sync::{Arc, Barrier, Mutex};
use std::thread;
use std::time::Duration;
@@ -588,10 +624,10 @@ mod test {
let addr = serve_handle.local_addr().clone();
let client = Client::new(addr, None).unwrap();
assert_eq!(Reply::Increment(0),
client.rpc(&Request::Increment).unwrap());
client.rpc(Request::Increment).unwrap());
assert_eq!(1, server.count());
assert_eq!(Reply::Increment(1),
client.rpc(&Request::Increment).unwrap());
client.rpc(Request::Increment).unwrap());
assert_eq!(2, server.count());
drop(client);
serve_handle.shutdown();
@@ -632,7 +668,7 @@ mod test {
let addr = serve_handle.local_addr().clone();
let client: Client<Request, Reply> = Client::new(addr, None).unwrap();
let thread = thread::spawn(move || serve_handle.shutdown());
info!("force_shutdown:: rpc1: {:?}", client.rpc(&Request::Increment));
info!("force_shutdown:: rpc1: {:?}", client.rpc(Request::Increment));
thread.join().unwrap();
}
@@ -644,34 +680,29 @@ mod test {
let addr = serve_handle.local_addr().clone();
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, None).unwrap());
serve_handle.shutdown();
match client.rpc(&Request::Increment) {
match client.rpc(Request::Increment) {
Err(super::Error::ConnectionBroken) => {} // success
otherwise => panic!("Expected Err(ConnectionBroken), got {:?}", otherwise),
}
let _ = client.rpc(&Request::Increment); // Test whether second failure hangs
let _ = client.rpc(Request::Increment); // Test whether second failure hangs
}
#[test]
fn concurrent() {
let _ = env_logger::init();
let concurrency = 10;
let pool = Pool::new(concurrency);
let server = Arc::new(BarrierServer::new(concurrency));
let serve_handle = serve_async("localhost:0", server.clone(), test_timeout()).unwrap();
let addr = serve_handle.local_addr().clone();
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, None).unwrap());
let mut join_handles = vec![];
for _ in 0..concurrency {
let my_client = client.clone();
join_handles.push(thread::spawn(move || my_client.rpc(&Request::Increment).unwrap()));
}
for handle in join_handles.into_iter() {
handle.join().unwrap();
}
let client: Client<Request, Reply> = Client::new(addr, None).unwrap();
pool.scoped(|scope| {
for _ in 0..concurrency {
let client = client.try_clone().unwrap();
scope.execute(move || { client.rpc(Request::Increment).unwrap(); });
}
});
assert_eq!(concurrency as u64, server.count());
let client = match Arc::try_unwrap(client) {
Err(_) => panic!("couldn't unwrap arc"),
Ok(c) => c,
};
drop(client);
serve_handle.shutdown();
}
@@ -685,9 +716,9 @@ mod test {
let client: Client<Request, Reply> = Client::new(addr, None).unwrap();
// Drop future immediately; does the reader channel panic when sending?
client.rpc_async(&Request::Increment);
client.rpc_async(Request::Increment);
// If the reader panicked, this won't succeed
client.rpc_async(&Request::Increment);
client.rpc_async(Request::Increment);
drop(client);
serve_handle.shutdown();
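The new Drop impl relies on shutting the TcpStream down before joining, so a reader thread blocked in a read is guaranteed to wake up. The following is a minimal, std-only sketch of that shutdown-then-join teardown under assumed names (the listener, address, and buffer are illustrative); it is not the commit's code, only the pattern it depends on.

use std::io::Read;
use std::net::{Shutdown, TcpListener, TcpStream};
use std::thread;

fn main() -> std::io::Result<()> {
    let listener = TcpListener::bind("127.0.0.1:0")?;
    let addr = listener.local_addr()?;
    // Keep the server half of the connection open so the client-side read actually blocks.
    let server = thread::spawn(move || listener.accept().unwrap().0);

    let stream = TcpStream::connect(addr)?;
    let _server_conn = server.join().unwrap();

    // The "reader thread": blocks on a clone of the same stream, as the client's reader does.
    let mut reader_stream = stream.try_clone()?;
    let reader = thread::spawn(move || {
        let mut buf = [0u8; 64];
        let res = reader_stream.read(&mut buf);
        println!("reader unblocked: {:?}", res);
    });

    // Drop-style teardown: shut the socket down first so the blocked read returns,
    // then join; joining without the shutdown could hang forever.
    stream.shutdown(Shutdown::Both)?;
    reader.join().unwrap();
    Ok(())
}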