diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index efb2684..5fce4e8 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -60,4 +60,5 @@ pub mod protocol; /// Provides the macro used for constructing rpc services and client stubs. pub mod macros; -pub use protocol::{Config, Error, Result, ServeHandle}; +pub use protocol::{Config, Dialer, Error, Listener, Result, ServeHandle, Stream, TcpDialer, + TcpDialerExt, TcpTransport, Transport}; diff --git a/tarpc/src/macros.rs b/tarpc/src/macros.rs index 02fac66..f068736 100644 --- a/tarpc/src/macros.rs +++ b/tarpc/src/macros.rs @@ -315,17 +315,18 @@ macro_rules! service_inner { )* #[doc="Spawn a running service."] - fn spawn(self, addr: A) -> $crate::Result<$crate::protocol::ServeHandle> + fn spawn(self, addr: A) + -> $crate::Result<$crate::protocol::ServeHandle<$crate::TcpDialer<::std::net::SocketAddr>>> where A: ::std::net::ToSocketAddrs, Self: 'static, { - self.spawn_with_config(addr, $crate::Config::default()) + self.spawn_with_config($crate::TcpTransport(addr), $crate::Config::default()) } #[doc="Spawn a running service."] - fn spawn_with_config(self, addr: A, config: $crate::Config) - -> $crate::Result<$crate::protocol::ServeHandle> - where A: ::std::net::ToSocketAddrs, + fn spawn_with_config(self, addr: T, config: $crate::Config) + -> $crate::Result<$crate::protocol::ServeHandle<::Dialer>> + where T: $crate::Transport, Self: 'static, { let server = ::std::sync::Arc::new(__Server(self)); @@ -385,25 +386,27 @@ macro_rules! service_inner { #[allow(unused)] #[doc="The client stub that makes RPC calls to the server."] - pub struct Client($crate::protocol::Client<__Request, __Reply>); + pub struct Client($crate::protocol::Client<__Request, __Reply, S>); - impl Client { - #[allow(unused)] - #[doc="Create a new client with default configuration that connects to the given \ - address."] + impl Client<::std::net::TcpStream> { pub fn new(addr: A) -> $crate::Result where A: ::std::net::ToSocketAddrs, { - Self::with_config(addr, $crate::Config::default()) + Self::with_config($crate::TcpDialer(addr), $crate::Config::default()) } + } + impl Client { + #[allow(unused)] + #[doc="Create a new client with default configuration that connects to the given \ + address."] #[allow(unused)] #[doc="Create a new client with the specified configuration that connects to the \ given address."] - pub fn with_config(addr: A, config: $crate::Config) -> $crate::Result - where A: ::std::net::ToSocketAddrs, + pub fn with_config(dialer: D, config: $crate::Config) -> $crate::Result + where D: $crate::Dialer, { - let inner = try!($crate::protocol::Client::with_config(addr, config)); + let inner = try!($crate::protocol::Client::with_config(dialer, config)); ::std::result::Result::Ok(Client(inner)) } @@ -424,25 +427,26 @@ macro_rules! service_inner { #[allow(unused)] #[doc="The client stub that makes asynchronous RPC calls to the server."] - pub struct AsyncClient($crate::protocol::Client<__Request, __Reply>); + pub struct AsyncClient($crate::protocol::Client<__Request, __Reply, S>); - impl AsyncClient { + impl AsyncClient<::std::net::TcpStream> { #[allow(unused)] #[doc="Create a new asynchronous client with default configuration that connects to \ the given address."] - pub fn new(addr: A) -> $crate::Result + pub fn new(addr: A) -> $crate::Result> where A: ::std::net::ToSocketAddrs, { - Self::with_config(addr, $crate::Config::default()) + Self::with_config($crate::TcpDialer(addr), $crate::Config::default()) } + } + impl AsyncClient { #[allow(unused)] #[doc="Create a new asynchronous client that connects to the given address."] - pub fn with_config(addr: A, config: $crate::Config) - -> $crate::Result - where A: ::std::net::ToSocketAddrs, + pub fn with_config(dialer: D, config: $crate::Config) -> $crate::Result + where D: $crate::Dialer { - let inner = try!($crate::protocol::Client::with_config(addr, config)); + let inner = try!($crate::protocol::Client::with_config(dialer, config)); ::std::result::Result::Ok(AsyncClient(inner)) } diff --git a/tarpc/src/protocol/client.rs b/tarpc/src/protocol/client.rs index edb0d36..38701cb 100644 --- a/tarpc/src/protocol/client.rs +++ b/tarpc/src/protocol/client.rs @@ -13,33 +13,41 @@ use std::sync::{Arc, Mutex}; use std::sync::mpsc::{Receiver, Sender, channel}; use std::thread; -use super::{Config, Deserialize, Error, Packet, Result, Serialize}; +use super::{Config, Deserialize, Dialer, Error, Packet, Result, Serialize, Stream, TcpDialer}; /// A client stub that connects to a server to run rpcs. -pub struct Client +pub struct Client where Request: serde::ser::Serialize { // The guard is in an option so it can be joined in the drop fn reader_guard: Arc>>, outbound: Sender<(Request, Sender>)>, requests: Arc>>, - shutdown: TcpStream, + shutdown: S, } -impl Client + +impl Client where Request: serde::ser::Serialize + Send + 'static, Reply: serde::de::Deserialize + Send + 'static { /// Create a new client that connects to `addr`. The client uses the given timeout /// for both reads and writes. pub fn new(addr: A) -> io::Result { - Self::with_config(addr, Config::default()) + Self::with_config(TcpDialer(addr), Config::default()) } +} +impl Client + where Request: serde::ser::Serialize + Send + 'static, + Reply: serde::de::Deserialize + Send + 'static +{ /// Create a new client that connects to `addr`. The client uses the given timeout /// for both reads and writes. - pub fn with_config(addr: A, config: Config) -> io::Result { - let stream = try!(TcpStream::connect(addr)); + pub fn with_config(dialer: D, config: Config) -> io::Result + where D: Dialer, + { + let stream = try!(dialer.dial()); try!(stream.set_read_timeout(config.timeout)); try!(stream.set_write_timeout(config.timeout)); let reader_stream = try!(stream.try_clone()); @@ -59,7 +67,7 @@ impl Client } /// Clones the Client so that it can be shared across threads. - pub fn try_clone(&self) -> io::Result> { + pub fn try_clone(&self) -> io::Result { Ok(Client { reader_guard: self.reader_guard.clone(), outbound: self.outbound.clone(), @@ -97,14 +105,14 @@ impl Client } } -impl Drop for Client +impl Drop for Client where Request: serde::ser::Serialize { fn drop(&mut self) { debug!("Dropping Client."); if let Some(reader_guard) = Arc::get_mut(&mut self.reader_guard) { debug!("Attempting to shut down writer and reader threads."); - if let Err(e) = self.shutdown.shutdown(::std::net::Shutdown::Both) { + if let Err(e) = self.shutdown.shutdown() { warn!("Client: couldn't shutdown writer and reader threads: {:?}", e); } else { @@ -185,9 +193,9 @@ impl RpcFutures { } } -fn write(outbound: Receiver<(Request, Sender>)>, +fn write(outbound: Receiver<(Request, Sender>)>, requests: Arc>>, - stream: TcpStream) + stream: S) where Request: serde::Serialize, Reply: serde::Deserialize { @@ -238,7 +246,7 @@ fn write(outbound: Receiver<(Request, Sender>)>, } -fn read(requests: Arc>>, stream: TcpStream) +fn read(requests: Arc>>, stream: S) where Reply: serde::Deserialize { let mut stream = BufReader::new(stream); diff --git a/tarpc/src/protocol/mod.rs b/tarpc/src/protocol/mod.rs index 3d9cd44..77bcd67 100644 --- a/tarpc/src/protocol/mod.rs +++ b/tarpc/src/protocol/mod.rs @@ -8,6 +8,7 @@ use bincode::serde::{deserialize_from, serialize_into}; use serde; use std::io::{self, Read, Write}; use std::convert; +use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs}; use std::sync::Arc; use std::time::Duration; @@ -62,6 +63,123 @@ pub struct Config { pub timeout: Option, } +/// A factory for creating a listener on a given address. +pub trait Transport { + /// The type of listener that binds to the given address. + type Listener: Listener; + /// Return a listener on the given address, and a dialer to that address. + fn bind(&self) -> io::Result; +} + +/// A transport for TCP. +pub struct TcpTransport(pub A); +impl Transport for TcpTransport { + type Listener = TcpListener; + fn bind(&self) -> io::Result { + TcpListener::bind(&self.0) + } +} + +/// Accepts incoming connections from dialers. +pub trait Listener: Send + 'static { + /// The type of address being listened on. + type Dialer: Dialer; + /// The type of stream this listener accepts. + type Stream: Stream; + /// Accept an incoming stream. + fn accept(&self) -> io::Result; + /// Returns the local address being listened on. + fn dialer(&self) -> io::Result; + /// Iterate over incoming connections. + fn incoming(&self) -> Incoming { + Incoming { + listener: self, + } + } +} + +impl Listener for TcpListener { + type Dialer = TcpDialer; + type Stream = TcpStream; + fn accept(&self) -> io::Result { + self.accept().map(|(stream, _)| stream) + } + fn dialer(&self) -> io::Result> { + self.local_addr().map(|addr| TcpDialer(addr)) + } +} + +/// A cloneable Reader/Writer. +pub trait Stream: Read + Write + Send + Sized + 'static { + /// Clone that can fail. + fn try_clone(&self) -> io::Result; + /// Sets a read timeout. + fn set_read_timeout(&self, dur: Option) -> io::Result<()>; + /// Sets a write timeout. + fn set_write_timeout(&self, dur: Option) -> io::Result<()>; + /// Shuts down both ends of the stream. + fn shutdown(&self) -> io::Result<()>; +} + +impl Stream for TcpStream { + fn try_clone(&self) -> io::Result { + self.try_clone() + } + fn set_read_timeout(&self, dur: Option) -> io::Result<()> { + self.set_read_timeout(dur) + } + fn set_write_timeout(&self, dur: Option) -> io::Result<()> { + self.set_write_timeout(dur) + } + fn shutdown(&self) -> io::Result<()> { + self.shutdown(::std::net::Shutdown::Both) + } +} + +/// A `Stream` factory. +pub trait Dialer { + /// The type of `Stream` this can create. + type Stream: Stream; + /// Open a stream. + fn dial(&self) -> io::Result; +} + +/// Allows retrieving the address when the Dialer is known to be a TcpDialer. +pub trait TcpDialerExt { + /// Type of the address. + type Addr: ToSocketAddrs; + /// Return the address the Dialer connects to. + fn addr(&self) -> &Self::Addr; +} + +/// Connects to a socket address. +pub struct TcpDialer(pub A); +impl Dialer for TcpDialer { + type Stream = TcpStream; + fn dial(&self) -> io::Result { + TcpStream::connect(&self.0) + } +} +impl TcpDialerExt for TcpDialer { + type Addr = A; + fn addr(&self) -> &A { + &self.0 + } +} + +/// Iterates over incoming connections. +pub struct Incoming<'a, L: Listener + ?Sized + 'a> { + listener: &'a L, +} + +impl<'a, L: Listener> Iterator for Incoming<'a, L> { + type Item = io::Result; + + fn next(&mut self) -> Option { + Some(self.listener.accept()) + } +} + /// Return type of rpc calls: either the successful return value, or a client error. pub type Result = ::std::result::Result; @@ -86,8 +204,9 @@ impl Serialize for W {} #[cfg(test)] mod test { extern crate env_logger; - use super::{Client, Config, Serve}; + use super::{Client, Config, Serve, TcpTransport}; use scoped_pool::Pool; + use std::net::TcpStream; use std::sync::{Arc, Barrier, Mutex}; use std::thread; use std::time::Duration; @@ -127,7 +246,7 @@ mod test { let _ = env_logger::init(); let server = Arc::new(Server::new()); let serve_handle = server.spawn("localhost:0").unwrap(); - let client: Client<(), u64> = Client::new(serve_handle.local_addr()).unwrap(); + let client: Client<(), u64, TcpStream> = Client::new(serve_handle.local_addr()).unwrap(); drop(client); serve_handle.shutdown(); } @@ -139,7 +258,7 @@ mod test { let serve_handle = server.clone().spawn("localhost:0").unwrap(); let addr = serve_handle.local_addr().clone(); // The explicit type is required so that it doesn't deserialize a u32 instead of u64 - let client: Client<(), u64> = Client::new(addr).unwrap(); + let client: Client<(), u64, _> = Client::new(addr).unwrap(); assert_eq!(0, client.rpc(()).unwrap()); assert_eq!(1, server.count()); assert_eq!(1, client.rpc(()).unwrap()); @@ -179,13 +298,13 @@ mod test { fn force_shutdown() { let _ = env_logger::init(); let server = Arc::new(Server::new()); - let serve_handle = server.spawn_with_config("localhost:0", + let serve_handle = server.spawn_with_config(TcpTransport("localhost:0"), Config { - timeout: Some(Duration::new(0, 10)) + timeout: Some(Duration::new(0, 10)), }) .unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Client<(), u64> = Client::new(addr).unwrap(); + let client: Client<(), u64, _> = Client::new(addr).unwrap(); let thread = thread::spawn(move || serve_handle.shutdown()); info!("force_shutdown:: rpc1: {:?}", client.rpc(())); thread.join().unwrap(); @@ -195,13 +314,13 @@ mod test { fn client_failed_rpc() { let _ = env_logger::init(); let server = Arc::new(Server::new()); - let serve_handle = server.spawn_with_config("localhost:0", + let serve_handle = server.spawn_with_config(TcpTransport("localhost:0"), Config { timeout: test_timeout(), }) .unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr).unwrap()); + let client: Arc> = Arc::new(Client::new(addr).unwrap()); client.rpc(()).unwrap(); serve_handle.shutdown(); match client.rpc(()) { @@ -219,7 +338,7 @@ mod test { let server = Arc::new(BarrierServer::new(concurrency)); let serve_handle = server.clone().spawn("localhost:0").unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Client<(), u64> = Client::new(addr).unwrap(); + let client: Client<(), u64, _> = Client::new(addr).unwrap(); pool.scoped(|scope| { for _ in 0..concurrency { let client = client.try_clone().unwrap(); @@ -239,7 +358,7 @@ mod test { let server = Arc::new(Server::new()); let serve_handle = server.spawn("localhost:0").unwrap(); let addr = serve_handle.local_addr().clone(); - let client: Client<(), u64> = Client::new(addr).unwrap(); + let client: Client<(), u64, _> = Client::new(addr).unwrap(); // Drop future immediately; does the reader channel panic when sending? client.rpc_async(()); diff --git a/tarpc/src/protocol/server.rs b/tarpc/src/protocol/server.rs index 0631712..87b866a 100644 --- a/tarpc/src/protocol/server.rs +++ b/tarpc/src/protocol/server.rs @@ -7,24 +7,27 @@ use serde; use scoped_pool::{Pool, Scope}; use std::fmt; use std::io::{self, BufReader, BufWriter}; -use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs}; +use std::net::{SocketAddr, ToSocketAddrs}; use std::sync::mpsc::{Receiver, Sender, TryRecvError, channel}; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use std::thread::{self, JoinHandle}; -use super::{Config, Deserialize, Error, Packet, Result, Serialize}; +use super::{Config, Deserialize, Dialer, Error, Listener, Packet, Result, Serialize, Stream, + TcpDialer, TcpDialerExt, TcpTransport, Transport}; -struct ConnectionHandler<'a, S> - where S: Serve +struct ConnectionHandler<'a, S, St> + where S: Serve, + St: Stream, { - read_stream: BufReader, - write_stream: BufWriter, + read_stream: BufReader, + write_stream: BufWriter, server: S, shutdown: &'a AtomicBool, } -impl<'a, S> ConnectionHandler<'a, S> - where S: Serve +impl<'a, S, St> ConnectionHandler<'a, S, St> + where S: Serve, + St: Stream, { fn handle_conn<'b>(&'b mut self, scope: &Scope<'b>) -> Result<()> { let ConnectionHandler { @@ -83,7 +86,7 @@ impl<'a, S> ConnectionHandler<'a, S> } } - fn write(rx: Receiver::Reply>>, stream: &mut BufWriter) { + fn write(rx: Receiver::Reply>>, stream: &mut BufWriter) { loop { match rx.recv() { Err(e) => { @@ -101,21 +104,30 @@ impl<'a, S> ConnectionHandler<'a, S> } /// Provides methods for blocking until the server completes, -pub struct ServeHandle { +pub struct ServeHandle + where D: Dialer +{ tx: Sender<()>, join_handle: JoinHandle<()>, - addr: SocketAddr, + dialer: D, } -impl ServeHandle { +impl ServeHandle + where D: Dialer +{ /// Block until the server completes pub fn wait(self) { self.join_handle.join().expect(pos!()); } - /// Returns the address the server is bound to - pub fn local_addr(&self) -> &SocketAddr { - &self.addr + /// Returns the dialer to the server. + pub fn dialer(&self) -> &D { + &self.dialer + } + + /// Returns the socket being listened on when the dialer is a `TcpDialer`. + pub fn local_addr(&self) -> &D::Addr where D: TcpDialerExt { + self.dialer().addr() } /// Shutdown the server. Gracefully shuts down the serve thread but currently does not @@ -123,7 +135,7 @@ impl ServeHandle { pub fn shutdown(self) { info!("ServeHandle: attempting to shut down the server."); self.tx.send(()).expect(pos!()); - if let Ok(_) = TcpStream::connect(self.addr) { + if let Ok(_) = self.dialer.dial() { self.join_handle.join().expect(pos!()); } else { warn!("ServeHandle: best effort shutdown of serve thread failed"); @@ -131,15 +143,15 @@ impl ServeHandle { } } -struct Server<'a, S: 'a> { +struct Server<'a, S: 'a, L: Listener> { server: &'a S, - listener: TcpListener, + listener: L, read_timeout: Option, die_rx: Receiver<()>, shutdown: &'a AtomicBool, } -impl<'a, S: 'a> Server<'a, S> +impl<'a, S: 'a, L: Listener> Server<'a, S, L> where S: Serve + 'static { fn serve<'b>(self, scope: &Scope<'b>) @@ -194,7 +206,7 @@ impl<'a, S: 'a> Server<'a, S> } } -impl<'a, S> Drop for Server<'a, S> { +impl<'a, S, L: Listener> Drop for Server<'a, S, L> { fn drop(&mut self) { debug!("Shutting down connection handlers."); self.shutdown.store(true, Ordering::SeqCst); @@ -212,29 +224,30 @@ pub trait Serve: Send + Sync + Sized { fn serve(&self, request: Self::Request) -> Self::Reply; /// spawn - fn spawn(self, addr: A) -> io::Result + fn spawn(self, addr: A) -> io::Result>> where A: ToSocketAddrs, Self: 'static, { - self.spawn_with_config(addr, Config::default()) + self.spawn_with_config(TcpTransport(addr), Config::default()) } /// spawn - fn spawn_with_config(self, addr: A, config: Config) -> io::Result - where A: ToSocketAddrs, - Self: 'static, + fn spawn_with_config(self, transport: T, config: Config) + -> io::Result::Dialer>> + where Self: 'static, { - let listener = try!(TcpListener::bind(&addr)); - let addr = try!(listener.local_addr()); - info!("spawn_with_config: spinning up server on {:?}", addr); + let listener = try!(transport.bind()); + let dialer = try!(listener.dialer()); + info!("spawn_with_config: spinning up server."); let (die_tx, die_rx) = channel(); + let timeout = config.timeout; let join_handle = thread::spawn(move || { let pool = Pool::new(100); // TODO(tjk): make this configurable, and expire idle threads let shutdown = AtomicBool::new(false); let server = Server { server: &self, listener: listener, - read_timeout: config.timeout, + read_timeout: timeout, die_rx: die_rx, shutdown: &shutdown, }; @@ -245,7 +258,7 @@ pub trait Serve: Send + Sync + Sized { Ok(ServeHandle { tx: die_tx, join_handle: join_handle, - addr: addr.clone(), + dialer: dialer, }) }