diff --git a/README.md b/README.md index 7a18621..5864a03 100644 --- a/README.md +++ b/README.md @@ -36,12 +36,12 @@ Add to your `Cargo.toml` dependencies: tarpc = "0.18.0" ``` -The `service!` macro expands to a collection of items that form an -rpc service. In the above example, the macro is called within the -`hello_service` module. This module will contain a `Client` stub and `Service` trait. There is -These generated types make it easy and ergonomic to write servers without dealing with serialization -directly. Simply implement one of the generated traits, and you're off to the -races! +The `service!` macro expands to a collection of items that form an rpc service. +In the above example, the macro is called within the `hello_service` module. +This module will contain a `Client` stub and `Service` trait. There is These +generated types make it easy and ergonomic to write servers without dealing with +serialization directly. Simply implement one of the generated traits, and you're +off to the races! ## Example @@ -49,7 +49,7 @@ For this example, in addition to tarpc, also add two other dependencies to your `Cargo.toml`: ```toml -futures-preview = { version = "0.3.0-alpha.16", features = ["compat"] } +futures-preview = { version = "0.3.0-alpha.17", features = ["compat"] } tokio = "0.1" ``` @@ -175,7 +175,7 @@ async fn run() -> io::Result<()> { let server = server::new(server::Config::default()) // incoming() takes a stream of transports such as would be returned by // TcpListener::incoming (but a stream instead of an iterator). - .incoming(stream::once(future::ready(Ok(server_transport)))) + .incoming(stream::once(future::ready(server_transport))) // serve is generated by the service! macro. It takes as input any type implementing // the generated Service trait. .respond_with(serve(HelloServer)); @@ -246,7 +246,7 @@ background tasks for the client and server. # let server = server::new(server::Config::default()) # // incoming() takes a stream of transports such as would be returned by # // TcpListener::incoming (but a stream instead of an iterator). -# .incoming(stream::once(future::ready(Ok(server_transport)))) +# .incoming(stream::once(future::ready(server_transport))) # // serve is generated by the service! macro. It takes as input any type implementing # // the generated Service trait. # .respond_with(serve(HelloServer)); diff --git a/bincode-transport/Cargo.toml b/bincode-transport/Cargo.toml index 3acca11..141152b 100644 --- a/bincode-transport/Cargo.toml +++ b/bincode-transport/Cargo.toml @@ -14,10 +14,9 @@ description = "A bincode-based transport for tarpc services." [dependencies] bincode = "1" -futures-preview = { version = "0.3.0-alpha.16", features = ["compat"] } +futures-preview = { version = "0.3.0-alpha.17", features = ["compat"] } futures_legacy = { version = "0.1", package = "futures" } pin-utils = "0.1.0-alpha.4" -rpc = { package = "tarpc-lib", version = "0.6", path = "../rpc", features = ["serde1"] } serde = "1.0" tokio-io = "0.1" async-bincode = "0.4" @@ -26,9 +25,10 @@ tokio-tcp = "0.1" [dev-dependencies] env_logger = "0.6" humantime = "1.0" -libtest = "0.0.1" log = "0.4" -rand = "0.6" +rand = "0.7" +rand_distr = "0.2" +rpc = { package = "tarpc-lib", version = "0.6", path = "../rpc", features = ["serde1"] } tokio = "0.1" tokio-executor = "0.1" tokio-reactor = "0.1" diff --git a/bincode-transport/src/lib.rs b/bincode-transport/src/lib.rs index 6ff3f66..bc42242 100644 --- a/bincode-transport/src/lib.rs +++ b/bincode-transport/src/lib.rs @@ -60,7 +60,7 @@ where S: AsyncWrite, SinkItem: Serialize, { - type SinkError = io::Error; + type Error = io::Error; fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> { self.inner() @@ -81,7 +81,9 @@ where } } -fn convert>>(poll: Poll>) -> Poll> { +fn convert>>( + poll: Poll>, +) -> Poll> { match poll { Poll::Pending => Poll::Pending, Poll::Ready(Ok(())) => Poll::Ready(Ok(())), @@ -89,23 +91,24 @@ fn convert>>(poll: Poll>) -> Poll } } -impl rpc::Transport for Transport -where - Item: for<'de> Deserialize<'de>, - SinkItem: Serialize, -{ - type Item = Item; - type SinkItem = SinkItem; - - fn peer_addr(&self) -> io::Result { +impl Transport { + /// Returns the address of the peer connected over the transport. + pub fn peer_addr(&self) -> io::Result { self.inner.get_ref().get_ref().peer_addr() } - fn local_addr(&self) -> io::Result { + /// Returns the address of this end of the transport. + pub fn local_addr(&self) -> io::Result { self.inner.get_ref().get_ref().local_addr() } } +impl AsRef for Transport { + fn as_ref(&self) -> &T { + self.inner.get_ref().get_ref() + } +} + /// Returns a new bincode transport that reads from and writes to `io`. pub fn new(io: TcpStream) -> Transport where diff --git a/bincode-transport/tests/bench.rs b/bincode-transport/tests/bench.rs index 4550ccf..c77cffe 100644 --- a/bincode-transport/tests/bench.rs +++ b/bincode-transport/tests/bench.rs @@ -8,8 +8,10 @@ #![feature(test, integer_atomics, async_await)] +extern crate test; + use futures::{compat::Executor01CompatExt, prelude::*}; -use libtest::stats::Stats; +use test::stats::Stats; use rpc::{ client, context, server::{Handler, Server}, @@ -20,8 +22,9 @@ use std::{ }; async fn bench() -> io::Result<()> { - let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; - let addr = listener.local_addr(); + let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())? + .filter_map(|r| future::ready(r.ok())); + let addr = listener.get_ref().local_addr(); tokio_executor::spawn( Server::::default() diff --git a/bincode-transport/tests/cancel.rs b/bincode-transport/tests/cancel.rs index b9950a1..38087bb 100644 --- a/bincode-transport/tests/cancel.rs +++ b/bincode-transport/tests/cancel.rs @@ -7,6 +7,7 @@ //! Tests client/server control flow. #![feature(async_await)] +#![feature(async_closure)] use futures::{ compat::{Executor01CompatExt, Future01CompatExt}, @@ -14,8 +15,11 @@ use futures::{ stream::FuturesUnordered, }; use log::{info, trace}; -use rand::distributions::{Distribution, Normal}; -use rpc::{client, context, server::Server}; +use rand_distr::{Distribution, Normal}; +use rpc::{ + client, context, + server::{Channel, Server}, +}; use std::{ io, time::{Duration, Instant, SystemTime}, @@ -34,18 +38,14 @@ impl AsDuration for SystemTime { } async fn run() -> io::Result<()> { - let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; - let addr = listener.local_addr(); + let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())? + .filter_map(|r| future::ready(r.ok())); + let addr = listener.get_ref().local_addr(); let server = Server::::default() .incoming(listener) .take(1) .for_each(async move |channel| { - let channel = if let Ok(channel) = channel { - channel - } else { - return; - }; - let client_addr = *channel.client_addr(); + let client_addr = channel.get_ref().peer_addr().unwrap(); let handler = channel.respond_with(move |ctx, request| { // Sleep for a time sampled from a normal distribution with: // - mean: 1/2 the deadline. @@ -53,7 +53,7 @@ async fn run() -> io::Result<()> { let deadline: Duration = ctx.deadline.as_duration(); let deadline_millis = deadline.as_secs() * 1000 + deadline.subsec_millis() as u64; let distribution = - Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.); + Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.).unwrap(); let delay_millis = distribution.sample(&mut rand::thread_rng()).max(0.); let delay = Duration::from_millis(delay_millis as u64); @@ -79,20 +79,16 @@ async fn run() -> io::Result<()> { let client = client::new::(client::Config::default(), conn).await?; // Proxy service - let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; - let addr = listener.local_addr(); + let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())? + .filter_map(|r| future::ready(r.ok())); + let addr = listener.get_ref().local_addr(); let proxy_server = Server::::default() .incoming(listener) .take(1) .for_each(move |channel| { let client = client.clone(); async move { - let channel = if let Ok(channel) = channel { - channel - } else { - return; - }; - let client_addr = *channel.client_addr(); + let client_addr = channel.get_ref().peer_addr().unwrap(); let handler = channel.respond_with(move |ctx, request| { trace!("[{}/{}] Proxying request.", ctx.trace_id(), client_addr); let mut client = client.clone(); diff --git a/bincode-transport/tests/pushback.rs b/bincode-transport/tests/pushback.rs index 72a5957..8f54547 100644 --- a/bincode-transport/tests/pushback.rs +++ b/bincode-transport/tests/pushback.rs @@ -7,14 +7,18 @@ //! Tests client/server control flow. #![feature(async_await)] +#![feature(async_closure)] use futures::{ compat::{Executor01CompatExt, Future01CompatExt}, prelude::*, }; use log::{error, info, trace}; -use rand::distributions::{Distribution, Normal}; -use rpc::{client, context, server::Server}; +use rand_distr::{Distribution, Normal}; +use rpc::{ + client, context, + server::{Channel, Handler, Server}, +}; use std::{ io, time::{Duration, Instant, SystemTime}, @@ -33,18 +37,15 @@ impl AsDuration for SystemTime { } async fn run() -> io::Result<()> { - let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; - let addr = listener.local_addr(); + let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())? + .filter_map(|r| future::ready(r.ok())); + let addr = listener.get_ref().local_addr(); let server = Server::::default() .incoming(listener) .take(1) + .max_concurrent_requests_per_channel(19) .for_each(async move |channel| { - let channel = if let Ok(channel) = channel { - channel - } else { - return; - }; - let client_addr = *channel.client_addr(); + let client_addr = channel.get_ref().get_ref().peer_addr().unwrap(); let handler = channel.respond_with(move |ctx, request| { // Sleep for a time sampled from a normal distribution with: // - mean: 1/2 the deadline. @@ -52,7 +53,7 @@ async fn run() -> io::Result<()> { let deadline: Duration = ctx.deadline.as_duration(); let deadline_millis = deadline.as_secs() * 1000 + deadline.subsec_millis() as u64; let distribution = - Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.); + Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.).unwrap(); let delay_millis = distribution.sample(&mut rand::thread_rng()).max(0.); let delay = Duration::from_millis(delay_millis as u64); @@ -75,7 +76,7 @@ async fn run() -> io::Result<()> { tokio_executor::spawn(server.unit_error().boxed().compat()); let mut config = client::Config::default(); - config.max_in_flight_requests = 10; + config.max_in_flight_requests = 20; config.pending_request_buffer = 10; let conn = tarpc_bincode_transport::connect(&addr).await?; @@ -103,17 +104,11 @@ async fn run() -> io::Result<()> { } #[test] -fn ping_pong() -> io::Result<()> { +fn pushback() -> io::Result<()> { env_logger::init(); rpc::init(tokio::executor::DefaultExecutor::current().compat()); - tokio::run( - run() - .map_ok(|_| println!("done")) - .map_err(|e| panic!(e.to_string())) - .boxed() - .compat(), - ); + tokio::run(run().map_err(|e| panic!(e.to_string())).boxed().compat()); Ok(()) } diff --git a/example-service/Cargo.toml b/example-service/Cargo.toml index f25cd89..90549cd 100644 --- a/example-service/Cargo.toml +++ b/example-service/Cargo.toml @@ -13,12 +13,13 @@ readme = "../README.md" description = "An example server built on tarpc." [dependencies] -bincode-transport = { package = "tarpc-bincode-transport", version = "0.7", path = "../bincode-transport" } +json-transport = { package = "tarpc-json-transport", version = "0.1", path = "../json-transport" } clap = "2.0" -futures-preview = { version = "0.3.0-alpha.16", features = ["compat"] } +futures-preview = { version = "0.3.0-alpha.17", features = ["compat"] } serde = { version = "1.0" } tarpc = { version = "0.18", path = "../tarpc", features = ["serde1"] } tokio = "0.1" +env_logger = "0.6" [lib] name = "service" diff --git a/example-service/src/client.rs b/example-service/src/client.rs index d4f030c..cbaa04d 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -12,7 +12,7 @@ use std::{io, net::SocketAddr}; use tarpc::{client, context}; async fn run(server_addr: SocketAddr, name: String) -> io::Result<()> { - let transport = bincode_transport::connect(&server_addr).await?; + let transport = json_transport::connect(&server_addr).await?; // new_stub is generated by the service! macro. Like Server, it takes a config and any // Transport as input, and returns a Client, also generated by the macro. diff --git a/example-service/src/lib.rs b/example-service/src/lib.rs index 0059610..eb0246e 100644 --- a/example-service/src/lib.rs +++ b/example-service/src/lib.rs @@ -10,5 +10,9 @@ // It defines one RPC, hello, which takes one arg, name, and returns a String. tarpc::service! { /// Returns a greeting for name. - rpc hello(name: String) -> String; + rpc hello(#[serde(default = "default_name")] name: String) -> String; +} + +fn default_name() -> String { + "DefaultName".into() } diff --git a/example-service/src/server.rs b/example-service/src/server.rs index b21ae3e..c0312f1 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -15,13 +15,13 @@ use futures::{ use std::{io, net::SocketAddr}; use tarpc::{ context, - server::{Handler, Server}, + server::{self, Channel, Handler}, }; // This is the type that implements the generated Service trait. It is the business logic // and is used to start the server. #[derive(Clone)] -struct HelloServer; +struct HelloServer(SocketAddr); impl service::Service for HelloServer { // Each defined rpc generates two items in the trait, a fn that serves the RPC, and @@ -30,29 +30,39 @@ impl service::Service for HelloServer { type HelloFut = Ready; fn hello(self, _: context::Context, name: String) -> Self::HelloFut { - future::ready(format!("Hello, {}!", name)) + future::ready(format!( + "Hello, {}! You are connected from {:?}.", + name, self.0 + )) } } async fn run(server_addr: SocketAddr) -> io::Result<()> { // bincode_transport is provided by the associated crate bincode-transport. It makes it easy // to start up a serde-powered bincode serialization strategy over TCP. - let transport = bincode_transport::listen(&server_addr)?; - - // The server is configured with the defaults. - let server = Server::default() - // Server can listen on any type that implements the Transport trait. - .incoming(transport) + json_transport::listen(&server_addr)? + // Ignore accept errors. + .filter_map(|r| future::ready(r.ok())) + .map(server::BaseChannel::with_defaults) + // Limit channels to 1 per IP. + .max_channels_per_key(1, |t| t.as_ref().peer_addr().unwrap().ip()) // serve is generated by the service! macro. It takes as input any type implementing // the generated Service trait. - .respond_with(service::serve(HelloServer)); - - server.await; + .map(|channel| { + let server = HelloServer(channel.as_ref().as_ref().peer_addr().unwrap()); + channel.respond_with(service::serve(server)) + }) + // Max 10 channels. + .buffer_unordered(10) + .for_each(|_| futures::future::ready(())) + .await; Ok(()) } fn main() { + env_logger::init(); + let flags = App::new("Hello Server") .version("0.1") .author("Tim ") diff --git a/json-transport/Cargo.toml b/json-transport/Cargo.toml index 8613f86..212db90 100644 --- a/json-transport/Cargo.toml +++ b/json-transport/Cargo.toml @@ -13,10 +13,9 @@ readme = "../README.md" description = "A JSON-based transport for tarpc services." [dependencies] -futures-preview = { version = "0.3.0-alpha.16", features = ["compat"] } +futures-preview = { version = "0.3.0-alpha.17", features = ["compat"] } futures_legacy = { version = "0.1", package = "futures" } pin-utils = "0.1.0-alpha.4" -rpc = { package = "tarpc-lib", version = "0.6", path = "../rpc", features = ["serde1"] } serde = "1.0" serde_json = "1.0" tokio = "0.1" @@ -27,9 +26,10 @@ tokio-tcp = "0.1" [dev-dependencies] env_logger = "0.6" humantime = "1.0" -libtest = "0.0.1" log = "0.4" -rand = "0.6" +rand = "0.7" +rand_distr = "0.2" +rpc = { package = "tarpc-lib", version = "0.6", path = "../rpc", features = ["serde1"] } tokio = "0.1" tokio-executor = "0.1" tokio-reactor = "0.1" diff --git a/json-transport/src/lib.rs b/json-transport/src/lib.rs index e1c9109..54aaf98 100644 --- a/json-transport/src/lib.rs +++ b/json-transport/src/lib.rs @@ -67,7 +67,7 @@ where S: AsyncWrite, SinkItem: Serialize, { - type SinkError = io::Error; + type Error = io::Error; fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> { self.inner() @@ -88,7 +88,9 @@ where } } -fn convert>>(poll: Poll>) -> Poll> { +fn convert>>( + poll: Poll>, +) -> Poll> { match poll { Poll::Pending => Poll::Pending, Poll::Ready(Ok(())) => Poll::Ready(Ok(())), @@ -96,15 +98,9 @@ fn convert>>(poll: Poll>) -> Poll } } -impl rpc::Transport for Transport -where - Item: for<'de> Deserialize<'de>, - SinkItem: Serialize, -{ - type Item = Item; - type SinkItem = SinkItem; - - fn peer_addr(&self) -> io::Result { +impl Transport { + /// Returns the peer address of the underlying TcpStream. + pub fn peer_addr(&self) -> io::Result { self.inner .get_ref() .get_ref() @@ -113,7 +109,8 @@ where .peer_addr() } - fn local_addr(&self) -> io::Result { + /// Returns the local address of the underlying TcpStream. + pub fn local_addr(&self) -> io::Result { self.inner .get_ref() .get_ref() diff --git a/json-transport/tests/bench.rs b/json-transport/tests/bench.rs index f440985..0a27bcf 100644 --- a/json-transport/tests/bench.rs +++ b/json-transport/tests/bench.rs @@ -8,8 +8,10 @@ #![feature(test, integer_atomics, async_await)] +extern crate test; + use futures::{compat::Executor01CompatExt, prelude::*}; -use libtest::stats::Stats; +use test::stats::Stats; use rpc::{ client, context, server::{Handler, Server}, @@ -20,8 +22,9 @@ use std::{ }; async fn bench() -> io::Result<()> { - let listener = tarpc_json_transport::listen(&"0.0.0.0:0".parse().unwrap())?; - let addr = listener.local_addr(); + let listener = tarpc_json_transport::listen(&"0.0.0.0:0".parse().unwrap())? + .filter_map(|r| future::ready(r.ok())); + let addr = listener.get_ref().local_addr(); tokio_executor::spawn( Server::::default() diff --git a/json-transport/tests/cancel.rs b/json-transport/tests/cancel.rs index 9f62cf7..c05a7c1 100644 --- a/json-transport/tests/cancel.rs +++ b/json-transport/tests/cancel.rs @@ -7,6 +7,7 @@ //! Tests client/server control flow. #![feature(async_await)] +#![feature(async_closure)] use futures::{ compat::{Executor01CompatExt, Future01CompatExt}, @@ -14,8 +15,11 @@ use futures::{ stream::FuturesUnordered, }; use log::{info, trace}; -use rand::distributions::{Distribution, Normal}; -use rpc::{client, context, server::Server}; +use rand_distr::{Distribution, Normal}; +use rpc::{ + client, context, + server::{Channel, Server}, +}; use std::{ io, time::{Duration, Instant, SystemTime}, @@ -34,18 +38,14 @@ impl AsDuration for SystemTime { } async fn run() -> io::Result<()> { - let listener = tarpc_json_transport::listen(&"0.0.0.0:0".parse().unwrap())?; - let addr = listener.local_addr(); + let listener = tarpc_json_transport::listen(&"0.0.0.0:0".parse().unwrap())? + .filter_map(|r| future::ready(r.ok())); + let addr = listener.get_ref().local_addr(); let server = Server::::default() .incoming(listener) .take(1) .for_each(async move |channel| { - let channel = if let Ok(channel) = channel { - channel - } else { - return; - }; - let client_addr = *channel.client_addr(); + let client_addr = channel.get_ref().peer_addr().unwrap(); let handler = channel.respond_with(move |ctx, request| { // Sleep for a time sampled from a normal distribution with: // - mean: 1/2 the deadline. @@ -53,7 +53,7 @@ async fn run() -> io::Result<()> { let deadline: Duration = ctx.deadline.as_duration(); let deadline_millis = deadline.as_secs() * 1000 + deadline.subsec_millis() as u64; let distribution = - Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.); + Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.).unwrap(); let delay_millis = distribution.sample(&mut rand::thread_rng()).max(0.); let delay = Duration::from_millis(delay_millis as u64); @@ -79,20 +79,16 @@ async fn run() -> io::Result<()> { let client = client::new::(client::Config::default(), conn).await?; // Proxy service - let listener = tarpc_json_transport::listen(&"0.0.0.0:0".parse().unwrap())?; - let addr = listener.local_addr(); + let listener = tarpc_json_transport::listen(&"0.0.0.0:0".parse().unwrap())? + .filter_map(|r| future::ready(r.ok())); + let addr = listener.get_ref().local_addr(); let proxy_server = Server::::default() .incoming(listener) .take(1) .for_each(move |channel| { let client = client.clone(); async move { - let channel = if let Ok(channel) = channel { - channel - } else { - return; - }; - let client_addr = *channel.client_addr(); + let client_addr = channel.get_ref().peer_addr().unwrap(); let handler = channel.respond_with(move |ctx, request| { trace!("[{}/{}] Proxying request.", ctx.trace_id(), client_addr); let mut client = client.clone(); diff --git a/json-transport/tests/pushback.rs b/json-transport/tests/pushback.rs index 81f8f0e..ab82d72 100644 --- a/json-transport/tests/pushback.rs +++ b/json-transport/tests/pushback.rs @@ -7,14 +7,18 @@ //! Tests client/server control flow. #![feature(async_await)] +#![feature(async_closure)] use futures::{ compat::{Executor01CompatExt, Future01CompatExt}, prelude::*, }; use log::{error, info, trace}; -use rand::distributions::{Distribution, Normal}; -use rpc::{client, context, server::Server}; +use rand_distr::{Distribution, Normal}; +use rpc::{ + client, context, + server::{Channel, Server}, +}; use std::{ io, time::{Duration, Instant, SystemTime}, @@ -33,18 +37,14 @@ impl AsDuration for SystemTime { } async fn run() -> io::Result<()> { - let listener = tarpc_json_transport::listen(&"0.0.0.0:0".parse().unwrap())?; - let addr = listener.local_addr(); + let listener = tarpc_json_transport::listen(&"0.0.0.0:0".parse().unwrap())? + .filter_map(|r| future::ready(r.ok())); + let addr = listener.get_ref().local_addr(); let server = Server::::default() .incoming(listener) .take(1) .for_each(async move |channel| { - let channel = if let Ok(channel) = channel { - channel - } else { - return; - }; - let client_addr = *channel.client_addr(); + let client_addr = channel.get_ref().peer_addr().unwrap(); let handler = channel.respond_with(move |ctx, request| { // Sleep for a time sampled from a normal distribution with: // - mean: 1/2 the deadline. @@ -52,7 +52,7 @@ async fn run() -> io::Result<()> { let deadline: Duration = ctx.deadline.as_duration(); let deadline_millis = deadline.as_secs() * 1000 + deadline.subsec_millis() as u64; let distribution = - Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.); + Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.).unwrap(); let delay_millis = distribution.sample(&mut rand::thread_rng()).max(0.); let delay = Duration::from_millis(delay_millis as u64); diff --git a/rpc/Cargo.toml b/rpc/Cargo.toml index 3d3ba7f..b0e611c 100644 --- a/rpc/Cargo.toml +++ b/rpc/Cargo.toml @@ -18,17 +18,17 @@ serde1 = ["trace/serde", "serde", "serde/derive"] [dependencies] fnv = "1.0" -futures-preview = { version = "0.3.0-alpha.16", features = ["compat"] } +futures-preview = { version = "0.3.0-alpha.17", features = ["compat"] } humantime = "1.0" log = "0.4" pin-utils = "0.1.0-alpha.4" -rand = "0.6" +rand = "0.7" tokio-timer = "0.2" trace = { package = "tarpc-trace", version = "0.2", path = "../trace" } serde = { optional = true, version = "1.0" } [dev-dependencies] -futures-test-preview = { version = "0.3.0-alpha.16" } +futures-test-preview = { version = "0.3.0-alpha.17" } env_logger = "0.6" tokio = "0.1" tokio-executor = "0.1" diff --git a/rpc/src/client/channel.rs b/rpc/src/client/channel.rs index 3577e31..ecd6913 100644 --- a/rpc/src/client/channel.rs +++ b/rpc/src/client/channel.rs @@ -7,7 +7,7 @@ use crate::{ context, util::{deadline_compat, AsDuration, Compact}, - ClientMessage, ClientMessageKind, PollIo, Request, Response, Transport, + ClientMessage, PollIo, Request, Response, Transport, }; use fnv::FnvHashMap; use futures::{ @@ -24,7 +24,6 @@ use pin_utils::{unsafe_pinned, unsafe_unpinned}; use std::{ io, marker::{self, Unpin}, - net::SocketAddr, pin::Pin, sync::{ atomic::{AtomicU64, Ordering}, @@ -44,7 +43,6 @@ pub struct Channel { cancellation: RequestCancellation, /// The ID to use for the next request to stage. next_request_id: Arc, - server_addr: SocketAddr, } impl Clone for Channel { @@ -53,7 +51,6 @@ impl Clone for Channel { to_dispatch: self.to_dispatch.clone(), cancellation: self.cancellation.clone(), next_request_id: self.next_request_id.clone(), - server_addr: self.server_addr, } } } @@ -122,9 +119,8 @@ impl Channel { let timeout = ctx.deadline.as_duration(); let deadline = Instant::now() + timeout; trace!( - "[{}/{}] Queuing request with deadline {} (timeout {:?}).", + "[{}] Queuing request with deadline {} (timeout {:?}).", ctx.trace_id(), - self.server_addr, format_rfc3339(ctx.deadline), timeout, ); @@ -132,7 +128,6 @@ impl Channel { let (response_completion, response) = oneshot::channel(); let cancellation = self.cancellation.clone(); let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed); - let server_addr = self.server_addr; Send { fut: MapOkDispatchResponse::new( MapErrConnectionReset::new(self.to_dispatch.send(DispatchRequest { @@ -147,7 +142,6 @@ impl Channel { request_id, cancellation, ctx, - server_addr, }, ), } @@ -171,11 +165,9 @@ struct DispatchResponse { complete: bool, cancellation: RequestCancellation, request_id: u64, - server_addr: SocketAddr, } impl DispatchResponse { - unsafe_pinned!(server_addr: SocketAddr); unsafe_pinned!(ctx: context::Context); } @@ -192,7 +184,6 @@ impl Future for DispatchResponse { } Err(e) => Err({ let trace_id = *self.as_mut().ctx().trace_id(); - let server_addr = *self.as_mut().server_addr(); if e.is_elapsed() { io::Error::new( @@ -209,12 +200,9 @@ impl Future for DispatchResponse { .to_string(), ) } else if e.is_shutdown() { - panic!("[{}/{}] Timer was shutdown", trace_id, server_addr) + panic!("[{}] Timer was shutdown", trace_id) } else { - panic!( - "[{}/{}] Unrecognized timer error: {}", - trace_id, server_addr, e - ) + panic!("[{}] Unrecognized timer error: {}", trace_id, e) } } else if e.is_inner() { // The oneshot is Canceled when the dispatch task ends. In that case, @@ -223,10 +211,7 @@ impl Future for DispatchResponse { self.complete = true; io::Error::from(io::ErrorKind::ConnectionReset) } else { - panic!( - "[{}/{}] Unrecognized deadline error: {}", - trace_id, server_addr, e - ) + panic!("[{}] Unrecognized deadline error: {}", trace_id, e) } }), }) @@ -255,15 +240,11 @@ impl Drop for DispatchResponse { /// Spawns a dispatch task on the default executor that manages the lifecycle of requests initiated /// by the returned [`Channel`]. -pub async fn spawn( - config: Config, - transport: C, - server_addr: SocketAddr, -) -> io::Result> +pub async fn spawn(config: Config, transport: C) -> io::Result> where Req: marker::Send + 'static, Resp: marker::Send + 'static, - C: Transport, SinkItem = ClientMessage> + marker::Send + 'static, + C: Transport, Response> + marker::Send + 'static, { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); @@ -272,13 +253,12 @@ where crate::spawn( RequestDispatch { config, - server_addr, canceled_requests, transport: transport.fuse(), in_flight_requests: FnvHashMap::default(), pending_requests: pending_requests.fuse(), } - .unwrap_or_else(move |e| error!("[{}] Connection broken: {}", server_addr, e)), + .unwrap_or_else(move |e| error!("Connection broken: {}", e)), ) .map_err(|e| { io::Error::new( @@ -293,7 +273,6 @@ where Ok(Channel { to_dispatch, cancellation, - server_addr, next_request_id: Arc::new(AtomicU64::new(0)), }) } @@ -311,17 +290,14 @@ struct RequestDispatch { in_flight_requests: FnvHashMap>, /// Configures limits to prevent unlimited resource usage. config: Config, - /// The address of the server connected to. - server_addr: SocketAddr, } impl RequestDispatch where Req: marker::Send, Resp: marker::Send, - C: Transport, SinkItem = ClientMessage>, + C: Transport, Response>, { - unsafe_pinned!(server_addr: SocketAddr); unsafe_pinned!(in_flight_requests: FnvHashMap>); unsafe_pinned!(canceled_requests: Fuse); unsafe_pinned!(pending_requests: Fuse>>); @@ -333,10 +309,7 @@ where self.complete(response); Some(Ok(())) } - None => { - trace!("[{}] read half closed", self.as_mut().server_addr()); - None - } + None => None, }) } @@ -415,10 +388,7 @@ where return Poll::Ready(Some(Ok(request))); } - None => { - trace!("[{}] pending_requests closed", self.as_mut().server_addr()); - return Poll::Ready(None); - } + None => return Poll::Ready(None), } } } @@ -440,23 +410,11 @@ where self.as_mut().in_flight_requests().remove(&request_id) { self.as_mut().in_flight_requests().compact(0.1); - - debug!( - "[{}/{}] Removed request.", - in_flight_data.ctx.trace_id(), - self.as_mut().server_addr() - ); - + debug!("[{}] Removed request.", in_flight_data.ctx.trace_id()); return Poll::Ready(Some(Ok((in_flight_data.ctx, request_id)))); } } - None => { - trace!( - "[{}] canceled_requests closed.", - self.as_mut().server_addr() - ); - return Poll::Ready(None); - } + None => return Poll::Ready(None), } } } @@ -466,14 +424,14 @@ where dispatch_request: DispatchRequest, ) -> io::Result<()> { let request_id = dispatch_request.request_id; - let request = ClientMessage { - trace_context: dispatch_request.ctx.trace_context, - message: ClientMessageKind::Request(Request { - id: request_id, - message: dispatch_request.request, + let request = ClientMessage::Request(Request { + id: request_id, + message: dispatch_request.request, + context: context::Context { deadline: dispatch_request.ctx.deadline, - }), - }; + trace_context: dispatch_request.ctx.trace_context, + }, + }); self.as_mut().transport().start_send(request)?; self.as_mut().in_flight_requests().insert( request_id, @@ -491,16 +449,12 @@ where request_id: u64, ) -> io::Result<()> { let trace_id = *context.trace_id(); - let cancel = ClientMessage { + let cancel = ClientMessage::Cancel { trace_context: context.trace_context, - message: ClientMessageKind::Cancel { request_id }, + request_id, }; self.as_mut().transport().start_send(cancel)?; - trace!( - "[{}/{}] Cancel message sent.", - trace_id, - self.as_mut().server_addr() - ); + trace!("[{}] Cancel message sent.", trace_id); Ok(()) } @@ -513,18 +467,13 @@ where { self.as_mut().in_flight_requests().compact(0.1); - trace!( - "[{}/{}] Received response.", - in_flight_data.ctx.trace_id(), - self.as_mut().server_addr() - ); + trace!("[{}] Received response.", in_flight_data.ctx.trace_id()); let _ = in_flight_data.response_completion.send(response); return true; } debug!( - "[{}] No in-flight request found for request_id = {}.", - self.as_mut().server_addr(), + "No in-flight request found for request_id = {}.", response.request_id ); @@ -537,58 +486,29 @@ impl Future for RequestDispatch where Req: marker::Send, Resp: marker::Send, - C: Transport, SinkItem = ClientMessage>, + C: Transport, Response>, { type Output = io::Result<()>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - trace!("[{}] RequestDispatch::poll", self.as_mut().server_addr()); loop { match (self.pump_read(cx)?, self.pump_write(cx)?) { - (read, write @ Poll::Ready(None)) => { + (read, Poll::Ready(None)) => { if self.as_mut().in_flight_requests().is_empty() { - info!( - "[{}] Shutdown: write half closed, and no requests in flight.", - self.as_mut().server_addr() - ); + info!("Shutdown: write half closed, and no requests in flight."); return Poll::Ready(Ok(())); } - let addr = *self.as_mut().server_addr(); info!( - "[{}] {} requests in flight.", - addr, + "Shutdown: write half closed, and {} requests in flight.", self.as_mut().in_flight_requests().len() ); match read { Poll::Ready(Some(())) => continue, - _ => { - trace!( - "[{}] read: {:?}, write: {:?}, (not ready)", - self.as_mut().server_addr(), - read, - write, - ); - return Poll::Pending; - } + _ => return Poll::Pending, } } - (read @ Poll::Ready(Some(())), write) | (read, write @ Poll::Ready(Some(()))) => { - trace!( - "[{}] read: {:?}, write: {:?}", - self.as_mut().server_addr(), - read, - write, - ) - } - (read, write) => { - trace!( - "[{}] read: {:?}, write: {:?} (not ready)", - self.as_mut().server_addr(), - read, - write, - ); - return Poll::Pending; - } + (Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {} + _ => return Poll::Pending, } } } @@ -848,14 +768,7 @@ mod tests { }; use futures_test::task::noop_waker_ref; use std::time::Duration; - use std::{ - marker, - net::{IpAddr, Ipv4Addr, SocketAddr}, - pin::Pin, - sync::atomic::AtomicU64, - sync::Arc, - time::Instant, - }; + use std::{marker, pin::Pin, sync::atomic::AtomicU64, sync::Arc, time::Instant}; #[test] fn dispatch_response_cancels_on_timeout() { @@ -869,7 +782,6 @@ mod tests { request_id: 3, cancellation, ctx: context::current(), - server_addr: SocketAddr::from(([0, 0, 0, 0], 9999)), }; { pin_utils::pin_mut!(resp); @@ -994,7 +906,6 @@ mod tests { canceled_requests: CanceledRequests(canceled_requests).fuse(), in_flight_requests: FnvHashMap::default(), config: Config::default(), - server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), }; let cancellation = RequestCancellation(cancel_tx); @@ -1002,7 +913,6 @@ mod tests { to_dispatch, cancellation, next_request_id: Arc::new(AtomicU64::new(0)), - server_addr: SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0), }; (dispatch, channel, server_channel) diff --git a/rpc/src/client/mod.rs b/rpc/src/client/mod.rs index 94f2325..ccb0758 100644 --- a/rpc/src/client/mod.rs +++ b/rpc/src/client/mod.rs @@ -8,11 +8,7 @@ use crate::{context, ClientMessage, Response, Transport}; use futures::prelude::*; -use log::warn; -use std::{ - io, - net::{Ipv4Addr, SocketAddr}, -}; +use std::io; /// Provides a [`Client`] backed by a transport. pub mod channel; @@ -137,15 +133,7 @@ pub async fn new(config: Config, transport: T) -> io::Result, SinkItem = ClientMessage> + Send + 'static, + T: Transport, Response> + Send + 'static, { - let server_addr = transport.peer_addr().unwrap_or_else(|e| { - warn!( - "Setting peer to unspecified because peer could not be determined: {}", - e - ); - SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0) - }); - - Ok(channel::spawn(config, transport, server_addr).await?) + Ok(channel::spawn(config, transport).await?) } diff --git a/rpc/src/context.rs b/rpc/src/context.rs index 677d355..83da870 100644 --- a/rpc/src/context.rs +++ b/rpc/src/context.rs @@ -16,10 +16,20 @@ use trace::{self, TraceId}; /// The context should not be stored directly in a server implementation, because the context will /// be different for each request in scope. #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[non_exhaustive] pub struct Context { /// When the client expects the request to be complete by. The server should cancel the request /// if it is not complete by this time. + #[cfg_attr( + feature = "serde1", + serde(serialize_with = "crate::util::serde::serialize_epoch_secs") + )] + #[cfg_attr( + feature = "serde1", + serde(deserialize_with = "crate::util::serde::deserialize_epoch_secs") + )] + #[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))] pub deadline: SystemTime, /// Uniquely identifies requests originating from the same source. /// When a service handles a request by making requests itself, those requests should @@ -28,6 +38,11 @@ pub struct Context { pub trace_context: trace::Context, } +#[cfg(feature = "serde1")] +fn ten_seconds_from_now() -> SystemTime { + return SystemTime::now() + Duration::from_secs(10); +} + /// Returns the context for the current request, or a default Context if no request is active. // TODO: populate Context with request-scoped data, with default fallbacks. pub fn current() -> Context { diff --git a/rpc/src/lib.rs b/rpc/src/lib.rs index fec9dd7..c42b1c0 100644 --- a/rpc/src/lib.rs +++ b/rpc/src/lib.rs @@ -5,11 +5,14 @@ // https://opensource.org/licenses/MIT. #![feature( + weak_counts, non_exhaustive, integer_atomics, try_trait, arbitrary_self_types, - async_await + async_await, + trait_alias, + async_closure )] #![deny(missing_docs, missing_debug_implementations)] @@ -49,19 +52,7 @@ use std::{cell::RefCell, io, sync::Once, time::SystemTime}; #[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[non_exhaustive] -pub struct ClientMessage { - /// The trace context associates the message with a specific chain of causally-related actions, - /// possibly orchestrated across many distributed systems. - pub trace_context: trace::Context, - /// The message payload. - pub message: ClientMessageKind, -} - -/// Different messages that can be sent from a client to a server. -#[derive(Debug)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -#[non_exhaustive] -pub enum ClientMessageKind { +pub enum ClientMessage { /// A request initiated by a user. The server responds to a request by invoking a /// service-provided request handler. The handler completes with a [`response`](Response), which /// the server sends back to the client. @@ -74,35 +65,30 @@ pub enum ClientMessageKind { /// not be canceled, because the framework layer does not /// know about them. Cancel { + /// The trace context associates the message with a specific chain of causally-related actions, + /// possibly orchestrated across many distributed systems. + #[cfg_attr(feature = "serde", serde(default))] + trace_context: trace::Context, /// The ID of the request to cancel. request_id: u64, }, } /// A request from a client to a server. -#[derive(Debug)] +#[derive(Clone, Copy, Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[non_exhaustive] pub struct Request { + /// Trace context, deadline, and other cross-cutting concerns. + pub context: context::Context, /// Uniquely identifies the request across all requests sent over a single channel. pub id: u64, /// The request body. pub message: T, - /// When the client expects the request to be complete by. The server will cancel the request - /// if it is not complete by this time. - #[cfg_attr( - feature = "serde1", - serde(serialize_with = "util::serde::serialize_epoch_secs") - )] - #[cfg_attr( - feature = "serde1", - serde(deserialize_with = "util::serde::deserialize_epoch_secs") - )] - pub deadline: SystemTime, } /// A response from a server to a client. -#[derive(Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[non_exhaustive] pub struct Response { @@ -113,7 +99,7 @@ pub struct Response { } /// An error response from a server to a client. -#[derive(Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[non_exhaustive] pub struct ServerError { @@ -140,7 +126,7 @@ impl From for io::Error { impl Request { /// Returns the deadline for this request. pub fn deadline(&self) -> &SystemTime { - &self.deadline + &self.context.deadline } } diff --git a/rpc/src/server/filter.rs b/rpc/src/server/filter.rs index 5d24d61..8b2ce9a 100644 --- a/rpc/src/server/filter.rs +++ b/rpc/src/server/filter.rs @@ -5,259 +5,331 @@ // https://opensource.org/licenses/MIT. use crate::{ - server::{Channel, Config}, + server::{self, Channel}, util::Compact, - ClientMessage, PollIo, Response, Transport, + Response, }; use fnv::FnvHashMap; use futures::{ channel::mpsc, + future::AbortRegistration, prelude::*, ready, stream::Fuse, task::{Context, Poll}, }; -use log::{debug, error, info, trace, warn}; -use pin_utils::unsafe_pinned; +use log::{debug, info, trace}; +use pin_utils::{unsafe_pinned, unsafe_unpinned}; +use std::sync::{Arc, Weak}; use std::{ - collections::hash_map::Entry, - io, - marker::PhantomData, - net::{IpAddr, SocketAddr}, - ops::Try, - option::NoneError, + collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, io, marker::Unpin, ops::Try, pin::Pin, }; -/// Drops connections under configurable conditions: -/// -/// 1. If the max number of connections is reached. -/// 2. If the max number of connections for a single IP is reached. +/// A single-threaded filter that drops channels based on per-key limits. #[derive(Debug)] -pub struct ConnectionFilter { +pub struct ChannelFilter +where + K: Eq + Hash, +{ listener: Fuse, - closed_connections: mpsc::UnboundedSender, - closed_connections_rx: mpsc::UnboundedReceiver, - config: Config, - connections_per_ip: FnvHashMap, - open_connections: usize, - ghost: PhantomData<(Req, Resp)>, + channels_per_key: u32, + dropped_keys: mpsc::UnboundedReceiver, + dropped_keys_tx: mpsc::UnboundedSender, + key_counts: FnvHashMap>>, + keymaker: F, } -enum NewConnection { - Filtered, - Accepted(Channel), +/// A channel that is tracked by a ChannelFilter. +#[derive(Debug)] +pub struct TrackedChannel { + inner: C, + tracker: Arc>, } -impl Try for NewConnection { - type Ok = Channel; - type Error = NoneError; +impl TrackedChannel { + unsafe_pinned!(inner: C); +} - fn into_result(self) -> Result, NoneError> { +#[derive(Debug)] +struct Tracker { + key: Option, + dropped_keys: mpsc::UnboundedSender, +} + +impl Drop for Tracker { + fn drop(&mut self) { + // Don't care if the listener is dropped. + let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap()); + } +} + +/// A running handler serving all requests for a single client. +#[derive(Debug)] +pub struct TrackedHandler { + inner: Fut, + tracker: Tracker, +} + +impl TrackedHandler +where + Fut: Future, +{ + unsafe_pinned!(inner: Fut); +} + +impl Future for TrackedHandler +where + Fut: Future, +{ + type Output = Fut::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.inner().poll(cx) + } +} + +impl Stream for TrackedChannel +where + C: Channel, +{ + type Item = ::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.channel().poll_next(cx) + } +} + +impl Sink> for TrackedChannel +where + C: Channel, +{ + type Error = io::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.channel().poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: Response) -> Result<(), Self::Error> { + self.channel().start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.channel().poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.channel().poll_close(cx) + } +} + +impl AsRef for TrackedChannel { + fn as_ref(&self) -> &C { + &self.inner + } +} + +impl Channel for TrackedChannel +where + C: Channel, +{ + type Req = C::Req; + type Resp = C::Resp; + + fn config(&self) -> &server::Config { + self.inner.config() + } + + fn in_flight_requests(self: Pin<&mut Self>) -> usize { + self.inner().in_flight_requests() + } + + fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration { + self.inner().start_request(request_id) + } +} + +impl TrackedChannel { + /// Returns the inner channel. + pub fn get_ref(&self) -> &C { + &self.inner + } + + /// Returns the pinned inner channel. + fn channel<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut C> { + self.inner() + } +} + +enum NewChannel { + Accepted(TrackedChannel), + Filtered(K), +} + +impl Try for NewChannel { + type Ok = TrackedChannel; + type Error = K; + + fn into_result(self) -> Result, K> { match self { - NewConnection::Filtered => Err(NoneError), - NewConnection::Accepted(channel) => Ok(channel), + NewChannel::Accepted(channel) => Ok(channel), + NewChannel::Filtered(k) => Err(k), } } - fn from_error(_: NoneError) -> Self { - NewConnection::Filtered + fn from_error(k: K) -> Self { + NewChannel::Filtered(k) } - fn from_ok(channel: Channel) -> Self { - NewConnection::Accepted(channel) + fn from_ok(channel: TrackedChannel) -> Self { + NewChannel::Accepted(channel) } } -impl ConnectionFilter { - unsafe_pinned!(open_connections: usize); - unsafe_pinned!(config: Config); - unsafe_pinned!(connections_per_ip: FnvHashMap); - unsafe_pinned!(closed_connections_rx: mpsc::UnboundedReceiver); +impl ChannelFilter +where + K: fmt::Display + Eq + Hash + Clone, +{ unsafe_pinned!(listener: Fuse); + unsafe_pinned!(dropped_keys: mpsc::UnboundedReceiver); + unsafe_pinned!(dropped_keys_tx: mpsc::UnboundedSender); + unsafe_unpinned!(key_counts: FnvHashMap>>); + unsafe_unpinned!(channels_per_key: u32); + unsafe_unpinned!(keymaker: F); +} - /// Sheds new connections to stay under configured limits. - pub fn filter(listener: S, config: Config) -> Self - where - S: Stream>, - C: Transport, SinkItem = Response> + Send, - { - let (closed_connections, closed_connections_rx) = mpsc::unbounded(); - - ConnectionFilter { +impl ChannelFilter +where + K: Eq + Hash, + S: Stream, +{ + /// Sheds new channels to stay under configured limits. + pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self { + let (dropped_keys_tx, dropped_keys) = mpsc::unbounded(); + ChannelFilter { listener: listener.fuse(), - closed_connections, - closed_connections_rx, - config, - connections_per_ip: FnvHashMap::default(), - open_connections: 0, - ghost: PhantomData, + channels_per_key, + dropped_keys, + dropped_keys_tx, + key_counts: FnvHashMap::default(), + keymaker, } } +} - fn handle_new_connection(self: &mut Pin<&mut Self>, stream: C) -> NewConnection - where - C: Transport, SinkItem = Response> + Send, - { - let peer = match stream.peer_addr() { - Ok(peer) => peer, - Err(e) => { - warn!("Could not get peer_addr of new connection: {}", e); - return NewConnection::Filtered; - } - }; - - let open_connections = *self.as_mut().open_connections(); - if open_connections >= self.as_mut().config().max_connections { - warn!( - "[{}] Shedding connection because the maximum open connections \ - limit is reached ({}/{}).", - peer, - open_connections, - self.as_mut().config().max_connections - ); - return NewConnection::Filtered; - } - - let config = self.config.clone(); - let open_connections_for_ip = self.increment_connections_for_ip(&peer)?; - *self.as_mut().open_connections() += 1; +impl ChannelFilter +where + S: Stream, + C: Channel, + K: fmt::Display + Eq + Hash + Clone + Unpin, + F: Fn(&C) -> K, +{ + fn handle_new_channel(self: &mut Pin<&mut Self>, stream: C) -> NewChannel { + let key = self.as_mut().keymaker()(&stream); + let tracker = self.increment_channels_for_key(key.clone())?; + let max = self.as_mut().channels_per_key(); debug!( - "[{}] Opening channel ({}/{} connections for IP, {} total).", - peer, - open_connections_for_ip, - config.max_connections_per_ip, - self.as_mut().open_connections(), + "[{}] Opening channel ({}/{}) channels for key.", + key, + Arc::strong_count(&tracker), + max ); - NewConnection::Accepted(Channel { - client_addr: peer, - closed_connections: self.closed_connections.clone(), - transport: stream.fuse(), - config, - ghost: PhantomData, + NewChannel::Accepted(TrackedChannel { + tracker, + inner: stream, }) } - fn handle_closed_connection(self: &mut Pin<&mut Self>, addr: &SocketAddr) { - *self.as_mut().open_connections() -= 1; - debug!( - "[{}] Closing channel. {} open connections remaining.", - addr, self.open_connections - ); - self.decrement_connections_for_ip(&addr); - self.as_mut().connections_per_ip().compact(0.1); - } + fn increment_channels_for_key(self: &mut Pin<&mut Self>, key: K) -> Result>, K> { + let channels_per_key = self.channels_per_key; + let dropped_keys = self.dropped_keys_tx.clone(); + let key_counts = &mut self.as_mut().key_counts(); + match key_counts.entry(key.clone()) { + Entry::Vacant(vacant) => { + let tracker = Arc::new(Tracker { + key: Some(key), + dropped_keys, + }); - fn increment_connections_for_ip(self: &mut Pin<&mut Self>, peer: &SocketAddr) -> Option { - let max_connections_per_ip = self.as_mut().config().max_connections_per_ip; - let mut occupied; - let mut connections_per_ip = self.as_mut().connections_per_ip(); - let occupied = match connections_per_ip.entry(peer.ip()) { - Entry::Vacant(vacant) => vacant.insert(0), - Entry::Occupied(o) => { - if *o.get() < max_connections_per_ip { - // Store the reference outside the block to extend the lifetime. - occupied = o; - occupied.get_mut() - } else { + vacant.insert(Arc::downgrade(&tracker)); + Ok(tracker) + } + Entry::Occupied(mut o) => { + let count = o.get().strong_count(); + if count >= channels_per_key.try_into().unwrap() { info!( - "[{}] Opened max connections from IP ({}/{}).", - peer, - o.get(), - max_connections_per_ip + "[{}] Opened max channels from key ({}/{}).", + key, count, channels_per_key ); - return None; - } - } - }; - *occupied += 1; - Some(*occupied) - } - - fn decrement_connections_for_ip(self: &mut Pin<&mut Self>, addr: &SocketAddr) { - let should_compact = match self.as_mut().connections_per_ip().entry(addr.ip()) { - Entry::Vacant(_) => { - error!("[{}] Got vacant entry when closing connection.", addr); - return; - } - Entry::Occupied(mut occupied) => { - *occupied.get_mut() -= 1; - if *occupied.get() == 0 { - occupied.remove(); - true + Err(key) } else { - false + Ok(o.get().upgrade().unwrap_or_else(|| { + let tracker = Arc::new(Tracker { + key: Some(key), + dropped_keys, + }); + + *o.get_mut() = Arc::downgrade(&tracker); + tracker + })) } } - }; - if should_compact { - self.as_mut().connections_per_ip().compact(0.1); } } - fn poll_listener( + fn poll_listener( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> PollIo> - where - S: Stream>, - C: Transport, SinkItem = Response> + Send, - { - match ready!(self.as_mut().listener().poll_next_unpin(cx)?) { - Some(codec) => Poll::Ready(Some(Ok(self.handle_new_connection(codec)))), + ) -> Poll>> { + match ready!(self.as_mut().listener().poll_next_unpin(cx)) { + Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))), None => Poll::Ready(None), } } - fn poll_closed_connections( - self: &mut Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match ready!(self.as_mut().closed_connections_rx().poll_next_unpin(cx)) { - Some(addr) => { - self.handle_closed_connection(&addr); - Poll::Ready(Ok(())) + fn poll_closed_channels(self: &mut Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + match ready!(self.as_mut().dropped_keys().poll_next_unpin(cx)) { + Some(key) => { + debug!("All channels dropped for key [{}]", key); + self.as_mut().key_counts().remove(&key); + self.as_mut().key_counts().compact(0.1); + Poll::Ready(()) } - None => unreachable!("Holding a copy of closed_connections and didn't close it."), + None => unreachable!("Holding a copy of closed_channels and didn't close it."), } } } -impl Stream for ConnectionFilter +impl Stream for ChannelFilter where - S: Stream>, - T: Transport, SinkItem = Response> + Send, + S: Stream, + C: Channel, + K: fmt::Display + Eq + Hash + Clone + Unpin, + F: Fn(&C) -> K, { - type Item = io::Result>; + type Item = TrackedChannel; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo> { + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { loop { match ( - self.as_mut().poll_listener(cx)?, - self.poll_closed_connections(cx)?, + self.as_mut().poll_listener(cx), + self.poll_closed_channels(cx), ) { - (Poll::Ready(Some(NewConnection::Accepted(channel))), _) => { - return Poll::Ready(Some(Ok(channel))); + (Poll::Ready(Some(NewChannel::Accepted(channel))), _) => { + return Poll::Ready(Some(channel)); } - (Poll::Ready(Some(NewConnection::Filtered)), _) | (_, Poll::Ready(())) => { - trace!( - "Filtered a connection; {} open.", - self.as_mut().open_connections() - ); + (Poll::Ready(Some(NewChannel::Filtered(_))), _) => { continue; } + (_, Poll::Ready(())) => continue, (Poll::Pending, Poll::Pending) => return Poll::Pending, (Poll::Ready(None), Poll::Pending) => { - if *self.as_mut().open_connections() > 0 { - trace!( - "Listener closed; {} open connections.", - self.as_mut().open_connections() - ); - return Poll::Pending; - } - trace!("Shutting down listener: all connections closed, and no more coming."); + trace!("Shutting down listener."); return Poll::Ready(None); } } diff --git a/rpc/src/server/mod.rs b/rpc/src/server/mod.rs index 16a8d9b..1a9f2c6 100644 --- a/rpc/src/server/mod.rs +++ b/rpc/src/server/mod.rs @@ -7,27 +7,27 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. use crate::{ - context, util::deadline_compat, util::AsDuration, util::Compact, ClientMessage, - ClientMessageKind, PollIo, Request, Response, ServerError, Transport, + context, util::deadline_compat, util::AsDuration, util::Compact, ClientMessage, PollIo, + Request, Response, ServerError, Transport, }; use fnv::FnvHashMap; use futures::{ channel::mpsc, - future::{abortable, AbortHandle}, + future::{AbortHandle, AbortRegistration, Abortable}, prelude::*, ready, stream::Fuse, task::{Context, Poll}, - try_ready, }; use humantime::format_rfc3339; use log::{debug, error, info, trace, warn}; use pin_utils::{unsafe_pinned, unsafe_unpinned}; use std::{ error::Error as StdError, + fmt, + hash::Hash, io, marker::PhantomData, - net::SocketAddr, pin::Pin, time::{Instant, SystemTime}, }; @@ -35,6 +35,14 @@ use tokio_timer::timeout; use trace::{self, TraceId}; mod filter; +#[cfg(test)] +mod testing; +mod throttle; + +pub use self::{ + filter::ChannelFilter, + throttle::{Throttler, ThrottlerStream}, +}; /// Manages clients, serving multiplexed requests over each connection. #[derive(Debug)] @@ -53,17 +61,6 @@ impl Default for Server { #[non_exhaustive] #[derive(Clone, Debug)] pub struct Config { - /// The maximum number of clients that can be connected to the server at once. When at the - /// limit, existing connections are honored and new connections are rejected. - pub max_connections: usize, - /// The maximum number of clients per IP address that can be connected to the server at once. - /// When an IP is at the limit, existing connections are honored and new connections on that IP - /// address are rejected. - pub max_connections_per_ip: usize, - /// The maximum number of requests that can be in flight for each client. When a client is at - /// the in-flight request limit, existing requests are fulfilled and new requests are rejected. - /// Rejected requests are sent a response error. - pub max_in_flight_requests_per_connection: usize, /// The number of responses per client that can be buffered server-side before being sent. /// `pending_response_buffer` controls the buffer size of the channel that a server's /// response tasks use to send responses to the client handler task. @@ -73,14 +70,21 @@ pub struct Config { impl Default for Config { fn default() -> Self { Config { - max_connections: 1_000_000, - max_connections_per_ip: 1_000, - max_in_flight_requests_per_connection: 1_000, pending_response_buffer: 100, } } } +impl Config { + /// Returns a channel backed by `transport` and configured with `self`. + pub fn channel(self, transport: T) -> BaseChannel + where + T: Transport, ClientMessage> + Send, + { + BaseChannel::new(self, transport) + } +} + /// Returns a new server with configuration specified `config`. pub fn new(config: Config) -> Server { Server { @@ -95,18 +99,15 @@ impl Server { &self.config } - /// Returns a stream of the incoming connections to the server. - pub fn incoming( - self, - listener: S, - ) -> impl Stream>> + /// Returns a stream of server channels. + pub fn incoming(self, listener: S) -> impl Stream> where Req: Send, Resp: Send, - S: Stream>, - T: Transport, SinkItem = Response> + Send, + S: Stream, + T: Transport, ClientMessage> + Send, { - self::filter::ConnectionFilter::filter(listener, self.config.clone()) + listener.map(move |t| BaseChannel::new(self.config.clone(), t)) } } @@ -122,31 +123,21 @@ impl Running { unsafe_unpinned!(request_handler: F); } -impl Future for Running +impl Future for Running where - S: Sized + Stream>>, - Req: Send + 'static, - Resp: Send + 'static, - T: Transport, SinkItem = Response> + Send + 'static, - F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone, - Fut: Future> + Send + 'static, + S: Sized + Stream, + C: Channel + Send + 'static, + F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone, + Fut: Future> + Send + 'static, { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { while let Some(channel) = ready!(self.as_mut().incoming().poll_next(cx)) { - match channel { - Ok(channel) => { - let peer = channel.client_addr; - if let Err(e) = - crate::spawn(channel.respond_with(self.as_mut().request_handler().clone())) - { - warn!("[{}] Failed to spawn connection handler: {:?}", peer, e); - } - } - Err(e) => { - warn!("Incoming connection error: {}", e); - } + if let Err(e) = + crate::spawn(channel.respond_with(self.as_mut().request_handler().clone())) + { + warn!("Failed to spawn channel handler: {:?}", e); } } info!("Server shutting down."); @@ -155,18 +146,30 @@ where } /// A utility trait enabling a stream to fluently chain a request handler. -pub trait Handler +pub trait Handler where - Self: Sized + Stream>>, - Req: Send, - Resp: Send, - T: Transport, SinkItem = Response> + Send, + Self: Sized + Stream, + C: Channel, { + /// Enforces channel per-key limits. + fn max_channels_per_key(self, n: u32, keymaker: KF) -> filter::ChannelFilter + where + K: fmt::Display + Eq + Hash + Clone + Unpin, + KF: Fn(&C) -> K, + { + ChannelFilter::new(self, n, keymaker) + } + + /// Caps the number of concurrent requests per channel. + fn max_concurrent_requests_per_channel(self, n: usize) -> ThrottlerStream { + ThrottlerStream::new(self, n) + } + /// Responds to all requests with `request_handler`. fn respond_with(self, request_handler: F) -> Running where - F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone, - Fut: Future> + Send + 'static, + F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone, + Fut: Future> + Send + 'static, { Running { incoming: self, @@ -175,191 +178,276 @@ where } } -impl Handler for S +impl Handler for S where - S: Sized + Stream>>, - Req: Send, - Resp: Send, - T: Transport, SinkItem = Response> + Send, + S: Sized + Stream, + C: Channel, { } -/// Responds to all requests with `request_handler`. -/// The server end of an open connection with a client. +/// BaseChannel lifts a Transport to a Channel by tracking in-flight requests. #[derive(Debug)] -pub struct Channel { +pub struct BaseChannel { + config: Config, /// Writes responses to the wire and reads requests off the wire. transport: Fuse, - /// Signals the connection is closed when `Channel` is dropped. - closed_connections: mpsc::UnboundedSender, - /// Channel limits to prevent unlimited resource usage. - config: Config, - /// The address of the server connected to. - client_addr: SocketAddr, + /// Number of requests currently being responded to. + in_flight_requests: FnvHashMap, /// Types the request and response. ghost: PhantomData<(Req, Resp)>, } -impl Drop for Channel { - fn drop(&mut self) { - trace!("[{}] Closing channel.", self.client_addr); +impl BaseChannel { + unsafe_unpinned!(in_flight_requests: FnvHashMap); +} - // Even in a bounded channel, each connection would have a guaranteed slot, so using - // an unbounded sender is actually no different. And, the bound is on the maximum number - // of open connections. - if self - .closed_connections - .unbounded_send(self.client_addr) - .is_err() - { - warn!( - "[{}] Failed to send closed connection message.", - self.client_addr +impl BaseChannel +where + T: Transport, ClientMessage> + Send, +{ + /// Creates a new channel backed by `transport` and configured with `config`. + pub fn new(config: Config, transport: T) -> Self { + BaseChannel { + config, + transport: transport.fuse(), + in_flight_requests: FnvHashMap::default(), + ghost: PhantomData, + } + } + + /// Creates a new channel backed by `transport` and configured with the defaults. + pub fn with_defaults(transport: T) -> Self { + Self::new(Config::default(), transport) + } + + /// Returns the inner transport. + pub fn get_ref(&self) -> &T { + self.transport.get_ref() + } + + /// Returns the pinned inner transport. + pub fn transport<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut T> { + unsafe { self.map_unchecked_mut(|me| me.transport.get_mut()) } + } + + fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) { + // It's possible the request was already completed, so it's fine + // if this is None. + if let Some(cancel_handle) = self.as_mut().in_flight_requests().remove(&request_id) { + self.as_mut().in_flight_requests().compact(0.1); + + cancel_handle.abort(); + let remaining = self.as_mut().in_flight_requests().len(); + trace!( + "[{}] Request canceled. In-flight requests = {}", + trace_context.trace_id, + remaining, + ); + } else { + trace!( + "[{}] Received cancellation, but response handler \ + is already complete.", + trace_context.trace_id, ); } } } -impl Channel { - unsafe_pinned!(transport: Fuse); -} - -impl Channel +/// The server end of an open connection with a client, streaming in requests from, and sinking +/// responses to, the client. +/// +/// Channels are free to somewhat rely on the assumption that all in-flight requests are eventually +/// either [cancelled](Channel::cancel_request) or [responded to](Sink::start_send). Safety cannot +/// rely on this assumption, but it is best for `Channel` users to always account for all outstanding +/// requests. +pub trait Channel where - T: Transport, SinkItem = Response> + Send, - Req: Send, - Resp: Send, + Self: Transport::Resp>, Request<::Req>>, { - pub(crate) fn start_send(mut self: Pin<&mut Self>, response: Response) -> io::Result<()> { - self.as_mut().transport().start_send(response) + /// Type of request item. + type Req: Send + 'static; + + /// Type of response sink item. + type Resp: Send + 'static; + + /// Configuration of the channel. + fn config(&self) -> &Config; + + /// Returns the number of in-flight requests over this channel. + fn in_flight_requests(self: Pin<&mut Self>) -> usize; + + /// Caps the number of concurrent requests. + fn max_concurrent_requests(self, n: usize) -> Throttler + where + Self: Sized, + { + Throttler::new(self, n) } - pub(crate) fn poll_ready( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self.as_mut().transport().poll_ready(cx) - } - - pub(crate) fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self.as_mut().transport().poll_flush(cx) - } - - pub(crate) fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> PollIo> { - self.as_mut().transport().poll_next(cx) - } - - /// Returns the address of the client connected to the channel. - pub fn client_addr(&self) -> &SocketAddr { - &self.client_addr - } + /// Tells the Channel that request with ID `request_id` is being handled. + /// The request will be tracked until a response with the same ID is sent + /// to the Channel. + fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration; /// Respond to requests coming over the channel with `f`. Returns a future that drives the /// responses and resolves when the connection is closed. - pub fn respond_with(self, f: F) -> impl Future + fn respond_with(self, f: F) -> ResponseHandler where - F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone, - Fut: Future> + Send + 'static, - Req: 'static, - Resp: 'static, + F: FnOnce(context::Context, Self::Req) -> Fut + Send + 'static + Clone, + Fut: Future> + Send + 'static, + Self: Sized, { - let (responses_tx, responses) = mpsc::channel(self.config.pending_response_buffer); + let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer); let responses = responses.fuse(); - let peer = self.client_addr; - ClientHandler { + ResponseHandler { channel: self, f, pending_responses: responses, responses_tx, - in_flight_requests: FnvHashMap::default(), } - .unwrap_or_else(move |e| { - info!("[{}] ClientHandler errored out: {}", peer, e); - }) } } +impl Stream for BaseChannel +where + T: Transport, ClientMessage> + Send + 'static, + Req: Send + 'static, + Resp: Send + 'static, +{ + type Item = io::Result>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + loop { + match ready!(self.as_mut().transport().poll_next(cx)?) { + Some(message) => match message { + ClientMessage::Request(request) => { + return Poll::Ready(Some(Ok(request))); + } + ClientMessage::Cancel { + trace_context, + request_id, + } => { + self.as_mut().cancel_request(&trace_context, request_id); + } + }, + None => return Poll::Ready(None), + } + } + } +} + +impl Sink> for BaseChannel +where + T: Transport, ClientMessage> + Send + 'static, + Req: Send + 'static, + Resp: Send + 'static, +{ + type Error = io::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.transport().poll_ready(cx) + } + + fn start_send( + mut self: Pin<&mut Self>, + response: Response, + ) -> Result<(), Self::Error> { + if self + .as_mut() + .in_flight_requests() + .remove(&response.request_id) + .is_some() + { + self.as_mut().in_flight_requests().compact(0.1); + } + + self.transport().start_send(response) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.transport().poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.transport().poll_close(cx) + } +} + +impl AsRef for BaseChannel { + fn as_ref(&self) -> &T { + self.transport.get_ref() + } +} + +impl Channel for BaseChannel +where + T: Transport, ClientMessage> + Send + 'static, + Req: Send + 'static, + Resp: Send + 'static, +{ + type Req = Req; + type Resp = Resp; + + fn config(&self) -> &Config { + &self.config + } + + fn in_flight_requests(mut self: Pin<&mut Self>) -> usize { + self.as_mut().in_flight_requests().len() + } + + fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration { + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + assert!(self + .in_flight_requests() + .insert(request_id, abort_handle) + .is_none()); + abort_registration + } +} + +/// A running handler serving all requests coming over a channel. #[derive(Debug)] -struct ClientHandler { - channel: Channel, +pub struct ResponseHandler +where + C: Channel, +{ + channel: C, /// Responses waiting to be written to the wire. - pending_responses: Fuse)>>, + pending_responses: Fuse)>>, /// Handed out to request handlers to fan in responses. - responses_tx: mpsc::Sender<(context::Context, Response)>, - /// Number of requests currently being responded to. - in_flight_requests: FnvHashMap, + responses_tx: mpsc::Sender<(context::Context, Response)>, /// Request handler. f: F, } -impl ClientHandler { - unsafe_pinned!(channel: Channel); - unsafe_pinned!(in_flight_requests: FnvHashMap); - unsafe_pinned!(pending_responses: Fuse)>>); - unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response)>); +impl ResponseHandler +where + C: Channel, +{ + unsafe_pinned!(channel: C); + unsafe_pinned!(pending_responses: Fuse)>>); + unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response)>); // For this to be safe, field f must be private, and code in this module must never // construct PinMut. unsafe_unpinned!(f: F); } -impl ClientHandler +impl ResponseHandler where - Req: Send + 'static, - Resp: Send + 'static, - T: Transport, SinkItem = Response> + Send, - F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone, - Fut: Future> + Send + 'static, + C: Channel, + F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone, + Fut: Future> + Send + 'static, { - /// If at max in-flight requests, check that there's room to immediately write a throttled - /// response. - fn poll_ready_if_throttling( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - if self.in_flight_requests.len() - >= self.channel.config.max_in_flight_requests_per_connection - { - let peer = self.as_mut().channel().client_addr; - - while let Poll::Pending = self.as_mut().channel().poll_ready(cx)? { - info!( - "[{}] In-flight requests at max ({}), and transport is not ready.", - peer, - self.as_mut().in_flight_requests().len(), - ); - try_ready!(self.as_mut().channel().poll_flush(cx)); - } - } - Poll::Ready(Ok(())) - } - fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> { - ready!(self.as_mut().poll_ready_if_throttling(cx)?); - - Poll::Ready(match ready!(self.as_mut().channel().poll_next(cx)?) { - Some(message) => { - match message.message { - ClientMessageKind::Request(request) => { - self.handle_request(message.trace_context, request)?; - } - ClientMessageKind::Cancel { request_id } => { - self.cancel_request(&message.trace_context, request_id); - } - } - Some(Ok(())) + match ready!(self.as_mut().channel().poll_next(cx)?) { + Some(request) => { + self.handle_request(request)?; + Poll::Ready(Some(Ok(()))) } - None => { - trace!("[{}] Read half closed", self.channel.client_addr); - None - } - }) + None => Poll::Ready(None), + } } fn pump_write( @@ -368,7 +456,12 @@ where read_half_closed: bool, ) -> PollIo<()> { match self.as_mut().poll_next_response(cx)? { - Poll::Ready(Some((_, response))) => { + Poll::Ready(Some((ctx, response))) => { + trace!( + "[{}] Staging response. In-flight requests = {}.", + ctx.trace_id(), + self.as_mut().channel().in_flight_requests(), + ); self.as_mut().channel().start_send(response)?; Poll::Ready(Some(Ok(()))) } @@ -384,7 +477,7 @@ where // Being here means there are no staged requests and all written responses are // fully flushed. So, if the read half is closed and there are no in-flight // requests, then we can close the write half. - if read_half_closed && self.as_mut().in_flight_requests().is_empty() { + if read_half_closed && self.as_mut().channel().in_flight_requests() == 0 { Poll::Ready(None) } else { Poll::Pending @@ -396,90 +489,33 @@ where fn poll_next_response( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> PollIo<(context::Context, Response)> { + ) -> PollIo<(context::Context, Response)> { // Ensure there's room to write a response. while let Poll::Pending = self.as_mut().channel().poll_ready(cx)? { ready!(self.as_mut().channel().poll_flush(cx)?); } - let peer = self.as_mut().channel().client_addr; - match ready!(self.as_mut().pending_responses().poll_next(cx)) { - Some((ctx, response)) => { - if self - .as_mut() - .in_flight_requests() - .remove(&response.request_id) - .is_some() - { - self.as_mut().in_flight_requests().compact(0.1); - } - trace!( - "[{}/{}] Staging response. In-flight requests = {}.", - ctx.trace_id(), - peer, - self.as_mut().in_flight_requests().len(), - ); - Poll::Ready(Some(Ok((ctx, response)))) - } + Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))), None => { - // This branch likely won't happen, since the ClientHandler is holding a Sender. - trace!("[{}] No new responses.", peer); + // This branch likely won't happen, since the ResponseHandler is holding a Sender. Poll::Ready(None) } } } - fn handle_request( - mut self: Pin<&mut Self>, - trace_context: trace::Context, - request: Request, - ) -> io::Result<()> { + fn handle_request(mut self: Pin<&mut Self>, request: Request) -> io::Result<()> { let request_id = request.id; - let peer = self.as_mut().channel().client_addr; - let ctx = context::Context { - deadline: request.deadline, - trace_context, - }; - let request = request.message; - - if self.as_mut().in_flight_requests().len() - >= self - .as_mut() - .channel() - .config - .max_in_flight_requests_per_connection - { - debug!( - "[{}/{}] Client has reached in-flight request limit ({}/{}).", - ctx.trace_id(), - peer, - self.as_mut().in_flight_requests().len(), - self.as_mut() - .channel() - .config - .max_in_flight_requests_per_connection - ); - - self.as_mut().channel().start_send(Response { - request_id, - message: Err(ServerError { - kind: io::ErrorKind::WouldBlock, - detail: Some("Server throttled the request.".into()), - }), - })?; - return Ok(()); - } - - let deadline = ctx.deadline; + let deadline = request.context.deadline; let timeout = deadline.as_duration(); trace!( - "[{}/{}] Received request with deadline {} (timeout {:?}).", - ctx.trace_id(), - peer, + "[{}] Received request with deadline {} (timeout {:?}).", + request.context.trace_id(), format_rfc3339(deadline), timeout, ); + let ctx = request.context; + let request = request.message; let mut response_tx = self.as_mut().responses_tx().clone(); let trace_id = *ctx.trace_id(); @@ -490,18 +526,19 @@ where request_id, message: match result { Ok(message) => Ok(message), - Err(e) => Err(make_server_error(e, trace_id, peer, deadline)), + Err(e) => Err(make_server_error(e, trace_id, deadline)), }, }; - trace!("[{}/{}] Sending response.", trace_id, peer); + trace!("[{}] Sending response.", trace_id); response_tx .send((ctx, response)) .unwrap_or_else(|_| ()) .await; }, ); - let (abortable_response, abort_handle) = abortable(response); - crate::spawn(abortable_response.map(|_| ())).map_err(|e| { + let abort_registration = self.as_mut().channel().start_request(request_id); + let response = Abortable::new(response, abort_registration); + crate::spawn(response.map(|_| ())).map_err(|e| { io::Error::new( io::ErrorKind::Other, format!( @@ -510,92 +547,49 @@ where ), ) })?; - self.as_mut() - .in_flight_requests() - .insert(request_id, abort_handle); Ok(()) } - - fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) { - // It's possible the request was already completed, so it's fine - // if this is None. - if let Some(cancel_handle) = self.as_mut().in_flight_requests().remove(&request_id) { - self.as_mut().in_flight_requests().compact(0.1); - - cancel_handle.abort(); - let remaining = self.as_mut().in_flight_requests().len(); - trace!( - "[{}/{}] Request canceled. In-flight requests = {}", - trace_context.trace_id, - self.channel.client_addr, - remaining, - ); - } else { - trace!( - "[{}/{}] Received cancellation, but response handler \ - is already complete.", - trace_context.trace_id, - self.channel.client_addr - ); - } - } } -impl Future for ClientHandler +impl Future for ResponseHandler where - Req: Send + 'static, - Resp: Send + 'static, - T: Transport, SinkItem = Response> + Send, - F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone, - Fut: Future> + Send + 'static, + C: Channel, + F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone, + Fut: Future> + Send + 'static, { - type Output = io::Result<()>; + type Output = (); - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - trace!("[{}] ClientHandler::poll", self.channel.client_addr); - loop { - let read = self.as_mut().pump_read(cx)?; - match ( - read, - self.as_mut().pump_write(cx, read == Poll::Ready(None))?, - ) { - (Poll::Ready(None), Poll::Ready(None)) => { - info!("[{}] Client disconnected.", self.channel.client_addr); - return Poll::Ready(Ok(())); - } - (read @ Poll::Ready(Some(())), write) | (read, write @ Poll::Ready(Some(()))) => { - trace!( - "[{}] read: {:?}, write: {:?}.", - self.channel.client_addr, - read, - write - ) - } - (read, write) => { - trace!( - "[{}] read: {:?}, write: {:?} (not ready).", - self.channel.client_addr, - read, - write, - ); - return Poll::Pending; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + move || -> Poll> { + loop { + let read = self.as_mut().pump_read(cx)?; + match ( + read, + self.as_mut().pump_write(cx, read == Poll::Ready(None))?, + ) { + (Poll::Ready(None), Poll::Ready(None)) => { + return Poll::Ready(Ok(())); + } + (Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {} + _ => { + return Poll::Pending; + } } } - } + }() + .map(|r| r.unwrap_or_else(|e| info!("ResponseHandler errored out: {}", e))) } } fn make_server_error( e: timeout::Error, trace_id: TraceId, - peer: SocketAddr, deadline: SystemTime, ) -> ServerError { if e.is_elapsed() { debug!( - "[{}/{}] Response did not complete before deadline of {}s.", + "[{}] Response did not complete before deadline of {}s.", trace_id, - peer, format_rfc3339(deadline) ); // No point in responding, since the client will have dropped the request. @@ -608,8 +602,8 @@ fn make_server_error( } } else if e.is_timer() { error!( - "[{}/{}] Response failed because of an issue with a timer: {}", - trace_id, peer, e + "[{}] Response failed because of an issue with a timer: {}", + trace_id, e ); ServerError { @@ -623,7 +617,7 @@ fn make_server_error( detail: Some(e.description().into()), } } else { - error!("[{}/{}] Unexpected response failure: {}", trace_id, peer, e); + error!("[{}] Unexpected response failure: {}", trace_id, e); ServerError { kind: io::ErrorKind::Other, diff --git a/rpc/src/server/testing.rs b/rpc/src/server/testing.rs new file mode 100644 index 0000000..a40002a --- /dev/null +++ b/rpc/src/server/testing.rs @@ -0,0 +1,125 @@ +use crate::server::{Channel, Config}; +use crate::{context, Request, Response}; +use fnv::FnvHashSet; +use futures::future::{AbortHandle, AbortRegistration}; +use futures::{Sink, Stream}; +use futures_test::task::noop_waker_ref; +use pin_utils::{unsafe_pinned, unsafe_unpinned}; +use std::collections::VecDeque; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::SystemTime; + +pub(crate) struct FakeChannel { + pub stream: VecDeque, + pub sink: VecDeque, + pub config: Config, + pub in_flight_requests: FnvHashSet, +} + +impl FakeChannel { + unsafe_pinned!(stream: VecDeque); + unsafe_pinned!(sink: VecDeque); + unsafe_unpinned!(in_flight_requests: FnvHashSet); +} + +impl Stream for FakeChannel +where + In: Unpin, +{ + type Item = In; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.stream().poll_next(cx) + } +} + +impl Sink> for FakeChannel> { + type Error = io::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.sink().poll_ready(cx).map_err(|e| match e {}) + } + + fn start_send( + mut self: Pin<&mut Self>, + response: Response, + ) -> Result<(), Self::Error> { + self.as_mut() + .in_flight_requests() + .remove(&response.request_id); + self.sink().start_send(response).map_err(|e| match e {}) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.sink().poll_flush(cx).map_err(|e| match e {}) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.sink().poll_close(cx).map_err(|e| match e {}) + } +} + +impl Channel for FakeChannel>, Response> +where + Req: Unpin + Send + 'static, + Resp: Send + 'static, +{ + type Req = Req; + type Resp = Resp; + + fn config(&self) -> &Config { + &self.config + } + + fn in_flight_requests(self: Pin<&mut Self>) -> usize { + self.in_flight_requests.len() + } + + fn start_request(self: Pin<&mut Self>, id: u64) -> AbortRegistration { + self.in_flight_requests().insert(id); + AbortHandle::new_pair().1 + } +} + +impl FakeChannel>, Response> { + pub fn push_req(&mut self, id: u64, message: Req) { + self.stream.push_back(Ok(Request { + context: context::Context { + deadline: SystemTime::UNIX_EPOCH, + trace_context: Default::default(), + }, + id, + message, + })); + } +} + +impl FakeChannel<(), ()> { + pub fn default() -> FakeChannel>, Response> { + FakeChannel { + stream: VecDeque::default(), + sink: VecDeque::default(), + config: Config::default(), + in_flight_requests: FnvHashSet::default(), + } + } +} + +pub trait PollExt { + fn is_done(&self) -> bool; +} + +impl PollExt for Poll> { + fn is_done(&self) -> bool { + match self { + Poll::Ready(None) => true, + _ => false, + } + } +} + +pub fn cx() -> Context<'static> { + Context::from_waker(&noop_waker_ref()) +} diff --git a/rpc/src/server/throttle.rs b/rpc/src/server/throttle.rs new file mode 100644 index 0000000..4ba9f19 --- /dev/null +++ b/rpc/src/server/throttle.rs @@ -0,0 +1,332 @@ +use super::{Channel, Config}; +use crate::{Response, ServerError}; +use futures::{ + future::AbortRegistration, + prelude::*, + ready, + task::{Context, Poll}, +}; +use log::debug; +use pin_utils::{unsafe_pinned, unsafe_unpinned}; +use std::{io, pin::Pin}; + +/// A [`Channel`] that limits the number of concurrent +/// requests by throttling. +#[derive(Debug)] +pub struct Throttler { + max_in_flight_requests: usize, + inner: C, +} + +impl Throttler { + unsafe_unpinned!(max_in_flight_requests: usize); + unsafe_pinned!(inner: C); + + /// Returns the inner channel. + pub fn get_ref(&self) -> &C { + &self.inner + } +} + +impl Throttler +where + C: Channel, +{ + /// Returns a new `Throttler` that wraps the given channel and limits concurrent requests to + /// `max_in_flight_requests`. + pub fn new(inner: C, max_in_flight_requests: usize) -> Self { + Throttler { + inner, + max_in_flight_requests, + } + } +} + +impl Stream for Throttler +where + C: Channel, +{ + type Item = ::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + while self.as_mut().in_flight_requests() >= *self.as_mut().max_in_flight_requests() { + ready!(self.as_mut().inner().poll_ready(cx)?); + + match ready!(self.as_mut().inner().poll_next(cx)?) { + Some(request) => { + debug!( + "[{}] Client has reached in-flight request limit ({}/{}).", + request.context.trace_id(), + self.as_mut().in_flight_requests(), + self.as_mut().max_in_flight_requests(), + ); + + self.as_mut().start_send(Response { + request_id: request.id, + message: Err(ServerError { + kind: io::ErrorKind::WouldBlock, + detail: Some("Server throttled the request.".into()), + }), + })?; + } + None => return Poll::Ready(None), + } + } + self.inner().poll_next(cx) + } +} + +impl Sink::Resp>> for Throttler +where + C: Channel, +{ + type Error = io::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.inner().poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: Response<::Resp>) -> io::Result<()> { + self.inner().start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.inner().poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.inner().poll_close(cx) + } +} + +impl AsRef for Throttler { + fn as_ref(&self) -> &C { + &self.inner + } +} + +impl Channel for Throttler +where + C: Channel, +{ + type Req = ::Req; + type Resp = ::Resp; + + fn in_flight_requests(self: Pin<&mut Self>) -> usize { + self.inner().in_flight_requests() + } + + fn config(&self) -> &Config { + self.inner.config() + } + + fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration { + self.inner().start_request(request_id) + } +} + +/// A stream of throttling channels. +#[derive(Debug)] +pub struct ThrottlerStream { + inner: S, + max_in_flight_requests: usize, +} + +impl ThrottlerStream +where + S: Stream, + ::Item: Channel, +{ + unsafe_pinned!(inner: S); + unsafe_unpinned!(max_in_flight_requests: usize); + + pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self { + Self { + inner, + max_in_flight_requests, + } + } +} + +impl Stream for ThrottlerStream +where + S: Stream, + ::Item: Channel, +{ + type Item = Throttler<::Item>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match ready!(self.as_mut().inner().poll_next(cx)) { + Some(channel) => Poll::Ready(Some(Throttler::new( + channel, + *self.max_in_flight_requests(), + ))), + None => Poll::Ready(None), + } + } +} + +#[cfg(test)] +use super::testing::{self, FakeChannel, PollExt}; +#[cfg(test)] +use crate::Request; +#[cfg(test)] +use pin_utils::pin_mut; +#[cfg(test)] +use std::marker::PhantomData; + +#[test] +fn throttler_in_flight_requests() { + let throttler = Throttler { + max_in_flight_requests: 0, + inner: FakeChannel::default::(), + }; + + pin_mut!(throttler); + for i in 0..5 { + throttler.inner.in_flight_requests.insert(i); + } + assert_eq!(throttler.as_mut().in_flight_requests(), 5); +} + +#[test] +fn throttler_start_request() { + let throttler = Throttler { + max_in_flight_requests: 0, + inner: FakeChannel::default::(), + }; + + pin_mut!(throttler); + throttler.as_mut().start_request(1); + assert_eq!(throttler.inner.in_flight_requests.len(), 1); +} + +#[test] +fn throttler_poll_next_done() { + let throttler = Throttler { + max_in_flight_requests: 0, + inner: FakeChannel::default::(), + }; + + pin_mut!(throttler); + assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done()); +} + +#[test] +fn throttler_poll_next_some() -> io::Result<()> { + let throttler = Throttler { + max_in_flight_requests: 1, + inner: FakeChannel::default::(), + }; + + pin_mut!(throttler); + throttler.inner.push_req(0, 1); + assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready()); + assert_eq!( + throttler + .as_mut() + .poll_next(&mut testing::cx())? + .map(|r| r.map(|r| (r.id, r.message))), + Poll::Ready(Some((0, 1))) + ); + Ok(()) +} + +#[test] +fn throttler_poll_next_throttled() { + let throttler = Throttler { + max_in_flight_requests: 0, + inner: FakeChannel::default::(), + }; + + pin_mut!(throttler); + throttler.inner.push_req(1, 1); + assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done()); + assert_eq!(throttler.inner.sink.len(), 1); + let resp = throttler.inner.sink.get(0).unwrap(); + assert_eq!(resp.request_id, 1); + assert!(resp.message.is_err()); +} + +#[test] +fn throttler_poll_next_throttled_sink_not_ready() { + let throttler = Throttler { + max_in_flight_requests: 0, + inner: PendingSink::default::(), + }; + pin_mut!(throttler); + assert!(throttler.poll_next(&mut testing::cx()).is_pending()); + + struct PendingSink { + ghost: PhantomData In>, + } + impl PendingSink<(), ()> { + pub fn default() -> PendingSink>, Response> { + PendingSink { ghost: PhantomData } + } + } + impl Stream for PendingSink { + type Item = In; + fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll> { + unimplemented!() + } + } + impl Sink for PendingSink { + type Error = io::Error; + fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll> { + Poll::Pending + } + fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> { + Err(io::Error::from(io::ErrorKind::WouldBlock)) + } + fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll> { + Poll::Pending + } + fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll> { + Poll::Pending + } + } + impl Channel for PendingSink>, Response> + where + Req: Send + 'static, + Resp: Send + 'static, + { + type Req = Req; + type Resp = Resp; + fn config(&self) -> &Config { + unimplemented!() + } + fn in_flight_requests(self: Pin<&mut Self>) -> usize { + 0 + } + fn start_request(self: Pin<&mut Self>, _: u64) -> AbortRegistration { + unimplemented!() + } + } +} + +#[test] +fn throttler_start_send() { + let throttler = Throttler { + max_in_flight_requests: 0, + inner: FakeChannel::default::(), + }; + + pin_mut!(throttler); + throttler.inner.in_flight_requests.insert(0); + throttler + .as_mut() + .start_send(Response { + request_id: 0, + message: Ok(1), + }) + .unwrap(); + assert!(throttler.inner.in_flight_requests.is_empty()); + assert_eq!( + throttler.inner.sink.get(0), + Some(&Response { + request_id: 0, + message: Ok(1), + }) + ); +} diff --git a/rpc/src/transport/channel.rs b/rpc/src/transport/channel.rs index a567950..d3f1e21 100644 --- a/rpc/src/transport/channel.rs +++ b/rpc/src/transport/channel.rs @@ -6,14 +6,11 @@ //! Transports backed by in-memory channels. -use crate::{PollIo, Transport}; +use crate::PollIo; use futures::{channel::mpsc, task::Context, Poll, Sink, Stream}; use pin_utils::unsafe_pinned; +use std::io; use std::pin::Pin; -use std::{ - io, - net::{IpAddr, Ipv4Addr, SocketAddr}, -}; /// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's /// [`Sink`]. @@ -51,7 +48,7 @@ impl Stream for UnboundedChannel { } impl Sink for UnboundedChannel { - type SinkError = io::Error; + type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.tx() @@ -65,7 +62,7 @@ impl Sink for UnboundedChannel { .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.tx() .poll_flush(cx) .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) @@ -78,19 +75,6 @@ impl Sink for UnboundedChannel { } } -impl Transport for UnboundedChannel { - type SinkItem = SinkItem; - type Item = Item; - - fn peer_addr(&self) -> io::Result { - Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)) - } - - fn local_addr(&self) -> io::Result { - Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)) - } -} - #[cfg(test)] mod tests { use crate::{ @@ -110,7 +94,7 @@ mod tests { let (client_channel, server_channel) = transport::channel::unbounded(); let server = Server::::default() - .incoming(stream::once(future::ready(Ok(server_channel)))) + .incoming(stream::once(future::ready(server_channel))) .respond_with(|_ctx, request| { future::ready(request.parse::().map_err(|_| { io::Error::new( diff --git a/rpc/src/transport/mod.rs b/rpc/src/transport/mod.rs index b926450..e934aef 100644 --- a/rpc/src/transport/mod.rs +++ b/rpc/src/transport/mod.rs @@ -10,114 +10,10 @@ //! can be plugged in, using whatever protocol it wants. use futures::prelude::*; -use std::{ - io, - marker::PhantomData, - net::SocketAddr, - pin::Pin, - task::{Context, Poll}, -}; +use std::io; pub mod channel; /// A bidirectional stream ([`Sink`] + [`Stream`]) of messages. -pub trait Transport -where - Self: Stream::Item>>, - Self: Sink<::SinkItem, SinkError = io::Error>, -{ - /// The type read off the transport. - type Item; - /// The type written to the transport. - type SinkItem; - - /// The address of the remote peer this transport is in communication with. - fn peer_addr(&self) -> io::Result; - /// The address of the local half of this transport. - fn local_addr(&self) -> io::Result; -} - -/// Returns a new Transport backed by the given Stream + Sink and connecting addresses. -pub fn new( - inner: S, - peer_addr: SocketAddr, - local_addr: SocketAddr, -) -> impl Transport -where - S: Stream>, - S: Sink, -{ - TransportShim { - inner, - peer_addr, - local_addr, - _marker: PhantomData, - } -} - -/// A transport created by adding peers to a Stream + Sink. -#[derive(Debug)] -struct TransportShim { - peer_addr: SocketAddr, - local_addr: SocketAddr, - inner: S, - _marker: PhantomData, -} - -impl TransportShim { - pin_utils::unsafe_pinned!(inner: S); -} - -impl Stream for TransportShim -where - S: Stream, -{ - type Item = S::Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.inner().poll_next(cx) - } -} - -impl Sink for TransportShim -where - S: Sink, -{ - type SinkError = S::SinkError; - - fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), S::SinkError> { - self.inner().start_send(item) - } - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.inner().poll_ready(cx) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.inner().poll_flush(cx) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.inner().poll_close(cx) - } -} - -impl Transport for TransportShim -where - S: Stream + Sink, - Self: Stream>, - Self: Sink, -{ - type Item = Item; - type SinkItem = SinkItem; - - /// The address of the remote peer this transport is in communication with. - fn peer_addr(&self) -> io::Result { - Ok(self.peer_addr) - } - - /// The address of the local half of this transport. - fn local_addr(&self) -> io::Result { - Ok(self.local_addr) - } -} +pub trait Transport = + Stream> + Sink; diff --git a/rpc/src/util/mod.rs b/rpc/src/util/mod.rs index 8c55178..a4a2642 100644 --- a/rpc/src/util/mod.rs +++ b/rpc/src/util/mod.rs @@ -38,9 +38,11 @@ where H: BuildHasher, { fn compact(&mut self, usage_ratio_threshold: f64) { - let usage_ratio = self.len() as f64 / self.capacity() as f64; - if usage_ratio < usage_ratio_threshold { - self.shrink_to_fit(); + if self.capacity() > 1000 { + let usage_ratio = self.len() as f64 / self.capacity() as f64; + if usage_ratio < usage_ratio_threshold { + self.shrink_to_fit(); + } } } } diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index fd4cc91..5ecdfd2 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -19,7 +19,7 @@ serde1 = ["rpc/serde1", "serde", "serde/derive"] travis-ci = { repository = "google/tarpc" } [dependencies] -futures-preview = { version = "0.3.0-alpha.16", features = ["compat"] } +futures-preview = { version = "0.3.0-alpha.17", features = ["compat"] } log = "0.4" serde = { optional = true, version = "1.0" } rpc = { package = "tarpc-lib", path = "../rpc", version = "0.6" } @@ -31,7 +31,6 @@ bytes = { version = "0.4", features = ["serde"] } humantime = "1.0" bincode-transport = { package = "tarpc-bincode-transport", version = "0.7", path = "../bincode-transport" } env_logger = "0.6" -libtest = "0.0.1" tokio = "0.1" tokio-executor = "0.1" tokio-tcp = "0.1" diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 5fae8f9..83a53f2 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -18,7 +18,7 @@ use futures::{ }; use rpc::{ client, context, - server::{self, Handler, Server}, + server::{self, Handler}, }; use std::{ collections::HashMap, @@ -60,8 +60,9 @@ impl subscriber::Service for Subscriber { impl Subscriber { async fn listen(id: u32, config: server::Config) -> io::Result { - let incoming = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; - let addr = incoming.local_addr(); + let incoming = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())? + .filter_map(|r| future::ready(r.ok())); + let addr = incoming.get_ref().local_addr(); tokio_executor::spawn( server::new(config) .incoming(incoming) @@ -140,12 +141,13 @@ impl publisher::Service for Publisher { async fn run() -> io::Result<()> { env_logger::init(); - let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; - let publisher_addr = transport.local_addr(); + let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())? + .filter_map(|r| future::ready(r.ok())); + let publisher_addr = transport.get_ref().local_addr(); tokio_executor::spawn( - Server::default() - .incoming(transport) + transport .take(1) + .map(server::BaseChannel::with_defaults) .respond_with(publisher::serve(Publisher::new())) .unit_error() .boxed() diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index ce50cca..adf95dc 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -13,7 +13,7 @@ use futures::{ }; use rpc::{ client, context, - server::{Handler, Server}, + server::{BaseChannel, Channel}, }; use std::io; @@ -43,20 +43,22 @@ impl Service for HelloServer { async fn run() -> io::Result<()> { // bincode_transport is provided by the associated crate bincode-transport. It makes it easy // to start up a serde-powered bincode serialization strategy over TCP. - let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; + let mut transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; let addr = transport.local_addr(); - // The server is configured with the defaults. - let server = Server::default() - // Server can listen on any type that implements the Transport trait. - .incoming(transport) - // Close the stream after the client connects - .take(1) + // For this example, we're just going to wait for one connection. + let client = transport.next().await.unwrap()?; + + // `Channel` is a trait representing a server-side connection. It is a trait to allow + // for some channels to be instrumented: for example, to track the number of open connections. + // BaseChannel is the most basic channel, simply wrapping a transport with no added + // functionality. + let server = BaseChannel::with_defaults(client) // serve is generated by the tarpc::service! macro. It takes as input any type implementing // the generated Service trait. .respond_with(serve(HelloServer)); - tokio_executor::spawn(server.unit_error().boxed().compat()); + tokio::spawn(server.unit_error().boxed().compat()); let transport = bincode_transport::connect(&addr).await?; diff --git a/tarpc/examples/server_calling_server.rs b/tarpc/examples/server_calling_server.rs index 53338b0..2dc532d 100644 --- a/tarpc/examples/server_calling_server.rs +++ b/tarpc/examples/server_calling_server.rs @@ -69,8 +69,9 @@ impl DoubleService for DoubleServer { } async fn run() -> io::Result<()> { - let add_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; - let addr = add_listener.local_addr(); + let add_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())? + .filter_map(|r| future::ready(r.ok())); + let addr = add_listener.get_ref().local_addr(); let add_server = Server::default() .incoming(add_listener) .take(1) @@ -80,8 +81,9 @@ async fn run() -> io::Result<()> { let to_add_server = bincode_transport::connect(&addr).await?; let add_client = add::new_stub(client::Config::default(), to_add_server).await?; - let double_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; - let addr = double_listener.local_addr(); + let double_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())? + .filter_map(|r| future::ready(r.ok())); + let addr = double_listener.get_ref().local_addr(); let double_server = rpc::Server::default() .incoming(double_listener) .take(1) diff --git a/tarpc/examples/service_registry.rs b/tarpc/examples/service_registry.rs index fe376de..df732e8 100644 --- a/tarpc/examples/service_registry.rs +++ b/tarpc/examples/service_registry.rs @@ -1,9 +1,4 @@ -#![feature( - async_await, - arbitrary_self_types, - proc_macro_hygiene, - impl_trait_in_bindings -)] +#![feature(async_await, arbitrary_self_types, proc_macro_hygiene)] mod registry { use bytes::Bytes; @@ -382,8 +377,9 @@ async fn run() -> io::Result<()> { read_service::serve(server.clone()), ); - let listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; - let server_addr = listener.local_addr(); + let listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())? + .filter_map(|r| future::ready(r.ok())); + let server_addr = listener.get_ref().local_addr(); let server = tarpc::Server::default() .incoming(listener) .take(1) diff --git a/tarpc/src/macros.rs b/tarpc/src/macros.rs index d8ead66..21f83ec 100644 --- a/tarpc/src/macros.rs +++ b/tarpc/src/macros.rs @@ -58,13 +58,13 @@ macro_rules! service { ( $( $(#[$attr:meta])* - rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) $(-> $out:ty)*; + rpc $fn_name:ident( $( $(#[$argattr:meta])* $arg:ident : $in_:ty ),* ) $(-> $out:ty)*; )* ) => { $crate::service! {{ $( $(#[$attr])* - rpc $fn_name( $( $arg : $in_ ),* ) $(-> $out)*; + rpc $fn_name( $( $(#[$argattr])* $arg : $in_ ),* ) $(-> $out)*; )* }} }; @@ -72,7 +72,7 @@ macro_rules! service { ( { $(#[$attr:meta])* - rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ); + rpc $fn_name:ident( $( $(#[$argattr:meta])* $arg:ident : $in_:ty ),* ); $( $unexpanded:tt )* } @@ -84,14 +84,14 @@ macro_rules! service { $( $expanded )* $(#[$attr])* - rpc $fn_name( $( $arg : $in_ ),* ) -> (); + rpc $fn_name( $( $(#[$argattr])* $arg : $in_ ),* ) -> (); } }; // Pattern for when the next rpc has an explicit return type. ( { $(#[$attr:meta])* - rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty; + rpc $fn_name:ident( $( $(#[$argattr:meta])* $arg:ident : $in_:ty ),* ) -> $out:ty; $( $unexpanded:tt )* } @@ -103,7 +103,7 @@ macro_rules! service { $( $expanded )* $(#[$attr])* - rpc $fn_name( $( $arg : $in_ ),* ) -> $out; + rpc $fn_name( $( $(#[$argattr])* $arg : $in_ ),* ) -> $out; } }; // Pattern for when all return types have been expanded @@ -111,7 +111,7 @@ macro_rules! service { { } // none left to expand $( $(#[$attr:meta])* - rpc $fn_name:ident ( $( $arg:ident : $in_:ty ),* ) -> $out:ty; + rpc $fn_name:ident ( $( $(#[$argattr:meta])* $arg:ident : $in_:ty ),* ) -> $out:ty; )* ) => { $crate::add_serde_if_enabled! { @@ -122,7 +122,7 @@ macro_rules! service { pub enum Request { $( $(#[$attr])* - $fn_name{ $($arg: $in_,)* } + $fn_name{ $( $(#[$argattr])* $arg: $in_,)* } ),* } } @@ -218,9 +218,7 @@ macro_rules! service { pub async fn new_stub(config: $crate::client::Config, transport: T) -> ::std::io::Result where - T: $crate::Transport< - Item = $crate::Response, - SinkItem = $crate::ClientMessage> + Send + 'static, + T: $crate::Transport<$crate::ClientMessage, $crate::Response> + Send + 'static, { Ok(Client($crate::client::new(config, transport).await?)) } @@ -321,7 +319,7 @@ mod functional_test { let (tx, rx) = channel::unbounded(); tokio_executor::spawn( crate::Server::default() - .incoming(stream::once(ready(Ok(rx)))) + .incoming(stream::once(ready(rx))) .respond_with(serve(Server)) .unit_error() .boxed() @@ -350,7 +348,7 @@ mod functional_test { let (tx, rx) = channel::unbounded(); tokio_executor::spawn( rpc::Server::default() - .incoming(stream::once(ready(Ok(rx)))) + .incoming(stream::once(ready(rx))) .respond_with(serve(Server)) .unit_error() .boxed() diff --git a/tarpc/tests/latency.rs b/tarpc/tests/latency.rs index 42fc0bd..ba6861b 100644 --- a/tarpc/tests/latency.rs +++ b/tarpc/tests/latency.rs @@ -12,8 +12,10 @@ proc_macro_hygiene )] +extern crate test; + use futures::{compat::Executor01CompatExt, future, prelude::*}; -use libtest::stats::Stats; +use test::stats::Stats; use rpc::{ client, context, server::{Handler, Server}, @@ -41,8 +43,9 @@ impl ack::Service for Serve { } async fn bench() -> io::Result<()> { - let listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; - let addr = listener.local_addr(); + let listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())? + .filter_map(|r| future::ready(r.ok())); + let addr = listener.get_ref().local_addr(); tokio_executor::spawn( Server::default() diff --git a/trace/Cargo.toml b/trace/Cargo.toml index 778359f..605a73f 100644 --- a/trace/Cargo.toml +++ b/trace/Cargo.toml @@ -13,7 +13,7 @@ readme = "../README.md" description = "foundations for tracing in tarpc" [dependencies] -rand = "0.6" +rand = "0.7" [dependencies.serde] version = "1.0" diff --git a/trace/src/lib.rs b/trace/src/lib.rs index 4255d07..28ac418 100644 --- a/trace/src/lib.rs +++ b/trace/src/lib.rs @@ -26,7 +26,7 @@ use std::{ /// /// Consists of a span identifying an event, an optional parent span identifying a causal event /// that triggered the current span, and a trace with which all related spans are associated. -#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Context { /// An identifier of the trace associated with the current context. A trace ID is typically @@ -46,12 +46,12 @@ pub struct Context { /// A 128-bit UUID identifying a trace. All spans caused by the same originating span share the /// same trace ID. -#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct TraceId(u128); /// A 64-bit identifier of a span within a trace. The identifier is unique within the span's trace. -#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct SpanId(u64);