diff --git a/README.md b/README.md index c146188..29335ec 100644 --- a/README.md +++ b/README.md @@ -46,8 +46,9 @@ impl HelloService for HelloServer { } fn main() { - let server_handle = HelloServer.spawn("0.0.0.0:0").unwrap(); - let client = hello_service::Client::new(server_handle.local_addr()).unwrap(); + let addr = "localhost:10000"; + let server_handle = HelloServer.spawn(addr).unwrap(); + let client = hello_service::Client::new(addr).unwrap(); assert_eq!("Hello, Mom!", client.hello("Mom".into()).unwrap()); drop(client); server_handle.shutdown(); @@ -56,17 +57,18 @@ fn main() { The `service!` macro expands to a collection of items that collectively form an rpc service. In the above example, the macro is called within the `hello_service` module. This module will contain a -`Client` (and `AsyncClient`) type, and a `Service` trait. The trait provides `default fn`s for -starting the service: `spawn` and `spawn_with_config`, which start the service listening on a tcp -port. A `Client` (or `AsyncClient`) can connect to such a service. These generated types make it -easy and ergonomic to write servers without dealing with sockets or serialization directly. See the -tarpc_examples package for more sophisticated examples. +`Client` (and `AsyncClient`) type, and a `Service` trait. The trait provides default `fn`s for +starting the service: `spawn` and `spawn_with_config`, which start the service listening over an +arbitrary transport. A `Client` (or `AsyncClient`) can connect to such a service. These generated +types make it easy and ergonomic to write servers without dealing with sockets or serialization +directly. See the tarpc_examples package for more sophisticated examples. ## Documentation Use `cargo doc` as you normally would to see the documentation created for all items expanded by a `service!` invocation. ## Additional Features +- Connect over any transport that `impl`s the `Transport` trait. - Concurrent requests from a single client. - Any type that `impl`s `serde`'s `Serialize` and `Deserialize` can be used in the rpc signatures. - Attributes can be specified on rpc methods. These will be included on both the `Service` trait diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 9208725..8658713 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -11,11 +11,13 @@ readme = "../README.md" description = "An RPC framework for Rust with a focus on ease of use." [dependencies] -bincode = "^0.5" -log = "^0.3" -scoped-pool = "^0.1" -serde = "^0.7" +bincode = "0.5" +log = "0.3" +scoped-pool = "0.1" +serde = "0.7" +unix_socket = "0.5" [dev-dependencies] -lazy_static = "^0.1" -env_logger = "^0.3" +lazy_static = "0.1" +env_logger = "0.3" +tempdir = "0.3" diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index efb2684..dafa622 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -30,14 +30,13 @@ //! } //! //! fn main() { -//! let addr = "127.0.0.1:9000"; -//! let shutdown = Server.spawn(addr).unwrap(); -//! let client = Client::new(addr).unwrap(); +//! let serve_handle = Server.spawn("localhost:0").unwrap(); +//! let client = Client::new(serve_handle.dialer()).unwrap(); //! assert_eq!(3, client.add(1, 2).unwrap()); //! assert_eq!("Hello, Mom!".to_string(), //! client.hello("Mom".to_string()).unwrap()); //! drop(client); -//! shutdown.shutdown(); +//! serve_handle.shutdown(); //! } //! ``` @@ -48,6 +47,7 @@ extern crate bincode; #[macro_use] extern crate log; extern crate scoped_pool; +extern crate unix_socket; macro_rules! pos { () => (concat!(file!(), ":", line!())) @@ -60,4 +60,7 @@ pub mod protocol; /// Provides the macro used for constructing rpc services and client stubs. pub mod macros; +/// Provides transport traits and implementations. +pub mod transport; + pub use protocol::{Config, Error, Result, ServeHandle}; diff --git a/tarpc/src/macros.rs b/tarpc/src/macros.rs index e4510fa..d34f054 100644 --- a/tarpc/src/macros.rs +++ b/tarpc/src/macros.rs @@ -262,7 +262,7 @@ macro_rules! service { #[doc(hidden)] #[macro_export] macro_rules! service_inner { - // Pattern for when the next rpc has an implicit unit return type +// Pattern for when the next rpc has an implicit unit return type ( { $(#[$attr:meta])* @@ -281,7 +281,7 @@ macro_rules! service_inner { rpc $fn_name( $( $arg : $in_ ),* ) -> (); } }; - // Pattern for when the next rpc has an explicit return type +// Pattern for when the next rpc has an explicit return type ( { $(#[$attr:meta])* @@ -300,7 +300,7 @@ macro_rules! service_inner { rpc $fn_name( $( $arg : $in_ ),* ) -> $out; } }; - // Pattern when all return types have been expanded +// Pattern when all return types have been expanded ( { } // none left to expand $( @@ -316,21 +316,30 @@ macro_rules! service_inner { )* #[doc="Spawn a running service."] - fn spawn(self, addr: A) -> $crate::Result<$crate::protocol::ServeHandle> - where A: ::std::net::ToSocketAddrs, + fn spawn(self, + transport: T) + -> $crate::Result< + $crate::protocol::ServeHandle< + ::Dialer>> + where T: $crate::transport::Transport, Self: 'static, { - self.spawn_with_config(addr, $crate::Config::default()) + self.spawn_with_config(transport, $crate::Config::default()) } #[doc="Spawn a running service."] - fn spawn_with_config(self, addr: A, config: $crate::Config) - -> $crate::Result<$crate::protocol::ServeHandle> - where A: ::std::net::ToSocketAddrs, + fn spawn_with_config(self, + transport: T, + config: $crate::Config) + -> $crate::Result< + $crate::protocol::ServeHandle< + ::Dialer>> + where T: $crate::transport::Transport, Self: 'static, { - let server = ::std::sync::Arc::new(__Server(self)); - let handle = try!($crate::protocol::Serve::spawn_with_config(server, addr, config)); + let server = __Server(self); + let result = $crate::protocol::Serve::spawn_with_config(server, transport, config); + let handle = try!(result); ::std::result::Result::Ok(handle) } } @@ -386,25 +395,29 @@ macro_rules! service_inner { #[allow(unused)] #[doc="The client stub that makes RPC calls to the server."] - pub struct Client($crate::protocol::Client<__Request, __Reply>); + pub struct Client( + $crate::protocol::Client<__Request, __Reply, S> + ) where S: $crate::transport::Stream; - impl Client { + impl Client + where S: $crate::transport::Stream + { #[allow(unused)] #[doc="Create a new client with default configuration that connects to the given \ address."] - pub fn new(addr: A) -> $crate::Result - where A: ::std::net::ToSocketAddrs, + pub fn new(dialer: D) -> $crate::Result + where D: $crate::transport::Dialer, { - Self::with_config(addr, $crate::Config::default()) + Self::with_config(dialer, $crate::Config::default()) } #[allow(unused)] #[doc="Create a new client with the specified configuration that connects to the \ given address."] - pub fn with_config(addr: A, config: $crate::Config) -> $crate::Result - where A: ::std::net::ToSocketAddrs, + pub fn with_config(dialer: D, config: $crate::Config) -> $crate::Result + where D: $crate::transport::Dialer, { - let inner = try!($crate::protocol::Client::with_config(addr, config)); + let inner = try!($crate::protocol::Client::with_config(dialer, config)); ::std::result::Result::Ok(Client(inner)) } @@ -425,25 +438,27 @@ macro_rules! service_inner { #[allow(unused)] #[doc="The client stub that makes asynchronous RPC calls to the server."] - pub struct AsyncClient($crate::protocol::Client<__Request, __Reply>); + pub struct AsyncClient( + $crate::protocol::Client<__Request, __Reply, S> + ) where S: $crate::transport::Stream; - impl AsyncClient { + impl AsyncClient + where S: $crate::transport::Stream { #[allow(unused)] #[doc="Create a new asynchronous client with default configuration that connects to \ the given address."] - pub fn new(addr: A) -> $crate::Result - where A: ::std::net::ToSocketAddrs, + pub fn new(dialer: D) -> $crate::Result + where D: $crate::transport::Dialer, { - Self::with_config(addr, $crate::Config::default()) + Self::with_config(dialer, $crate::Config::default()) } #[allow(unused)] #[doc="Create a new asynchronous client that connects to the given address."] - pub fn with_config(addr: A, config: $crate::Config) - -> $crate::Result - where A: ::std::net::ToSocketAddrs, + pub fn with_config(dialer: D, config: $crate::Config) -> $crate::Result + where D: $crate::transport::Dialer, { - let inner = try!($crate::protocol::Client::with_config(addr, config)); + let inner = try!($crate::protocol::Client::with_config(dialer, config)); ::std::result::Result::Ok(AsyncClient(inner)) } @@ -463,7 +478,8 @@ macro_rules! service_inner { } #[allow(unused)] - struct __Server(S); + struct __Server(S) + where S: 'static + Service; impl $crate::protocol::Serve for __Server where S: 'static + Service @@ -513,6 +529,8 @@ mod syntax_test { #[cfg(test)] mod functional_test { extern crate env_logger; + extern crate tempdir; + use transport::unix::UnixTransport; service! { rpc add(x: i32, y: i32) -> i32; @@ -534,7 +552,7 @@ mod functional_test { fn simple() { let _ = env_logger::init(); let handle = Server.spawn("localhost:0").unwrap(); - let client = Client::new(handle.local_addr()).unwrap(); + let client = Client::new(handle.dialer()).unwrap(); assert_eq!(3, client.add(1, 2).unwrap()); assert_eq!("Hey, Tim.", client.hey("Tim".into()).unwrap()); drop(client); @@ -545,7 +563,7 @@ mod functional_test { fn simple_async() { let _ = env_logger::init(); let handle = Server.spawn("localhost:0").unwrap(); - let client = AsyncClient::new(handle.local_addr()).unwrap(); + let client = AsyncClient::new(handle.dialer()).unwrap(); assert_eq!(3, client.add(1, 2).get().unwrap()); assert_eq!("Hey, Adam.", client.hey("Adam".into()).get().unwrap()); drop(client); @@ -555,7 +573,7 @@ mod functional_test { #[test] fn try_clone() { let handle = Server.spawn("localhost:0").unwrap(); - let client1 = Client::new(handle.local_addr()).unwrap(); + let client1 = Client::new(handle.dialer()).unwrap(); let client2 = client1.try_clone().unwrap(); assert_eq!(3, client1.add(1, 2).unwrap()); assert_eq!(3, client2.add(1, 2).unwrap()); @@ -564,7 +582,19 @@ mod functional_test { #[test] fn async_try_clone() { let handle = Server.spawn("localhost:0").unwrap(); - let client1 = AsyncClient::new(handle.local_addr()).unwrap(); + let client1 = AsyncClient::new(handle.dialer()).unwrap(); + let client2 = client1.try_clone().unwrap(); + assert_eq!(3, client1.add(1, 2).get().unwrap()); + assert_eq!(3, client2.add(1, 2).get().unwrap()); + } + + #[test] + fn async_try_clone_unix() { + let temp_dir = tempdir::TempDir::new("tarpc").unwrap(); + let temp_file = temp_dir.path() + .join("async_try_clone_unix.tmp"); + let handle = Server.spawn(UnixTransport(temp_file)).unwrap(); + let client1 = AsyncClient::new(handle.dialer()).unwrap(); let client2 = client1.try_clone().unwrap(); assert_eq!(3, client1.add(1, 2).get().unwrap()); assert_eq!(3, client2.add(1, 2).get().unwrap()); @@ -576,6 +606,12 @@ mod functional_test { let _ = ::std::sync::Arc::new(Server).spawn("localhost:0"); } + // Tests that a tcp client can be created from &str + #[allow(dead_code)] + fn test_client_str() { + let _ = Client::new("localhost:0"); + } + #[test] fn serde() { use bincode; diff --git a/tarpc/src/protocol/client.rs b/tarpc/src/protocol/client.rs index edb0d36..dd2ed56 100644 --- a/tarpc/src/protocol/client.rs +++ b/tarpc/src/protocol/client.rs @@ -8,38 +8,44 @@ use std::fmt; use std::io::{self, BufReader, BufWriter, Read}; use std::collections::HashMap; use std::mem; -use std::net::{TcpStream, ToSocketAddrs}; use std::sync::{Arc, Mutex}; use std::sync::mpsc::{Receiver, Sender, channel}; use std::thread; use super::{Config, Deserialize, Error, Packet, Result, Serialize}; +use transport::{Dialer, Stream}; /// A client stub that connects to a server to run rpcs. -pub struct Client - where Request: serde::ser::Serialize +pub struct Client + where Request: serde::ser::Serialize, + S: Stream { // The guard is in an option so it can be joined in the drop fn reader_guard: Arc>>, outbound: Sender<(Request, Sender>)>, requests: Arc>>, - shutdown: TcpStream, + shutdown: S, } -impl Client +impl Client where Request: serde::ser::Serialize + Send + 'static, - Reply: serde::de::Deserialize + Send + 'static + Reply: serde::de::Deserialize + Send + 'static, + S: Stream { /// Create a new client that connects to `addr`. The client uses the given timeout /// for both reads and writes. - pub fn new(addr: A) -> io::Result { - Self::with_config(addr, Config::default()) + pub fn new(dialer: D) -> io::Result + where D: Dialer + { + Self::with_config(dialer, Config::default()) } /// Create a new client that connects to `addr`. The client uses the given timeout /// for both reads and writes. - pub fn with_config(addr: A, config: Config) -> io::Result { - let stream = try!(TcpStream::connect(addr)); + pub fn with_config(dialer: D, config: Config) -> io::Result + where D: Dialer + { + let stream = try!(dialer.dial()); try!(stream.set_read_timeout(config.timeout)); try!(stream.set_write_timeout(config.timeout)); let reader_stream = try!(stream.try_clone()); @@ -59,7 +65,7 @@ impl Client } /// Clones the Client so that it can be shared across threads. - pub fn try_clone(&self) -> io::Result> { + pub fn try_clone(&self) -> io::Result { Ok(Client { reader_guard: self.reader_guard.clone(), outbound: self.outbound.clone(), @@ -97,14 +103,15 @@ impl Client } } -impl Drop for Client - where Request: serde::ser::Serialize +impl Drop for Client + where Request: serde::ser::Serialize, + S: Stream { fn drop(&mut self) { debug!("Dropping Client."); if let Some(reader_guard) = Arc::get_mut(&mut self.reader_guard) { debug!("Attempting to shut down writer and reader threads."); - if let Err(e) = self.shutdown.shutdown(::std::net::Shutdown::Both) { + if let Err(e) = self.shutdown.shutdown() { warn!("Client: couldn't shutdown writer and reader threads: {:?}", e); } else { @@ -185,11 +192,12 @@ impl RpcFutures { } } -fn write(outbound: Receiver<(Request, Sender>)>, - requests: Arc>>, - stream: TcpStream) +fn write(outbound: Receiver<(Request, Sender>)>, + requests: Arc>>, + stream: S) where Request: serde::Serialize, - Reply: serde::Deserialize + Reply: serde::Deserialize, + S: Stream { let mut next_id = 0; let mut stream = BufWriter::new(stream); @@ -238,8 +246,9 @@ fn write(outbound: Receiver<(Request, Sender>)>, } -fn read(requests: Arc>>, stream: TcpStream) - where Reply: serde::Deserialize +fn read(requests: Arc>>, stream: S) + where Reply: serde::Deserialize, + S: Stream { let mut stream = BufReader::new(stream); loop { diff --git a/tarpc/src/protocol/mod.rs b/tarpc/src/protocol/mod.rs index c7e2132..3a7b2fd 100644 --- a/tarpc/src/protocol/mod.rs +++ b/tarpc/src/protocol/mod.rs @@ -93,6 +93,7 @@ mod test { extern crate env_logger; use super::{Client, Config, Serve}; use scoped_pool::Pool; + use std::net::TcpStream; use std::sync::{Arc, Barrier, Mutex}; use std::thread; use std::time::Duration; @@ -132,7 +133,7 @@ mod test { let _ = env_logger::init(); let server = Arc::new(Server::new()); let serve_handle = server.spawn("localhost:0").unwrap(); - let client: Client<(), u64> = Client::new(serve_handle.local_addr()).unwrap(); + let client: Client<(), u64, TcpStream> = Client::new(serve_handle.dialer()).unwrap(); drop(client); serve_handle.shutdown(); } @@ -142,9 +143,8 @@ mod test { let _ = env_logger::init(); let server = Arc::new(Server::new()); let serve_handle = server.clone().spawn("localhost:0").unwrap(); - let addr = serve_handle.local_addr().clone(); // The explicit type is required so that it doesn't deserialize a u32 instead of u64 - let client: Client<(), u64> = Client::new(addr).unwrap(); + let client: Client<(), u64, _> = Client::new(serve_handle.dialer()).unwrap(); assert_eq!(0, client.rpc(()).unwrap()); assert_eq!(1, server.count()); assert_eq!(1, client.rpc(()).unwrap()); @@ -187,8 +187,7 @@ mod test { let serve_handle = server.spawn_with_config("localhost:0", Config { timeout: Some(Duration::new(0, 10)) }) .unwrap(); - let addr = serve_handle.local_addr().clone(); - let client: Client<(), u64> = Client::new(addr).unwrap(); + let client: Client<(), u64, _> = Client::new(serve_handle.dialer()).unwrap(); let thread = thread::spawn(move || serve_handle.shutdown()); info!("force_shutdown:: rpc1: {:?}", client.rpc(())); thread.join().unwrap(); @@ -201,8 +200,7 @@ mod test { let serve_handle = server.spawn_with_config("localhost:0", Config { timeout: test_timeout() }) .unwrap(); - let addr = serve_handle.local_addr().clone(); - let client: Arc> = Arc::new(Client::new(addr).unwrap()); + let client: Arc> = Arc::new(Client::new(serve_handle.dialer()).unwrap()); client.rpc(()).unwrap(); serve_handle.shutdown(); match client.rpc(()) { @@ -219,8 +217,7 @@ mod test { let pool = Pool::new(concurrency); let server = Arc::new(BarrierServer::new(concurrency)); let serve_handle = server.clone().spawn("localhost:0").unwrap(); - let addr = serve_handle.local_addr().clone(); - let client: Client<(), u64> = Client::new(addr).unwrap(); + let client: Client<(), u64, _> = Client::new(serve_handle.dialer()).unwrap(); pool.scoped(|scope| { for _ in 0..concurrency { let client = client.try_clone().unwrap(); @@ -239,8 +236,7 @@ mod test { let _ = env_logger::init(); let server = Arc::new(Server::new()); let serve_handle = server.spawn("localhost:0").unwrap(); - let addr = serve_handle.local_addr().clone(); - let client: Client<(), u64> = Client::new(addr).unwrap(); + let client: Client<(), u64, _> = Client::new(serve_handle.dialer()).unwrap(); // Drop future immediately; does the reader channel panic when sending? client.rpc_async(()); diff --git a/tarpc/src/protocol/server.rs b/tarpc/src/protocol/server.rs index 69149d0..f635145 100644 --- a/tarpc/src/protocol/server.rs +++ b/tarpc/src/protocol/server.rs @@ -7,24 +7,27 @@ use serde; use scoped_pool::{Pool, Scope}; use std::fmt; use std::io::{self, BufReader, BufWriter}; -use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs}; use std::sync::mpsc::{Receiver, Sender, TryRecvError, channel}; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use std::thread::{self, JoinHandle}; use super::{Config, Deserialize, Error, Packet, Result, Serialize}; +use transport::{Dialer, Listener, Stream, Transport}; +use transport::tcp::TcpDialer; -struct ConnectionHandler<'a, S> - where S: Serve +struct ConnectionHandler<'a, S, St> + where S: Serve, + St: Stream { - read_stream: BufReader, - write_stream: BufWriter, + read_stream: BufReader, + write_stream: BufWriter, server: S, shutdown: &'a AtomicBool, } -impl<'a, S> ConnectionHandler<'a, S> - where S: Serve +impl<'a, S, St> ConnectionHandler<'a, S, St> + where S: Serve, + St: Stream { fn handle_conn<'b>(&'b mut self, scope: &Scope<'b>) -> Result<()> { let ConnectionHandler { @@ -83,7 +86,7 @@ impl<'a, S> ConnectionHandler<'a, S> } } - fn write(rx: Receiver::Reply>>, stream: &mut BufWriter) { + fn write(rx: Receiver::Reply>>, stream: &mut BufWriter) { loop { match rx.recv() { Err(e) => { @@ -101,21 +104,25 @@ impl<'a, S> ConnectionHandler<'a, S> } /// Provides methods for blocking until the server completes, -pub struct ServeHandle { +pub struct ServeHandle + where D: Dialer +{ tx: Sender<()>, join_handle: JoinHandle<()>, - addr: SocketAddr, + dialer: D, } -impl ServeHandle { +impl ServeHandle + where D: Dialer +{ /// Block until the server completes pub fn wait(self) { self.join_handle.join().expect(pos!()); } - /// Returns the address the server is bound to - pub fn local_addr(&self) -> &SocketAddr { - &self.addr + /// Returns the dialer to the server. + pub fn dialer(&self) -> &D { + &self.dialer } /// Shutdown the server. Gracefully shuts down the serve thread but currently does not @@ -123,7 +130,7 @@ impl ServeHandle { pub fn shutdown(self) { info!("ServeHandle: attempting to shut down the server."); self.tx.send(()).expect(pos!()); - if let Ok(_) = TcpStream::connect(self.addr) { + if let Ok(_) = self.dialer.dial() { self.join_handle.join().expect(pos!()); } else { warn!("ServeHandle: best effort shutdown of serve thread failed"); @@ -131,16 +138,19 @@ impl ServeHandle { } } -struct Server<'a, S: 'a> { +struct Server<'a, S: 'a, L> + where L: Listener +{ server: &'a S, - listener: TcpListener, + listener: L, read_timeout: Option, die_rx: Receiver<()>, shutdown: &'a AtomicBool, } -impl<'a, S: 'a> Server<'a, S> - where S: Serve + 'static +impl<'a, S, L> Server<'a, S, L> + where S: Serve + 'static, + L: Listener { fn serve<'b>(self, scope: &Scope<'b>) where 'a: 'b @@ -194,7 +204,9 @@ impl<'a, S: 'a> Server<'a, S> } } -impl<'a, S> Drop for Server<'a, S> { +impl<'a, S, L> Drop for Server<'a, S, L> + where L: Listener +{ fn drop(&mut self) { debug!("Shutting down connection handlers."); self.shutdown.store(true, Ordering::SeqCst); @@ -212,29 +224,33 @@ pub trait Serve: Send + Sync + Sized { fn serve(&self, request: Self::Request) -> Self::Reply; /// spawn - fn spawn(self, addr: A) -> io::Result - where A: ToSocketAddrs, + fn spawn(self, transport: T) -> io::Result::Dialer>> + where T: Transport, Self: 'static { - self.spawn_with_config(addr, Config::default()) + self.spawn_with_config(transport, Config::default()) } /// spawn - fn spawn_with_config(self, addr: A, config: Config) -> io::Result - where A: ToSocketAddrs, + fn spawn_with_config(self, + transport: T, + config: Config) + -> io::Result::Dialer>> + where T: Transport, Self: 'static { - let listener = try!(TcpListener::bind(&addr)); - let addr = try!(listener.local_addr()); - info!("spawn_with_config: spinning up server on {:?}", addr); + let listener = try!(transport.bind()); + let dialer = try!(listener.dialer()); + info!("spawn_with_config: spinning up server."); let (die_tx, die_rx) = channel(); + let timeout = config.timeout; let join_handle = thread::spawn(move || { let pool = Pool::new(100); // TODO(tjk): make this configurable, and expire idle threads let shutdown = AtomicBool::new(false); let server = Server { server: &self, listener: listener, - read_timeout: config.timeout, + read_timeout: timeout, die_rx: die_rx, shutdown: &shutdown, }; @@ -245,7 +261,7 @@ pub trait Serve: Send + Sync + Sized { Ok(ServeHandle { tx: die_tx, join_handle: join_handle, - addr: addr.clone(), + dialer: dialer, }) } diff --git a/tarpc/src/transport/mod.rs b/tarpc/src/transport/mod.rs new file mode 100644 index 0000000..f70b2bb --- /dev/null +++ b/tarpc/src/transport/mod.rs @@ -0,0 +1,91 @@ +use std::io::{self, Read, Write}; +use std::time::Duration; + +/// A factory for creating a listener on a given address. +/// For TCP, an address might be an IPv4 address; for Unix sockets, it +/// is just a file name. +pub trait Transport { + /// The type of listener that binds to the given address. + type Listener: Listener; + /// Return a listener on the given address, and a dialer to that address. + fn bind(&self) -> io::Result; +} + +/// Accepts incoming connections from dialers. +pub trait Listener: Send + 'static { + /// The type of address being listened on. + type Dialer: Dialer; + /// The type of stream this listener accepts. + type Stream: Stream; + /// Accept an incoming stream. + fn accept(&self) -> io::Result; + /// Returns the local address being listened on. + fn dialer(&self) -> io::Result; + /// Iterate over incoming connections. + fn incoming(&self) -> Incoming { + Incoming { listener: self } + } +} + +/// A cloneable Reader/Writer. +pub trait Stream: Read + Write + Send + Sized + 'static { + /// Creates a new independently owned handle to the Stream. + /// + /// The returned Stream should reference the same stream that this + /// object references. Both handles should read and write the same + /// stream of data, and options set on one stream should be propagated + /// to the other stream. + fn try_clone(&self) -> io::Result; + /// Sets a read timeout. + /// + /// If the value specified is `None`, then read calls will block indefinitely. + /// It is an error to pass the zero `Duration` to this method. + fn set_read_timeout(&self, dur: Option) -> io::Result<()>; + /// Sets a write timeout. + /// + /// If the value specified is `None`, then write calls will block indefinitely. + /// It is an error to pass the zero `Duration` to this method. + fn set_write_timeout(&self, dur: Option) -> io::Result<()>; + /// Shuts down both ends of the stream. + /// + /// Implementations should cause all pending and future I/O on the specified + /// portions to return immediately with an appropriate value. + fn shutdown(&self) -> io::Result<()>; +} + +/// A `Stream` factory. +pub trait Dialer { + /// The type of `Stream` this can create. + type Stream: Stream; + /// Open a stream. + fn dial(&self) -> io::Result; +} + +impl Dialer for P + where P: ::std::ops::Deref, + D: Dialer + 'static +{ + type Stream = D::Stream; + + fn dial(&self) -> io::Result { + (**self).dial() + } +} + +/// Iterates over incoming connections. +pub struct Incoming<'a, L: Listener + ?Sized + 'a> { + listener: &'a L, +} + +impl<'a, L: Listener> Iterator for Incoming<'a, L> { + type Item = io::Result; + + fn next(&mut self) -> Option { + Some(self.listener.accept()) + } +} + +/// Provides a TCP transport. +pub mod tcp; +/// Provides a unix socket transport. +pub mod unix; diff --git a/tarpc/src/transport/tcp.rs b/tarpc/src/transport/tcp.rs new file mode 100644 index 0000000..2a78b65 --- /dev/null +++ b/tarpc/src/transport/tcp.rs @@ -0,0 +1,75 @@ +use std::io; +use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs}; +use std::time::Duration; + +/// A transport for TCP. +pub struct TcpTransport(pub A); + +impl super::Transport for TcpTransport { + type Listener = TcpListener; + + fn bind(&self) -> io::Result { + TcpListener::bind(&self.0) + } +} + +impl super::Transport for A { + type Listener = TcpListener; + + fn bind(&self) -> io::Result { + TcpListener::bind(self) + } +} + +impl super::Listener for TcpListener { + type Dialer = TcpDialer; + + type Stream = TcpStream; + + fn accept(&self) -> io::Result { + self.accept().map(|(stream, _)| stream) + } + + fn dialer(&self) -> io::Result> { + self.local_addr().map(|addr| TcpDialer(addr)) + } +} + +impl super::Stream for TcpStream { + fn try_clone(&self) -> io::Result { + self.try_clone() + } + + fn set_read_timeout(&self, dur: Option) -> io::Result<()> { + self.set_read_timeout(dur) + } + + fn set_write_timeout(&self, dur: Option) -> io::Result<()> { + self.set_write_timeout(dur) + } + + fn shutdown(&self) -> io::Result<()> { + self.shutdown(::std::net::Shutdown::Both) + } +} + +/// Connects to a socket address. +pub struct TcpDialer(pub A) where A: ToSocketAddrs; + +impl super::Dialer for TcpDialer + where A: ToSocketAddrs +{ + type Stream = TcpStream; + + fn dial(&self) -> io::Result { + TcpStream::connect(&self.0) + } +} + +impl super::Dialer for str { + type Stream = TcpStream; + + fn dial(&self) -> io::Result { + TcpStream::connect(self) + } +} diff --git a/tarpc/src/transport/unix.rs b/tarpc/src/transport/unix.rs new file mode 100644 index 0000000..9a4b590 --- /dev/null +++ b/tarpc/src/transport/unix.rs @@ -0,0 +1,70 @@ +use std::io; +use std::path::{Path, PathBuf}; +use std::time::Duration; +use unix_socket::{UnixListener, UnixStream}; + +/// A transport for unix sockets. +pub struct UnixTransport

(pub P) where P: AsRef; + +impl

super::Transport for UnixTransport

+ where P: AsRef +{ + type Listener = UnixListener; + + fn bind(&self) -> io::Result { + UnixListener::bind(&self.0) + } +} + +/// Connects to a unix socket address. +pub struct UnixDialer

(pub P) where P: AsRef; + +impl

super::Dialer for UnixDialer

+ where P: AsRef +{ + type Stream = UnixStream; + + fn dial(&self) -> io::Result { + UnixStream::connect(&self.0) + } +} + +impl super::Listener for UnixListener { + type Stream = UnixStream; + + type Dialer = UnixDialer; + + fn accept(&self) -> io::Result { + self.accept().map(|(stream, _)| stream) + } + + fn dialer(&self) -> io::Result> { + self.local_addr().and_then(|addr| { + match addr.as_pathname() { + Some(path) => Ok(UnixDialer(path.to_owned())), + None => { + Err(io::Error::new(io::ErrorKind::AddrNotAvailable, + "Couldn't get a path to bound unix socket")) + } + } + }) + } +} + +impl super::Stream for UnixStream { + fn try_clone(&self) -> io::Result { + self.try_clone() + } + + fn set_read_timeout(&self, timeout: Option) -> io::Result<()> { + self.set_read_timeout(timeout) + } + + fn set_write_timeout(&self, timeout: Option) -> io::Result<()> { + self.set_write_timeout(timeout) + } + + fn shutdown(&self) -> io::Result<()> { + self.shutdown(::std::net::Shutdown::Both) + } +} diff --git a/tarpc_examples/src/lib.rs b/tarpc_examples/src/lib.rs index 9842c57..72bf45a 100644 --- a/tarpc_examples/src/lib.rs +++ b/tarpc_examples/src/lib.rs @@ -41,8 +41,9 @@ mod benchmark { Arc::new(Mutex::new(handle)) }; static ref CLIENT: Arc> = { - let addr = HANDLE.lock().unwrap().local_addr().clone(); - let client = AsyncClient::new(addr).unwrap(); + let lock = HANDLE.lock().unwrap(); + let dialer = lock.dialer(); + let client = AsyncClient::new(dialer).unwrap(); Arc::new(Mutex::new(client)) }; }