mirror of
https://github.com/OMGeeky/tarpc.git
synced 2025-12-30 16:18:56 +01:00
Merge branch 'orderly-shutdown' into 'master'
Orderly shutdown of serving threads when calling ServeHandle::shutdown. Use condition variable server-side. Change client-side reader thread to empty the sender hashmap before returning. See merge request !1
This commit is contained in:
@@ -8,3 +8,4 @@ serde = "*"
|
||||
bincode = "*"
|
||||
serde_macros = "*"
|
||||
log = "*"
|
||||
env_logger = "*"
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
//! }
|
||||
//!
|
||||
//! use self::my_server::*;
|
||||
//! use std::time::Duration;
|
||||
//!
|
||||
//! impl my_server::Service for () {
|
||||
//! fn hello(&self, s: String) -> String {
|
||||
@@ -30,8 +31,9 @@
|
||||
//!
|
||||
//! fn main() {
|
||||
//! let addr = "127.0.0.1:9000";
|
||||
//! let shutdown = my_server::serve(addr, ()).unwrap();
|
||||
//! let client = Client::new(addr).unwrap();
|
||||
//! let shutdown = my_server::serve(addr, (),
|
||||
//! Some(Duration::from_secs(30))).unwrap();
|
||||
//! let client = Client::new(addr, None).unwrap();
|
||||
//! assert_eq!(3, client.add(1, 2).unwrap());
|
||||
//! assert_eq!("Hello, Mom!".to_string(),
|
||||
//! client.hello("Mom".to_string()).unwrap());
|
||||
|
||||
@@ -49,15 +49,15 @@ macro_rules! request_variant {
|
||||
|
||||
// The main macro that creates RPC services.
|
||||
#[macro_export]
|
||||
macro_rules! rpc {
|
||||
macro_rules! rpc {
|
||||
(
|
||||
mod $server:ident {
|
||||
|
||||
service {
|
||||
$(
|
||||
$(
|
||||
$(#[$attr:meta])*
|
||||
rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty;
|
||||
)*
|
||||
)*
|
||||
}
|
||||
}
|
||||
) => {
|
||||
@@ -66,7 +66,7 @@ macro_rules! rpc {
|
||||
|
||||
items { }
|
||||
|
||||
service {
|
||||
service {
|
||||
$(
|
||||
$(#[$attr])*
|
||||
rpc $fn_name($($arg: $in_),*) -> $out;
|
||||
@@ -125,10 +125,11 @@ macro_rules! rpc {
|
||||
|
||||
impl Client {
|
||||
#[doc="Create a new client that connects to the given address."]
|
||||
pub fn new<A>(addr: A) -> $crate::Result<Self>
|
||||
pub fn new<A>(addr: A, timeout: ::std::option::Option<::std::time::Duration>)
|
||||
-> $crate::Result<Self>
|
||||
where A: ::std::net::ToSocketAddrs,
|
||||
{
|
||||
let inner = try!($crate::protocol::Client::new(addr));
|
||||
let inner = try!($crate::protocol::Client::new(addr, timeout));
|
||||
Ok(Client(inner))
|
||||
}
|
||||
|
||||
@@ -137,9 +138,11 @@ macro_rules! rpc {
|
||||
|
||||
struct __Server<S: 'static + Service>(S);
|
||||
|
||||
impl<S> $crate::protocol::Serve<__Request, __Reply> for __Server<S>
|
||||
impl<S> $crate::protocol::Serve for __Server<S>
|
||||
where S: 'static + Service
|
||||
{
|
||||
type Request = __Request;
|
||||
type Reply = __Reply;
|
||||
fn serve(&self, request: __Request) -> __Reply {
|
||||
match request {
|
||||
$(
|
||||
@@ -151,12 +154,15 @@ macro_rules! rpc {
|
||||
}
|
||||
|
||||
#[doc="Start a running service."]
|
||||
pub fn serve<A, S>(addr: A, service: S) -> $crate::Result<$crate::protocol::ServeHandle>
|
||||
pub fn serve<A, S>(addr: A,
|
||||
service: S,
|
||||
read_timeout: ::std::option::Option<::std::time::Duration>)
|
||||
-> $crate::Result<$crate::protocol::ServeHandle>
|
||||
where A: ::std::net::ToSocketAddrs,
|
||||
S: 'static + Service
|
||||
{
|
||||
let server = ::std::sync::Arc::new(__Server(service));
|
||||
Ok(try!($crate::protocol::serve_async(addr, server)))
|
||||
Ok(try!($crate::protocol::serve_async(addr, server, read_timeout)))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -165,6 +171,12 @@ macro_rules! rpc {
|
||||
#[cfg(test)]
|
||||
#[allow(dead_code)]
|
||||
mod test {
|
||||
use std::time::Duration;
|
||||
|
||||
fn test_timeout() -> Option<Duration> {
|
||||
Some(Duration::from_secs(5))
|
||||
}
|
||||
|
||||
rpc! {
|
||||
mod my_server {
|
||||
items {
|
||||
@@ -197,8 +209,8 @@ mod test {
|
||||
fn simple_test() {
|
||||
println!("Starting");
|
||||
let addr = "127.0.0.1:9000";
|
||||
let shutdown = my_server::serve(addr, ()).unwrap();
|
||||
let client = Client::new(addr).unwrap();
|
||||
let shutdown = my_server::serve(addr, (), test_timeout()).unwrap();
|
||||
let client = Client::new(addr, None).unwrap();
|
||||
assert_eq!(3, client.add(1, 2).unwrap());
|
||||
let foo = Foo { message: "Adam".into() };
|
||||
let want = Foo { message: format!("Hello, {}", &foo.message) };
|
||||
|
||||
@@ -6,8 +6,10 @@ use std::convert;
|
||||
use std::collections::HashMap;
|
||||
use std::marker::PhantomData;
|
||||
use std::net::{TcpListener, TcpStream, SocketAddr, ToSocketAddrs};
|
||||
use std::sync::{self, Mutex, Arc};
|
||||
use std::sync::{self, Arc, Condvar, Mutex};
|
||||
use std::sync::mpsc::{channel, Sender, TryRecvError};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::time::Duration;
|
||||
use std::thread::{self, JoinHandle};
|
||||
|
||||
/// Client errors that can occur during rpc calls
|
||||
@@ -23,6 +25,10 @@ pub enum Error {
|
||||
/// Channels are used for the client's inter-thread communication. This message is
|
||||
/// propagated if the receiver unexpectedly hangs up.
|
||||
Sender,
|
||||
/// An internal message failed to be received.
|
||||
/// Channels are used for the client's inter-thread communication. This message is
|
||||
/// propagated if the sender unexpectedly hangs up.
|
||||
Receiver,
|
||||
/// The server hung up.
|
||||
ConnectionBroken,
|
||||
}
|
||||
@@ -57,47 +63,146 @@ impl<T> convert::From<sync::mpsc::SendError<T>> for Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl convert::From<sync::mpsc::RecvError> for Error {
|
||||
fn from(_: sync::mpsc::RecvError) -> Error {
|
||||
Error::Receiver
|
||||
}
|
||||
}
|
||||
|
||||
/// Return type of rpc calls: either the successful return value, or a client error.
|
||||
pub type Result<T> = ::std::result::Result<T, Error>;
|
||||
|
||||
fn handle_conn<F, Request, Reply>(stream: TcpStream, f: F) -> Result<()>
|
||||
where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize,
|
||||
Reply: 'static + fmt::Debug + serde::ser::Serialize,
|
||||
F: 'static + Clone + Serve<Request, Reply>
|
||||
{
|
||||
let mut read_stream = try!(stream.try_clone());
|
||||
let stream = Arc::new(Mutex::new(stream));
|
||||
loop {
|
||||
let request_packet: Packet<Request> =
|
||||
try!(bincode::serde::deserialize_from(&mut read_stream, bincode::SizeLimit::Infinite));
|
||||
match request_packet {
|
||||
Packet::Shutdown => {
|
||||
let stream = stream.clone();
|
||||
let mut my_stream = stream.lock().unwrap();
|
||||
try!(bincode::serde::serialize_into(&mut *my_stream,
|
||||
&request_packet,
|
||||
bincode::SizeLimit::Infinite));
|
||||
break;
|
||||
}
|
||||
Packet::Message(id, message) => {
|
||||
let f = f.clone();
|
||||
let arc_stream = stream.clone();
|
||||
thread::spawn(move || {
|
||||
let reply = f.serve(message);
|
||||
let reply_packet = Packet::Message(id, reply);
|
||||
let mut my_stream = arc_stream.lock().unwrap();
|
||||
bincode::serde::serialize_into(&mut *my_stream,
|
||||
&reply_packet,
|
||||
bincode::SizeLimit::Infinite)
|
||||
.unwrap();
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
#[derive(Clone)]
|
||||
struct InflightRpcs {
|
||||
inflight_rpcs: Arc<(Mutex<u64>, Condvar)>,
|
||||
}
|
||||
|
||||
/// Provides methods for blocking until the server completes,
|
||||
impl InflightRpcs {
|
||||
fn new(mutex: Mutex<u64>, cvar: Condvar) -> InflightRpcs {
|
||||
InflightRpcs { inflight_rpcs: Arc::new((mutex, cvar)) }
|
||||
}
|
||||
|
||||
fn wait_until_zero(&self) {
|
||||
let &(ref count, ref cvar) = &*self.inflight_rpcs;
|
||||
let mut count = count.lock().unwrap();
|
||||
while *count != 0 {
|
||||
count = cvar.wait(count).unwrap();
|
||||
}
|
||||
info!("serve_async: shutdown complete ({} connections alive)",
|
||||
*count);
|
||||
}
|
||||
|
||||
fn increment(&self) {
|
||||
let &(ref count, _) = &*self.inflight_rpcs;
|
||||
*count.lock().unwrap() += 1;
|
||||
}
|
||||
|
||||
fn decrement(&self) {
|
||||
let &(ref count, _) = &*self.inflight_rpcs;
|
||||
*count.lock().unwrap() -= 1;
|
||||
}
|
||||
|
||||
|
||||
fn decrement_and_notify(&self) {
|
||||
let &(ref count, ref cvar) = &*self.inflight_rpcs;
|
||||
*count.lock().unwrap() -= 1;
|
||||
cvar.notify_one();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
struct ConnectionHandler {
|
||||
read_stream: TcpStream,
|
||||
write_stream: Arc<Mutex<TcpStream>>,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
inflight_rpcs: InflightRpcs,
|
||||
timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
impl Drop for ConnectionHandler {
|
||||
fn drop(&mut self) {
|
||||
trace!("ConnectionHandler: finished serving client.");
|
||||
self.inflight_rpcs.decrement_and_notify();
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectionHandler {
|
||||
fn read<Request>(&mut self) -> bincode::serde::DeserializeResult<Packet<Request>>
|
||||
where Request: serde::de::Deserialize
|
||||
{
|
||||
try!(self.read_stream.set_read_timeout(self.timeout));
|
||||
bincode::serde::deserialize_from(&mut self.read_stream, bincode::SizeLimit::Infinite)
|
||||
}
|
||||
|
||||
fn handle_conn<F>(&mut self, f: F) -> Result<()>
|
||||
where F: 'static + Clone + Serve
|
||||
{
|
||||
trace!("ConnectionHandler: serving client...");
|
||||
loop {
|
||||
match self.read() {
|
||||
Ok(Packet {
|
||||
rpc_id: id,
|
||||
message: message
|
||||
}) => {
|
||||
let f = f.clone();
|
||||
let inflight_rpcs = self.inflight_rpcs.clone();
|
||||
inflight_rpcs.increment();
|
||||
let stream = self.write_stream.clone();
|
||||
thread::spawn(move || {
|
||||
let reply = f.serve(message);
|
||||
let reply_packet = Packet {
|
||||
rpc_id: id,
|
||||
message: reply
|
||||
};
|
||||
let mut stream = stream.lock().unwrap();
|
||||
if let Err(e) =
|
||||
bincode::serde::serialize_into(&mut *stream,
|
||||
&reply_packet,
|
||||
bincode::SizeLimit::Infinite) {
|
||||
warn!("ConnectionHandler: failed to write reply to Client: {:?}",
|
||||
e);
|
||||
}
|
||||
inflight_rpcs.decrement();
|
||||
});
|
||||
if self.shutdown.load(Ordering::SeqCst) {
|
||||
info!("ConnectionHandler: server shutdown, so closing connection.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(bincode::serde::DeserializeError::IoError(ref err))
|
||||
if Self::timed_out(err.kind()) => {
|
||||
if !self.shutdown.load(Ordering::SeqCst) {
|
||||
info!("ConnectionHandler: read timed out ({:?}). Server not shutdown, so \
|
||||
retrying read.",
|
||||
err);
|
||||
continue;
|
||||
} else {
|
||||
info!("ConnectionHandler: read timed out ({:?}). Server shutdown, so \
|
||||
closing connection.",
|
||||
err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("ConnectionHandler: closing client connection due to error while \
|
||||
serving: {:?}",
|
||||
e);
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn timed_out(error_kind: io::ErrorKind) -> bool {
|
||||
match error_kind {
|
||||
io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides methods for blocking until the server completes,
|
||||
pub struct ServeHandle {
|
||||
tx: Sender<()>,
|
||||
join_handle: JoinHandle<()>,
|
||||
@@ -118,47 +223,64 @@ impl ServeHandle {
|
||||
/// Shutdown the server. Gracefully shuts down the serve thread but currently does not
|
||||
/// gracefully close open connections.
|
||||
pub fn shutdown(self) {
|
||||
info!("ServeHandle: attempting to shut down the server.");
|
||||
self.tx.send(()).expect(&line!().to_string());
|
||||
if let Ok(_) = TcpStream::connect(self.addr) {
|
||||
self.join_handle.join().expect(&line!().to_string());
|
||||
} else {
|
||||
warn!("Best effort shutdown of serve thread failed");
|
||||
warn!("ServeHandle: best effort shutdown of serve thread failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Start
|
||||
pub fn serve_async<A, F, Request, Reply>(addr: A, f: F) -> io::Result<ServeHandle>
|
||||
/// Start
|
||||
pub fn serve_async<A, F>(addr: A, f: F, read_timeout: Option<Duration>) -> io::Result<ServeHandle>
|
||||
where A: ToSocketAddrs,
|
||||
Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize,
|
||||
Reply: 'static + fmt::Debug + serde::ser::Serialize,
|
||||
F: 'static + Clone + Send + Serve<Request, Reply>
|
||||
F: 'static + Clone + Send + Serve
|
||||
{
|
||||
let listener = try!(TcpListener::bind(&addr));
|
||||
let addr = try!(listener.local_addr());
|
||||
info!("Spinning up server on {:?}", addr);
|
||||
info!("serve_async: spinning up server on {:?}", addr);
|
||||
let (die_tx, die_rx) = channel();
|
||||
let join_handle = thread::spawn(move || {
|
||||
let shutdown = Arc::new(AtomicBool::new(false));
|
||||
let inflight_rpcs = InflightRpcs::new(Mutex::new(0), Condvar::new());
|
||||
for conn in listener.incoming() {
|
||||
match die_rx.try_recv() {
|
||||
Ok(_) => break,
|
||||
Ok(_) => {
|
||||
info!("serve_async: shutdown received. Waiting for open connections to \
|
||||
return...");
|
||||
shutdown.store(true, Ordering::SeqCst);
|
||||
inflight_rpcs.wait_until_zero();
|
||||
break;
|
||||
}
|
||||
Err(TryRecvError::Disconnected) => {
|
||||
info!("Sender disconnected.");
|
||||
info!("serve_async: sender disconnected.");
|
||||
break;
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
let conn = match conn {
|
||||
Err(err) => {
|
||||
error!("Failed to accept connection: {:?}", err);
|
||||
error!("serve_async: failed to accept connection: {:?}", err);
|
||||
return;
|
||||
}
|
||||
Ok(c) => c,
|
||||
};
|
||||
let f = f.clone();
|
||||
let shutdown = shutdown.clone();
|
||||
inflight_rpcs.increment();
|
||||
let inflight_rpcs = inflight_rpcs.clone();
|
||||
let mut handler = ConnectionHandler {
|
||||
read_stream: conn.try_clone().unwrap(),
|
||||
write_stream: Arc::new(Mutex::new(conn)),
|
||||
shutdown: shutdown,
|
||||
inflight_rpcs: inflight_rpcs,
|
||||
timeout: read_timeout,
|
||||
};
|
||||
thread::spawn(move || {
|
||||
if let Err(err) = handle_conn(conn, f) {
|
||||
error!("Error in connection handling: {:?}", err);
|
||||
if let Err(err) = handler.handle_conn(f) {
|
||||
error!("ConnectionHandler: error in connection handling: {:?}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -171,46 +293,74 @@ pub fn serve_async<A, F, Request, Reply>(addr: A, f: F) -> io::Result<ServeHandl
|
||||
}
|
||||
|
||||
/// A service provided by a server
|
||||
pub trait Serve<Request, Reply>: Send + Sync {
|
||||
pub trait Serve: Send + Sync {
|
||||
/// The type of request received by the server
|
||||
type Request: 'static + fmt::Debug + serde::ser::Serialize + serde::de::Deserialize + Send;
|
||||
/// The type of reply sent by the server
|
||||
type Reply: 'static + fmt::Debug + serde::ser::Serialize + serde::de::Deserialize;
|
||||
|
||||
/// Return a reply for a given request
|
||||
fn serve(&self, request: Request) -> Reply;
|
||||
fn serve(&self, request: Self::Request) -> Self::Reply;
|
||||
}
|
||||
|
||||
impl<Request, Reply, S> Serve<Request, Reply> for Arc<S>
|
||||
where S: Serve<Request, Reply>
|
||||
impl<S> Serve for Arc<S> where S: Serve
|
||||
{
|
||||
fn serve(&self, request: Request) -> Reply {
|
||||
type Request = S::Request;
|
||||
type Reply = S::Reply;
|
||||
|
||||
fn serve(&self, request: S::Request) -> S::Reply {
|
||||
S::serve(self, request)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
enum Packet<T> {
|
||||
Message(u64, T),
|
||||
Shutdown,
|
||||
struct Packet<T> {
|
||||
rpc_id: u64,
|
||||
message: T
|
||||
}
|
||||
|
||||
fn reader<Reply>(mut stream: TcpStream, requests: Arc<Mutex<HashMap<u64, Sender<Reply>>>>)
|
||||
where Reply: serde::Deserialize
|
||||
{
|
||||
loop {
|
||||
let packet: bincode::serde::DeserializeResult<Packet<Reply>> =
|
||||
bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite);
|
||||
match packet {
|
||||
Ok(Packet::Message(id, reply)) => {
|
||||
let mut requests = requests.lock().unwrap();
|
||||
let reply_tx = requests.remove(&id).unwrap();
|
||||
reply_tx.send(reply).unwrap();
|
||||
struct Reader<Reply> {
|
||||
requests: Arc<Mutex<Option<HashMap<u64, Sender<Reply>>>>>,
|
||||
}
|
||||
|
||||
impl<Reply> Reader<Reply> {
|
||||
fn read(self, mut stream: TcpStream)
|
||||
where Reply: serde::Deserialize
|
||||
{
|
||||
loop {
|
||||
let packet: bincode::serde::DeserializeResult<Packet<Reply>> =
|
||||
bincode::serde::deserialize_from(&mut stream, bincode::SizeLimit::Infinite);
|
||||
match packet {
|
||||
Ok(Packet {
|
||||
rpc_id: id,
|
||||
message: reply
|
||||
}) => {
|
||||
debug!("Client: received message, id={}", id);
|
||||
let mut requests = self.requests.lock().unwrap();
|
||||
let mut requests = requests.as_mut().unwrap();
|
||||
let reply_tx = requests.remove(&id).unwrap();
|
||||
reply_tx.send(reply).unwrap();
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Client: reader thread encountered an unexpected error while parsing; \
|
||||
returning now. Error: {:?}",
|
||||
err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(Packet::Shutdown) => {
|
||||
break;
|
||||
}
|
||||
// TODO: This shutdown logic is janky.. What's the right way to do this?
|
||||
Err(err) => panic!("unexpected error while parsing!: {:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Reply> Drop for Reader<Reply> {
|
||||
fn drop(&mut self) {
|
||||
let mut guard = self.requests.lock().unwrap();
|
||||
guard.as_mut().unwrap().clear();
|
||||
// remove the hashmap so no one can put more senders and accidentally block
|
||||
guard.take();
|
||||
}
|
||||
}
|
||||
|
||||
fn increment(cur_id: &mut u64) -> u64 {
|
||||
let id = *cur_id;
|
||||
*cur_id += 1;
|
||||
@@ -227,8 +377,9 @@ pub struct Client<Request, Reply>
|
||||
where Request: serde::ser::Serialize
|
||||
{
|
||||
synced_state: Mutex<SyncedClientState>,
|
||||
requests: Arc<Mutex<HashMap<u64, Sender<Reply>>>>,
|
||||
requests: Arc<Mutex<Option<HashMap<u64, Sender<Reply>>>>>,
|
||||
reader_guard: Option<thread::JoinHandle<()>>,
|
||||
timeout: Option<Duration>,
|
||||
_request: PhantomData<Request>,
|
||||
}
|
||||
|
||||
@@ -237,12 +388,12 @@ impl<Request, Reply> Client<Request, Reply>
|
||||
Request: serde::ser::Serialize
|
||||
{
|
||||
/// Create a new client that connects to `addr`
|
||||
pub fn new<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
|
||||
pub fn new<A: ToSocketAddrs>(addr: A, timeout: Option<Duration>) -> io::Result<Self> {
|
||||
let stream = try!(TcpStream::connect(addr));
|
||||
let requests = Arc::new(Mutex::new(HashMap::new()));
|
||||
let requests = Arc::new(Mutex::new(Some(HashMap::new())));
|
||||
let reader_stream = try!(stream.try_clone());
|
||||
let reader_requests = requests.clone();
|
||||
let reader_guard = thread::spawn(move || reader(reader_stream, reader_requests));
|
||||
let reader = Reader { requests: requests.clone() };
|
||||
let reader_guard = thread::spawn(move || reader.read(reader_stream));
|
||||
Ok(Client {
|
||||
synced_state: Mutex::new(SyncedClientState {
|
||||
next_id: 0,
|
||||
@@ -250,6 +401,7 @@ impl<Request, Reply> Client<Request, Reply>
|
||||
}),
|
||||
requests: requests,
|
||||
reader_guard: Some(reader_guard),
|
||||
timeout: timeout,
|
||||
_request: PhantomData,
|
||||
})
|
||||
}
|
||||
@@ -262,21 +414,36 @@ impl<Request, Reply> Client<Request, Reply>
|
||||
let mut state = self.synced_state.lock().unwrap();
|
||||
let id = increment(&mut state.next_id);
|
||||
{
|
||||
let mut requests = self.requests.lock().unwrap();
|
||||
requests.insert(id, tx);
|
||||
if let Some(ref mut requests) = *self.requests.lock().unwrap() {
|
||||
requests.insert(id, tx);
|
||||
} else {
|
||||
return Err(Error::ConnectionBroken);
|
||||
}
|
||||
}
|
||||
let packet = Packet::Message(id, request);
|
||||
let packet = Packet {
|
||||
rpc_id: id,
|
||||
message: request,
|
||||
};
|
||||
try!(state.stream.set_write_timeout(self.timeout));
|
||||
try!(state.stream.set_read_timeout(self.timeout));
|
||||
debug!("Client: calling rpc({:?})", request);
|
||||
if let Err(err) = bincode::serde::serialize_into(&mut state.stream,
|
||||
&packet,
|
||||
bincode::SizeLimit::Infinite) {
|
||||
warn!("Failed to write client packet.\nPacket: {:?}\nError: {:?}",
|
||||
warn!("Client: failed to write packet.\nPacket: {:?}\nError: {:?}",
|
||||
packet,
|
||||
err);
|
||||
self.requests.lock().unwrap().remove(&id);
|
||||
if let Some(requests) = self.requests.lock().unwrap().as_mut() {
|
||||
requests.remove(&id);
|
||||
} else {
|
||||
warn!("Client: couldn't remove sender for request {} because reader thread \
|
||||
returned.",
|
||||
id);
|
||||
}
|
||||
return Err(err.into());
|
||||
}
|
||||
drop(state);
|
||||
Ok(rx.recv().unwrap())
|
||||
Ok(try!(rx.recv()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -284,24 +451,30 @@ impl<Request, Reply> Drop for Client<Request, Reply>
|
||||
where Request: serde::ser::Serialize
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
{
|
||||
let mut state = self.synced_state.lock().unwrap();
|
||||
let packet: Packet<Request> = Packet::Shutdown;
|
||||
if let Err(err) = bincode::serde::serialize_into(&mut state.stream,
|
||||
&packet,
|
||||
bincode::SizeLimit::Infinite) {
|
||||
warn!("While disconnecting client from server: {:?}", err);
|
||||
}
|
||||
if let Err(e) = self.synced_state
|
||||
.lock()
|
||||
.unwrap()
|
||||
.stream
|
||||
.shutdown(::std::net::Shutdown::Both) {
|
||||
warn!("Client: couldn't shutdown reader thread: {:?}", e);
|
||||
} else {
|
||||
self.reader_guard.take().unwrap().join().unwrap();
|
||||
}
|
||||
self.reader_guard.take().unwrap().join().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
extern crate env_logger;
|
||||
|
||||
use super::{Client, Serve, serve_async};
|
||||
use std::sync::{Arc, Mutex, Barrier};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
fn test_timeout() -> Option<Duration> {
|
||||
Some(Duration::from_millis(1))
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
|
||||
enum Request {
|
||||
@@ -317,7 +490,10 @@ mod test {
|
||||
counter: Mutex<u64>,
|
||||
}
|
||||
|
||||
impl Serve<Request, Reply> for Server {
|
||||
impl Serve for Server {
|
||||
type Request = Request;
|
||||
type Reply = Reply;
|
||||
|
||||
fn serve(&self, _: Request) -> Reply {
|
||||
let mut counter = self.counter.lock().unwrap();
|
||||
let reply = Reply::Increment(*counter);
|
||||
@@ -337,21 +513,23 @@ mod test {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handle() {
|
||||
fn handle() {
|
||||
let _ = env_logger::init();
|
||||
let server = Arc::new(Server::new());
|
||||
let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap();
|
||||
let client: Client<Request, Reply> = Client::new(serve_handle.local_addr().clone())
|
||||
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
|
||||
let client: Client<Request, Reply> = Client::new(serve_handle.local_addr().clone(), None)
|
||||
.expect(&line!().to_string());
|
||||
drop(client);
|
||||
serve_handle.shutdown();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple() {
|
||||
fn simple() {
|
||||
let _ = env_logger::init();
|
||||
let server = Arc::new(Server::new());
|
||||
let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap();
|
||||
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
|
||||
let addr = serve_handle.local_addr().clone();
|
||||
let client = Client::new(addr).unwrap();
|
||||
let client = Client::new(addr, None).unwrap();
|
||||
assert_eq!(Reply::Increment(0),
|
||||
client.rpc(&Request::Increment).unwrap());
|
||||
assert_eq!(1, server.count());
|
||||
@@ -367,7 +545,9 @@ mod test {
|
||||
inner: Server,
|
||||
}
|
||||
|
||||
impl Serve<Request, Reply> for BarrierServer {
|
||||
impl Serve for BarrierServer {
|
||||
type Request = Request;
|
||||
type Reply = Reply;
|
||||
fn serve(&self, request: Request) -> Reply {
|
||||
self.barrier.wait();
|
||||
self.inner.serve(request)
|
||||
@@ -388,11 +568,36 @@ mod test {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_concurrent() {
|
||||
let server = Arc::new(BarrierServer::new(10));
|
||||
let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap();
|
||||
fn force_shutdown() {
|
||||
let _ = env_logger::init();
|
||||
let server = Arc::new(Server::new());
|
||||
let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap();
|
||||
let addr = serve_handle.local_addr().clone();
|
||||
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr).unwrap());
|
||||
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, None).unwrap());
|
||||
let thread = thread::spawn(move || serve_handle.shutdown());
|
||||
info!("force_shutdown:: rpc1: {:?}", client.rpc(&Request::Increment));
|
||||
thread.join().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn client_failed_rpc() {
|
||||
let _ = env_logger::init();
|
||||
let server = Arc::new(Server::new());
|
||||
let serve_handle = serve_async("0.0.0.0:0", server, Some(Duration::new(0, 10))).unwrap();
|
||||
let addr = serve_handle.local_addr().clone();
|
||||
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, None).unwrap());
|
||||
serve_handle.shutdown();
|
||||
let _ = client.rpc(&Request::Increment); // First failure will trigger reader to shutdown
|
||||
let _ = client.rpc(&Request::Increment); // Test whether second failure hangs
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn concurrent() {
|
||||
let _ = env_logger::init();
|
||||
let server = Arc::new(BarrierServer::new(10));
|
||||
let serve_handle = serve_async("0.0.0.0:0", server.clone(), test_timeout()).unwrap();
|
||||
let addr = serve_handle.local_addr().clone();
|
||||
let client: Arc<Client<Request, Reply>> = Arc::new(Client::new(addr, None).unwrap());
|
||||
let mut join_handles = vec![];
|
||||
for _ in 0..10 {
|
||||
let my_client = client.clone();
|
||||
|
||||
Reference in New Issue
Block a user