Properly wait for spawned connection handler threads to shutdown. Set client timeout to None in tests.

This commit is contained in:
Tim Kuehn
2016-01-14 09:22:46 -08:00
parent 8c51d2ca1b
commit 0df3cfdd98

View File

@@ -72,56 +72,101 @@ impl convert::From<sync::mpsc::RecvError> for Error {
/// Return type of rpc calls: either the successful return value, or a client error.
pub type Result<T> = ::std::result::Result<T, Error>;
struct ConnectionHandler {
shutdown: Arc<AtomicBool>,
#[derive(Clone)]
struct OpenConnections {
open_connections: Arc<(Mutex<u64>, Condvar)>,
}
impl OpenConnections {
fn new(mutex: Mutex<u64>, cvar: Condvar) -> OpenConnections {
OpenConnections {
open_connections: Arc::new((mutex, cvar)),
}
}
fn wait_until_zero(&self) {
let &(ref count, ref cvar) = &*self.open_connections;
let mut count = count.lock().unwrap();
while *count != 0 {
count = cvar.wait(count).unwrap();
}
info!("serve_async: shutdown complete ({} connections alive)", *count);
}
fn increment(&self) {
let &(ref count, _) = &*self.open_connections;
*count.lock().unwrap() += 1;
}
fn decrement(&self) {
let &(ref count, _) = &*self.open_connections;
*count.lock().unwrap() -= 1;
}
fn decrement_and_notify(&self) {
let &(ref count, ref cvar) = &*self.open_connections;
*count.lock().unwrap() -= 1;
cvar.notify_one();
}
}
struct ConnectionHandler {
read_stream: TcpStream,
write_stream: Arc<Mutex<TcpStream>>,
shutdown: Arc<AtomicBool>,
open_connections: OpenConnections,
timeout: Option<Duration>,
}
impl Drop for ConnectionHandler {
fn drop(&mut self) {
let &(ref count, ref cvar) = &*self.open_connections;
*count.lock().unwrap() -= 1;
cvar.notify_one();
if let Err(e) = bincode::serde::serialize_into(&mut self.read_stream,
&Packet::Shutdown::<()>,
bincode::SizeLimit::Infinite) {
warn!("ConnectionHandler: could not notify client of shutdown: {:?}", e);
}
trace!("ConnectionHandler: finished serving client.");
self.open_connections.decrement_and_notify();
}
}
impl ConnectionHandler {
fn handle_conn<F, Request, Reply>(&self, stream: TcpStream, f: F) -> Result<()>
fn read<Request>(&mut self) -> bincode::serde::DeserializeResult<Packet<Request>>
where Request: serde::de::Deserialize
{
try!(self.read_stream.set_read_timeout(self.timeout));
bincode::serde::deserialize_from(&mut self.read_stream, bincode::SizeLimit::Infinite)
}
fn handle_conn<F, Request, Reply>(&mut self, f: F) -> Result<()>
where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize,
Reply: 'static + fmt::Debug + serde::ser::Serialize,
F: 'static + Clone + Serve<Request, Reply>
{
trace!("ConnectionHandler: serving client...");
let mut read_stream = try!(stream.try_clone());
let stream = Arc::new(Mutex::new(stream));
loop {
try!(read_stream.set_read_timeout(self.timeout));
match bincode::serde::deserialize_from(&mut read_stream, bincode::SizeLimit::Infinite) {
Ok(request_packet @ Packet::Shutdown) => {
let stream = stream.clone();
let mut my_stream = stream.lock().unwrap();
try!(bincode::serde::serialize_into(&mut *my_stream,
&request_packet,
bincode::SizeLimit::Infinite));
break;
}
match self.read() {
Ok(Packet::Shutdown) => break,
Ok(Packet::Message(id, message)) => {
let f = f.clone();
let arc_stream = stream.clone();
let open_connections = self.open_connections.clone();
open_connections.increment();
let stream = self.write_stream.clone();
thread::spawn(move || {
let reply = f.serve(message);
let reply_packet = Packet::Message(id, reply);
let mut my_stream = arc_stream.lock().unwrap();
bincode::serde::serialize_into(&mut *my_stream,
let mut stream = stream.lock().unwrap();
if let Err(e) = bincode::serde::serialize_into(&mut *stream,
&reply_packet,
bincode::SizeLimit::Infinite)
.unwrap();
bincode::SizeLimit::Infinite) {
warn!("ConnectionHandler: failed to write reply to Client: {:?}", e);
}
open_connections.decrement();
});
}
Err(bincode::serde::DeserializeError::IoError(ref err))
if Self::timed_out(err.kind()) => {
Err(bincode::serde::DeserializeError::IoError(ref err)) if Self::timed_out(err.kind()) => {
if !self.shutdown() {
warn!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so \
retrying read.",
@@ -131,10 +176,6 @@ impl ConnectionHandler {
warn!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \
closing connection.",
err);
let mut stream = stream.lock().unwrap();
try!(bincode::serde::serialize_into(&mut *stream,
&Packet::Shutdown::<Reply>,
bincode::SizeLimit::Infinite));
break;
}
}
@@ -208,20 +249,14 @@ pub fn serve_async<A, F, Request, Reply>(addr: A,
let (die_tx, die_rx) = channel();
let join_handle = thread::spawn(move || {
let shutdown = Arc::new(AtomicBool::new(false));
let open_connections = Arc::new((Mutex::new(0), Condvar::new()));
let open_connections = OpenConnections::new(Mutex::new(0), Condvar::new());
for conn in listener.incoming() {
match die_rx.try_recv() {
Ok(_) => {
info!("serve_async: shutdown received. Waiting for open connections to \
return...");
shutdown.store(true, Ordering::SeqCst);
let &(ref count, ref cvar) = &*open_connections;
let mut count = count.lock().unwrap();
while *count != 0 {
count = cvar.wait(count).unwrap();
}
info!("serve_async: shutdown complete ({} connections alive)",
*count);
open_connections.wait_until_zero();
break;
}
Err(TryRecvError::Disconnected) => {
@@ -239,16 +274,17 @@ pub fn serve_async<A, F, Request, Reply>(addr: A,
};
let f = f.clone();
let shutdown = shutdown.clone();
let &(ref count, _) = &*open_connections;
*count.lock().unwrap() += 1;
open_connections.increment();
let open_connections = open_connections.clone();
let mut handler = ConnectionHandler {
read_stream: conn.try_clone().unwrap(),
write_stream: Arc::new(Mutex::new(conn)),
shutdown: shutdown,
open_connections: open_connections,
timeout: read_timeout,
};
thread::spawn(move || {
let handler = ConnectionHandler {
shutdown: shutdown,
open_connections: open_connections,
timeout: read_timeout,
};
if let Err(err) = handler.handle_conn(conn, f) {
if let Err(err) = handler.handle_conn(f) {
error!("ConnectionHandler: error in connection handling: {:?}", err);
}
});
@@ -472,7 +508,7 @@ mod test {
let server = Arc::new(Server::new());
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
let client: Client<Request, Reply> = Client::new(serve_handle.local_addr().clone(),
test_timeout())
None)
.expect(&line!().to_string());
drop(client);
serve_handle.shutdown();
@@ -484,7 +520,7 @@ mod test {
let server = Arc::new(Server::new());
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
let addr = serve_handle.local_addr().clone();
let client = Client::new(addr, test_timeout()).unwrap();
let client = Client::new(addr, None).unwrap();
assert_eq!(Reply::Increment(0),
client.rpc(&Request::Increment).unwrap());
assert_eq!(1, server.count());
@@ -526,7 +562,7 @@ mod test {
let server = Arc::new(Server::new());
let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap();
let addr = serve_handle.local_addr().clone();
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, test_timeout())
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, None)
.unwrap());
let thread = thread::spawn(move || serve_handle.shutdown());
info!("force_shutdown:: rpc1: {:?}",
@@ -542,7 +578,7 @@ mod test {
let server = Arc::new(BarrierServer::new(10));
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
let addr = serve_handle.local_addr().clone();
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, test_timeout())
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, None)
.unwrap());
let mut join_handles = vec![];
for _ in 0..10 {