Bundle of small changes.

1. Rename OpenConnections => InflightRpcs, as it represents the rpc calls currently being processed.
2. Change Packet from a tuple struct to a regular struct, to clarify its fields.
3. Lower log statements from WARN to INFO where appropriate.
4. Remove the shutdown method on ConnectionHandler to disambiguate it from the shutdown field.
5. Add a test of client behavior when calling rpc on a client whose stream closed.
This commit is contained in:
Tim Kuehn
2016-01-15 00:28:42 -08:00
parent e4faff74be
commit ebd825e679

View File

@@ -73,17 +73,17 @@ impl convert::From<sync::mpsc::RecvError> for Error {
pub type Result<T> = ::std::result::Result<T, Error>; pub type Result<T> = ::std::result::Result<T, Error>;
#[derive(Clone)] #[derive(Clone)]
struct OpenConnections { struct InflightRpcs {
open_connections: Arc<(Mutex<u64>, Condvar)>, inflight_rpcs: Arc<(Mutex<u64>, Condvar)>,
} }
impl OpenConnections { impl InflightRpcs {
fn new(mutex: Mutex<u64>, cvar: Condvar) -> OpenConnections { fn new(mutex: Mutex<u64>, cvar: Condvar) -> InflightRpcs {
OpenConnections { open_connections: Arc::new((mutex, cvar)) } InflightRpcs { inflight_rpcs: Arc::new((mutex, cvar)) }
} }
fn wait_until_zero(&self) { fn wait_until_zero(&self) {
let &(ref count, ref cvar) = &*self.open_connections; let &(ref count, ref cvar) = &*self.inflight_rpcs;
let mut count = count.lock().unwrap(); let mut count = count.lock().unwrap();
while *count != 0 { while *count != 0 {
count = cvar.wait(count).unwrap(); count = cvar.wait(count).unwrap();
@@ -93,18 +93,18 @@ impl OpenConnections {
} }
fn increment(&self) { fn increment(&self) {
let &(ref count, _) = &*self.open_connections; let &(ref count, _) = &*self.inflight_rpcs;
*count.lock().unwrap() += 1; *count.lock().unwrap() += 1;
} }
fn decrement(&self) { fn decrement(&self) {
let &(ref count, _) = &*self.open_connections; let &(ref count, _) = &*self.inflight_rpcs;
*count.lock().unwrap() -= 1; *count.lock().unwrap() -= 1;
} }
fn decrement_and_notify(&self) { fn decrement_and_notify(&self) {
let &(ref count, ref cvar) = &*self.open_connections; let &(ref count, ref cvar) = &*self.inflight_rpcs;
*count.lock().unwrap() -= 1; *count.lock().unwrap() -= 1;
cvar.notify_one(); cvar.notify_one();
} }
@@ -115,14 +115,14 @@ struct ConnectionHandler {
read_stream: TcpStream, read_stream: TcpStream,
write_stream: Arc<Mutex<TcpStream>>, write_stream: Arc<Mutex<TcpStream>>,
shutdown: Arc<AtomicBool>, shutdown: Arc<AtomicBool>,
open_connections: OpenConnections, inflight_rpcs: InflightRpcs,
timeout: Option<Duration>, timeout: Option<Duration>,
} }
impl Drop for ConnectionHandler { impl Drop for ConnectionHandler {
fn drop(&mut self) { fn drop(&mut self) {
trace!("ConnectionHandler: finished serving client."); trace!("ConnectionHandler: finished serving client.");
self.open_connections.decrement_and_notify(); self.inflight_rpcs.decrement_and_notify();
} }
} }
@@ -140,14 +140,20 @@ impl ConnectionHandler {
trace!("ConnectionHandler: serving client..."); trace!("ConnectionHandler: serving client...");
loop { loop {
match self.read() { match self.read() {
Ok(Packet(id, message)) => { Ok(Packet {
rpc_id: id,
message: message
}) => {
let f = f.clone(); let f = f.clone();
let open_connections = self.open_connections.clone(); let inflight_rpcs = self.inflight_rpcs.clone();
open_connections.increment(); inflight_rpcs.increment();
let stream = self.write_stream.clone(); let stream = self.write_stream.clone();
thread::spawn(move || { thread::spawn(move || {
let reply = f.serve(message); let reply = f.serve(message);
let reply_packet = Packet(id, reply); let reply_packet = Packet {
rpc_id: id,
message: reply
};
let mut stream = stream.lock().unwrap(); let mut stream = stream.lock().unwrap();
if let Err(e) = if let Err(e) =
bincode::serde::serialize_into(&mut *stream, bincode::serde::serialize_into(&mut *stream,
@@ -156,18 +162,22 @@ impl ConnectionHandler {
warn!("ConnectionHandler: failed to write reply to Client: {:?}", warn!("ConnectionHandler: failed to write reply to Client: {:?}",
e); e);
} }
open_connections.decrement(); inflight_rpcs.decrement();
}); });
if self.shutdown.load(Ordering::SeqCst) {
info!("ConnectionHandler: server shutdown, so closing connection.");
break;
}
} }
Err(bincode::serde::DeserializeError::IoError(ref err)) Err(bincode::serde::DeserializeError::IoError(ref err))
if Self::timed_out(err.kind()) => { if Self::timed_out(err.kind()) => {
if !self.shutdown() { if !self.shutdown.load(Ordering::SeqCst) {
warn!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so \ info!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so \
retrying read.", retrying read.",
err); err);
continue; continue;
} else { } else {
warn!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \ info!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \
closing connection.", closing connection.",
err); err);
break; break;
@@ -184,10 +194,6 @@ impl ConnectionHandler {
Ok(()) Ok(())
} }
fn shutdown(&self) -> bool {
self.shutdown.load(Ordering::SeqCst)
}
fn timed_out(error_kind: io::ErrorKind) -> bool { fn timed_out(error_kind: io::ErrorKind) -> bool {
match error_kind { match error_kind {
io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock => true, io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock => true,
@@ -238,14 +244,14 @@ pub fn serve_async<A, F>(addr: A, f: F, read_timeout: Option<Duration>) -> io::R
let (die_tx, die_rx) = channel(); let (die_tx, die_rx) = channel();
let join_handle = thread::spawn(move || { let join_handle = thread::spawn(move || {
let shutdown = Arc::new(AtomicBool::new(false)); let shutdown = Arc::new(AtomicBool::new(false));
let open_connections = OpenConnections::new(Mutex::new(0), Condvar::new()); let inflight_rpcs = InflightRpcs::new(Mutex::new(0), Condvar::new());
for conn in listener.incoming() { for conn in listener.incoming() {
match die_rx.try_recv() { match die_rx.try_recv() {
Ok(_) => { Ok(_) => {
info!("serve_async: shutdown received. Waiting for open connections to \ info!("serve_async: shutdown received. Waiting for open connections to \
return..."); return...");
shutdown.store(true, Ordering::SeqCst); shutdown.store(true, Ordering::SeqCst);
open_connections.wait_until_zero(); inflight_rpcs.wait_until_zero();
break; break;
} }
Err(TryRecvError::Disconnected) => { Err(TryRecvError::Disconnected) => {
@@ -263,13 +269,13 @@ pub fn serve_async<A, F>(addr: A, f: F, read_timeout: Option<Duration>) -> io::R
}; };
let f = f.clone(); let f = f.clone();
let shutdown = shutdown.clone(); let shutdown = shutdown.clone();
open_connections.increment(); inflight_rpcs.increment();
let open_connections = open_connections.clone(); let inflight_rpcs = inflight_rpcs.clone();
let mut handler = ConnectionHandler { let mut handler = ConnectionHandler {
read_stream: conn.try_clone().unwrap(), read_stream: conn.try_clone().unwrap(),
write_stream: Arc::new(Mutex::new(conn)), write_stream: Arc::new(Mutex::new(conn)),
shutdown: shutdown, shutdown: shutdown,
open_connections: open_connections, inflight_rpcs: inflight_rpcs,
timeout: read_timeout, timeout: read_timeout,
}; };
thread::spawn(move || { thread::spawn(move || {
@@ -308,7 +314,10 @@ impl<S> Serve for Arc<S> where S: Serve
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
struct Packet<T>(u64, T); struct Packet<T> {
rpc_id: u64,
message: T
}
struct Reader<Reply> { struct Reader<Reply> {
requests: Arc<Mutex<Option<HashMap<u64, Sender<Reply>>>>>, requests: Arc<Mutex<Option<HashMap<u64, Sender<Reply>>>>>,
@@ -322,7 +331,10 @@ impl<Reply> Reader<Reply> {
let packet: bincode::serde::DeserializeResult<Packet<Reply>> = let packet: bincode::serde::DeserializeResult<Packet<Reply>> =
bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite); bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite);
match packet { match packet {
Ok(Packet(id, reply)) => { Ok(Packet {
rpc_id: id,
message: reply
}) => {
debug!("Client: received message, id={}", id); debug!("Client: received message, id={}", id);
let mut requests = self.requests.lock().unwrap(); let mut requests = self.requests.lock().unwrap();
let mut requests = requests.as_mut().unwrap(); let mut requests = requests.as_mut().unwrap();
@@ -408,7 +420,10 @@ impl<Request, Reply> Client<Request, Reply>
return Err(Error::ConnectionBroken); return Err(Error::ConnectionBroken);
} }
} }
let packet = Packet(id, request); let packet = Packet {
rpc_id: id,
message: request,
};
try!(state.stream.set_write_timeout(self.timeout)); try!(state.stream.set_write_timeout(self.timeout));
try!(state.stream.set_read_timeout(self.timeout)); try!(state.stream.set_read_timeout(self.timeout));
debug!("Client: calling rpc({:?})", request); debug!("Client: calling rpc({:?})", request);
@@ -559,13 +574,22 @@ mod test {
let addr = serve_handle.local_addr().clone(); let addr = serve_handle.local_addr().clone();
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, None).unwrap()); let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, None).unwrap());
let thread = thread::spawn(move || serve_handle.shutdown()); let thread = thread::spawn(move || serve_handle.shutdown());
info!("force_shutdown:: rpc1: {:?}", info!("force_shutdown:: rpc1: {:?}", client.rpc(&Request::Increment));
client.rpc(&Request::Increment));
info!("force_shutdown:: rpc2: {:?}",
client.rpc(&Request::Increment));
thread.join().unwrap(); thread.join().unwrap();
} }
#[test]
fn client_failed_rpc() {
let _ = env_logger::init();
let server = Arc::new(Server::new());
let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap();
let addr = serve_handle.local_addr().clone();
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, None).unwrap());
serve_handle.shutdown();
let _ = client.rpc(&Request::Increment); // First failure will trigger reader to shutdown
let _ = client.rpc(&Request::Increment); // Test whether second failure hangs
}
#[test] #[test]
fn concurrent() { fn concurrent() {
let _ = env_logger::init(); let _ = env_logger::init();