Orderly shutdown of serving threads when calling ServeHandle::shutdown.

This commit is contained in:
Tim Kuehn
2016-01-14 01:09:47 -08:00
parent f7bc1586c2
commit 91053b96c0
4 changed files with 183 additions and 62 deletions

View File

@@ -8,3 +8,4 @@ serde = "*"
bincode = "*"
serde_macros = "*"
log = "*"
env_logger = "*"

View File

@@ -18,6 +18,7 @@
//! }
//!
//! use self::my_server::*;
//! use std::time::Duration;
//!
//! impl my_server::Service for () {
//! fn hello(&self, s: String) -> String {
@@ -30,8 +31,8 @@
//!
//! fn main() {
//! let addr = "127.0.0.1:9000";
//! let shutdown = my_server::serve(addr, ()).unwrap();
//! let client = Client::new(addr).unwrap();
//! let shutdown = my_server::serve(addr, (), Some(Duration::from_secs(30))).unwrap();
//! let client = Client::new(addr, None).unwrap();
//! assert_eq!(3, client.add(1, 2).unwrap());
//! assert_eq!("Hello, Mom!".to_string(),
//! client.hello("Mom".to_string()).unwrap());

View File

@@ -125,10 +125,10 @@ macro_rules! rpc {
impl Client {
#[doc="Create a new client that connects to the given address."]
pub fn new<A>(addr: A) -> $crate::Result<Self>
pub fn new<A>(addr: A, timeout: ::std::option::Option<::std::time::Duration>) -> $crate::Result<Self>
where A: ::std::net::ToSocketAddrs,
{
let inner = try!($crate::protocol::Client::new(addr));
let inner = try!($crate::protocol::Client::new(addr, timeout));
Ok(Client(inner))
}
@@ -151,12 +151,12 @@ macro_rules! rpc {
}
#[doc="Start a running service."]
pub fn serve<A, S>(addr: A, service: S) -> $crate::Result<$crate::protocol::ServeHandle>
pub fn serve<A, S>(addr: A, service: S, read_timeout: ::std::option::Option<::std::time::Duration>) -> $crate::Result<$crate::protocol::ServeHandle>
where A: ::std::net::ToSocketAddrs,
S: 'static + Service
{
let server = ::std::sync::Arc::new(__Server(service));
Ok(try!($crate::protocol::serve_async(addr, server)))
Ok(try!($crate::protocol::serve_async(addr, server, read_timeout)))
}
}
}
@@ -165,6 +165,12 @@ macro_rules! rpc {
#[cfg(test)]
#[allow(dead_code)]
mod test {
use std::time::Duration;
fn test_timeout() -> Option<Duration> {
Some(Duration::from_secs(5))
}
rpc! {
mod my_server {
items {
@@ -197,8 +203,8 @@ mod test {
fn simple_test() {
println!("Starting");
let addr = "127.0.0.1:9000";
let shutdown = my_server::serve(addr, ()).unwrap();
let client = Client::new(addr).unwrap();
let shutdown = my_server::serve(addr, (), test_timeout()).unwrap();
let client = Client::new(addr, None).unwrap();
assert_eq!(3, client.add(1, 2).unwrap());
let foo = Foo { message: "Adam".into() };
let want = Foo { message: format!("Hello, {}", &foo.message) };

View File

@@ -6,8 +6,10 @@ use std::convert;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::net::{TcpListener, TcpStream, SocketAddr, ToSocketAddrs};
use std::sync::{self, Mutex, Arc};
use std::sync::{self, Arc, Condvar, Mutex};
use std::sync::mpsc::{channel, Sender, TryRecvError};
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use std::thread::{self, JoinHandle};
/// Client errors that can occur during rpc calls
@@ -23,6 +25,10 @@ pub enum Error {
/// Channels are used for the client's inter-thread communication. This message is
/// propagated if the receiver unexpectedly hangs up.
Sender,
/// An internal message failed to be received.
/// Channels are used for the client's inter-thread communication. This message is
/// propagated if the sender unexpectedly hangs up.
Receiver,
/// The server hung up.
ConnectionBroken,
}
@@ -57,44 +63,97 @@ impl<T> convert::From<sync::mpsc::SendError<T>> for Error {
}
}
impl convert::From<sync::mpsc::RecvError> for Error {
fn from(_: sync::mpsc::RecvError) -> Error {
Error::Receiver
}
}
/// Return type of rpc calls: either the successful return value, or a client error.
pub type Result<T> = ::std::result::Result<T, Error>;
fn handle_conn<F, Request, Reply>(stream: TcpStream, f: F) -> Result<()>
where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize,
Reply: 'static + fmt::Debug + serde::ser::Serialize,
F: 'static + Clone + Serve<Request, Reply>
{
let mut read_stream = try!(stream.try_clone());
let stream = Arc::new(Mutex::new(stream));
loop {
let request_packet: Packet<Request> =
try!(bincode::serde::deserialize_from(&mut read_stream, bincode::SizeLimit::Infinite));
match request_packet {
Packet::Shutdown => {
let stream = stream.clone();
let mut my_stream = stream.lock().unwrap();
try!(bincode::serde::serialize_into(&mut *my_stream,
&request_packet,
bincode::SizeLimit::Infinite));
break;
}
Packet::Message(id, message) => {
let f = f.clone();
let arc_stream = stream.clone();
thread::spawn(move || {
let reply = f.serve(message);
let reply_packet = Packet::Message(id, reply);
let mut my_stream = arc_stream.lock().unwrap();
bincode::serde::serialize_into(&mut *my_stream,
&reply_packet,
bincode::SizeLimit::Infinite)
.unwrap();
});
struct ConnectionHandler {
shutdown: Arc<AtomicBool>,
open_connections: Arc<(Mutex<u64>, Condvar)>,
timeout: Option<Duration>,
}
impl Drop for ConnectionHandler {
fn drop(&mut self) {
let &(ref count, ref cvar) = &*self.open_connections;
*count.lock().unwrap() -= 1;
cvar.notify_one();
trace!("ConnectionHandler: finished serving client.");
}
}
impl ConnectionHandler {
fn handle_conn<F, Request, Reply>(&self, stream: TcpStream, f: F) -> Result<()>
where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize,
Reply: 'static + fmt::Debug + serde::ser::Serialize,
F: 'static + Clone + Serve<Request, Reply>
{
trace!("ConnectionHandler: serving client...");
let mut read_stream = try!(stream.try_clone());
let stream = Arc::new(Mutex::new(stream));
loop {
try!(read_stream.set_read_timeout(self.timeout));
match bincode::serde::deserialize_from(&mut read_stream, bincode::SizeLimit::Infinite) {
Ok(request_packet @ Packet::Shutdown) => {
let stream = stream.clone();
let mut my_stream = stream.lock().unwrap();
try!(bincode::serde::serialize_into(&mut *my_stream,
&request_packet,
bincode::SizeLimit::Infinite));
break;
}
Ok(Packet::Message(id, message)) => {
let f = f.clone();
let arc_stream = stream.clone();
thread::spawn(move || {
let reply = f.serve(message);
let reply_packet = Packet::Message(id, reply);
let mut my_stream = arc_stream.lock().unwrap();
bincode::serde::serialize_into(&mut *my_stream,
&reply_packet,
bincode::SizeLimit::Infinite)
.unwrap();
});
}
Err(bincode::serde::DeserializeError::IoError(ref err))
if Self::timed_out(err.kind()) =>
{
if !self.shutdown() {
warn!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so retrying read.", err);
continue;
} else {
warn!("ConnectionHandler: read timed out ({:?}). Server shutdown, so closing connection.", err);
let mut stream = stream.lock().unwrap();
try!(bincode::serde::serialize_into(&mut *stream,
&Packet::Shutdown::<Reply>,
bincode::SizeLimit::Infinite));
break;
}
}
Err(e) => {
warn!("ConnectionHandler: closing client connection due to error while serving: {:?}", e);
return Err(e.into());
}
}
}
Ok(())
}
fn shutdown(&self) -> bool {
self.shutdown.load(Ordering::SeqCst)
}
fn timed_out(error_kind: io::ErrorKind) -> bool {
match error_kind {
io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock => true,
_ => false,
}
}
Ok(())
}
/// Provides methods for blocking until the server completes,
@@ -118,17 +177,18 @@ impl ServeHandle {
/// Shutdown the server. Gracefully shuts down the serve thread but currently does not
/// gracefully close open connections.
pub fn shutdown(self) {
info!("ServeHandle: attempting to shut down the server.");
self.tx.send(()).expect(&line!().to_string());
if let Ok(_) = TcpStream::connect(self.addr) {
self.join_handle.join().expect(&line!().to_string());
} else {
warn!("Best effort shutdown of serve thread failed");
warn!("ServeHandle: best effort shutdown of serve thread failed");
}
}
}
/// Start
pub fn serve_async<A, F, Request, Reply>(addr: A, f: F) -> io::Result<ServeHandle>
pub fn serve_async<A, F, Request, Reply>(addr: A, f: F, read_timeout: Option<Duration>) -> io::Result<ServeHandle>
where A: ToSocketAddrs,
Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize,
Reply: 'static + fmt::Debug + serde::ser::Serialize,
@@ -136,29 +196,50 @@ pub fn serve_async<A, F, Request, Reply>(addr: A, f: F) -> io::Result<ServeHandl
{
let listener = try!(TcpListener::bind(&addr));
let addr = try!(listener.local_addr());
info!("Spinning up server on {:?}", addr);
info!("serve_async: spinning up server on {:?}", addr);
let (die_tx, die_rx) = channel();
let join_handle = thread::spawn(move || {
let shutdown = Arc::new(AtomicBool::new(false));
let open_connections = Arc::new((Mutex::new(0), Condvar::new()));
for conn in listener.incoming() {
match die_rx.try_recv() {
Ok(_) => break,
Ok(_) => {
info!("serve_async: shutdown received. Waiting for open connections to return...");
shutdown.store(true, Ordering::SeqCst);
let &(ref count, ref cvar) = &*open_connections;
let mut count = count.lock().unwrap();
while *count != 0 {
count = cvar.wait(count).unwrap();
}
info!("serve_async: shutdown complete ({} connections alive)", *count);
break;
}
Err(TryRecvError::Disconnected) => {
info!("Sender disconnected.");
info!("serve_async: sender disconnected.");
break;
}
_ => (),
}
let conn = match conn {
Err(err) => {
error!("Failed to accept connection: {:?}", err);
error!("serve_async: failed to accept connection: {:?}", err);
return;
}
Ok(c) => c,
};
let f = f.clone();
let shutdown = shutdown.clone();
let &(ref count, _) = &*open_connections;
*count.lock().unwrap() += 1;
let open_connections = open_connections.clone();
thread::spawn(move || {
if let Err(err) = handle_conn(conn, f) {
error!("Error in connection handling: {:?}", err);
let handler = ConnectionHandler {
shutdown: shutdown,
open_connections: open_connections,
timeout: read_timeout,
};
if let Err(err) = handler.handle_conn(conn, f) {
error!("ConnectionHandler: error in connection handling: {:?}", err);
}
});
}
@@ -198,11 +279,14 @@ fn reader<Reply>(mut stream: TcpStream, requests: Arc<Mutex<HashMap<u64, Sender<
bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite);
match packet {
Ok(Packet::Message(id, reply)) => {
debug!("Client: received message, id={}", id);
let mut requests = requests.lock().unwrap();
let reply_tx = requests.remove(&id).unwrap();
reply_tx.send(reply).unwrap();
}
Ok(Packet::Shutdown) => {
info!("Client: got shutdown message.");
requests.lock().unwrap().clear();
break;
}
// TODO: This shutdown logic is janky.. What's the right way to do this?
@@ -229,6 +313,7 @@ pub struct Client<Request, Reply>
synced_state: Mutex<SyncedClientState>,
requests: Arc<Mutex<HashMap<u64, Sender<Reply>>>>,
reader_guard: Option<thread::JoinHandle<()>>,
timeout: Option<Duration>,
_request: PhantomData<Request>,
}
@@ -237,7 +322,7 @@ impl<Request, Reply> Client<Request, Reply>
Request: serde::ser::Serialize
{
/// Create a new client that connects to `addr`
pub fn new<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
pub fn new<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<Self> {
let stream = try!(TcpStream::connect(addr));
let requests = Arc::new(Mutex::new(HashMap::new()));
let reader_stream = try!(stream.try_clone());
@@ -250,6 +335,7 @@ impl<Request, Reply> Client<Request, Reply>
}),
requests: requests,
reader_guard: Some(reader_guard),
timeout: timeout,
_request: PhantomData,
})
}
@@ -266,17 +352,20 @@ impl<Request, Reply> Client<Request, Reply>
requests.insert(id, tx);
}
let packet = Packet::Message(id, request);
try!(state.stream.set_write_timeout(self.timeout));
try!(state.stream.set_read_timeout(self.timeout));
debug!("Client: calling rpc({:?})", request);
if let Err(err) = bincode::serde::serialize_into(&mut state.stream,
&packet,
bincode::SizeLimit::Infinite) {
warn!("Failed to write client packet.\nPacket: {:?}\nError: {:?}",
warn!("Client: failed to write packet.\nPacket: {:?}\nError: {:?}",
packet,
err);
self.requests.lock().unwrap().remove(&id);
return Err(err.into());
}
drop(state);
Ok(rx.recv().unwrap())
Ok(try!(rx.recv()))
}
}
@@ -299,9 +388,16 @@ impl<Request, Reply> Drop for Client<Request, Reply>
#[cfg(test)]
mod test {
extern crate env_logger;
use super::*;
use std::sync::{Arc, Mutex, Barrier};
use std::thread;
use std::time::Duration;
fn test_timeout() -> Option<Duration> {
Some(Duration::from_millis(1))
}
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
enum Request {
@@ -337,21 +433,24 @@ mod test {
}
#[test]
fn test_handle() {
fn handle() {
let _ = env_logger::init();
let server = Arc::new(Server::new());
let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap();
let client: Client<Request, Reply> = Client::new(serve_handle.local_addr().clone())
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
let client: Client<Request, Reply> = Client::new(serve_handle.local_addr().clone(),
test_timeout())
.expect(&line!().to_string());
drop(client);
serve_handle.shutdown();
}
#[test]
fn test_simple() {
fn simple() {
let _ = env_logger::init();
let server = Arc::new(Server::new());
let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap();
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
let addr = serve_handle.local_addr().clone();
let client = Client::new(addr).unwrap();
let client = Client::new(addr, test_timeout()).unwrap();
assert_eq!(Reply::Increment(0),
client.rpc(&Request::Increment).unwrap());
assert_eq!(1, server.count());
@@ -388,11 +487,25 @@ mod test {
}
#[test]
fn test_concurrent() {
let server = Arc::new(BarrierServer::new(10));
let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap();
fn force_shutdown() {
let _ = env_logger::init();
let server = Arc::new(Server::new());
let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap();
let addr = serve_handle.local_addr().clone();
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr).unwrap());
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, test_timeout())
.unwrap());
let thread = thread::spawn(move || serve_handle.shutdown());
info!("force_shutdown::client: {:?}", client.rpc(&Request::Increment));
thread.join().unwrap();
}
#[test]
fn concurrent() {
let _ = env_logger::init();
let server = Arc::new(BarrierServer::new(10));
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
let addr = serve_handle.local_addr().clone();
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, test_timeout()).unwrap());
let mut join_handles = vec![];
for _ in 0..10 {
let my_client = client.clone();