From 7ad0e4b070d32e9c40e64eee48aa0f953e5f40f8 Mon Sep 17 00:00:00 2001 From: Tim Date: Thu, 25 Oct 2018 11:22:55 -0700 Subject: [PATCH] Service registry (#204) # Changes ## Client is now a trait And `Channel` implements `Client`. Previously, `Client` was a thin wrapper around `Channel`. This was changed to allow for mapping the request and response types. For example, you can take a `channel: Channel` and do: ```rust channel .with_request(|req: Req2| -> Req { ... }) .map_response(|resp: Resp| -> Resp2 { ... }) ``` ...which returns a type that implements `Client`. ### Why would you want to map request and response types? The main benefit of this is that it enables creating different client types backed by the same channel. For example, you could run multiple clients multiplexing requests over a single `TcpStream`. I have a demo in `tarpc/examples/service_registry.rs` showing how you might do this with a bincode transport. I am considering factoring out the service registry portion of that to an actual library, because it's doing pretty cool stuff. For this PR, though, it'll just be part of the example. ## Client::new is now client::new This is pretty minor, but necessary because async fns can't currently exist on traits. I changed `Server::new` to match this as well. ## Macro-generated Clients are generic over the backing Client. This is a natural consequence of the above change. However, it is transparent to the user by keeping `Channel` as the default type for the `` type parameter. `new_stub` returns `Client>`, and other clients can be created via the `From` trait. ## example-service/ now has two binaries, one for client and one for server. This serves as a "realistic" example of how one might set up a service. The other examples all run the client and server in the same binary, which isn't realistic in distributed systems use cases. ## `service!` trait fns take self by value. Services are already cloned per request, so this just passes on that flexibility to the trait implementers. # Open Questions In the service registry example, multiple services are running on a single port, and thus multiple clients are sending requests over a single `TcpStream`. This has implications for throttling: [`max_in_flight_requests_per_connection`](https://github.com/google/tarpc/blob/master/rpc/src/server/mod.rs#L57-L60) will set a maximum for the sum of requests for all clients sharing a single connection. I think this is reasonable behavior, but users may expect this setting to act like `max_in_flight_requests_per_client`. Fixes #103 #153 #205 --- .travis.yml | 2 +- bincode-transport/rustfmt.toml | 2 +- bincode-transport/src/compat.rs | 15 +- bincode-transport/src/lib.rs | 39 +- bincode-transport/tests/bench.rs | 18 +- bincode-transport/tests/cancel.rs | 23 +- bincode-transport/tests/pushback.rs | 17 +- example-service/Cargo.toml | 7 +- example-service/src/client.rs | 79 ++++ example-service/src/lib.rs | 2 +- example-service/src/{main.rs => server.rs} | 61 ++-- plugins/rustfmt.toml | 2 +- plugins/src/lib.rs | 23 +- rpc/rustfmt.toml | 2 +- rpc/src/client/{dispatch.rs => channel.rs} | 335 ++++++++++++++--- rpc/src/client/mod.rs | 146 +++++--- rpc/src/lib.rs | 11 +- rpc/src/server/filter.rs | 13 +- rpc/src/server/mod.rs | 72 ++-- rpc/src/transport/channel.rs | 19 +- rpc/src/transport/mod.rs | 23 +- rpc/src/util/deadline_compat.rs | 4 +- rpc/src/util/serde.rs | 3 +- tarpc/Cargo.toml | 2 + tarpc/examples/pubsub.rs | 26 +- tarpc/examples/readme.rs | 11 +- tarpc/examples/server_calling_server.rs | 21 +- tarpc/examples/service_registry.rs | 402 +++++++++++++++++++++ tarpc/rustfmt.toml | 2 +- tarpc/src/lib.rs | 12 +- tarpc/src/macros.rs | 44 ++- tarpc/tests/latency.rs | 17 +- trace/rustfmt.toml | 2 +- 33 files changed, 1127 insertions(+), 330 deletions(-) create mode 100644 example-service/src/client.rs rename example-service/src/{main.rs => server.rs} (53%) rename rpc/src/client/{dispatch.rs => channel.rs} (74%) create mode 100644 tarpc/examples/service_registry.rs diff --git a/.travis.yml b/.travis.yml index b34256e..0f33cb9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,4 +9,4 @@ os: - linux script: - - cargo test --all --all-features + - cargo test --all-targets --all-features diff --git a/bincode-transport/rustfmt.toml b/bincode-transport/rustfmt.toml index 0ef5137..32a9786 100644 --- a/bincode-transport/rustfmt.toml +++ b/bincode-transport/rustfmt.toml @@ -1 +1 @@ -edition = "Edition2018" +edition = "2018" diff --git a/bincode-transport/src/compat.rs b/bincode-transport/src/compat.rs index a9aa7c4..8b37d49 100644 --- a/bincode-transport/src/compat.rs +++ b/bincode-transport/src/compat.rs @@ -4,9 +4,12 @@ use futures_legacy::{ self as executor01, Notify as Notify01, NotifyHandle as NotifyHandle01, UnsafeNotify as UnsafeNotify01, }, - Async as Async01, AsyncSink as AsyncSink01, Stream as Stream01, Sink as Sink01 + Async as Async01, AsyncSink as AsyncSink01, Sink as Sink01, Stream as Stream01, +}; +use std::{ + pin::Pin, + task::{self, LocalWaker, Poll}, }; -use std::{pin::Pin, task::{self, LocalWaker, Poll}}; /// A shim to convert a 0.1 Sink + Stream to a 0.3 Sink + Stream. #[derive(Debug)] @@ -18,7 +21,10 @@ pub struct Compat { impl Compat { /// Returns a new Compat. pub fn new(inner: S) -> Self { - Compat { inner, staged_item: None } + Compat { + inner, + staged_item: None, + } } /// Unwraps Compat, returning the inner value. @@ -34,7 +40,7 @@ impl Compat { impl Stream for Compat where - S: Stream01 + S: Stream01, { type Item = Result; @@ -142,4 +148,3 @@ impl<'a> From> for NotifyHandle01 { unsafe { NotifyWaker(handle.0.clone().into_waker()).clone_raw() } } } - diff --git a/bincode-transport/src/lib.rs b/bincode-transport/src/lib.rs index 17b467d..7c20481 100644 --- a/bincode-transport/src/lib.rs +++ b/bincode-transport/src/lib.rs @@ -11,16 +11,27 @@ pin, arbitrary_self_types, await_macro, - async_await, + async_await )] #![deny(missing_docs, missing_debug_implementations)] -use async_bincode::{AsyncBincodeStream, AsyncDestination}; -use futures::{compat::{Compat01As03, Future01CompatExt, Stream01CompatExt}, prelude::*, ready}; -use pin_utils::unsafe_pinned; -use serde::{Serialize, Deserialize}; use self::compat::Compat; -use std::{error::Error, io, marker::PhantomData, net::SocketAddr, pin::Pin, task::{LocalWaker, Poll}}; +use async_bincode::{AsyncBincodeStream, AsyncDestination}; +use futures::{ + compat::{Compat01As03, Future01CompatExt, Stream01CompatExt}, + prelude::*, + ready, +}; +use pin_utils::unsafe_pinned; +use serde::{Deserialize, Serialize}; +use std::{ + error::Error, + io, + marker::PhantomData, + net::SocketAddr, + pin::Pin, + task::{LocalWaker, Poll}, +}; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_tcp::{TcpListener, TcpStream}; @@ -29,7 +40,7 @@ mod compat; /// A transport that serializes to, and deserializes from, a [`TcpStream`]. #[derive(Debug)] pub struct Transport { - inner: Compat, SinkItem> + inner: Compat, SinkItem>, } impl Transport { @@ -40,7 +51,9 @@ impl Transport { } impl Transport { - unsafe_pinned!(inner: Compat, SinkItem>); + unsafe_pinned!( + inner: Compat, SinkItem> + ); } impl Stream for Transport @@ -55,7 +68,9 @@ where Poll::Pending => Poll::Pending, Poll::Ready(None) => Poll::Ready(None), Poll::Ready(Some(Ok(next))) => Poll::Ready(Some(Ok(next))), - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e)))), + Poll::Ready(Some(Err(e))) => { + Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e)))) + } } } } @@ -119,7 +134,6 @@ where SinkItem: Serialize, { Transport::from(io) - } impl From for Transport { @@ -131,7 +145,9 @@ impl From for Transport { } /// Connects to `addr`, wrapping the connection in a bincode transport. -pub async fn connect(addr: &SocketAddr) -> io::Result> +pub async fn connect( + addr: &SocketAddr, +) -> io::Result> where Item: for<'de> Deserialize<'de>, SinkItem: Serialize, @@ -184,4 +200,3 @@ where Poll::Ready(next.map(|conn| Ok(new(conn)))) } } - diff --git a/bincode-transport/tests/bench.rs b/bincode-transport/tests/bench.rs index 670acf8..5e80f12 100644 --- a/bincode-transport/tests/bench.rs +++ b/bincode-transport/tests/bench.rs @@ -20,9 +20,8 @@ extern crate test; use self::test::stats::Stats; use futures::{compat::TokioDefaultSpawner, prelude::*}; use rpc::{ - client::{self, Client}, - context, - server::{self, Handler, Server}, + client, context, + server::{Handler, Server}, }; use std::{ io, @@ -34,17 +33,17 @@ async fn bench() -> io::Result<()> { let addr = listener.local_addr(); tokio_executor::spawn( - Server::::new(server::Config::default()) + Server::::default() .incoming(listener) .take(1) .respond_with(|_ctx, request| futures::future::ready(Ok(request))) .unit_error() .boxed() - .compat() + .compat(), ); let conn = await!(tarpc_bincode_transport::connect(&addr))?; - let client = &mut await!(Client::::new(client::Config::default(), conn))?; + let client = &mut await!(client::new::(client::Config::default(), conn))?; let total = 10_000usize; let mut successful = 0u32; @@ -104,12 +103,7 @@ fn bench_small_packet() -> io::Result<()> { env_logger::init(); rpc::init(TokioDefaultSpawner); - tokio::run( - bench() - .map_err(|e| panic!(e.to_string())) - .boxed() - .compat(), - ); + tokio::run(bench().map_err(|e| panic!(e.to_string())).boxed().compat()); println!("done"); Ok(()) diff --git a/bincode-transport/tests/cancel.rs b/bincode-transport/tests/cancel.rs index ed14fb7..6008eaa 100644 --- a/bincode-transport/tests/cancel.rs +++ b/bincode-transport/tests/cancel.rs @@ -6,7 +6,7 @@ //! Tests client/server control flow. -#![feature(generators, await_macro, async_await, futures_api,)] +#![feature(generators, await_macro, async_await, futures_api)] use futures::{ compat::{Future01CompatExt, TokioDefaultSpawner}, @@ -15,11 +15,7 @@ use futures::{ }; use log::{info, trace}; use rand::distributions::{Distribution, Normal}; -use rpc::{ - client::{self, Client}, - context, - server::{self, Server}, -}; +use rpc::{client, context, server::Server}; use std::{ io, time::{Duration, Instant, SystemTime}, @@ -40,7 +36,7 @@ impl AsDuration for SystemTime { async fn run() -> io::Result<()> { let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; let addr = listener.local_addr(); - let server = Server::::new(server::Config::default()) + let server = Server::::default() .incoming(listener) .take(1) .for_each(async move |channel| { @@ -80,7 +76,7 @@ async fn run() -> io::Result<()> { tokio_executor::spawn(server.unit_error().boxed().compat()); let conn = await!(tarpc_bincode_transport::connect(&addr))?; - let client = await!(Client::::new( + let client = await!(client::new::( client::Config::default(), conn ))?; @@ -88,7 +84,7 @@ async fn run() -> io::Result<()> { // Proxy service let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; let addr = listener.local_addr(); - let proxy_server = Server::::new(server::Config::default()) + let proxy_server = Server::::default() .incoming(listener) .take(1) .for_each(move |channel| { @@ -115,7 +111,7 @@ async fn run() -> io::Result<()> { config.max_in_flight_requests = 10; config.pending_request_buffer = 10; - let client = await!(Client::::new( + let client = await!(client::new::( config, await!(tarpc_bincode_transport::connect(&addr))? ))?; @@ -142,11 +138,6 @@ fn cancel_slower() -> io::Result<()> { env_logger::init(); rpc::init(TokioDefaultSpawner); - tokio::run( - run() - .boxed() - .map_err(|e| panic!(e)) - .compat(), - ); + tokio::run(run().boxed().map_err(|e| panic!(e)).compat()); Ok(()) } diff --git a/bincode-transport/tests/pushback.rs b/bincode-transport/tests/pushback.rs index e1e5e5a..4431bea 100644 --- a/bincode-transport/tests/pushback.rs +++ b/bincode-transport/tests/pushback.rs @@ -6,7 +6,7 @@ //! Tests client/server control flow. -#![feature(generators, await_macro, async_await, futures_api,)] +#![feature(generators, await_macro, async_await, futures_api)] use futures::{ compat::{Future01CompatExt, TokioDefaultSpawner}, @@ -14,11 +14,7 @@ use futures::{ }; use log::{error, info, trace}; use rand::distributions::{Distribution, Normal}; -use rpc::{ - client::{self, Client}, - context, - server::{self, Server}, -}; +use rpc::{client, context, server::Server}; use std::{ io, time::{Duration, Instant, SystemTime}, @@ -39,7 +35,7 @@ impl AsDuration for SystemTime { async fn run() -> io::Result<()> { let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; let addr = listener.local_addr(); - let server = Server::::new(server::Config::default()) + let server = Server::::default() .incoming(listener) .take(1) .for_each(async move |channel| { @@ -83,7 +79,7 @@ async fn run() -> io::Result<()> { config.pending_request_buffer = 10; let conn = await!(tarpc_bincode_transport::connect(&addr))?; - let client = await!(Client::::new(config, conn))?; + let client = await!(client::new::(config, conn))?; let clients = (1..=100u32).map(|_| client.clone()).collect::>(); for mut client in clients { @@ -96,7 +92,10 @@ async fn run() -> io::Result<()> { Ok(response) => info!("[{}] response: {}", trace_id, response), Err(e) => error!("[{}] request error: {:?}: {}", trace_id, e.kind(), e), } - }.unit_error().boxed().compat() + } + .unit_error() + .boxed() + .compat(), ); } diff --git a/example-service/Cargo.toml b/example-service/Cargo.toml index baa1a5c..9b36772 100644 --- a/example-service/Cargo.toml +++ b/example-service/Cargo.toml @@ -16,6 +16,7 @@ description = "An example server built on tarpc." [dependencies] bincode-transport = { package = "tarpc-bincode-transport", version = "0.2", path = "../bincode-transport" } +clap = "2.0" futures-preview = { version = "0.3.0-alpha.9", features = ["compat", "tokio-compat"] } serde = { version = "1.0" } tarpc = { version = "0.13", path = "../tarpc", features = ["serde1"] } @@ -28,4 +29,8 @@ path = "src/lib.rs" [[bin]] name = "server" -path = "src/main.rs" +path = "src/server.rs" + +[[bin]] +name = "client" +path = "src/client.rs" diff --git a/example-service/src/client.rs b/example-service/src/client.rs new file mode 100644 index 0000000..7ed3be3 --- /dev/null +++ b/example-service/src/client.rs @@ -0,0 +1,79 @@ +// Copyright 2018 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +#![feature( + futures_api, + pin, + arbitrary_self_types, + await_macro, + async_await +)] + +use clap::{App, Arg}; +use futures::{compat::TokioDefaultSpawner, prelude::*}; +use std::{io, net::SocketAddr}; +use tarpc::{client, context}; + +async fn run(server_addr: SocketAddr, name: String) -> io::Result<()> { + let transport = await!(bincode_transport::connect(&server_addr))?; + + // new_stub is generated by the service! macro. Like Server, it takes a config and any + // Transport as input, and returns a Client, also generated by the macro. + // by the service mcro. + let mut client = await!(service::new_stub(client::Config::default(), transport))?; + + // The client has an RPC method for each RPC defined in service!. It takes the same args + // as defined, with the addition of a Context, which is always the first arg. The Context + // specifies a deadline and trace information which can be helpful in debugging requests. + let hello = await!(client.hello(context::current(), name))?; + + println!("{}", hello); + + Ok(()) +} + +fn main() { + let flags = App::new("Hello Client") + .version("0.1") + .author("Tim ") + .about("Say hello!") + .arg( + Arg::with_name("server_addr") + .long("server_addr") + .value_name("ADDRESS") + .help("Sets the server address to connect to.") + .required(true) + .takes_value(true), + ) + .arg( + Arg::with_name("name") + .short("n") + .long("name") + .value_name("STRING") + .help("Sets the name to say hello to.") + .required(true) + .takes_value(true), + ) + .get_matches(); + + tarpc::init(TokioDefaultSpawner); + + let server_addr = flags.value_of("server_addr").unwrap(); + let server_addr = server_addr + .parse() + .unwrap_or_else(|e| panic!(r#"--server_addr value "{}" invalid: {}"#, server_addr, e)); + + let name = flags.value_of("name").unwrap(); + + tarpc::init(TokioDefaultSpawner); + + tokio::run( + run(server_addr, name.into()) + .map_err(|e| eprintln!("Oh no: {}", e)) + .boxed() + .compat(), + ); +} diff --git a/example-service/src/lib.rs b/example-service/src/lib.rs index 1cf003d..9bc8b7a 100644 --- a/example-service/src/lib.rs +++ b/example-service/src/lib.rs @@ -10,7 +10,7 @@ arbitrary_self_types, await_macro, async_await, - proc_macro_hygiene, + proc_macro_hygiene )] // This is the service definition. It looks a lot like a trait definition. diff --git a/example-service/src/main.rs b/example-service/src/server.rs similarity index 53% rename from example-service/src/main.rs rename to example-service/src/server.rs index e544edf..b8b77d2 100644 --- a/example-service/src/main.rs +++ b/example-service/src/server.rs @@ -9,19 +9,20 @@ pin, arbitrary_self_types, await_macro, - async_await, + async_await )] +use clap::{App, Arg}; use futures::{ compat::TokioDefaultSpawner, future::{self, Ready}, prelude::*, }; +use std::{io, net::SocketAddr}; use tarpc::{ - client, context, - server::{self, Handler, Server}, + context, + server::{Handler, Server}, }; -use std::io; // This is the type that implements the generated Service trait. It is the business logic // and is used to start the server. @@ -34,52 +35,56 @@ impl service::Service for HelloServer { type HelloFut = Ready; - fn hello(&self, _: context::Context, name: String) -> Self::HelloFut { + fn hello(self, _: context::Context, name: String) -> Self::HelloFut { future::ready(format!("Hello, {}!", name)) } } -async fn run() -> io::Result<()> { +async fn run(server_addr: SocketAddr) -> io::Result<()> { // bincode_transport is provided by the associated crate bincode-transport. It makes it easy // to start up a serde-powered bincode serialization strategy over TCP. - let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; - let addr = transport.local_addr(); + let transport = bincode_transport::listen(&server_addr)?; // The server is configured with the defaults. - let server = Server::new(server::Config::default()) + let server = Server::default() // Server can listen on any type that implements the Transport trait. .incoming(transport) - // Close the stream after the client connects - .take(1) // serve is generated by the service! macro. It takes as input any type implementing // the generated Service trait. .respond_with(service::serve(HelloServer)); - tokio_executor::spawn(server.unit_error().boxed().compat()); - - let transport = await!(bincode_transport::connect(&addr))?; - - // new_stub is generated by the service! macro. Like Server, it takes a config and any - // Transport as input, and returns a Client, also generated by the macro. - // by the service mcro. - let mut client = await!(service::new_stub(client::Config::default(), transport))?; - - // The client has an RPC method for each RPC defined in service!. It takes the same args - // as defined, with the addition of a Context, which is always the first arg. The Context - // specifies a deadline and trace information which can be helpful in debugging requests. - let hello = await!(client.hello(context::current(), "Stim".to_string()))?; - - println!("{}", hello); + await!(server); Ok(()) } fn main() { + let flags = App::new("Hello Server") + .version("0.1") + .author("Tim ") + .about("Say hello!") + .arg( + Arg::with_name("port") + .short("p") + .long("port") + .value_name("NUMBER") + .help("Sets the port number to listen on") + .required(true) + .takes_value(true), + ) + .get_matches(); + + let port = flags.value_of("port").unwrap(); + let port = port + .parse() + .unwrap_or_else(|e| panic!(r#"--port value "{}" invalid: {}"#, port, e)); + tarpc::init(TokioDefaultSpawner); - tokio::run(run() + tokio::run( + run(([0, 0, 0, 0], port).into()) .map_err(|e| eprintln!("Oh no: {}", e)) .boxed() - .compat() + .compat(), ); } diff --git a/plugins/rustfmt.toml b/plugins/rustfmt.toml index 0ef5137..32a9786 100644 --- a/plugins/rustfmt.toml +++ b/plugins/rustfmt.toml @@ -1 +1 @@ -edition = "Edition2018" +edition = "2018" diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index dc750c4..1109c9a 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -4,31 +4,35 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +extern crate itertools; extern crate proc_macro; extern crate proc_macro2; -extern crate syn; -extern crate itertools; extern crate quote; +extern crate syn; use proc_macro::TokenStream; use itertools::Itertools; -use quote::ToTokens; -use syn::{Ident, TraitItemType, TypePath, parse}; use proc_macro2::Span; +use quote::ToTokens; use std::str::FromStr; +use syn::{parse, Ident, TraitItemType, TypePath}; #[proc_macro] pub fn snake_to_camel(input: TokenStream) -> TokenStream { let i = input.clone(); - let mut assoc_type = parse::(input).unwrap_or_else(|_| panic!("Could not parse trait item from:\n{}", i)); + let mut assoc_type = parse::(input) + .unwrap_or_else(|_| panic!("Could not parse trait item from:\n{}", i)); let old_ident = convert(&mut assoc_type.ident); for mut attr in &mut assoc_type.attrs { if let Some(pair) = attr.path.segments.first() { if pair.value().ident == "doc" { - attr.tts = proc_macro2::TokenStream::from_str(&attr.tts.to_string().replace("{}", &old_ident)).unwrap(); + attr.tts = proc_macro2::TokenStream::from_str( + &attr.tts.to_string().replace("{}", &old_ident), + ) + .unwrap(); } } } @@ -41,12 +45,7 @@ pub fn ty_snake_to_camel(input: TokenStream) -> TokenStream { let mut path = parse::(input).unwrap(); // Only capitalize the final segment - convert(&mut path.path - .segments - .last_mut() - .unwrap() - .into_value() - .ident); + convert(&mut path.path.segments.last_mut().unwrap().into_value().ident); path.into_token_stream().into() } diff --git a/rpc/rustfmt.toml b/rpc/rustfmt.toml index 0ef5137..32a9786 100644 --- a/rpc/rustfmt.toml +++ b/rpc/rustfmt.toml @@ -1 +1 @@ -edition = "Edition2018" +edition = "2018" diff --git a/rpc/src/client/dispatch.rs b/rpc/src/client/channel.rs similarity index 74% rename from rpc/src/client/dispatch.rs rename to rpc/src/client/channel.rs index d0e53b1..f31fd3f 100644 --- a/rpc/src/client/dispatch.rs +++ b/rpc/src/client/channel.rs @@ -11,18 +11,19 @@ use crate::{ }; use fnv::FnvHashMap; use futures::{ - Poll, channel::{mpsc, oneshot}, prelude::*, ready, stream::Fuse, task::LocalWaker, + Poll, }; use humantime::format_rfc3339; use log::{debug, error, info, trace}; -use pin_utils::unsafe_pinned; +use pin_utils::{unsafe_pinned, unsafe_unpinned}; use std::{ io, + marker::{self, Unpin}, net::SocketAddr, pin::Pin, sync::{ @@ -37,7 +38,7 @@ use super::Config; /// Handles communication from the client to request dispatch. #[derive(Debug)] -pub(crate) struct Channel { +pub struct Channel { to_dispatch: mpsc::Sender>, /// Channel to send a cancel message to the dispatcher. cancellation: RequestCancellation, @@ -57,14 +58,58 @@ impl Clone for Channel { } } +/// A future returned by [`Channel::send`] that resolves to a server response. +#[derive(Debug)] +#[must_use = "futures do nothing unless polled"] +struct Send<'a, Req, Resp> { + fut: MapOkDispatchResponse< + MapErrConnectionReset>>>, + Resp, + >, +} + +impl<'a, Req, Resp> Send<'a, Req, Resp> { + unsafe_pinned!( + fut: MapOkDispatchResponse< + MapErrConnectionReset< + futures::sink::Send<'a, mpsc::Sender>>, + >, + Resp, + > + ); +} + +impl<'a, Req, Resp> Future for Send<'a, Req, Resp> { + type Output = io::Result>; + + fn poll(mut self: Pin<&mut Self>, lw: &LocalWaker) -> Poll { + self.fut().poll(lw) + } +} + +/// A future returned by [`Channel::call`] that resolves to a server response. +#[derive(Debug)] +#[must_use = "futures do nothing unless polled"] +pub struct Call<'a, Req, Resp> { + fut: AndThenIdent, DispatchResponse>, +} + +impl<'a, Req, Resp> Call<'a, Req, Resp> { + unsafe_pinned!(fut: AndThenIdent, DispatchResponse>); +} + +impl<'a, Req, Resp> Future for Call<'a, Req, Resp> { + type Output = io::Result; + + fn poll(mut self: Pin<&mut Self>, lw: &LocalWaker) -> Poll { + self.fut().poll(lw) + } +} + impl Channel { /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves when the request is sent (not when the response is received). - pub(crate) async fn send( - &mut self, - mut ctx: context::Context, - request: Req, - ) -> io::Result> { + fn send<'a>(&'a mut self, mut ctx: context::Context, request: Req) -> Send<'a, Req, Resp> { // Convert the context to the call context. ctx.trace_context.parent_id = Some(ctx.trace_context.span_id); ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng()); @@ -82,38 +127,40 @@ impl Channel { let (response_completion, response) = oneshot::channel(); let cancellation = self.cancellation.clone(); let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed); - await!(self.to_dispatch.send(DispatchRequest { - ctx, - request_id, - request, - response_completion, - })).map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset))?; - Ok(DispatchResponse { - response: deadline_compat::Deadline::new(response, deadline), - complete: false, - request_id, - cancellation, - ctx, - server_addr: self.server_addr, - }) + let server_addr = self.server_addr; + Send { + fut: MapOkDispatchResponse::new( + MapErrConnectionReset::new(self.to_dispatch.send(DispatchRequest { + ctx, + request_id, + request, + response_completion, + })), + DispatchResponse { + response: deadline_compat::Deadline::new(response, deadline), + complete: false, + request_id, + cancellation, + ctx, + server_addr, + }, + ), + } } /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves to the response. - pub(crate) async fn call( - &mut self, - context: context::Context, - request: Req, - ) -> io::Result { - let response_future = await!(self.send(context, request))?; - await!(response_future) + pub fn call<'a>(&'a mut self, context: context::Context, request: Req) -> Call<'a, Req, Resp> { + Call { + fut: AndThenIdent::new(self.send(context, request)), + } } } /// A server response that is completed by request dispatch when the corresponding response /// arrives off the wire. #[derive(Debug)] -pub struct DispatchResponse { +struct DispatchResponse { response: deadline_compat::Deadline>>, ctx: context::Context, complete: bool, @@ -205,9 +252,9 @@ pub async fn spawn( server_addr: SocketAddr, ) -> io::Result> where - Req: Send, - Resp: Send, - C: Transport, SinkItem = ClientMessage> + Send, + Req: marker::Send + 'static, + Resp: marker::Send + 'static, + C: Transport, SinkItem = ClientMessage> + marker::Send + 'static, { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); @@ -220,16 +267,18 @@ where transport: transport.fuse(), in_flight_requests: FnvHashMap::default(), pending_requests: pending_requests.fuse(), - }.unwrap_or_else(move |e| error!("[{}] Connection broken: {}", server_addr, e)) - ).map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - format!( - "Could not spawn client dispatch task. Is shutdown: {}", - e.is_shutdown() - ), - ) - })?; + } + .unwrap_or_else(move |e| error!("[{}] Connection broken: {}", server_addr, e)), + ) + .map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!( + "Could not spawn client dispatch task. Is shutdown: {}", + e.is_shutdown() + ), + ) + })?; Ok(Channel { to_dispatch, @@ -258,8 +307,8 @@ struct RequestDispatch { impl RequestDispatch where - Req: Send, - Resp: Send, + Req: marker::Send, + Resp: marker::Send, C: Transport, SinkItem = ClientMessage>, { unsafe_pinned!(server_addr: SocketAddr); @@ -462,8 +511,8 @@ where impl Future for RequestDispatch where - Req: Send, - Resp: Send, + Req: marker::Send, + Resp: marker::Send, C: Transport, SinkItem = ClientMessage>, { type Output = io::Result<()>; @@ -563,6 +612,185 @@ impl Stream for CanceledRequests { } } +#[derive(Debug)] +#[must_use = "futures do nothing unless polled"] +struct MapErrConnectionReset { + future: Fut, + finished: Option<()>, +} + +impl MapErrConnectionReset { + unsafe_pinned!(future: Fut); + unsafe_unpinned!(finished: Option<()>); + + fn new(future: Fut) -> MapErrConnectionReset { + MapErrConnectionReset { + future, + finished: Some(()), + } + } +} + +impl Unpin for MapErrConnectionReset {} + +impl Future for MapErrConnectionReset +where + Fut: TryFuture, +{ + type Output = io::Result; + + fn poll(mut self: Pin<&mut Self>, lw: &LocalWaker) -> Poll { + match self.future().try_poll(lw) { + Poll::Pending => Poll::Pending, + Poll::Ready(result) => { + self.finished().take().expect( + "MapErrConnectionReset must not be polled after it returned `Poll::Ready`", + ); + Poll::Ready(result.map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset))) + } + } + } +} + +#[derive(Debug)] +#[must_use = "futures do nothing unless polled"] +struct MapOkDispatchResponse { + future: Fut, + response: Option>, +} + +impl MapOkDispatchResponse { + unsafe_pinned!(future: Fut); + unsafe_unpinned!(response: Option>); + + fn new(future: Fut, response: DispatchResponse) -> MapOkDispatchResponse { + MapOkDispatchResponse { + future, + response: Some(response), + } + } +} + +impl Unpin for MapOkDispatchResponse {} + +impl Future for MapOkDispatchResponse +where + Fut: TryFuture, +{ + type Output = Result, Fut::Error>; + + fn poll(mut self: Pin<&mut Self>, lw: &LocalWaker) -> Poll { + match self.future().try_poll(lw) { + Poll::Pending => Poll::Pending, + Poll::Ready(result) => { + let response = self + .response() + .take() + .expect("MapOk must not be polled after it returned `Poll::Ready`"); + Poll::Ready(result.map(|_| response)) + } + } + } +} + +#[derive(Debug)] +#[must_use = "futures do nothing unless polled"] +struct AndThenIdent { + try_chain: TryChain, +} + +impl AndThenIdent +where + Fut1: TryFuture, + Fut2: TryFuture, +{ + unsafe_pinned!(try_chain: TryChain); + + /// Creates a new `Then`. + fn new(future: Fut1) -> AndThenIdent { + AndThenIdent { + try_chain: TryChain::new(future), + } + } +} + +impl Future for AndThenIdent +where + Fut1: TryFuture, + Fut2: TryFuture, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, lw: &LocalWaker) -> Poll { + self.try_chain().poll(lw, |result| match result { + Ok(ok) => TryChainAction::Future(ok), + Err(err) => TryChainAction::Output(Err(err)), + }) + } +} + +#[must_use = "futures do nothing unless polled"] +#[derive(Debug)] +enum TryChain { + First(Fut1), + Second(Fut2), + Empty, +} + +enum TryChainAction +where + Fut2: TryFuture, +{ + Future(Fut2), + Output(Result), +} + +impl TryChain +where + Fut1: TryFuture, + Fut2: TryFuture, +{ + fn new(fut1: Fut1) -> TryChain { + TryChain::First(fut1) + } + + fn poll(self: Pin<&mut Self>, lw: &LocalWaker, f: F) -> Poll> + where + F: FnOnce(Result) -> TryChainAction, + { + let mut f = Some(f); + + // Safe to call `get_mut_unchecked` because we won't move the futures. + let this = unsafe { Pin::get_mut_unchecked(self) }; + + loop { + let output = match this { + TryChain::First(fut1) => { + // Poll the first future + match unsafe { Pin::new_unchecked(fut1) }.try_poll(lw) { + Poll::Pending => return Poll::Pending, + Poll::Ready(output) => output, + } + } + TryChain::Second(fut2) => { + // Poll the second future + return unsafe { Pin::new_unchecked(fut2) }.try_poll(lw); + } + TryChain::Empty => { + panic!("future must not be polled after it returned `Poll::Ready`"); + } + }; + + *this = TryChain::Empty; // Drop fut1 + let f = f.take().unwrap(); + match f(output) { + TryChainAction::Future(fut2) => *this = TryChain::Second(fut2), + TryChainAction::Output(output) => return Poll::Ready(output), + } + } + } +} + #[cfg(test)] mod tests { use super::{CanceledRequests, Channel, RequestCancellation, RequestDispatch}; @@ -573,9 +801,10 @@ mod tests { ClientMessage, Response, }; use fnv::FnvHashMap; - use futures::{Poll, channel::mpsc, prelude::*}; - use futures_test::task::{noop_local_waker_ref}; + use futures::{channel::mpsc, prelude::*, Poll}; + use futures_test::task::noop_local_waker_ref; use std::{ + marker, net::{IpAddr, Ipv4Addr, SocketAddr}, pin::Pin, sync::atomic::AtomicU64, @@ -617,7 +846,8 @@ mod tests { .send(context::current(), "hi".into()) .boxed() .compat(), - ).unwrap(); + ) + .unwrap(); drop(resp); drop(channel); @@ -640,7 +870,8 @@ mod tests { .send(context::current(), "hi".into()) .boxed() .compat(), - ).unwrap(); + ) + .unwrap(); drop(resp); drop(channel); @@ -688,7 +919,7 @@ mod tests { impl PollTest for Poll>> where - E: ::std::fmt::Display + Send + 'static, + E: ::std::fmt::Display + marker::Send + 'static, { type T = Option; diff --git a/rpc/src/client/mod.rs b/rpc/src/client/mod.rs index 208b3fc..939d18a 100644 --- a/rpc/src/client/mod.rs +++ b/rpc/src/client/mod.rs @@ -6,27 +6,103 @@ //! Provides a client that connects to a server and sends multiplexed requests. -use crate::{context::Context, ClientMessage, Response, Transport}; +use crate::{context, ClientMessage, Response, Transport}; +use futures::prelude::*; use log::warn; use std::{ io, net::{Ipv4Addr, SocketAddr}, }; -mod dispatch; +/// Provides a [`Client`] backed by a transport. +pub mod channel; +pub use self::channel::Channel; /// Sends multiplexed requests to, and receives responses from, a server. -#[derive(Debug)] -pub struct Client { - /// Channel to send requests to the dispatch task. - channel: dispatch::Channel, +pub trait Client<'a, Req> { + /// The response type. + type Response; + + /// The future response. + type Future: Future> + 'a; + + /// Initiates a request, sending it to the dispatch task. + /// + /// Returns a [`Future`] that resolves to this client and the future response + /// once the request is successfully enqueued. + /// + /// [`Future`]: futures::Future + fn call(&'a mut self, ctx: context::Context, request: Req) -> Self::Future; + + /// Returns a Client that applies a post-processing function to the returned response. + fn map_response(self, f: F) -> MapResponse + where + F: FnMut(Self::Response) -> R, + Self: Sized, + { + MapResponse { inner: self, f } + } + + /// Returns a Client that applies a pre-processing function to the request. + fn with_request(self, f: F) -> WithRequest + where + F: FnMut(Req2) -> Req, + Self: Sized, + { + WithRequest { inner: self, f } + } } -impl Clone for Client { - fn clone(&self) -> Self { - Client { - channel: self.channel.clone(), - } +/// A Client that applies a function to the returned response. +#[derive(Clone, Debug)] +pub struct MapResponse { + inner: C, + f: F, +} + +impl<'a, C, F, Req, Resp, Resp2> Client<'a, Req> for MapResponse +where + C: Client<'a, Req, Response = Resp>, + F: FnMut(Resp) -> Resp2 + 'a, +{ + type Response = Resp2; + type Future = futures::future::MapOk<>::Future, &'a mut F>; + + fn call(&'a mut self, ctx: context::Context, request: Req) -> Self::Future { + self.inner.call(ctx, request).map_ok(&mut self.f) + } +} + +/// A Client that applies a pre-processing function to the request. +#[derive(Clone, Debug)] +pub struct WithRequest { + inner: C, + f: F, +} + +impl<'a, C, F, Req, Req2, Resp> Client<'a, Req2> for WithRequest +where + C: Client<'a, Req, Response = Resp>, + F: FnMut(Req2) -> Req, +{ + type Response = Resp; + type Future = >::Future; + + fn call(&'a mut self, ctx: context::Context, request: Req2) -> Self::Future { + self.inner.call(ctx, (self.f)(request)) + } +} + +impl<'a, Req, Resp> Client<'a, Req> for Channel +where + Req: 'a, + Resp: 'a, +{ + type Response = Resp; + type Future = channel::Call<'a, Req, Resp>; + + fn call(&'a mut self, ctx: context::Context, request: Req) -> channel::Call<'a, Req, Resp> { + self.call(ctx, request) } } @@ -53,39 +129,23 @@ impl Default for Config { } } -impl Client +/// Creates a new Client by wrapping a [`Transport`] and spawning a dispatch task +/// that manages the lifecycle of requests. +/// +/// Must only be called from on an executor. +pub async fn new(config: Config, transport: T) -> io::Result> where - Req: Send, - Resp: Send, + Req: Send + 'static, + Resp: Send + 'static, + T: Transport, SinkItem = ClientMessage> + Send + 'static, { - /// Creates a new Client by wrapping a [`Transport`] and spawning a dispatch task - /// that manages the lifecycle of requests. - /// - /// Must only be called from on an executor. - pub async fn new(config: Config, transport: T) -> io::Result - where - T: Transport, SinkItem = ClientMessage> + Send, - { - let server_addr = transport.peer_addr().unwrap_or_else(|e| { - warn!( - "Setting peer to unspecified because peer could not be determined: {}", - e - ); - SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0) - }); + let server_addr = transport.peer_addr().unwrap_or_else(|e| { + warn!( + "Setting peer to unspecified because peer could not be determined: {}", + e + ); + SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0) + }); - Ok(Client { - channel: await!(dispatch::spawn(config, transport, server_addr))?, - }) - } - - /// Initiates a request, sending it to the dispatch task. - /// - /// Returns a [`Future`] that resolves to this client and the future response - /// once the request is successfully enqueued. - /// - /// [`Future`]: futures::Future - pub async fn call(&mut self, ctx: Context, request: Req) -> io::Result { - await!(self.channel.call(ctx, request)) - } + Ok(await!(channel::spawn(config, transport, server_addr))?) } diff --git a/rpc/src/lib.rs b/rpc/src/lib.rs index ab95a9b..7191cdf 100644 --- a/rpc/src/lib.rs +++ b/rpc/src/lib.rs @@ -19,7 +19,7 @@ optin_builtin_traits, generator_trait, gen_future, - decl_macro, + decl_macro )] #![deny(missing_docs, missing_debug_implementations)] @@ -49,7 +49,10 @@ pub(crate) mod util; pub use crate::{client::Client, server::Server, transport::Transport}; -use futures::{Future, task::{Spawn, SpawnExt, SpawnError}}; +use futures::{ + task::{Spawn, SpawnError, SpawnExt}, + Future, +}; use std::{cell::RefCell, io, sync::Once, time::SystemTime}; /// A message from a client to a server. @@ -193,9 +196,7 @@ pub fn init(spawn: impl Spawn + Clone + 'static) { } pub(crate) fn spawn(future: impl Future + Send + 'static) -> Result<(), SpawnError> { - SPAWN.with(|spawn| { - spawn.borrow_mut().spawn(future) - }) + SPAWN.with(|spawn| spawn.borrow_mut().spawn(future)) } trait CloneSpawn: Spawn { diff --git a/rpc/src/server/filter.rs b/rpc/src/server/filter.rs index ed2c9e3..ae75a55 100644 --- a/rpc/src/server/filter.rs +++ b/rpc/src/server/filter.rs @@ -10,7 +10,13 @@ use crate::{ ClientMessage, Response, Transport, }; use fnv::FnvHashMap; -use futures::{channel::mpsc, prelude::*, ready, stream::Fuse, task::{LocalWaker, Poll}}; +use futures::{ + channel::mpsc, + prelude::*, + ready, + stream::Fuse, + task::{LocalWaker, Poll}, +}; use log::{debug, error, info, trace, warn}; use pin_utils::unsafe_pinned; use std::{ @@ -205,10 +211,7 @@ impl ConnectionFilter { } } - fn poll_closed_connections( - self: &mut Pin<&mut Self>, - cx: &LocalWaker, - ) -> Poll> { + fn poll_closed_connections(self: &mut Pin<&mut Self>, cx: &LocalWaker) -> Poll> { match ready!(self.closed_connections_rx().poll_next_unpin(cx)) { Some(addr) => { self.handle_closed_connection(&addr); diff --git a/rpc/src/server/mod.rs b/rpc/src/server/mod.rs index 38d75fd..897c08d 100644 --- a/rpc/src/server/mod.rs +++ b/rpc/src/server/mod.rs @@ -43,6 +43,12 @@ pub struct Server { ghost: PhantomData<(Req, Resp)>, } +impl Default for Server { + fn default() -> Self { + new(Config::default()) + } +} + /// Settings that control the behavior of the server. #[non_exhaustive] #[derive(Clone, Debug)] @@ -75,15 +81,15 @@ impl Default for Config { } } -impl Server { - /// Returns a new server with configuration specified `config`. - pub fn new(config: Config) -> Self { - Server { - config, - ghost: PhantomData, - } +/// Returns a new server with configuration specified `config`. +pub fn new(config: Config) -> Server { + Server { + config, + ghost: PhantomData, } +} +impl Server { /// Returns the config for this server. pub fn config(&self) -> &Config { &self.config @@ -122,7 +128,7 @@ where Req: Send + 'static, Resp: Send + 'static, T: Transport, SinkItem = Response> + Send + 'static, - F: FnMut(Context, Req) -> Fut + Send + 'static + Clone, + F: FnOnce(Context, Req) -> Fut + Send + 'static + Clone, Fut: Future> + Send + 'static, { type Output = (); @@ -132,7 +138,8 @@ where match channel { Ok(channel) => { let peer = channel.client_addr; - if let Err(e) = crate::spawn(channel.respond_with(self.request_handler().clone())) + if let Err(e) = + crate::spawn(channel.respond_with(self.request_handler().clone())) { warn!("[{}] Failed to spawn connection handler: {:?}", peer, e); } @@ -158,7 +165,7 @@ where /// Responds to all requests with `request_handler`. fn respond_with(self, request_handler: F) -> Running where - F: FnMut(Context, Req) -> Fut + Send + 'static + Clone, + F: FnOnce(Context, Req) -> Fut + Send + 'static + Clone, Fut: Future> + Send + 'static, { Running { @@ -222,21 +229,18 @@ where Req: Send, Resp: Send, { - pub(crate) fn start_send(self: &mut Pin<&mut Self>, response: Response) -> io::Result<()> { + pub(crate) fn start_send( + self: &mut Pin<&mut Self>, + response: Response, + ) -> io::Result<()> { self.transport().start_send(response) } - pub(crate) fn poll_ready( - self: &mut Pin<&mut Self>, - cx: &LocalWaker, - ) -> Poll> { + pub(crate) fn poll_ready(self: &mut Pin<&mut Self>, cx: &LocalWaker) -> Poll> { self.transport().poll_ready(cx) } - pub(crate) fn poll_flush( - self: &mut Pin<&mut Self>, - cx: &LocalWaker, - ) -> Poll> { + pub(crate) fn poll_flush(self: &mut Pin<&mut Self>, cx: &LocalWaker) -> Poll> { self.transport().poll_flush(cx) } @@ -256,7 +260,7 @@ where /// responses and resolves when the connection is closed. pub fn respond_with(self, f: F) -> impl Future where - F: FnMut(Context, Req) -> Fut + Send + 'static, + F: FnOnce(Context, Req) -> Fut + Send + 'static + Clone, Fut: Future> + Send + 'static, Req: 'static, Resp: 'static, @@ -271,7 +275,8 @@ where pending_responses: responses, responses_tx, in_flight_requests: FnvHashMap::default(), - }.unwrap_or_else(move |e| { + } + .unwrap_or_else(move |e| { info!("[{}] ClientHandler errored out: {}", peer, e); }) } @@ -305,7 +310,7 @@ where Req: Send + 'static, Resp: Send + 'static, T: Transport, SinkItem = Response> + Send, - F: FnMut(Context, Req) -> Fut + Send + 'static, + F: FnOnce(Context, Req) -> Fut + Send + 'static + Clone, Fut: Future> + Send + 'static, { /// If at max in-flight requests, check that there's room to immediately write a throttled @@ -462,7 +467,7 @@ where let mut response_tx = self.responses_tx().clone(); let trace_id = *ctx.trace_id(); - let response = self.f()(ctx.clone(), request); + let response = self.f().clone()(ctx.clone(), request); let response = deadline_compat::Deadline::new(response, Instant::now() + timeout).then( async move |result| { let response = Response { @@ -477,16 +482,15 @@ where }, ); let (abortable_response, abort_handle) = abortable(response); - crate::spawn(abortable_response.map(|_| ())) - .map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - format!( - "Could not spawn response task. Is shutdown: {}", - e.is_shutdown() - ), - ) - })?; + crate::spawn(abortable_response.map(|_| ())).map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!( + "Could not spawn response task. Is shutdown: {}", + e.is_shutdown() + ), + ) + })?; self.in_flight_requests().insert(request_id, abort_handle); Ok(()) } @@ -521,7 +525,7 @@ where Req: Send + 'static, Resp: Send + 'static, T: Transport, SinkItem = Response> + Send, - F: FnMut(Context, Req) -> Fut + Send + 'static, + F: FnOnce(Context, Req) -> Fut + Send + 'static + Clone, Fut: Future> + Send + 'static, { type Output = io::Result<()>; diff --git a/rpc/src/transport/channel.rs b/rpc/src/transport/channel.rs index ac55447..9b2ec1a 100644 --- a/rpc/src/transport/channel.rs +++ b/rpc/src/transport/channel.rs @@ -7,7 +7,7 @@ //! Transports backed by in-memory channels. use crate::Transport; -use futures::{channel::mpsc, task::{LocalWaker}, Poll, Sink, Stream}; +use futures::{channel::mpsc, task::LocalWaker, Poll, Sink, Stream}; use pin_utils::unsafe_pinned; use std::pin::Pin; use std::{ @@ -66,10 +66,7 @@ impl Sink for UnboundedChannel { .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) } - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &LocalWaker, - ) -> Poll> { + fn poll_flush(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll> { self.tx() .poll_flush(cx) .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) @@ -97,8 +94,12 @@ impl Transport for UnboundedChannel { #[cfg(test)] mod tests { - use crate::{client::{self, Client}, context, server::{self, Handler, Server}, transport}; - use futures::{prelude::*, stream, compat::TokioDefaultSpawner}; + use crate::{ + client, context, + server::{Handler, Server}, + transport, + }; + use futures::{compat::TokioDefaultSpawner, prelude::*, stream}; use log::trace; use std::io; @@ -108,7 +109,7 @@ mod tests { crate::init(TokioDefaultSpawner); let (client_channel, server_channel) = transport::channel::unbounded(); - let server = Server::::new(server::Config::default()) + let server = Server::::default() .incoming(stream::once(future::ready(Ok(server_channel)))) .respond_with(|_ctx, request| { future::ready(request.parse::().map_err(|_| { @@ -120,7 +121,7 @@ mod tests { }); let responses = async { - let mut client = await!(Client::new(client::Config::default(), client_channel))?; + let mut client = await!(client::new(client::Config::default(), client_channel))?; let response1 = await!(client.call(context::current(), "123".into())); let response2 = await!(client.call(context::current(), "abc".into())); diff --git a/rpc/src/transport/mod.rs b/rpc/src/transport/mod.rs index 4b1e147..212babf 100644 --- a/rpc/src/transport/mod.rs +++ b/rpc/src/transport/mod.rs @@ -10,7 +10,12 @@ //! can be plugged in, using whatever protocol it wants. use futures::prelude::*; -use std::{io, net::SocketAddr, pin::Pin, task::{Poll, LocalWaker}}; +use std::{ + io, + net::SocketAddr, + pin::Pin, + task::{LocalWaker, Poll}, +}; pub mod channel; @@ -32,12 +37,20 @@ where } /// Returns a new Transport backed by the given Stream + Sink and connecting addresses. -pub fn new(inner: S, peer_addr: SocketAddr, local_addr: SocketAddr) -> impl Transport +pub fn new( + inner: S, + peer_addr: SocketAddr, + local_addr: SocketAddr, +) -> impl Transport where S: Stream>, S: Sink, { - TransportShim { inner, peer_addr, local_addr } + TransportShim { + inner, + peer_addr, + local_addr, + } } /// A transport created by adding peers to a Stream + Sink. @@ -48,10 +61,8 @@ struct TransportShim { inner: S, } - impl TransportShim { pin_utils::unsafe_pinned!(inner: S); - } impl Stream for TransportShim @@ -67,7 +78,7 @@ where impl Sink for TransportShim where - S: Sink + S: Sink, { type SinkItem = S::SinkItem; type SinkError = S::SinkError; diff --git a/rpc/src/util/deadline_compat.rs b/rpc/src/util/deadline_compat.rs index 0a0e9dc..7d155d1 100644 --- a/rpc/src/util/deadline_compat.rs +++ b/rpc/src/util/deadline_compat.rs @@ -7,7 +7,8 @@ use futures::{ compat::{Compat01As03, Future01CompatExt}, prelude::*, - ready, task::{Poll, LocalWaker}, + ready, + task::{LocalWaker, Poll}, }; use pin_utils::unsafe_pinned; use std::pin::Pin; @@ -50,7 +51,6 @@ where type Output = Result>; fn poll(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll { - // First, try polling the future match self.future().try_poll(waker) { Poll::Ready(Ok(v)) => return Poll::Ready(Ok(v)), diff --git a/rpc/src/util/serde.rs b/rpc/src/util/serde.rs index ac8f703..f0851fa 100644 --- a/rpc/src/util/serde.rs +++ b/rpc/src/util/serde.rs @@ -59,7 +59,8 @@ where Other => 16, UnexpectedEof => 17, _ => 16, - }.serialize(serializer) + } + .serialize(serializer) } /// Deserializes [`io::ErrorKind`] from a `u32`. diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 26fb105..5e410c5 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -30,6 +30,8 @@ rpc = { package = "tarpc-lib", path = "../rpc", version = "0.1" } futures-preview = "0.3.0-alpha.9" [dev-dependencies] +bincode = "1.0" +bytes = { version = "0.4", features = ["serde"] } humantime = "1.0" futures-preview = { version = "0.3.0-alpha.9", features = ["compat", "tokio-compat"] } bincode-transport = { package = "tarpc-bincode-transport", version = "0.2", path = "../bincode-transport" } diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 258cd69..1b90eb5 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -11,7 +11,7 @@ await_macro, async_await, existential_type, - proc_macro_hygiene, + proc_macro_hygiene )] use futures::{ @@ -55,7 +55,7 @@ struct Subscriber { impl subscriber::Service for Subscriber { type ReceiveFut = Ready<()>; - fn receive(&self, _: context::Context, message: String) -> Self::ReceiveFut { + fn receive(self, _: context::Context, message: String) -> Self::ReceiveFut { println!("{} received message: {}", self.id, message); future::ready(()) } @@ -66,13 +66,13 @@ impl Subscriber { let incoming = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; let addr = incoming.local_addr(); tokio_executor::spawn( - Server::new(config) + server::new(config) .incoming(incoming) .take(1) .respond_with(subscriber::serve(Subscriber { id })) .unit_error() .boxed() - .compat() + .compat(), ); Ok(addr) } @@ -94,7 +94,7 @@ impl Publisher { impl publisher::Service for Publisher { existential type BroadcastFut: Future; - fn broadcast(&self, _: context::Context, message: String) -> Self::BroadcastFut { + fn broadcast(self, _: context::Context, message: String) -> Self::BroadcastFut { async fn broadcast(clients: Arc>>, message: String) { let mut clients = clients.lock().unwrap().clone(); for client in clients.values_mut() { @@ -110,7 +110,7 @@ impl publisher::Service for Publisher { existential type SubscribeFut: Future>; - fn subscribe(&self, _: context::Context, id: u32, addr: SocketAddr) -> Self::SubscribeFut { + fn subscribe(self, _: context::Context, id: u32, addr: SocketAddr) -> Self::SubscribeFut { async fn subscribe( clients: Arc>>, id: u32, @@ -128,7 +128,7 @@ impl publisher::Service for Publisher { existential type UnsubscribeFut: Future; - fn unsubscribe(&self, _: context::Context, id: u32) -> Self::UnsubscribeFut { + fn unsubscribe(self, _: context::Context, id: u32) -> Self::UnsubscribeFut { println!("Unsubscribing {}", id); let mut clients = self.clients.lock().unwrap(); if let None = clients.remove(&id) { @@ -146,13 +146,13 @@ async fn run() -> io::Result<()> { let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; let publisher_addr = transport.local_addr(); tokio_executor::spawn( - Server::new(server::Config::default()) + Server::default() .incoming(transport) .take(1) .respond_with(publisher::serve(Publisher::new())) .unit_error() .boxed() - .compat() + .compat(), ); let subscriber1 = await!(Subscriber::listen(0, server::Config::default()))?; @@ -180,12 +180,6 @@ async fn run() -> io::Result<()> { } fn main() { - tokio::run( - run() - .boxed() - .map_err(|e| panic!(e)) - .boxed() - .compat(), - ); + tokio::run(run().boxed().map_err(|e| panic!(e)).boxed().compat()); thread::sleep(Duration::from_millis(100)); } diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 0841e57..2c2cf0c 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -10,16 +10,17 @@ arbitrary_self_types, await_macro, async_await, - proc_macro_hygiene, + proc_macro_hygiene )] use futures::{ + compat::TokioDefaultSpawner, future::{self, Ready}, prelude::*, }; use rpc::{ client, context, - server::{self, Handler, Server}, + server::{Handler, Server}, }; use std::io; @@ -41,7 +42,7 @@ impl Service for HelloServer { type HelloFut = Ready; - fn hello(&self, _: context::Context, name: String) -> Self::HelloFut { + fn hello(self, _: context::Context, name: String) -> Self::HelloFut { future::ready(format!("Hello, {}!", name)) } } @@ -53,7 +54,7 @@ async fn run() -> io::Result<()> { let addr = transport.local_addr(); // The server is configured with the defaults. - let server = Server::new(server::Config::default()) + let server = Server::default() // Server can listen on any type that implements the Transport trait. .incoming(transport) // Close the stream after the client connects @@ -82,6 +83,8 @@ async fn run() -> io::Result<()> { } fn main() { + tarpc::init(TokioDefaultSpawner); + tokio::run( run() .map_err(|e| eprintln!("Oh no: {}", e)) diff --git a/tarpc/examples/server_calling_server.rs b/tarpc/examples/server_calling_server.rs index b1e7a8c..101e75d 100644 --- a/tarpc/examples/server_calling_server.rs +++ b/tarpc/examples/server_calling_server.rs @@ -11,17 +11,18 @@ futures_api, await_macro, async_await, - proc_macro_hygiene, + proc_macro_hygiene )] use crate::{add::Service as AddService, double::Service as DoubleService}; use futures::{ + compat::TokioDefaultSpawner, future::{self, Ready}, prelude::*, }; use rpc::{ client, context, - server::{self, Handler, Server}, + server::{Handler, Server}, }; use std::io; @@ -45,7 +46,7 @@ struct AddServer; impl AddService for AddServer { type AddFut = Ready; - fn add(&self, _: context::Context, x: i32, y: i32) -> Self::AddFut { + fn add(self, _: context::Context, x: i32, y: i32) -> Self::AddFut { future::ready(x + y) } } @@ -58,7 +59,7 @@ struct DoubleServer { impl DoubleService for DoubleServer { existential type DoubleFut: Future> + Send; - fn double(&self, _: context::Context, x: i32) -> Self::DoubleFut { + fn double(self, _: context::Context, x: i32) -> Self::DoubleFut { async fn double(mut client: add::Client, x: i32) -> Result { let result = await!(client.add(context::current(), x, x)); result.map_err(|e| e.to_string()) @@ -71,7 +72,7 @@ impl DoubleService for DoubleServer { async fn run() -> io::Result<()> { let add_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; let addr = add_listener.local_addr(); - let add_server = Server::new(server::Config::default()) + let add_server = Server::default() .incoming(add_listener) .take(1) .respond_with(add::serve(AddServer)); @@ -82,7 +83,7 @@ async fn run() -> io::Result<()> { let double_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; let addr = double_listener.local_addr(); - let double_server = rpc::Server::new(server::Config::default()) + let double_server = rpc::Server::default() .incoming(double_listener) .take(1) .respond_with(double::serve(DoubleServer { add_client })); @@ -102,10 +103,6 @@ async fn run() -> io::Result<()> { fn main() { env_logger::init(); - tokio::run( - run() - .map_err(|e| panic!(e)) - .boxed() - .compat(), - ); + tarpc::init(TokioDefaultSpawner); + tokio::run(run().map_err(|e| panic!(e)).boxed().compat()); } diff --git a/tarpc/examples/service_registry.rs b/tarpc/examples/service_registry.rs new file mode 100644 index 0000000..b880ead --- /dev/null +++ b/tarpc/examples/service_registry.rs @@ -0,0 +1,402 @@ +#![feature( + pin, + async_await, + await_macro, + futures_api, + arbitrary_self_types, + proc_macro_hygiene, + impl_trait_in_bindings +)] + +mod registry { + use bytes::Bytes; + use futures::{ + future::{ready, Ready}, + prelude::*, + }; + use serde::{Deserialize, Serialize}; + use std::{ + io, + pin::Pin, + sync::Arc, + task::{LocalWaker, Poll}, + }; + use tarpc::{ + client::{self, Client}, + context, + }; + + /// A request to a named service. + #[derive(Serialize, Deserialize)] + pub struct ServiceRequest { + service_name: String, + request: Bytes, + } + + /// A response from a named service. + #[derive(Serialize, Deserialize)] + pub struct ServiceResponse { + response: Bytes, + } + + /// A list of registered services. + pub struct Registry { + registrations: Services, + } + + impl Default for Registry { + fn default() -> Self { + Registry { registrations: Nil } + } + } + + impl Registry { + /// Returns a function that serves requests for the registered services. + pub fn serve( + self, + ) -> impl FnOnce(context::Context, ServiceRequest) + -> Either>> + + Clone { + let registrations = Arc::new(self.registrations); + move |cx, req: ServiceRequest| match registrations.serve(cx, &req) { + Some(serve) => Either::Left(serve), + None => Either::Right(ready(Err(io::Error::new( + io::ErrorKind::NotFound, + format!("Service '{}' not registered", req.service_name), + )))), + } + } + + /// Registers `serve` with the given `name` using the given serialization scheme. + pub fn register( + self, + name: String, + serve: S, + deserialize: De, + serialize: Ser) + -> Registry> + where + Req: Send, + S: FnOnce(context::Context, Req) -> RespFut + Send + 'static + Clone, + RespFut: Future> + Send + 'static, + De: FnOnce(Bytes) -> io::Result + Send + 'static + Clone, + Ser: FnOnce(Resp) -> io::Result + Send + 'static + Clone, + { + let registrations = Registration { + name: name, + serve: move |cx, req: Bytes| { + async move { + let req = deserialize.clone()(req)?; + let response = await!(serve.clone()(cx, req))?; + let response = serialize.clone()(response)?; + Ok(ServiceResponse { response }) + } + }, + rest: self.registrations, + }; + Registry { registrations } + } + } + + /// Creates a client that sends requests to a service + /// named `service_name`, over the given channel, using + /// the specified serialization scheme. + pub fn new_client( + service_name: String, + channel: &client::Channel, + mut serialize: Ser, + mut deserialize: De, + ) -> client::MapResponse< + client::WithRequest< + client::Channel, + impl FnMut(Req) -> ServiceRequest, + >, + impl FnMut(ServiceResponse) -> Resp, + > + where + Req: Send + 'static, + Resp: Send + 'static, + De: FnMut(Bytes) -> io::Result + Clone + Send + 'static, + Ser: FnMut(Req) -> io::Result + Clone + Send + 'static, + { + channel + .clone() + .with_request(move |req| { + ServiceRequest { + service_name: service_name.clone(), + // TODO: shouldn't need to unwrap here. Maybe with_request should allow for + // returning Result. + request: serialize(req).unwrap(), + } + }) + // TODO: same thing. Maybe this should be more like and_then rather than map. + .map_response(move |resp| deserialize(resp.response).unwrap()) + } + + /// Serves a request. + /// + /// This trait is mostly an implementation detail that isn't used outside of the registry + /// internals. + pub trait Serve: Clone + Send + 'static { + type Response: Future> + Send + 'static; + fn serve(self, cx: context::Context, request: Bytes) -> Self::Response; + } + + /// Serves a request if the request is for a registered service. + /// + /// This trait is mostly an implementation detail that isn't used outside of the registry + /// internals. + pub trait MaybeServe: Send + 'static { + type Future: Future> + Send + 'static; + + fn serve(&self, cx: context::Context, request: &ServiceRequest) -> Option; + } + + /// A registry starting with service S, followed by Rest. + /// + /// This type is mostly an implementation detail that is not used directly + /// outside of the registry internals. + pub struct Registration { + /// The registered service's name. Must be unique across all registered services. + name: String, + /// The registered service. + serve: S, + /// Any remaining registered services. + rest: Rest, + } + + /// An empty registry. + /// + /// This type is mostly an implementation detail that is not used directly + /// outside of the registry internals. + pub struct Nil; + + impl MaybeServe for Nil { + type Future = futures::future::Ready>; + + fn serve(&self, _: context::Context, _: &ServiceRequest) -> Option { + None + } + } + + impl MaybeServe for Registration + where + S: Serve, + Rest: MaybeServe, + { + type Future = Either; + + fn serve(&self, cx: context::Context, request: &ServiceRequest) -> Option { + if self.name == request.service_name { + Some(Either::Left(self.serve.clone().serve(cx, request.request.clone()))) + } else { + self.rest.serve(cx, request).map(Either::Right) + } + } + } + + /// Wraps either of two future types that both resolve to the same output type. + #[derive(Debug)] + #[must_use = "futures do nothing unless polled"] + pub enum Either { + Left(Left), + Right(Right), + } + + impl Future for Either + where + Left: Future, + Right: Future + { + type Output = Output; + + fn poll(self: Pin<&mut Self>, waker: &LocalWaker) -> Poll { + unsafe { + match Pin::get_mut_unchecked(self) { + Either::Left(car) => Pin::new_unchecked(car).poll(waker), + Either::Right(cdr) => Pin::new_unchecked(cdr).poll(waker), + } + } + } + } + + impl Serve for F + where + F: FnOnce(context::Context, Bytes) -> Resp + Clone + Send + 'static, + Resp: Future> + Send + 'static + { + type Response = Resp; + + fn serve(self, cx: context::Context, request: Bytes) -> Resp { + self(cx, request) + } + } +} + +// Example +use bytes::Bytes; +use futures::{ + future::{ready, Ready}, + prelude::*, +}; +use serde::{Deserialize, Serialize}; +use std::{ + collections::HashMap, + io, + sync::{Arc, RwLock}, +}; +use tarpc::{client, context, server::Handler}; + +fn deserialize(req: Bytes) -> io::Result +where + Req: for<'a> Deserialize<'a> + Send, +{ + bincode::deserialize(req.as_ref()).map_err(|e| io::Error::new(io::ErrorKind::Other, e)) +} + +fn serialize(resp: Resp) -> io::Result +where + Resp: Serialize, +{ + Ok(bincode::serialize(&resp) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))? + .into()) +} + +mod write_service { + tarpc::service! { + rpc write(key: String, value: String); + } +} + +mod read_service { + tarpc::service! { + rpc read(key: String) -> Option; + } +} + +#[derive(Debug, Default, Clone)] +struct Server { + data: Arc>>, +} + +impl write_service::Service for Server { + type WriteFut = Ready<()>; + + fn write(self, _: context::Context, key: String, value: String) -> Self::WriteFut { + self.data.write().unwrap().insert(key, value); + ready(()) + } +} + +impl read_service::Service for Server { + type ReadFut = Ready>; + + fn read(self, _: context::Context, key: String) -> Self::ReadFut { + ready(self.data.read().unwrap().get(&key).cloned()) + } +} + +trait DefaultSpawn { + fn spawn(self); +} + +impl DefaultSpawn for F +where + F: Future + Send + 'static, +{ + fn spawn(self) { + tokio_executor::spawn(self.unit_error().boxed().compat()) + } +} + +struct BincodeRegistry { + registry: registry::Registry, +} + +impl Default for BincodeRegistry { + fn default() -> Self { + BincodeRegistry { registry: registry::Registry::default() } + } +} + +impl BincodeRegistry { + fn serve(self) -> impl FnOnce(context::Context, registry::ServiceRequest) + -> registry::Either>> + Clone + { + self.registry.serve() + } + + fn register( + self, + name: String, + serve: S, + ) -> BincodeRegistry> + where + Req: for<'a> Deserialize<'a> + Send + 'static, + Resp: Serialize + 'static, + S: FnOnce(context::Context, Req) -> RespFut + Send + 'static + Clone, + RespFut: Future> + Send + 'static, + { + let registry = self.registry.register(name, serve, deserialize, serialize); + BincodeRegistry { registry } + } +} + +pub fn new_client( + service_name: String, + channel: &client::Channel, +) -> client::MapResponse< + client::WithRequest< + client::Channel, + impl FnMut(Req) -> registry::ServiceRequest, + >, + impl FnMut(registry::ServiceResponse) -> Resp, +> +where + Req: Serialize + Send + 'static, + Resp: for<'a> Deserialize<'a> + Send + 'static, +{ + registry::new_client(service_name, channel, serialize, deserialize) +} + +async fn run() -> io::Result<()> { + let server = Server::default(); + let registry = BincodeRegistry::default() + .register( + "WriteService".to_string(), + write_service::serve(server.clone()), + ) + .register( + "ReadService".to_string(), + read_service::serve(server.clone()), + ); + + let listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; + let server_addr = listener.local_addr(); + let server = tarpc::Server::default() + .incoming(listener) + .take(1) + .respond_with(registry.serve()); + tokio_executor::spawn(server.unit_error().boxed().compat()); + + let transport = await!(bincode_transport::connect(&server_addr))?; + let channel = await!(client::new(client::Config::default(), transport))?; + + let write_client = new_client("WriteService".to_string(), &channel); + let mut write_client = write_service::Client::from(write_client); + + let read_client = new_client("ReadService".to_string(), &channel); + let mut read_client = read_service::Client::from(read_client); + + await!(write_client.write(context::current(), "key".to_string(), "val".to_string()))?; + let val = await!(read_client.read(context::current(), "key".to_string()))?; + println!("{:?}", val); + + Ok(()) +} + +fn main() { + tarpc::init(futures::compat::TokioDefaultSpawner); + tokio::run(run().boxed().map_err(|e| panic!(e)).boxed().compat()); +} diff --git a/tarpc/rustfmt.toml b/tarpc/rustfmt.toml index 0ef5137..32a9786 100644 --- a/tarpc/rustfmt.toml +++ b/tarpc/rustfmt.toml @@ -1 +1 @@ -edition = "Edition2018" +edition = "2018" diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 89aa271..d95a491 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -40,7 +40,7 @@ //! }; //! use tarpc::{ //! client, context, -//! server::{self, Handler, Server}, +//! server::{self, Handler}, //! }; //! use std::io; //! @@ -74,7 +74,7 @@ //! let addr = transport.local_addr(); //! //! // The server is configured with the defaults. -//! let server = Server::new(server::Config::default()) +//! let server = server::new(server::Config::default()) //! // Server can listen on any type that implements the Transport trait. //! .incoming(transport) //! // Close the stream after the client connects @@ -113,13 +113,7 @@ //! ``` #![deny(missing_docs, missing_debug_implementations)] -#![feature( - futures_api, - pin, - await_macro, - async_await, - decl_macro, -)] +#![feature(futures_api, pin, await_macro, async_await, decl_macro)] #![cfg_attr(test, feature(proc_macro_hygiene, arbitrary_self_types))] #[doc(hidden)] diff --git a/tarpc/src/macros.rs b/tarpc/src/macros.rs index 26ba657..d25f94e 100644 --- a/tarpc/src/macros.rs +++ b/tarpc/src/macros.rs @@ -154,7 +154,7 @@ macro_rules! service { } $(#[$attr])* - fn $fn_name(&self, ctx: $crate::context::Context, $($arg:$in_),*) -> $crate::ty_snake_to_camel!(Self::$fn_name); + fn $fn_name(self, ctx: $crate::context::Context, $($arg:$in_),*) -> $crate::ty_snake_to_camel!(Self::$fn_name); )* } @@ -196,12 +196,12 @@ macro_rules! service { /// Returns a serving function to use with rpc::server::Server. pub fn serve(service: S) - -> impl FnMut($crate::context::Context, Request) -> ResponseFut + Send + 'static + Clone { + -> impl FnOnce($crate::context::Context, Request) -> ResponseFut + Send + 'static + Clone { move |ctx, req| { match req { $( Request::$fn_name{ $($arg,)* } => { - let resp = Service::$fn_name(&mut service.clone(), ctx, $($arg),*); + let resp = Service::$fn_name(service.clone(), ctx, $($arg),*); ResponseFut::$fn_name(resp) } )* @@ -212,7 +212,7 @@ macro_rules! service { #[allow(unused)] #[derive(Clone, Debug)] /// The client stub that makes RPC calls to the server. Exposes a Future interface. - pub struct Client($crate::client::Client); + pub struct Client>(C); /// Returns a new client stub that sends requests over the given transport. pub async fn new_stub(config: $crate::client::Config, transport: T) @@ -220,19 +220,29 @@ macro_rules! service { where T: $crate::Transport< Item = $crate::Response, - SinkItem = $crate::ClientMessage> + Send, + SinkItem = $crate::ClientMessage> + Send + 'static, { - Ok(Client(await!($crate::client::Client::new(config, transport))?)) + Ok(Client(await!($crate::client::new(config, transport))?)) } - impl Client { + impl From for Client + where for <'a> C: $crate::Client<'a, Request, Response = Response> + { + fn from(client: C) -> Self { + Client(client) + } + } + + impl Client + where for<'a> C: $crate::Client<'a, Request, Response = Response> + { $( #[allow(unused)] $(#[$attr])* pub fn $fn_name(&mut self, ctx: $crate::context::Context, $($arg: $in_),*) -> impl ::std::future::Future> + '_ { let request__ = Request::$fn_name { $($arg,)* }; - let resp = self.0.call(ctx, request__); + let resp = $crate::Client::call(&mut self.0, ctx, request__); async move { match await!(resp)? { Response::$fn_name(msg__) => ::std::result::Result::Ok(msg__), @@ -276,11 +286,7 @@ mod functional_test { future::{ready, Ready}, prelude::*, }; - use rpc::{ - client, context, - server::{self, Handler}, - transport::channel, - }; + use rpc::{client, context, server::Handler, transport::channel}; use std::io; use tokio::runtime::current_thread; @@ -295,13 +301,13 @@ mod functional_test { impl Service for Server { type AddFut = Ready; - fn add(&self, _: context::Context, x: i32, y: i32) -> Self::AddFut { + fn add(self, _: context::Context, x: i32, y: i32) -> Self::AddFut { ready(x + y) } type HeyFut = Ready; - fn hey(&self, _: context::Context, name: String) -> Self::HeyFut { + fn hey(self, _: context::Context, name: String) -> Self::HeyFut { ready(format!("Hey, {}.", name)) } } @@ -314,12 +320,12 @@ mod functional_test { let test = async { let (tx, rx) = channel::unbounded(); tokio_executor::spawn( - rpc::Server::new(server::Config::default()) + crate::Server::default() .incoming(stream::once(ready(Ok(rx)))) .respond_with(serve(Server)) .unit_error() .boxed() - .compat() + .compat(), ); let mut client = await!(new_stub(client::Config::default(), tx))?; @@ -343,12 +349,12 @@ mod functional_test { let test = async { let (tx, rx) = channel::unbounded(); tokio_executor::spawn( - rpc::Server::new(server::Config::default()) + rpc::Server::default() .incoming(stream::once(ready(Ok(rx)))) .respond_with(serve(Server)) .unit_error() .boxed() - .compat() + .compat(), ); let client = await!(new_stub(client::Config::default(), tx))?; diff --git a/tarpc/tests/latency.rs b/tarpc/tests/latency.rs index c543133..ec30845 100644 --- a/tarpc/tests/latency.rs +++ b/tarpc/tests/latency.rs @@ -13,7 +13,7 @@ generators, await_macro, async_await, - proc_macro_hygiene, + proc_macro_hygiene )] extern crate test; @@ -22,7 +22,7 @@ use self::test::stats::Stats; use futures::{compat::TokioDefaultSpawner, future, prelude::*}; use rpc::{ client, context, - server::{self, Handler, Server}, + server::{Handler, Server}, }; use std::{ io, @@ -41,7 +41,7 @@ struct Serve; impl ack::Service for Serve { type AckFut = future::Ready<()>; - fn ack(&self, _: context::Context) -> Self::AckFut { + fn ack(self, _: context::Context) -> Self::AckFut { future::ready(()) } } @@ -51,13 +51,13 @@ async fn bench() -> io::Result<()> { let addr = listener.local_addr(); tokio_executor::spawn( - Server::new(server::Config::default()) + Server::default() .incoming(listener) .take(1) .respond_with(ack::serve(Serve)) .unit_error() .boxed() - .compat() + .compat(), ); let conn = await!(bincode_transport::connect(&addr))?; @@ -122,10 +122,5 @@ fn bench_small_packet() { env_logger::init(); tarpc::init(TokioDefaultSpawner); - tokio::run( - bench() - .map_err(|e| panic!(e.to_string())) - .boxed() - .compat(), - ) + tokio::run(bench().map_err(|e| panic!(e.to_string())).boxed().compat()) } diff --git a/trace/rustfmt.toml b/trace/rustfmt.toml index 0ef5137..32a9786 100644 --- a/trace/rustfmt.toml +++ b/trace/rustfmt.toml @@ -1 +1 @@ -edition = "Edition2018" +edition = "2018"