diff --git a/src/lib.rs b/src/lib.rs index ae161ad..41d3fb0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,39 @@ -//! Provides the tarpc client and server, which implements the tarpc protocol. The protocol -//! is defined by the implementation. +//! An RPC library for Rust. +//! +//! Example usage: +//! +//! ``` +//! # #![feature(custom_derive)] +//! # #![feature(custom_derive, plugin)] +//! # #![plugin(serde_macros)] +//! # #[macro_use] extern crate tarpc; +//! # extern crate serde; +//! rpc_service!(my_server: +//! hello(name: String) -> String; +//! add(x: i32, y: i32) -> i32; +//! ); +//! +//! use self::my_server::*; +//! +//! impl my_server::Service for () { +//! fn hello(&self, s: String) -> String { +//! format!("Hello, {}!", s) +//! } +//! fn add(&self, x: i32, y: i32) -> i32 { +//! x + y +//! } +//! } +//! +//! fn main() { +//! let addr = "127.0.0.1:9000"; +//! let shutdown = my_server::serve(addr, ()).unwrap(); +//! let client = Client::new(addr).unwrap(); +//! assert_eq!(3, client.add(1, 2).unwrap()); +//! assert_eq!("Hello, Mom!".to_string(), client.hello("Mom".to_string()).unwrap()); +//! drop(client); +//! shutdown.shutdown(); +//! } +//! ``` #![feature(const_fn)] #![feature(custom_derive, plugin)] @@ -11,392 +45,9 @@ extern crate serde_json; #[macro_use] extern crate log; -use serde::Deserialize; -use std::fmt; -use std::io::{self, Read}; -use std::convert; -use std::collections::HashMap; -use std::net::{TcpListener, TcpStream, SocketAddr, ToSocketAddrs}; -use std::sync::{self, Mutex, Arc}; -use std::sync::mpsc::{channel, Sender, TryRecvError}; -use std::thread::{self, JoinHandle}; +/// Provides the tarpc client and server, which implements the tarpc protocol. The protocol +/// is defined by the implementation. +pub mod protocol; -/// Client errors that can occur during rpc calls -#[derive(Debug)] -pub enum Error { - /// An IO-related error - Io(io::Error), - /// An error in serialization or deserialization - Json(serde_json::Error), - /// An internal message failed to send. - /// Channels are used for the client's inter-thread communication. This message is - /// propagated if the receiver unexpectedly hangs up. - Sender, - /// The server hung up. - ConnectionBroken, -} - -impl convert::From for Error { - fn from(err: serde_json::Error) -> Error { - match err { - serde_json::Error::IoError(err) => Error::Io(err), - err => Error::Json(err), - } - } -} - -impl convert::From for Error { - fn from(err: io::Error) -> Error { - Error::Io(err) - } -} - -impl convert::From> for Error { - fn from(_: sync::mpsc::SendError) -> Error { - Error::Sender - } -} - -/// Return type of rpc calls: either the successful return value, or a client error. -pub type Result = std::result::Result; - -fn handle_conn(stream: TcpStream, f: F) -> Result<()> - where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, - Reply: 'static + fmt::Debug + serde::ser::Serialize, - F: 'static + Clone + Serve -{ - let read_stream = try!(stream.try_clone()); - let mut de = serde_json::Deserializer::new(read_stream.bytes()); - let stream = Arc::new(Mutex::new(stream)); - loop { - let request_packet: Packet = try!(Packet::deserialize(&mut de)); - match request_packet { - Packet::Shutdown => { - let stream = stream.clone(); - let mut my_stream = stream.lock().unwrap(); - try!(serde_json::to_writer(&mut *my_stream, &request_packet)); - break; - } - Packet::Message(id, message) => { - let f = f.clone(); - let arc_stream = stream.clone(); - thread::spawn(move || { - let reply = f.serve(message); - let reply_packet = Packet::Message(id, reply); - let mut my_stream = arc_stream.lock().unwrap(); - serde_json::to_writer(&mut *my_stream, &reply_packet).unwrap(); - }); - } - } - } - Ok(()) -} - -/// Provides methods for blocking until the server completes, -pub struct ServeHandle { - tx: Sender<()>, - join_handle: JoinHandle<()>, - addr: SocketAddr, -} - -impl ServeHandle { - /// Block until the server completes - pub fn wait(self) { - self.join_handle.join().unwrap(); - } - - /// Returns the address the server is bound to - pub fn local_addr(&self) -> &SocketAddr { - &self.addr - } - - /// Shutdown the server. Gracefully shuts down the serve thread but currently does not - /// gracefully close open connections. - pub fn shutdown(self) { - self.tx.send(()).expect(&line!().to_string()); - if let Ok(_) = TcpStream::connect(self.addr) { - self.join_handle.join().expect(&line!().to_string()); - } else { - warn!("Best effort shutdown of serve thread failed"); - } - } -} - -/// Start -pub fn serve_async(addr: A, f: F) -> io::Result - where A: ToSocketAddrs, - Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, - Reply: 'static + fmt::Debug + serde::ser::Serialize, - F: 'static + Clone + Send + Serve, -{ - let listener = try!(TcpListener::bind(&addr)); - let addr = try!(listener.local_addr()); - info!("Spinning up server on {:?}", addr); - let (die_tx, die_rx) = channel(); - let join_handle = thread::spawn(move || { - for conn in listener.incoming() { - match die_rx.try_recv() { - Ok(_) => break, - Err(TryRecvError::Disconnected) => { - info!("Sender disconnected."); - break; - } - _ => (), - } - let conn = match conn { - Err(err) => { - error!("Failed to accept connection: {:?}", err); - return; - } - Ok(c) => c, - }; - let f = f.clone(); - thread::spawn(move || { - if let Err(err) = handle_conn(conn, f) { - error!("Error in connection handling: {:?}", err); - } - }); - } - }); - Ok(ServeHandle { - tx: die_tx, - join_handle: join_handle, - addr: addr.clone(), - }) -} - -/// A service provided by a server -pub trait Serve: Send + Sync { - /// Return a reply for a given request - fn serve(&self, request: Request) -> Reply; -} - -impl Serve for Arc - where S: Serve -{ - fn serve(&self, request: Request) -> Reply { - S::serve(self, request) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -enum Packet { - Message(u64, T), - Shutdown, -} - -fn reader(stream: TcpStream, requests: Arc>>>) - where Reply: serde::Deserialize -{ - let mut de = serde_json::Deserializer::new(stream.bytes()); - loop { - match Packet::deserialize(&mut de) { - Ok(Packet::Message(id, reply)) => { - let mut requests = requests.lock().unwrap(); - let reply_tx = requests.remove(&id).unwrap(); - reply_tx.send(reply).unwrap(); - } - Ok(Packet::Shutdown) => { - break; - } - // TODO: This shutdown logic is janky.. What's the right way to do this? - Err(err) => panic!("unexpected error while parsing!: {:?}", err), - } - } -} - -fn increment(cur_id: &mut u64) -> u64 { - let id = *cur_id; - *cur_id += 1; - id -} - -struct SyncedClientState { - next_id: u64, - stream: TcpStream, -} - -/// A client stub that connects to a server to run rpcs. -pub struct Client - where Request: serde::ser::Serialize -{ - synced_state: Mutex, - requests: Arc>>>, - reader_guard: Option>, - _request: std::marker::PhantomData, -} - -impl Client - where Reply: serde::de::Deserialize + Send + 'static, - Request: serde::ser::Serialize -{ - /// Create a new client that connects to `addr` - pub fn new(addr: SocketAddr) -> io::Result { - let stream = try!(TcpStream::connect(addr)); - let requests = Arc::new(Mutex::new(HashMap::new())); - let reader_stream = try!(stream.try_clone()); - let reader_requests = requests.clone(); - let reader_guard = thread::spawn(move || reader(reader_stream, reader_requests)); - Ok(Client { - synced_state: Mutex::new(SyncedClientState { - next_id: 0, - stream: stream, - }), - requests: requests, - reader_guard: Some(reader_guard), - _request: std::marker::PhantomData, - }) - } - - /// Run the specified rpc method on the server this client is connected to - pub fn rpc(&self, request: &Request) -> Result - where Request: serde::ser::Serialize + std::fmt::Debug + Send + 'static - { - let (tx, rx) = channel(); - let mut state = self.synced_state.lock().unwrap(); - let id = increment(&mut state.next_id); - { - let mut requests = self.requests.lock().unwrap(); - requests.insert(id, tx); - } - let packet = Packet::Message(id, request); - if let Err(err) = serde_json::to_writer(&mut state.stream, &packet) { - warn!("Failed to write client packet.\nPacket: {:?}\nError: {:?}", - packet, - err); - self.requests.lock().unwrap().remove(&id); - return Err(err.into()); - } - drop(state); - Ok(rx.recv().unwrap()) - } -} - -impl Drop for Client - where Request: serde::ser::Serialize -{ - fn drop(&mut self) { - { - let mut state = self.synced_state.lock().unwrap(); - let packet: Packet = Packet::Shutdown; - if let Err(err) = serde_json::to_writer(&mut state.stream, &packet) { - warn!("While disconnecting client from server: {:?}", err); - } - } - self.reader_guard.take().unwrap().join().unwrap(); - } -} - -#[cfg(test)] -mod test { - use super::*; - use std::sync::{Arc, Mutex, Barrier}; - use std::thread; - - #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] - enum Request { - Increment, - } - - #[derive(Debug, PartialEq, Serialize, Deserialize)] - enum Reply { - Increment(u64), - } - - struct Server { - counter: Mutex, - } - - impl Serve for Server { - fn serve(&self, _: Request) -> Reply { - let mut counter = self.counter.lock().unwrap(); - let reply = Reply::Increment(*counter); - *counter += 1; - reply - } - } - - impl Server { - fn new() -> Server { - Server { counter: Mutex::new(0) } - } - - fn count(&self) -> u64 { - *self.counter.lock().unwrap() - } - } - - #[test] - fn test_handle() { - let server = Arc::new(Server::new()); - let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap(); - let client: Client = Client::new(serve_handle.local_addr().clone()) - .expect(&line!().to_string()); - drop(client); - serve_handle.shutdown(); - } - - #[test] - fn test_simple() { - let server = Arc::new(Server::new()); - let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap(); - let addr = serve_handle.local_addr().clone(); - let client = Client::new(addr).unwrap(); - assert_eq!(Reply::Increment(0), - client.rpc(&Request::Increment).unwrap()); - assert_eq!(1, server.count()); - assert_eq!(Reply::Increment(1), - client.rpc(&Request::Increment).unwrap()); - assert_eq!(2, server.count()); - drop(client); - serve_handle.shutdown(); - } - - struct BarrierServer { - barrier: Barrier, - inner: Server, - } - - impl Serve for BarrierServer { - fn serve(&self, request: Request) -> Reply { - self.barrier.wait(); - self.inner.serve(request) - } - } - - impl BarrierServer { - fn new(n: usize) -> BarrierServer { - BarrierServer { - barrier: Barrier::new(n), - inner: Server::new(), - } - } - - fn count(&self) -> u64 { - self.inner.count() - } - } - - #[test] - fn test_concurrent() { - let server = Arc::new(BarrierServer::new(10)); - let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap(); - let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr).unwrap()); - let mut join_handles = vec![]; - for _ in 0..10 { - let my_client = client.clone(); - join_handles.push(thread::spawn(move || my_client.rpc(&Request::Increment).unwrap())); - } - for handle in join_handles.into_iter() { - handle.join().unwrap(); - } - assert_eq!(10, server.count()); - let client = match Arc::try_unwrap(client) { - Err(_) => panic!("couldn't unwrap arc"), - Ok(c) => c, - }; - drop(client); - serve_handle.shutdown(); - } -} +/// Provides the macro used for constructing rpc services and client stubs. +pub mod macros; diff --git a/src/macros/src/lib.rs b/src/macros.rs similarity index 72% rename from src/macros/src/lib.rs rename to src/macros.rs index 3833174..48b8596 100644 --- a/src/macros/src/lib.rs +++ b/src/macros.rs @@ -1,61 +1,14 @@ //! Provides a macro for creating an rpc service and client stub. -//! Ex: -//! -//! ``` -//! # #![feature(custom_derive)] -//! # #![feature(custom_derive, plugin)] -//! # #![plugin(serde_macros)] -//! # extern crate tarpc; -//! # #[macro_use] extern crate tarpc_macros; -//! # extern crate serde; -//! rpc_service!(my_server: -//! hello(name: String) -> String; -//! add(x: i32, y: i32) -> i32; -//! ); -//! -//! use self::my_server::*; -//! -//! impl my_server::Service for () { -//! fn hello(&self, s: String) -> String { -//! format!("Hello, {}!", s) -//! } -//! fn add(&self, x: i32, y: i32) -> i32 { -//! x + y -//! } -//! } -//! -//! fn main() { -//! let addr = "127.0.0.1:9000"; -//! let shutdown = my_server::serve(addr, ()).unwrap(); -//! let client = Client::new(addr).unwrap(); -//! assert_eq!(3, client.add(1, 2).unwrap()); -//! assert_eq!("Hello, Mom!".to_string(), client.hello("Mom".to_string()).unwrap()); -//! drop(client); -//! shutdown.shutdown(); -//! } -//! ``` -#![feature(custom_derive, plugin)] -#![plugin(serde_macros)] -#![deny(missing_docs)] - -extern crate serde; -extern crate tarpc; -#[macro_use] -extern crate log; - #[macro_export] macro_rules! rpc_service { ($server:ident: $( $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty;)*) => { #[allow(dead_code)] mod $server { - use std::net::{ - TcpStream, - ToSocketAddrs, - }; + use std::net::ToSocketAddrs; use std::io; use std::sync::Arc; - use tarpc::{ + use $crate::protocol::{ self, ServeHandle, serve_async, @@ -71,10 +24,10 @@ macro_rules! rpc_service { ($server:ident: InternalError, } - impl ::std::convert::From for Error { - fn from(err: tarpc::Error) -> Error { + impl ::std::convert::From for Error { + fn from(err: protocol::Error) -> Error { match err { - tarpc::Error::Io(err) => Error::Io(err), + protocol::Error::Io(err) => Error::Io(err), _ => Error::InternalError, } } @@ -114,15 +67,14 @@ macro_rules! rpc_service { ($server:ident: } #[doc="The client stub that makes RPC calls to the server."] - pub struct Client(tarpc::Client); + pub struct Client(protocol::Client); impl Client { #[doc="Create a new client that connects to the given address."] pub fn new(addr: A) -> Result where A: ToSocketAddrs, { - let stream = try!(TcpStream::connect(addr)); - let inner = try!(tarpc::Client::new(stream)); + let inner = try!(protocol::Client::new(addr)); Ok(Client(inner)) } @@ -140,7 +92,7 @@ macro_rules! rpc_service { ($server:ident: struct Server(S); - impl tarpc::Serve for Server + impl protocol::Serve for Server where S: 'static + Service { fn serve(&self, request: Request) -> Reply { diff --git a/src/macros/Cargo.toml b/src/macros/Cargo.toml deleted file mode 100644 index 6e6ef80..0000000 --- a/src/macros/Cargo.toml +++ /dev/null @@ -1,10 +0,0 @@ -[package] -name = "tarpc_macros" -version = "0.1.0" -authors = ["Adam Wright ", "Tim Kuehn "] - -[dependencies] -tarpc = { path = "../../" } -serde = "*" -serde_macros = "*" -log = "*" diff --git a/src/protocol.rs b/src/protocol.rs new file mode 100644 index 0000000..7a9a41d --- /dev/null +++ b/src/protocol.rs @@ -0,0 +1,391 @@ +use serde::{self, Deserialize}; +use serde_json; +use std::fmt; +use std::io::{self, Read}; +use std::convert; +use std::collections::HashMap; +use std::marker::PhantomData; +use std::net::{TcpListener, TcpStream, SocketAddr, ToSocketAddrs}; +use std::sync::{self, Mutex, Arc}; +use std::sync::mpsc::{channel, Sender, TryRecvError}; +use std::thread::{self, JoinHandle}; + +/// Client errors that can occur during rpc calls +#[derive(Debug)] +pub enum Error { + /// An IO-related error + Io(io::Error), + /// An error in serialization or deserialization + Json(serde_json::Error), + /// An internal message failed to send. + /// Channels are used for the client's inter-thread communication. This message is + /// propagated if the receiver unexpectedly hangs up. + Sender, + /// The server hung up. + ConnectionBroken, +} + +impl convert::From for Error { + fn from(err: serde_json::Error) -> Error { + match err { + serde_json::Error::IoError(err) => Error::Io(err), + err => Error::Json(err), + } + } +} + +impl convert::From for Error { + fn from(err: io::Error) -> Error { + Error::Io(err) + } +} + +impl convert::From> for Error { + fn from(_: sync::mpsc::SendError) -> Error { + Error::Sender + } +} + +/// Return type of rpc calls: either the successful return value, or a client error. +pub type Result = ::std::result::Result; + +fn handle_conn(stream: TcpStream, f: F) -> Result<()> + where Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, + Reply: 'static + fmt::Debug + serde::ser::Serialize, + F: 'static + Clone + Serve +{ + let read_stream = try!(stream.try_clone()); + let mut de = serde_json::Deserializer::new(read_stream.bytes()); + let stream = Arc::new(Mutex::new(stream)); + loop { + let request_packet: Packet = try!(Packet::deserialize(&mut de)); + match request_packet { + Packet::Shutdown => { + let stream = stream.clone(); + let mut my_stream = stream.lock().unwrap(); + try!(serde_json::to_writer(&mut *my_stream, &request_packet)); + break; + } + Packet::Message(id, message) => { + let f = f.clone(); + let arc_stream = stream.clone(); + thread::spawn(move || { + let reply = f.serve(message); + let reply_packet = Packet::Message(id, reply); + let mut my_stream = arc_stream.lock().unwrap(); + serde_json::to_writer(&mut *my_stream, &reply_packet).unwrap(); + }); + } + } + } + Ok(()) +} + +/// Provides methods for blocking until the server completes, +pub struct ServeHandle { + tx: Sender<()>, + join_handle: JoinHandle<()>, + addr: SocketAddr, +} + +impl ServeHandle { + /// Block until the server completes + pub fn wait(self) { + self.join_handle.join().unwrap(); + } + + /// Returns the address the server is bound to + pub fn local_addr(&self) -> &SocketAddr { + &self.addr + } + + /// Shutdown the server. Gracefully shuts down the serve thread but currently does not + /// gracefully close open connections. + pub fn shutdown(self) { + self.tx.send(()).expect(&line!().to_string()); + if let Ok(_) = TcpStream::connect(self.addr) { + self.join_handle.join().expect(&line!().to_string()); + } else { + warn!("Best effort shutdown of serve thread failed"); + } + } +} + +/// Start +pub fn serve_async(addr: A, f: F) -> io::Result + where A: ToSocketAddrs, + Request: 'static + fmt::Debug + Send + serde::de::Deserialize + serde::ser::Serialize, + Reply: 'static + fmt::Debug + serde::ser::Serialize, + F: 'static + Clone + Send + Serve, +{ + let listener = try!(TcpListener::bind(&addr)); + let addr = try!(listener.local_addr()); + info!("Spinning up server on {:?}", addr); + let (die_tx, die_rx) = channel(); + let join_handle = thread::spawn(move || { + for conn in listener.incoming() { + match die_rx.try_recv() { + Ok(_) => break, + Err(TryRecvError::Disconnected) => { + info!("Sender disconnected."); + break; + } + _ => (), + } + let conn = match conn { + Err(err) => { + error!("Failed to accept connection: {:?}", err); + return; + } + Ok(c) => c, + }; + let f = f.clone(); + thread::spawn(move || { + if let Err(err) = handle_conn(conn, f) { + error!("Error in connection handling: {:?}", err); + } + }); + } + }); + Ok(ServeHandle { + tx: die_tx, + join_handle: join_handle, + addr: addr.clone(), + }) +} + +/// A service provided by a server +pub trait Serve: Send + Sync { + /// Return a reply for a given request + fn serve(&self, request: Request) -> Reply; +} + +impl Serve for Arc + where S: Serve +{ + fn serve(&self, request: Request) -> Reply { + S::serve(self, request) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +enum Packet { + Message(u64, T), + Shutdown, +} + +fn reader(stream: TcpStream, requests: Arc>>>) + where Reply: serde::Deserialize +{ + let mut de = serde_json::Deserializer::new(stream.bytes()); + loop { + match Packet::deserialize(&mut de) { + Ok(Packet::Message(id, reply)) => { + let mut requests = requests.lock().unwrap(); + let reply_tx = requests.remove(&id).unwrap(); + reply_tx.send(reply).unwrap(); + } + Ok(Packet::Shutdown) => { + break; + } + // TODO: This shutdown logic is janky.. What's the right way to do this? + Err(err) => panic!("unexpected error while parsing!: {:?}", err), + } + } +} + +fn increment(cur_id: &mut u64) -> u64 { + let id = *cur_id; + *cur_id += 1; + id +} + +struct SyncedClientState { + next_id: u64, + stream: TcpStream, +} + +/// A client stub that connects to a server to run rpcs. +pub struct Client + where Request: serde::ser::Serialize +{ + synced_state: Mutex, + requests: Arc>>>, + reader_guard: Option>, + _request: PhantomData, +} + +impl Client + where Reply: serde::de::Deserialize + Send + 'static, + Request: serde::ser::Serialize +{ + /// Create a new client that connects to `addr` + pub fn new(addr: A) -> io::Result { + let stream = try!(TcpStream::connect(addr)); + let requests = Arc::new(Mutex::new(HashMap::new())); + let reader_stream = try!(stream.try_clone()); + let reader_requests = requests.clone(); + let reader_guard = thread::spawn(move || reader(reader_stream, reader_requests)); + Ok(Client { + synced_state: Mutex::new(SyncedClientState { + next_id: 0, + stream: stream, + }), + requests: requests, + reader_guard: Some(reader_guard), + _request: PhantomData, + }) + } + + /// Run the specified rpc method on the server this client is connected to + pub fn rpc(&self, request: &Request) -> Result + where Request: serde::ser::Serialize + fmt::Debug + Send + 'static + { + let (tx, rx) = channel(); + let mut state = self.synced_state.lock().unwrap(); + let id = increment(&mut state.next_id); + { + let mut requests = self.requests.lock().unwrap(); + requests.insert(id, tx); + } + let packet = Packet::Message(id, request); + if let Err(err) = serde_json::to_writer(&mut state.stream, &packet) { + warn!("Failed to write client packet.\nPacket: {:?}\nError: {:?}", + packet, + err); + self.requests.lock().unwrap().remove(&id); + return Err(err.into()); + } + drop(state); + Ok(rx.recv().unwrap()) + } +} + +impl Drop for Client + where Request: serde::ser::Serialize +{ + fn drop(&mut self) { + { + let mut state = self.synced_state.lock().unwrap(); + let packet: Packet = Packet::Shutdown; + if let Err(err) = serde_json::to_writer(&mut state.stream, &packet) { + warn!("While disconnecting client from server: {:?}", err); + } + } + self.reader_guard.take().unwrap().join().unwrap(); + } +} + +#[cfg(test)] +mod test { + use super::*; + use std::sync::{Arc, Mutex, Barrier}; + use std::thread; + + #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] + enum Request { + Increment, + } + + #[derive(Debug, PartialEq, Serialize, Deserialize)] + enum Reply { + Increment(u64), + } + + struct Server { + counter: Mutex, + } + + impl Serve for Server { + fn serve(&self, _: Request) -> Reply { + let mut counter = self.counter.lock().unwrap(); + let reply = Reply::Increment(*counter); + *counter += 1; + reply + } + } + + impl Server { + fn new() -> Server { + Server { counter: Mutex::new(0) } + } + + fn count(&self) -> u64 { + *self.counter.lock().unwrap() + } + } + + #[test] + fn test_handle() { + let server = Arc::new(Server::new()); + let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap(); + let client: Client = Client::new(serve_handle.local_addr().clone()) + .expect(&line!().to_string()); + drop(client); + serve_handle.shutdown(); + } + + #[test] + fn test_simple() { + let server = Arc::new(Server::new()); + let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap(); + let addr = serve_handle.local_addr().clone(); + let client = Client::new(addr).unwrap(); + assert_eq!(Reply::Increment(0), + client.rpc(&Request::Increment).unwrap()); + assert_eq!(1, server.count()); + assert_eq!(Reply::Increment(1), + client.rpc(&Request::Increment).unwrap()); + assert_eq!(2, server.count()); + drop(client); + serve_handle.shutdown(); + } + + struct BarrierServer { + barrier: Barrier, + inner: Server, + } + + impl Serve for BarrierServer { + fn serve(&self, request: Request) -> Reply { + self.barrier.wait(); + self.inner.serve(request) + } + } + + impl BarrierServer { + fn new(n: usize) -> BarrierServer { + BarrierServer { + barrier: Barrier::new(n), + inner: Server::new(), + } + } + + fn count(&self) -> u64 { + self.inner.count() + } + } + + #[test] + fn test_concurrent() { + let server = Arc::new(BarrierServer::new(10)); + let serve_handle = serve_async("0.0.0.0:0", server.clone()).unwrap(); + let addr = serve_handle.local_addr().clone(); + let client: Arc> = Arc::new(Client::new(addr).unwrap()); + let mut join_handles = vec![]; + for _ in 0..10 { + let my_client = client.clone(); + join_handles.push(thread::spawn(move || my_client.rpc(&Request::Increment).unwrap())); + } + for handle in join_handles.into_iter() { + handle.join().unwrap(); + } + assert_eq!(10, server.count()); + let client = match Arc::try_unwrap(client) { + Err(_) => panic!("couldn't unwrap arc"), + Ok(c) => c, + }; + drop(client); + serve_handle.shutdown(); + } +}