From e59116fb482f2962304c79e825a510e1a5d7c97f Mon Sep 17 00:00:00 2001 From: Tim Date: Mon, 6 Mar 2017 20:57:12 -0800 Subject: [PATCH] Add server::Handle::shutdown (#117) * Add server::Handle::shutdown * Hybrid approach: lameduck + total shutdown when all clients disconnect. * The future handle has addr() and shutdown(), but not run(). --- README.md | 8 +- benches/latency.rs | 5 +- examples/concurrency.rs | 6 +- examples/pubsub.rs | 18 +- examples/readme_errors.rs | 3 +- examples/readme_futures.rs | 4 +- examples/readme_sync.rs | 3 +- examples/server_calling_server.rs | 8 +- examples/throughput.rs | 3 +- examples/two_clients.rs | 16 +- src/lib.rs | 2 +- src/macros.rs | 224 ++++++++++---- src/server.rs | 465 ++++++++++++++++++++++++++---- 13 files changed, 623 insertions(+), 142 deletions(-) diff --git a/README.md b/README.md index af3fe31..de47c22 100644 --- a/README.md +++ b/README.md @@ -131,13 +131,13 @@ impl FutureService for HelloServer { fn main() { let mut reactor = reactor::Core::new().unwrap(); - let (addr, server) = HelloServer.listen("localhost:10000".first_socket_addr(), + let (handle, server) = HelloServer.listen("localhost:10000".first_socket_addr(), &reactor.handle(), server::Options::default()) .unwrap(); reactor.handle().spawn(server); let options = client::Options::default().handle(reactor.handle()); - reactor.run(FutureClient::connect(addr, options) + reactor.run(FutureClient::connect(handle.addr(), options) .map_err(tarpc::Error::from) .and_then(|client| client.hello("Mom".to_string())) .map(|resp| println!("{}", resp))) @@ -210,14 +210,14 @@ fn get_acceptor() -> TlsAcceptor { fn main() { let mut reactor = reactor::Core::new().unwrap(); let acceptor = get_acceptor(); - let (addr, server) = HelloServer.listen("localhost:10000".first_socket_addr(), + let (handle, server) = HelloServer.listen("localhost:10000".first_socket_addr(), &reactor.handle(), server::Options::default().tls(acceptor)).unwrap(); reactor.handle().spawn(server); let options = client::Options::default() .handle(reactor.handle()) .tls(client::tls::Context::new("foobar.com").unwrap()); - reactor.run(FutureClient::connect(addr, options) + reactor.run(FutureClient::connect(handle.addr(), options) .map_err(tarpc::Error::from) .and_then(|client| client.hello("Mom".to_string())) .map(|resp| println!("{}", resp))) diff --git a/benches/latency.rs b/benches/latency.rs index df1672a..c8357ed 100644 --- a/benches/latency.rs +++ b/benches/latency.rs @@ -40,12 +40,13 @@ impl FutureService for Server { fn latency(bencher: &mut Bencher) { let _ = env_logger::init(); let mut reactor = reactor::Core::new().unwrap(); - let (addr, server) = Server.listen("localhost:0".first_socket_addr(), + let (handle, server) = Server.listen("localhost:0".first_socket_addr(), &reactor.handle(), server::Options::default()) .unwrap(); reactor.handle().spawn(server); - let client = FutureClient::connect(addr, client::Options::default().handle(reactor.handle())); + let client = FutureClient::connect(handle.addr(), + client::Options::default().handle(reactor.handle())); let client = reactor.run(client).unwrap(); bencher.iter(|| reactor.run(client.ack()).unwrap()); diff --git a/examples/concurrency.rs b/examples/concurrency.rs index 35c5dfd..767057f 100644 --- a/examples/concurrency.rs +++ b/examples/concurrency.rs @@ -167,20 +167,20 @@ fn main() { .unwrap_or(4); let mut reactor = reactor::Core::new().unwrap(); - let (addr, server) = Server::new() + let (handle, server) = Server::new() .listen("localhost:0".first_socket_addr(), &reactor.handle(), server::Options::default()) .unwrap(); reactor.handle().spawn(server); - info!("Server listening on {}.", addr); + info!("Server listening on {}.", handle.addr()); let clients = (0..num_clients) // Spin up a couple threads to drive the clients. .map(|i| (i, spawn_core())) .map(|(i, remote)| { info!("Client {} connecting...", i); - FutureClient::connect(addr, client::Options::default().remote(remote)) + FutureClient::connect(handle.addr(), client::Options::default().remote(remote)) .map_err(|e| panic!(e)) }); diff --git a/examples/pubsub.rs b/examples/pubsub.rs index 189f7d8..7435314 100644 --- a/examples/pubsub.rs +++ b/examples/pubsub.rs @@ -58,12 +58,15 @@ impl subscriber::FutureService for Subscriber { } impl Subscriber { - fn listen(id: u32, handle: &reactor::Handle, options: server::Options) -> SocketAddr { - let (addr, server) = Subscriber { id: id } + fn listen(id: u32, + handle: &reactor::Handle, + options: server::Options) + -> server::future::Handle { + let (server_handle, server) = Subscriber { id: id } .listen("localhost:0".first_socket_addr(), handle, options) .unwrap(); handle.spawn(server); - addr + server_handle } } @@ -120,7 +123,7 @@ impl publisher::FutureService for Publisher { fn main() { let _ = env_logger::init(); let mut reactor = reactor::Core::new().unwrap(); - let (publisher_addr, server) = Publisher::new() + let (publisher_handle, server) = Publisher::new() .listen("localhost:0".first_socket_addr(), &reactor.handle(), server::Options::default()) @@ -131,10 +134,11 @@ fn main() { let subscriber2 = Subscriber::listen(1, &reactor.handle(), server::Options::default()); let publisher = - reactor.run(publisher::FutureClient::connect(publisher_addr, client::Options::default())) + reactor.run(publisher::FutureClient::connect(publisher_handle.addr(), + client::Options::default())) .unwrap(); - reactor.run(publisher.subscribe(0, subscriber1) - .and_then(|_| publisher.subscribe(1, subscriber2)) + reactor.run(publisher.subscribe(0, subscriber1.addr()) + .and_then(|_| publisher.subscribe(1, subscriber2.addr())) .map_err(|e| panic!(e)) .and_then(|_| { println!("Broadcasting..."); diff --git a/examples/readme_errors.rs b/examples/readme_errors.rs index 4833c10..c823732 100644 --- a/examples/readme_errors.rs +++ b/examples/readme_errors.rs @@ -55,8 +55,7 @@ impl SyncService for HelloServer { fn main() { let (tx, rx) = mpsc::channel(); thread::spawn(move || { - let mut handle = HelloServer.listen("localhost:10000", server::Options::default()) - .unwrap(); + let handle = HelloServer.listen("localhost:10000", server::Options::default()).unwrap(); tx.send(handle.addr()).unwrap(); handle.run(); }); diff --git a/examples/readme_futures.rs b/examples/readme_futures.rs index c4cb4b6..8e68ffd 100644 --- a/examples/readme_futures.rs +++ b/examples/readme_futures.rs @@ -34,14 +34,14 @@ impl FutureService for HelloServer { fn main() { let mut reactor = reactor::Core::new().unwrap(); - let (addr, server) = HelloServer.listen("localhost:10000".first_socket_addr(), + let (handle, server) = HelloServer.listen("localhost:10000".first_socket_addr(), &reactor.handle(), server::Options::default()) .unwrap(); reactor.handle().spawn(server); let options = client::Options::default().handle(reactor.handle()); - reactor.run(FutureClient::connect(addr, options) + reactor.run(FutureClient::connect(handle.addr(), options) .map_err(tarpc::Error::from) .and_then(|client| client.hello("Mom".to_string())) .map(|resp| println!("{}", resp))) diff --git a/examples/readme_sync.rs b/examples/readme_sync.rs index 85d50de..5ded1a4 100644 --- a/examples/readme_sync.rs +++ b/examples/readme_sync.rs @@ -34,8 +34,7 @@ impl SyncService for HelloServer { fn main() { let (tx, rx) = mpsc::channel(); thread::spawn(move || { - let mut handle = HelloServer.listen("localhost:0", server::Options::default()) - .unwrap(); + let handle = HelloServer.listen("localhost:0", server::Options::default()).unwrap(); tx.send(handle.addr()).unwrap(); handle.run(); }); diff --git a/examples/server_calling_server.rs b/examples/server_calling_server.rs index 56714e8..256e223 100644 --- a/examples/server_calling_server.rs +++ b/examples/server_calling_server.rs @@ -72,16 +72,16 @@ impl DoubleFutureService for DoubleServer { fn main() { let _ = env_logger::init(); let mut reactor = reactor::Core::new().unwrap(); - let (add_addr, server) = AddServer.listen("localhost:0".first_socket_addr(), + let (add, server) = AddServer.listen("localhost:0".first_socket_addr(), &reactor.handle(), server::Options::default()) .unwrap(); reactor.handle().spawn(server); let options = client::Options::default().handle(reactor.handle()); - let add_client = reactor.run(add::FutureClient::connect(add_addr, options)).unwrap(); + let add_client = reactor.run(add::FutureClient::connect(add.addr(), options)).unwrap(); - let (double_addr, server) = DoubleServer::new(add_client) + let (double, server) = DoubleServer::new(add_client) .listen("localhost:0".first_socket_addr(), &reactor.handle(), server::Options::default()) @@ -89,7 +89,7 @@ fn main() { reactor.handle().spawn(server); let double_client = - reactor.run(double::FutureClient::connect(double_addr, client::Options::default())) + reactor.run(double::FutureClient::connect(double.addr(), client::Options::default())) .unwrap(); reactor.run(futures::stream::futures_unordered((0..5).map(|i| double_client.double(i))) .map_err(|e| println!("{}", e)) diff --git a/examples/throughput.rs b/examples/throughput.rs index b1d93df..421bc53 100644 --- a/examples/throughput.rs +++ b/examples/throughput.rs @@ -66,7 +66,8 @@ fn bench_tarpc(target: u64) { tx.send(addr).unwrap(); reactor.run(server).unwrap(); }); - let mut client = SyncClient::connect(rx.recv().unwrap(), client::Options::default()).unwrap(); + let mut client = SyncClient::connect(rx.recv().unwrap().addr(), client::Options::default()) + .unwrap(); let start = time::Instant::now(); let mut nread = 0; while nread < target { diff --git a/examples/two_clients.rs b/examples/two_clients.rs index 5ec2954..55dd0ba 100644 --- a/examples/two_clients.rs +++ b/examples/two_clients.rs @@ -66,30 +66,30 @@ fn main() { let (tx, rx) = mpsc::channel(); thread::spawn(move || { let mut reactor = reactor::Core::new().unwrap(); - let (addr, server) = Bar.listen("localhost:0".first_socket_addr(), + let (handle, server) = Bar.listen("localhost:0".first_socket_addr(), &reactor.handle(), server::Options::default()) .unwrap(); - tx.send(addr).unwrap(); + tx.send(handle).unwrap(); reactor.run(server).unwrap(); }); - let addr = rx.recv().unwrap(); - bar::SyncClient::connect(addr, client::Options::default()).unwrap() + let handle = rx.recv().unwrap(); + bar::SyncClient::connect(handle.addr(), client::Options::default()).unwrap() }; let mut baz_client = { let (tx, rx) = mpsc::channel(); thread::spawn(move || { let mut reactor = reactor::Core::new().unwrap(); - let (addr, server) = Baz.listen("localhost:0".first_socket_addr(), + let (handle, server) = Baz.listen("localhost:0".first_socket_addr(), &reactor.handle(), server::Options::default()) .unwrap(); - tx.send(addr).unwrap(); + tx.send(handle).unwrap(); reactor.run(server).unwrap(); }); - let addr = rx.recv().unwrap(); - baz::SyncClient::connect(addr, client::Options::default()).unwrap() + let handle = rx.recv().unwrap(); + baz::SyncClient::connect(handle.addr(), client::Options::default()).unwrap() }; diff --git a/src/lib.rs b/src/lib.rs index 352b1b8..dbed326 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -114,7 +114,7 @@ //! ``` //! #![deny(missing_docs)] -#![feature(never_type, plugin, struct_field_attributes, fn_traits, unboxed_closures)] +#![feature(fn_traits, move_cell, never_type, plugin, struct_field_attributes, unboxed_closures)] #![plugin(tarpc_plugins)] extern crate byteorder; diff --git a/src/macros.rs b/src/macros.rs index 3cd7be6..6a07784 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -495,13 +495,13 @@ macro_rules! service { addr: ::std::net::SocketAddr, handle: &$crate::tokio_core::reactor::Handle, options: $crate::server::Options) - -> ::std::io::Result<(::std::net::SocketAddr, Listen)> + -> ::std::io::Result<($crate::server::future::Handle, Listen)> { - $crate::server::listen(tarpc_service_AsyncServer__(self), + $crate::server::future::Handle::listen(tarpc_service_AsyncServer__(self), addr, handle, options) - .map(|(addr, inner)| (addr, Listen { inner })) + .map(|(handle, inner)| (handle, Listen { inner })) } } @@ -526,7 +526,7 @@ macro_rules! service { /// /// To actually run the server, call `run` on the returned handle. fn listen(self, addr: A, options: $crate::server::Options) - -> ::std::io::Result<$crate::server::Handle> + -> ::std::io::Result<$crate::server::sync::Handle> where A: ::std::net::ToSocketAddrs { let tarpc_service__ = tarpc_service_AsyncServer__(SyncServer__ { @@ -536,14 +536,9 @@ macro_rules! service { let tarpc_service_addr__ = $crate::util::FirstSocketAddr::try_first_socket_addr(&addr)?; - let reactor_ = $crate::tokio_core::reactor::Core::new()?; - let (addr_, server_) = $crate::server::listen( - tarpc_service__, - tarpc_service_addr__, - &reactor_.handle(), - options)?; - reactor_.handle().spawn(server_); - return Ok($crate::server::Handle::new(reactor_, addr_)); + return $crate::server::sync::Handle::listen(tarpc_service__, + tarpc_service_addr__, + options); #[derive(Clone)] struct SyncServer__ { @@ -891,50 +886,75 @@ mod functional_test { } } - fn start_server_with_sync_client(server: S) -> io::Result<(SocketAddr, C)> + fn get_sync_client(addr: SocketAddr) -> io::Result + where C: client::sync::ClientExt + { + C::connect(addr, get_tls_client_options()) + } + + fn get_future_client(addr: SocketAddr, handle: reactor::Handle) -> C::ConnectFut + where C: client::future::ClientExt + { + C::connect(addr, get_tls_client_options().handle(handle)) + } + + fn start_server_with_sync_client(server: S) + -> io::Result<(SocketAddr, C, server::Shutdown)> where C: client::sync::ClientExt, S: SyncServiceExt { let options = get_tls_server_options(); let (tx, rx) = ::std::sync::mpsc::channel(); ::std::thread::spawn(move || { - let mut handle = unwrap!(server.listen("localhost:0".first_socket_addr(), + let handle = unwrap!(server.listen("localhost:0".first_socket_addr(), options)); - tx.send(handle.addr()).unwrap(); + tx.send((handle.addr(), handle.shutdown())).unwrap(); handle.run(); }); - let addr = rx.recv().unwrap(); + let (addr, shutdown) = rx.recv().unwrap(); let client = unwrap!(C::connect(addr, get_tls_client_options())); - Ok((addr, client)) + Ok((addr, client, shutdown)) } fn start_server_with_async_client(server: S) - -> io::Result<(SocketAddr, reactor::Core, C)> + -> io::Result<(server::future::Handle, reactor::Core, C)> where C: client::future::ClientExt, S: FutureServiceExt { let mut reactor = reactor::Core::new()?; let server_options = get_tls_server_options(); - let (addr, server) = server.listen("localhost:0".first_socket_addr(), + let (handle, server) = server.listen("localhost:0".first_socket_addr(), &reactor.handle(), server_options)?; reactor.handle().spawn(server); let client_options = get_tls_client_options().handle(reactor.handle()); - let client = unwrap!(reactor.run(C::connect(addr, client_options))); - Ok((addr, reactor, client)) + let client = unwrap!(reactor.run(C::connect(handle.addr(), client_options))); + Ok((handle, reactor, client)) + } + + fn return_server(server: S) + -> io::Result<(server::future::Handle, reactor::Core, Listen)> + where S: FutureServiceExt + { + let mut reactor = reactor::Core::new()?; + let server_options = get_tls_server_options(); + let (handle, server) = server.listen("localhost:0".first_socket_addr(), + &reactor.handle(), + server_options)?; + Ok((handle, reactor, server)) } fn start_err_server_with_async_client(server: S) - -> io::Result<(SocketAddr, reactor::Core, C)> + -> io::Result<(server::future::Handle, reactor::Core, C)> where C: client::future::ClientExt, S: error_service::FutureServiceExt { let mut reactor = reactor::Core::new()?; let server_options = get_tls_server_options(); - let (addr, server) = server.listen("localhost:0".first_socket_addr(), + let (handle, server) = server.listen("localhost:0".first_socket_addr(), &reactor.handle(), server_options)?; reactor.handle().spawn(server); let client_options = get_tls_client_options().handle(reactor.handle()); - let client = unwrap!(reactor.run(C::connect(addr, client_options))); - Ok((addr, reactor, client)) + let client = unwrap!(reactor.run(C::connect(handle.addr(), client_options))); + Ok((handle, reactor, client)) } } else { fn get_server_options() -> server::Options { @@ -951,56 +971,74 @@ mod functional_test { C::connect(addr, get_client_options()) } - fn start_server_with_sync_client(server: S) -> io::Result<(SocketAddr, C)> + fn get_future_client(addr: SocketAddr, handle: reactor::Handle) -> C::ConnectFut + where C: client::future::ClientExt + { + C::connect(addr, get_client_options().handle(handle)) + } + + fn start_server_with_sync_client(server: S) + -> io::Result<(SocketAddr, C, server::Shutdown)> where C: client::sync::ClientExt, S: SyncServiceExt { let options = get_server_options(); let (tx, rx) = ::std::sync::mpsc::channel(); ::std::thread::spawn(move || { - let mut handle = unwrap!(server.listen("localhost:0".first_socket_addr(), - options)); - tx.send(handle.addr()).unwrap(); + let handle = unwrap!(server.listen("localhost:0".first_socket_addr(), options)); + tx.send((handle.addr(), handle.shutdown())).unwrap(); handle.run(); }); - let addr = rx.recv().unwrap(); + let (addr, shutdown) = rx.recv().unwrap(); let client = unwrap!(get_sync_client(addr)); - Ok((addr, client)) + Ok((addr, client, shutdown)) } fn start_server_with_async_client(server: S) - -> io::Result<(SocketAddr, reactor::Core, C)> + -> io::Result<(server::future::Handle, reactor::Core, C)> where C: client::future::ClientExt, S: FutureServiceExt { let mut reactor = reactor::Core::new()?; let options = get_server_options(); - let (addr, server) = server.listen("localhost:0".first_socket_addr(), + let (handle, server) = server.listen("localhost:0".first_socket_addr(), &reactor.handle(), options)?; reactor.handle().spawn(server); - let client = unwrap!(reactor.run(C::connect(addr, get_client_options()))); - Ok((addr, reactor, client)) + let client = unwrap!(reactor.run(C::connect(handle.addr(), get_client_options()))); + Ok((handle, reactor, client)) + } + + fn return_server(server: S) + -> io::Result<(server::future::Handle, reactor::Core, Listen)> + where S: FutureServiceExt + { + let reactor = reactor::Core::new()?; + let options = get_server_options(); + let (handle, server) = server.listen("localhost:0".first_socket_addr(), + &reactor.handle(), + options)?; + Ok((handle, reactor, server)) } fn start_err_server_with_async_client(server: S) - -> io::Result<(SocketAddr, reactor::Core, C)> + -> io::Result<(server::future::Handle, reactor::Core, C)> where C: client::future::ClientExt, S: error_service::FutureServiceExt { let mut reactor = reactor::Core::new()?; let options = get_server_options(); - let (addr, server) = server.listen("localhost:0".first_socket_addr(), + let (handle, server) = server.listen("localhost:0".first_socket_addr(), &reactor.handle(), options)?; reactor.handle().spawn(server); - let client = C::connect(addr, get_client_options()); + let client = C::connect(handle.addr(), get_client_options()); let client = unwrap!(reactor.run(client)); - Ok((addr, reactor, client)) + Ok((handle, reactor, client)) } } } mod sync { - use super::{SyncClient, SyncService, env_logger, start_server_with_sync_client}; + use super::{SyncClient, SyncService, get_sync_client, env_logger, start_server_with_sync_client}; use util::Never; #[derive(Clone, Copy)] @@ -1018,16 +1056,65 @@ mod functional_test { #[test] fn simple() { let _ = env_logger::init(); - let (_, mut client) = unwrap!(start_server_with_sync_client::(Server)); + let (_, mut client, _) = unwrap!(start_server_with_sync_client::(Server)); assert_eq!(3, client.add(1, 2).unwrap()); assert_eq!("Hey, Tim.", client.hey("Tim".to_string()).unwrap()); } + #[test] + fn shutdown() { + use futures::Future; + + let _ = env_logger::init(); + let (addr, mut client, shutdown) = + unwrap!(start_server_with_sync_client::(Server)); + assert_eq!(3, client.add(1, 2).unwrap()); + assert_eq!("Hey, Tim.", client.hey("Tim".to_string()).unwrap()); + + info!("Dropping client."); + drop(client); + let (tx, rx) = ::std::sync::mpsc::channel(); + let (tx2, rx2) = ::std::sync::mpsc::channel(); + let shutdown2 = shutdown.clone(); + ::std::thread::spawn(move || { + let mut client = get_sync_client::(addr).unwrap(); + tx.send(()).unwrap(); + let add = client.add(3, 2).unwrap(); + drop(client); + // Make sure 2 shutdowns are concurrent safe. + shutdown2.shutdown().wait().unwrap(); + tx2.send(add).unwrap(); + }); + rx.recv().unwrap(); + shutdown.shutdown().wait().unwrap(); + // Existing clients are served + assert_eq!(5, rx2.recv().unwrap()); + + let e = get_sync_client::(addr).err().unwrap(); + debug!("(Success) shutdown caused client err: {}", e); + } + + #[test] + fn no_shutdown() { + let _ = env_logger::init(); + let (addr, mut client, shutdown) = + unwrap!(start_server_with_sync_client::(Server)); + assert_eq!(3, client.add(1, 2).unwrap()); + assert_eq!("Hey, Tim.", client.hey("Tim".to_string()).unwrap()); + + drop(shutdown); + + // Existing clients are served. + assert_eq!(3, client.add(1, 2).unwrap()); + // New connections are accepted. + assert!(get_sync_client::(addr).is_ok()); + } + #[test] fn other_service() { let _ = env_logger::init(); - let (_, mut client) = + let (_, mut client, _) = unwrap!(start_server_with_sync_client::(Server)); match client.foo().err().expect("failed unwrap") { @@ -1038,7 +1125,8 @@ mod functional_test { } mod future { - use super::{FutureClient, FutureService, env_logger, start_server_with_async_client}; + use super::{FutureClient, FutureService, env_logger, get_future_client, return_server, + start_server_with_async_client}; use futures::{Finished, finished}; use tokio_core::reactor; use util::Never; @@ -1070,6 +1158,31 @@ mod functional_test { reactor.run(client.hey("Tim".to_string())).unwrap()); } + #[test] + fn shutdown() { + use futures::Future; + use tokio_core::reactor; + + let _ = env_logger::init(); + let (handle, mut reactor, server) = unwrap!(return_server::(Server)); + + let (tx, rx) = ::std::sync::mpsc::channel(); + ::std::thread::spawn(move || { + let mut reactor = reactor::Core::new().unwrap(); + let client = get_future_client::(handle.addr(), reactor.handle()); + let client = reactor.run(client).unwrap(); + let add = reactor.run(client.add(3, 2)).unwrap(); + assert_eq!(add, 5); + trace!("Dropping client."); + drop(reactor); + debug!("Shutting down..."); + handle.shutdown().shutdown().wait().unwrap(); + tx.send(add).unwrap(); + }); + reactor.run(server).unwrap(); + assert_eq!(rx.recv().unwrap(), 5); + } + #[test] fn concurrent() { let _ = env_logger::init(); @@ -1103,11 +1216,12 @@ mod functional_test { let _ = env_logger::init(); let reactor = reactor::Core::new().unwrap(); - let (addr, _) = Server.listen("localhost:0".first_socket_addr(), + let handle = Server.listen("localhost:0".first_socket_addr(), &reactor.handle(), server::Options::default()) - .unwrap(); - Server.listen(addr, &reactor.handle(), server::Options::default()).unwrap(); + .unwrap() + .0; + Server.listen(handle.addr(), &reactor.handle(), server::Options::default()).unwrap(); } #[test] @@ -1119,22 +1233,20 @@ mod functional_test { let _ = env_logger::init(); let mut reactor = reactor::Core::new().unwrap(); - let (addr, server) = Server.listen("localhost:0".first_socket_addr(), + let (handle, server) = Server.listen("localhost:0".first_socket_addr(), &reactor.handle(), server::Options::default()) .unwrap(); reactor.handle().spawn(server); - let client = FutureClient::connect(addr, - client::Options::default() - .handle(reactor.handle())); + let client = FutureClient::connect(handle.addr(), + client::Options::default().handle(reactor.handle())); let client = unwrap!(reactor.run(client)); assert_eq!(reactor.run(client.add(1, 2)).unwrap(), 3); drop(client); - let client = FutureClient::connect(addr, - client::Options::default() - .handle(reactor.handle())); + let client = FutureClient::connect(handle.addr(), + client::Options::default().handle(reactor.handle())); let client = unwrap!(reactor.run(client)); assert_eq!(reactor.run(client.add(1, 2)).unwrap(), 3); } @@ -1154,13 +1266,13 @@ mod functional_test { assert_eq!("Hey, Tim.", reactor.run(client.hey("Tim".to_string())).unwrap()); - let (addr, server) = Server.listen("localhost:0".first_socket_addr(), + let (handle, server) = Server.listen("localhost:0".first_socket_addr(), &reactor.handle(), server::Options::default()) .unwrap(); reactor.handle().spawn(server); let options = client::Options::default().handle(reactor.handle()); - let client = reactor.run(FutureClient::connect(addr, options)).unwrap(); + let client = reactor.run(FutureClient::connect(handle.addr(), options)).unwrap(); assert_eq!(3, reactor.run(client.add(1, 2)).unwrap()); assert_eq!("Hey, Tim.", reactor.run(client.hey("Tim".to_string())).unwrap()); diff --git a/src/server.rs b/src/server.rs index 4aa4c67..d6f8ec3 100644 --- a/src/server.rs +++ b/src/server.rs @@ -5,17 +5,21 @@ use bincode; use errors::WireError; -use futures::{Future, Poll, Stream, future, stream}; +use futures::{Future, Poll, Stream, future as futures, stream}; +use futures::sync::{mpsc, oneshot}; +use futures::unsync; use net2; use protocol::Proto; use serde::{Deserialize, Serialize}; +use std::cell::Cell; use std::io; use std::net::SocketAddr; +use std::rc::Rc; use tokio_core::io::Io; use tokio_core::net::{Incoming, TcpListener, TcpStream}; use tokio_core::reactor; use tokio_proto::BindServer; -use tokio_service::NewService; +use tokio_service::{NewService, Service}; cfg_if! { if #[cfg(feature = "tls")] { @@ -33,30 +37,30 @@ enum Acceptor { } #[cfg(feature = "tls")] -type Accept = future::Either, - fn(TlsStream) -> StreamType>, - fn(native_tls::Error) -> io::Error>, - future::FutureResult>; +type Accept = futures::Either, + fn(TlsStream) -> StreamType>, + fn(native_tls::Error) -> io::Error>, + futures::FutureResult>; #[cfg(not(feature = "tls"))] -type Accept = future::FutureResult; +type Accept = futures::FutureResult; impl Acceptor { #[cfg(feature = "tls")] fn accept(&self, socket: TcpStream) -> Accept { match *self { Acceptor::Tls(ref tls_acceptor) => { - future::Either::A(tls_acceptor.accept_async(socket) + futures::Either::A(tls_acceptor.accept_async(socket) .map(StreamType::Tls as _) .map_err(native_to_io)) } - Acceptor::Tcp => future::Either::B(future::ok(StreamType::Tcp(socket))), + Acceptor::Tcp => futures::Either::B(futures::ok(StreamType::Tcp(socket))), } } #[cfg(not(feature = "tls"))] fn accept(&self, socket: TcpStream) -> Accept { - future::ok(socket) + futures::ok(socket) } } @@ -119,51 +123,406 @@ impl Options { #[doc(hidden)] pub type Response = Result>; -#[doc(hidden)] -pub fn listen(new_service: S, - addr: SocketAddr, - handle: &reactor::Handle, - options: Options) - -> io::Result<(SocketAddr, Listen)> - where S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: Deserialize + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static -{ - listen_with(new_service, addr, handle, Acceptor::from(options)) +/// A hook to shut down a running server. +#[derive(Clone)] +pub struct Shutdown { + tx: mpsc::UnboundedSender>, } -/// A handle to a bound server. Must be run to start serving requests. -#[must_use = "A server does nothing until `run` is called."] -pub struct Handle { - reactor: reactor::Core, - addr: SocketAddr, +/// A future that resolves when server shutdown completes. +pub struct ShutdownFuture { + inner: futures::Either, + futures::OrElse, Result<(), ()>, AlwaysOk>>, } -impl Handle { - #[doc(hidden)] - pub fn new(reactor: reactor::Core, addr: SocketAddr) -> Self { - Handle { - reactor: reactor, - addr: addr, +impl Future for ShutdownFuture { + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll<(), ()> { + self.inner.poll() + } +} + +impl Shutdown { + /// Initiates an orderly server shutdown. + /// + /// First, the server enters lameduck mode, in which + /// existing connections are honored but no new connections are accepted. Then, once all + /// connections are closed, it initates total shutdown. + /// + /// This fn will not return until the server is completely shut down. + pub fn shutdown(self) -> ShutdownFuture { + let (tx, rx) = oneshot::channel(); + let inner = if let Err(_) = self.tx.send(tx) { + trace!("Server already initiated shutdown."); + futures::Either::A(futures::ok(())) + } else { + futures::Either::B(rx.or_else(AlwaysOk)) + }; + ShutdownFuture { inner: inner } + } +} + +enum ConnectionAction { + Increment, + Decrement, +} + +#[derive(Clone)] +struct ConnectionTracker { + tx: unsync::mpsc::UnboundedSender, +} + +impl ConnectionTracker { + fn increment(&self) { + let _ = self.tx.send(ConnectionAction::Increment); + } + + fn decrement(&self) { + debug!("Closing connection"); + let _ = self.tx.send(ConnectionAction::Decrement); + } +} + +struct ConnectionTrackingService { + service: S, + tracker: ConnectionTracker, +} + +impl Service for ConnectionTrackingService { + type Request = S::Request; + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn call(&self, req: Self::Request) -> Self::Future { + trace!("Calling service."); + self.service.call(req) + } +} + +impl Drop for ConnectionTrackingService { + fn drop(&mut self) { + debug!("Dropping ConnnectionTrackingService."); + self.tracker.decrement(); + } +} + +struct ConnectionTrackingNewService { + new_service: S, + connection_tracker: ConnectionTracker, +} + +impl NewService for ConnectionTrackingNewService { + type Request = S::Request; + type Response = S::Response; + type Error = S::Error; + type Instance = ConnectionTrackingService; + + fn new_service(&self) -> io::Result { + self.connection_tracker.increment(); + Ok(ConnectionTrackingService { + service: self.new_service.new_service()?, + tracker: self.connection_tracker.clone(), + }) + } +} + +/// Future-specific server utilities. +pub mod future { + pub use super::*; + + /// A handle to a bound server. + #[derive(Clone)] + pub struct Handle { + addr: SocketAddr, + shutdown: Shutdown, + } + + impl Handle { + #[doc(hidden)] + pub fn listen(new_service: S, + addr: SocketAddr, + handle: &reactor::Handle, + options: Options) + -> io::Result<(Self, Listen)> + where S: NewService, + Response = Response, + Error = io::Error> + 'static, + Req: Deserialize + 'static, + Resp: Serialize + 'static, + E: Serialize + 'static + { + let (addr, shutdown, server) = + listen_with(new_service, addr, handle, Acceptor::from(options))?; + Ok((Handle { + addr: addr, + shutdown: shutdown, + }, + server)) + } + + /// Returns a hook for shutting down the server. + pub fn shutdown(&self) -> Shutdown { + self.shutdown.clone() + } + + /// The socket address the server is bound to. + pub fn addr(&self) -> SocketAddr { + self.addr } } +} - /// Runs the server on the current thread, blocking indefinitely. - pub fn run(&mut self) -> ! { - loop { - self.reactor.turn(None) +/// Sync-specific server utilities. +pub mod sync { + pub use super::*; + + /// A handle to a bound server. Must be run to start serving requests. + #[must_use = "A server does nothing until `run` is called."] + pub struct Handle { + reactor: reactor::Core, + handle: future::Handle, + server: Box>, + } + + impl Handle { + #[doc(hidden)] + pub fn listen(new_service: S, + addr: SocketAddr, + options: Options) + -> io::Result + where S: NewService, + Response = Response, + Error = io::Error> + 'static, + Req: Deserialize + 'static, + Resp: Serialize + 'static, + E: Serialize + 'static + { + let reactor = reactor::Core::new()?; + let (handle, server) = + future::Handle::listen(new_service, addr, &reactor.handle(), options)?; + let server = Box::new(server); + Ok(Handle { + reactor: reactor, + handle: handle, + server: server, + }) + } + + /// Runs the server on the current thread, blocking indefinitely. + pub fn run(mut self) { + trace!("Running..."); + match self.reactor.run(self.server) { + Ok(()) => debug!("Server successfully shutdown."), + Err(()) => debug!("Server shutdown due to error."), + } + } + + /// Returns a hook for shutting down the server. + pub fn shutdown(&self) -> Shutdown { + self.handle.shutdown().clone() + } + + /// The socket address the server is bound to. + pub fn addr(&self) -> SocketAddr { + self.handle.addr() } } +} - /// The socket address the server is bound to. - pub fn addr(&self) -> SocketAddr { - self.addr +struct ShutdownSetter { + shutdown: Rc>>>, +} + +impl FnOnce<(oneshot::Sender<()>,)> for ShutdownSetter { + type Output = (); + + extern "rust-call" fn call_once(self, tx: (oneshot::Sender<()>,)) { + self.call(tx); } } +impl FnMut<(oneshot::Sender<()>,)> for ShutdownSetter { + extern "rust-call" fn call_mut(&mut self, tx: (oneshot::Sender<()>,)) { + self.call(tx); + } +} + +impl Fn<(oneshot::Sender<()>,)> for ShutdownSetter { + extern "rust-call" fn call(&self, (tx,): (oneshot::Sender<()>,)) { + debug!("Received shutdown request."); + self.shutdown.set(Some(tx)); + } +} + +struct ConnectionWatcher { + connections: Rc>, +} + +impl FnOnce<(ConnectionAction,)> for ConnectionWatcher { + type Output = (); + + extern "rust-call" fn call_once(self, action: (ConnectionAction,)) { + self.call(action); + } +} + +impl FnMut<(ConnectionAction,)> for ConnectionWatcher { + extern "rust-call" fn call_mut(&mut self, action: (ConnectionAction,)) { + self.call(action); + } +} + +impl Fn<(ConnectionAction,)> for ConnectionWatcher { + extern "rust-call" fn call(&self, (action,): (ConnectionAction,)) { + match action { + ConnectionAction::Increment => self.connections.set(self.connections.get() + 1), + ConnectionAction::Decrement => self.connections.set(self.connections.get() - 1), + } + trace!("Open connections: {}", self.connections.get()); + } +} + +struct ShutdownPredicate { + shutdown: Rc>>>, + connections: Rc>, +} + +impl FnOnce for ShutdownPredicate { + type Output = Result; + + extern "rust-call" fn call_once(self, arg: T) -> Self::Output { + self.call(arg) + } +} + +impl FnMut for ShutdownPredicate { + extern "rust-call" fn call_mut(&mut self, arg: T) -> Self::Output { + self.call(arg) + } +} + +impl Fn for ShutdownPredicate { + extern "rust-call" fn call(&self, _: T) -> Self::Output { + match self.shutdown.take() { + Some(shutdown) => { + let num_connections = self.connections.get(); + debug!("Lameduck mode: {} open connections", num_connections); + if num_connections == 0 { + debug!("Shutting down."); + let _ = shutdown.complete(()); + Ok(false) + } else { + self.shutdown.set(Some(shutdown)); + Ok(true) + } + } + None => Ok(true), + } + } +} + +struct Warn(&'static str); + +impl FnOnce for Warn { + type Output = (); + + extern "rust-call" fn call_once(self, arg: T) -> Self::Output { + self.call(arg) + } +} + +impl FnMut for Warn { + extern "rust-call" fn call_mut(&mut self, arg: T) -> Self::Output { + self.call(arg) + } +} + +impl Fn for Warn { + extern "rust-call" fn call(&self, _: T) -> Self::Output { + warn!("{}", self.0) + } +} + +struct AlwaysOk; + +impl FnOnce for AlwaysOk { + type Output = Result<(), ()>; + + extern "rust-call" fn call_once(self, arg: T) -> Self::Output { + self.call(arg) + } +} + +impl FnMut for AlwaysOk { + extern "rust-call" fn call_mut(&mut self, arg: T) -> Self::Output { + self.call(arg) + } +} + +impl Fn for AlwaysOk { + extern "rust-call" fn call(&self, _: T) -> Self::Output { + Ok(()) + } +} + +type ShutdownStream = stream::Map>>, + ShutdownSetter>; + +type ConnectionStream = stream::Map, + ConnectionWatcher>; + +struct ShutdownWatcher { + inner: stream::ForEach, + ShutdownPredicate, + Result>, + Warn>, + AlwaysOk, + Result<(), ()>>, +} + +impl Future for ShutdownWatcher { + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll<(), ()> { + self.inner.poll() + } +} + +/// Creates a future that completes when a shutdown is signaled and no connections are open. +fn shutdown_watcher() -> (ConnectionTracker, Shutdown, ShutdownWatcher) { + let (shutdown_tx, shutdown_rx) = mpsc::unbounded::>(); + let (connection_tx, connection_rx) = unsync::mpsc::unbounded(); + let shutdown = Rc::new(Cell::new(None)); + let connections = Rc::new(Cell::new(0)); + let shutdown2 = shutdown.clone(); + let connections2 = connections.clone(); + + let inner = shutdown_rx.take(1) + .map(ShutdownSetter { shutdown: shutdown }) + .merge(connection_rx.map(ConnectionWatcher { connections: connections })) + .take_while(ShutdownPredicate { + shutdown: shutdown2, + connections: connections2, + }) + .map_err(Warn("UnboundedReceiver resolved to an Err; can it do that?")) + .for_each(AlwaysOk); + + (ConnectionTracker { tx: connection_tx }, + Shutdown { tx: shutdown_tx }, + ShutdownWatcher { inner: inner }) +} + +type AcceptStream = stream::AndThen; + +type BindStream = stream::ForEach>, + io::Result<()>>; + /// The future representing a running server. #[doc(hidden)] pub struct Listen @@ -174,10 +533,10 @@ pub struct Listen Resp: Serialize + 'static, E: Serialize + 'static { - inner: future::MapErr, - Bind, - io::Result<()>>, - fn(io::Error)>, + inner: futures::Then, fn(io::Error)>, + ShutdownWatcher>, + Result<(), ()>, + AlwaysOk>, } impl Future for Listen @@ -201,7 +560,7 @@ fn listen_with(new_service: S, addr: SocketAddr, handle: &reactor::Handle, acceptor: Acceptor) - -> io::Result<(SocketAddr, Listen)> + -> io::Result<(SocketAddr, Shutdown, Listen)> where S: NewService, Response = Response, Error = io::Error> + 'static, @@ -215,14 +574,20 @@ fn listen_with(new_service: S, let handle = handle.clone(); - let inner = listener.incoming() + let (connection_tracker, shutdown, shutdown_future) = shutdown_watcher(); + let server = listener.incoming() .and_then(acceptor) .for_each(Bind { handle: handle, - new_service: new_service, + new_service: ConnectionTrackingNewService { + connection_tracker: connection_tracker, + new_service: new_service, + }, }) .map_err(log_err as _); - Ok((addr, Listen { inner: inner })) + + let server = server.select(shutdown_future).then(AlwaysOk); + Ok((addr, shutdown, Listen { inner: server })) } fn log_err(e: io::Error) {