mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-02-13 21:18:18 +01:00
Properly wait for spawned connection handler threads to shutdown. Set client timeout to None in tests.
This commit is contained in:
@@ -72,56 +72,101 @@ impl convert::From<sync::mpsc::RecvError> for Error {
|
|||||||
/// Return type of rpc calls: either the successful return value, or a client error.
|
/// Return type of rpc calls: either the successful return value, or a client error.
|
||||||
pub type Result<T> = ::std::result::Result<T, Error>;
|
pub type Result<T> = ::std::result::Result<T, Error>;
|
||||||
|
|
||||||
struct ConnectionHandler {
|
#[derive(Clone)]
|
||||||
shutdown: Arc<AtomicBool>,
|
struct OpenConnections {
|
||||||
open_connections: Arc<(Mutex<u64>, Condvar)>,
|
open_connections: Arc<(Mutex<u64>, Condvar)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenConnections {
|
||||||
|
fn new(mutex: Mutex<u64>, cvar: Condvar) -> OpenConnections {
|
||||||
|
OpenConnections {
|
||||||
|
open_connections: Arc::new((mutex, cvar)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn wait_until_zero(&self) {
|
||||||
|
let &(ref count, ref cvar) = &*self.open_connections;
|
||||||
|
let mut count = count.lock().unwrap();
|
||||||
|
while *count != 0 {
|
||||||
|
count = cvar.wait(count).unwrap();
|
||||||
|
}
|
||||||
|
info!("serve_async: shutdown complete ({} connections alive)", *count);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn increment(&self) {
|
||||||
|
let &(ref count, _) = &*self.open_connections;
|
||||||
|
*count.lock().unwrap() += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decrement(&self) {
|
||||||
|
let &(ref count, _) = &*self.open_connections;
|
||||||
|
*count.lock().unwrap() -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fn decrement_and_notify(&self) {
|
||||||
|
let &(ref count, ref cvar) = &*self.open_connections;
|
||||||
|
*count.lock().unwrap() -= 1;
|
||||||
|
cvar.notify_one();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ConnectionHandler {
|
||||||
|
read_stream: TcpStream,
|
||||||
|
write_stream: Arc<Mutex<TcpStream>>,
|
||||||
|
shutdown: Arc<AtomicBool>,
|
||||||
|
open_connections: OpenConnections,
|
||||||
timeout: Option<Duration>,
|
timeout: Option<Duration>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for ConnectionHandler {
|
impl Drop for ConnectionHandler {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
let &(ref count, ref cvar) = &*self.open_connections;
|
if let Err(e) = bincode::serde::serialize_into(&mut self.read_stream,
|
||||||
*count.lock().unwrap() -= 1;
|
&Packet::Shutdown::<()>,
|
||||||
cvar.notify_one();
|
bincode::SizeLimit::Infinite) {
|
||||||
|
warn!("ConnectionHandler: could not notify client of shutdown: {:?}", e);
|
||||||
|
}
|
||||||
trace!("ConnectionHandler: finished serving client.");
|
trace!("ConnectionHandler: finished serving client.");
|
||||||
|
self.open_connections.decrement_and_notify();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ConnectionHandler {
|
impl ConnectionHandler {
|
||||||
fn handle_conn<F, Request, Reply>(&self, stream: TcpStream, f: F) -> Result<()>
|
fn read<Request>(&mut self) -> bincode::serde::DeserializeResult<Packet<Request>>
|
||||||
|
where Request: serde::de::Deserialize
|
||||||
|
{
|
||||||
|
try!(self.read_stream.set_read_timeout(self.timeout));
|
||||||
|
bincode::serde::deserialize_from(&mut self.read_stream, bincode::SizeLimit::Infinite)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_conn<F, Request, Reply>(&mut self, f: F) -> Result<()>
|
||||||
where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize,
|
where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize,
|
||||||
Reply: 'static + fmt::Debug + serde::ser::Serialize,
|
Reply: 'static + fmt::Debug + serde::ser::Serialize,
|
||||||
F: 'static + Clone + Serve<Request, Reply>
|
F: 'static + Clone + Serve<Request, Reply>
|
||||||
{
|
{
|
||||||
trace!("ConnectionHandler: serving client...");
|
trace!("ConnectionHandler: serving client...");
|
||||||
let mut read_stream = try!(stream.try_clone());
|
|
||||||
let stream = Arc::new(Mutex::new(stream));
|
|
||||||
loop {
|
loop {
|
||||||
try!(read_stream.set_read_timeout(self.timeout));
|
match self.read() {
|
||||||
match bincode::serde::deserialize_from(&mut read_stream, bincode::SizeLimit::Infinite) {
|
Ok(Packet::Shutdown) => break,
|
||||||
Ok(request_packet @ Packet::Shutdown) => {
|
|
||||||
let stream = stream.clone();
|
|
||||||
let mut my_stream = stream.lock().unwrap();
|
|
||||||
try!(bincode::serde::serialize_into(&mut *my_stream,
|
|
||||||
&request_packet,
|
|
||||||
bincode::SizeLimit::Infinite));
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
Ok(Packet::Message(id, message)) => {
|
Ok(Packet::Message(id, message)) => {
|
||||||
let f = f.clone();
|
let f = f.clone();
|
||||||
let arc_stream = stream.clone();
|
let open_connections = self.open_connections.clone();
|
||||||
|
open_connections.increment();
|
||||||
|
let stream = self.write_stream.clone();
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
let reply = f.serve(message);
|
let reply = f.serve(message);
|
||||||
let reply_packet = Packet::Message(id, reply);
|
let reply_packet = Packet::Message(id, reply);
|
||||||
let mut my_stream = arc_stream.lock().unwrap();
|
let mut stream = stream.lock().unwrap();
|
||||||
bincode::serde::serialize_into(&mut *my_stream,
|
if let Err(e) = bincode::serde::serialize_into(&mut *stream,
|
||||||
&reply_packet,
|
&reply_packet,
|
||||||
bincode::SizeLimit::Infinite)
|
bincode::SizeLimit::Infinite) {
|
||||||
.unwrap();
|
warn!("ConnectionHandler: failed to write reply to Client: {:?}", e);
|
||||||
|
}
|
||||||
|
open_connections.decrement();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
Err(bincode::serde::DeserializeError::IoError(ref err))
|
Err(bincode::serde::DeserializeError::IoError(ref err)) if Self::timed_out(err.kind()) => {
|
||||||
if Self::timed_out(err.kind()) => {
|
|
||||||
if !self.shutdown() {
|
if !self.shutdown() {
|
||||||
warn!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so \
|
warn!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so \
|
||||||
retrying read.",
|
retrying read.",
|
||||||
@@ -131,10 +176,6 @@ impl ConnectionHandler {
|
|||||||
warn!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \
|
warn!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \
|
||||||
closing connection.",
|
closing connection.",
|
||||||
err);
|
err);
|
||||||
let mut stream = stream.lock().unwrap();
|
|
||||||
try!(bincode::serde::serialize_into(&mut *stream,
|
|
||||||
&Packet::Shutdown::<Reply>,
|
|
||||||
bincode::SizeLimit::Infinite));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -208,20 +249,14 @@ pub fn serve_async<A, F, Request, Reply>(addr: A,
|
|||||||
let (die_tx, die_rx) = channel();
|
let (die_tx, die_rx) = channel();
|
||||||
let join_handle = thread::spawn(move || {
|
let join_handle = thread::spawn(move || {
|
||||||
let shutdown = Arc::new(AtomicBool::new(false));
|
let shutdown = Arc::new(AtomicBool::new(false));
|
||||||
let open_connections = Arc::new((Mutex::new(0), Condvar::new()));
|
let open_connections = OpenConnections::new(Mutex::new(0), Condvar::new());
|
||||||
for conn in listener.incoming() {
|
for conn in listener.incoming() {
|
||||||
match die_rx.try_recv() {
|
match die_rx.try_recv() {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
info!("serve_async: shutdown received. Waiting for open connections to \
|
info!("serve_async: shutdown received. Waiting for open connections to \
|
||||||
return...");
|
return...");
|
||||||
shutdown.store(true, Ordering::SeqCst);
|
shutdown.store(true, Ordering::SeqCst);
|
||||||
let &(ref count, ref cvar) = &*open_connections;
|
open_connections.wait_until_zero();
|
||||||
let mut count = count.lock().unwrap();
|
|
||||||
while *count != 0 {
|
|
||||||
count = cvar.wait(count).unwrap();
|
|
||||||
}
|
|
||||||
info!("serve_async: shutdown complete ({} connections alive)",
|
|
||||||
*count);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Err(TryRecvError::Disconnected) => {
|
Err(TryRecvError::Disconnected) => {
|
||||||
@@ -239,16 +274,17 @@ pub fn serve_async<A, F, Request, Reply>(addr: A,
|
|||||||
};
|
};
|
||||||
let f = f.clone();
|
let f = f.clone();
|
||||||
let shutdown = shutdown.clone();
|
let shutdown = shutdown.clone();
|
||||||
let &(ref count, _) = &*open_connections;
|
open_connections.increment();
|
||||||
*count.lock().unwrap() += 1;
|
|
||||||
let open_connections = open_connections.clone();
|
let open_connections = open_connections.clone();
|
||||||
|
let mut handler = ConnectionHandler {
|
||||||
|
read_stream: conn.try_clone().unwrap(),
|
||||||
|
write_stream: Arc::new(Mutex::new(conn)),
|
||||||
|
shutdown: shutdown,
|
||||||
|
open_connections: open_connections,
|
||||||
|
timeout: read_timeout,
|
||||||
|
};
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
let handler = ConnectionHandler {
|
if let Err(err) = handler.handle_conn(f) {
|
||||||
shutdown: shutdown,
|
|
||||||
open_connections: open_connections,
|
|
||||||
timeout: read_timeout,
|
|
||||||
};
|
|
||||||
if let Err(err) = handler.handle_conn(conn, f) {
|
|
||||||
error!("ConnectionHandler: error in connection handling: {:?}", err);
|
error!("ConnectionHandler: error in connection handling: {:?}", err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -472,7 +508,7 @@ mod test {
|
|||||||
let server = Arc::new(Server::new());
|
let server = Arc::new(Server::new());
|
||||||
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
|
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
|
||||||
let client: Client<Request, Reply> = Client::new(serve_handle.local_addr().clone(),
|
let client: Client<Request, Reply> = Client::new(serve_handle.local_addr().clone(),
|
||||||
test_timeout())
|
None)
|
||||||
.expect(&line!().to_string());
|
.expect(&line!().to_string());
|
||||||
drop(client);
|
drop(client);
|
||||||
serve_handle.shutdown();
|
serve_handle.shutdown();
|
||||||
@@ -484,7 +520,7 @@ mod test {
|
|||||||
let server = Arc::new(Server::new());
|
let server = Arc::new(Server::new());
|
||||||
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
|
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
|
||||||
let addr = serve_handle.local_addr().clone();
|
let addr = serve_handle.local_addr().clone();
|
||||||
let client = Client::new(addr, test_timeout()).unwrap();
|
let client = Client::new(addr, None).unwrap();
|
||||||
assert_eq!(Reply::Increment(0),
|
assert_eq!(Reply::Increment(0),
|
||||||
client.rpc(&Request::Increment).unwrap());
|
client.rpc(&Request::Increment).unwrap());
|
||||||
assert_eq!(1, server.count());
|
assert_eq!(1, server.count());
|
||||||
@@ -526,7 +562,7 @@ mod test {
|
|||||||
let server = Arc::new(Server::new());
|
let server = Arc::new(Server::new());
|
||||||
let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap();
|
let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap();
|
||||||
let addr = serve_handle.local_addr().clone();
|
let addr = serve_handle.local_addr().clone();
|
||||||
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, test_timeout())
|
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, None)
|
||||||
.unwrap());
|
.unwrap());
|
||||||
let thread = thread::spawn(move || serve_handle.shutdown());
|
let thread = thread::spawn(move || serve_handle.shutdown());
|
||||||
info!("force_shutdown:: rpc1: {:?}",
|
info!("force_shutdown:: rpc1: {:?}",
|
||||||
@@ -542,7 +578,7 @@ mod test {
|
|||||||
let server = Arc::new(BarrierServer::new(10));
|
let server = Arc::new(BarrierServer::new(10));
|
||||||
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
|
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
|
||||||
let addr = serve_handle.local_addr().clone();
|
let addr = serve_handle.local_addr().clone();
|
||||||
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, test_timeout())
|
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, None)
|
||||||
.unwrap());
|
.unwrap());
|
||||||
let mut join_handles = vec![];
|
let mut join_handles = vec![];
|
||||||
for _ in 0..10 {
|
for _ in 0..10 {
|
||||||
|
|||||||
Reference in New Issue
Block a user