Make server methods more composable.

-- Connection Limits

The problem with having ConnectionFilter enabled by default is elaborated on in https://github.com/google/tarpc/issues/217. The gist of it is that not all servers want a policy based on `SocketAddr`. This PR makes the behavior of ConnectionFilter customizable, at the cost of no longer enabling it by default. However, enabling it is as simple as one line:

incoming.max_channels_per_key(10, ip_addr)

The second argument is a key function that takes a reference to the user-chosen transport and returns some hashable, equatable, cloneable key. In the above example, it returns an `IpAddr`.
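As a concrete sketch of such a key function, modeled on the example server in this diff (the `as_ref()`/`peer_addr()` calls assume a TCP-backed transport like this commit's bincode/JSON transports, and the `unwrap()` is example-grade error handling, not a recommendation):

```rust
incoming
    // Key each channel by the peer's IP address, allowing at most
    // 10 channels per IP.
    .max_channels_per_key(10, |channel| {
        channel.as_ref().peer_addr().unwrap().ip()
    })
```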

This also allows the addr fns to be removed from the `Transport` trait, reducing it to a simple alias for `Stream + Sink`.
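A sketch of the resulting shape (the exact bounds here are an assumption rather than the authoritative definition; the commit enables the nightly `trait_alias` feature for this):

```rust
#![feature(trait_alias)]
use futures::prelude::*;
use std::io;

// Assumed shape: a transport is just a paired stream/sink over I/O, used
// throughout this diff as e.g. Transport<ClientMessage<Req>, Response<Resp>>.
pub trait Transport<SinkItem, Item> =
    Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error>;
```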

-- Per-Channel Request Throttling

The same argument applies to `Channel`'s throttling behavior. There is no one-size-fits-all solution to throttling requests, and the policy applied by tarpc is just one of potentially many. As such, `Channel` is now a trait that offers a few combinators, one of which is throttling:

channel.max_concurrent_requests(10).respond_with(serve(Server))
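For a single accepted transport, the pieces compose like this (a sketch: `BaseChannel::with_defaults` is the constructor this commit adds for lifting a transport into a `Channel`, and `serve`/`Server` stand in for the generated service types):

```rust
use rpc::server::{self, Channel};

// Lift a raw transport into a channel, cap it at 10 concurrent requests,
// and drive all of its requests to completion.
let channel = server::BaseChannel::with_defaults(transport);
channel
    .max_concurrent_requests(10)
    .respond_with(serve(Server))
    .await;
```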

This functionality is also available on the existing `Handler` trait, which applies it to all incoming channels and can be used in tandem with connection limits:

incoming
    .max_channels_per_key(10, ip_addr)
    .max_concurrent_requests_per_channel(10)
    .respond_with(serve(Server))

-- Global Request Throttling

I've entirely removed the overall request limit enforced across all channels. This functionality is easily recovered via [`StreamExt::buffer_unordered`](https://rust-lang-nursery.github.io/futures-api-docs/0.3.0-alpha.1/futures/stream/trait.StreamExt.html#method.buffer_unordered), with the difference that the previous behavior allowed channels to be spawned onto different threads, whereas with `buffer_unordered` the `Channel`s are handled on a single thread (per-request handlers are still spawned). Given the existing options, I don't believe this functionality pulled its weight.
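For instance, to cap the whole server at 100 concurrently handled channels on one task (a sketch following the example server in this diff):

```rust
incoming
    .map(|channel| channel.respond_with(serve(Server)))
    // At most 100 channel-handler futures are polled concurrently here;
    // individual request handlers are still spawned.
    .buffer_unordered(100)
    .for_each(|_| futures::future::ready(()))
    .await;
```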
Tim Kuehn
2019-07-15 18:58:36 -07:00
parent 146496d08c
commit 1089415451
36 changed files with 1303 additions and 989 deletions

View File

@@ -36,12 +36,12 @@ Add to your `Cargo.toml` dependencies:
tarpc = "0.18.0"
```
The `service!` macro expands to a collection of items that form an
rpc service. In the above example, the macro is called within the
`hello_service` module. This module will contain a `Client` stub and `Service` trait. There is
These generated types make it easy and ergonomic to write servers without dealing with serialization
directly. Simply implement one of the generated traits, and you're off to the
races!
The `service!` macro expands to a collection of items that form an rpc service.
In the above example, the macro is called within the `hello_service` module.
This module will contain a `Client` stub and `Service` trait. There is These
generated types make it easy and ergonomic to write servers without dealing with
serialization directly. Simply implement one of the generated traits, and you're
off to the races!
## Example
@@ -49,7 +49,7 @@ For this example, in addition to tarpc, also add two other dependencies to
your `Cargo.toml`:
```toml
futures-preview = { version = "0.3.0-alpha.16", features = ["compat"] }
futures-preview = { version = "0.3.0-alpha.17", features = ["compat"] }
tokio = "0.1"
```
@@ -175,7 +175,7 @@ async fn run() -> io::Result<()> {
let server = server::new(server::Config::default())
// incoming() takes a stream of transports such as would be returned by
// TcpListener::incoming (but a stream instead of an iterator).
.incoming(stream::once(future::ready(Ok(server_transport))))
.incoming(stream::once(future::ready(server_transport)))
// serve is generated by the service! macro. It takes as input any type implementing
// the generated Service trait.
.respond_with(serve(HelloServer));
@@ -246,7 +246,7 @@ background tasks for the client and server.
# let server = server::new(server::Config::default())
# // incoming() takes a stream of transports such as would be returned by
# // TcpListener::incoming (but a stream instead of an iterator).
# .incoming(stream::once(future::ready(Ok(server_transport))))
# .incoming(stream::once(future::ready(server_transport)))
# // serve is generated by the service! macro. It takes as input any type implementing
# // the generated Service trait.
# .respond_with(serve(HelloServer));

View File

@@ -14,10 +14,9 @@ description = "A bincode-based transport for tarpc services."
[dependencies]
bincode = "1"
futures-preview = { version = "0.3.0-alpha.16", features = ["compat"] }
futures-preview = { version = "0.3.0-alpha.17", features = ["compat"] }
futures_legacy = { version = "0.1", package = "futures" }
pin-utils = "0.1.0-alpha.4"
rpc = { package = "tarpc-lib", version = "0.6", path = "../rpc", features = ["serde1"] }
serde = "1.0"
tokio-io = "0.1"
async-bincode = "0.4"
@@ -26,9 +25,10 @@ tokio-tcp = "0.1"
[dev-dependencies]
env_logger = "0.6"
humantime = "1.0"
libtest = "0.0.1"
log = "0.4"
rand = "0.6"
rand = "0.7"
rand_distr = "0.2"
rpc = { package = "tarpc-lib", version = "0.6", path = "../rpc", features = ["serde1"] }
tokio = "0.1"
tokio-executor = "0.1"
tokio-reactor = "0.1"

View File

@@ -60,7 +60,7 @@ where
S: AsyncWrite,
SinkItem: Serialize,
{
type SinkError = io::Error;
type Error = io::Error;
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
self.inner()
@@ -81,7 +81,9 @@ where
}
}
fn convert<E: Into<Box<Error + Send + Sync>>>(poll: Poll<Result<(), E>>) -> Poll<io::Result<()>> {
fn convert<E: Into<Box<dyn Error + Send + Sync>>>(
poll: Poll<Result<(), E>>,
) -> Poll<io::Result<()>> {
match poll {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
@@ -89,23 +91,24 @@ fn convert<E: Into<Box<Error + Send + Sync>>>(poll: Poll<Result<(), E>>) -> Poll
}
}
impl<Item, SinkItem> rpc::Transport for Transport<TcpStream, Item, SinkItem>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
type Item = Item;
type SinkItem = SinkItem;
fn peer_addr(&self) -> io::Result<SocketAddr> {
impl<Item, SinkItem> Transport<TcpStream, Item, SinkItem> {
/// Returns the address of the peer connected over the transport.
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().peer_addr()
}
fn local_addr(&self) -> io::Result<SocketAddr> {
/// Returns the address of this end of the transport.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().local_addr()
}
}
impl<T, Item, SinkItem> AsRef<T> for Transport<T, Item, SinkItem> {
fn as_ref(&self) -> &T {
self.inner.get_ref().get_ref()
}
}
/// Returns a new bincode transport that reads from and writes to `io`.
pub fn new<Item, SinkItem>(io: TcpStream) -> Transport<TcpStream, Item, SinkItem>
where

View File

@@ -8,8 +8,10 @@
#![feature(test, integer_atomics, async_await)]
extern crate test;
use futures::{compat::Executor01CompatExt, prelude::*};
use libtest::stats::Stats;
use test::stats::Stats;
use rpc::{
client, context,
server::{Handler, Server},
@@ -20,8 +22,9 @@ use std::{
};
async fn bench() -> io::Result<()> {
let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?
.filter_map(|r| future::ready(r.ok()));
let addr = listener.get_ref().local_addr();
tokio_executor::spawn(
Server::<u32, u32>::default()

View File

@@ -7,6 +7,7 @@
//! Tests client/server control flow.
#![feature(async_await)]
#![feature(async_closure)]
use futures::{
compat::{Executor01CompatExt, Future01CompatExt},
@@ -14,8 +15,11 @@ use futures::{
stream::FuturesUnordered,
};
use log::{info, trace};
use rand::distributions::{Distribution, Normal};
use rpc::{client, context, server::Server};
use rand_distr::{Distribution, Normal};
use rpc::{
client, context,
server::{Channel, Server},
};
use std::{
io,
time::{Duration, Instant, SystemTime},
@@ -34,18 +38,14 @@ impl AsDuration for SystemTime {
}
async fn run() -> io::Result<()> {
let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?
.filter_map(|r| future::ready(r.ok()));
let addr = listener.get_ref().local_addr();
let server = Server::<String, String>::default()
.incoming(listener)
.take(1)
.for_each(async move |channel| {
let channel = if let Ok(channel) = channel {
channel
} else {
return;
};
let client_addr = *channel.client_addr();
let client_addr = channel.get_ref().peer_addr().unwrap();
let handler = channel.respond_with(move |ctx, request| {
// Sleep for a time sampled from a normal distribution with:
// - mean: 1/2 the deadline.
@@ -53,7 +53,7 @@ async fn run() -> io::Result<()> {
let deadline: Duration = ctx.deadline.as_duration();
let deadline_millis = deadline.as_secs() * 1000 + deadline.subsec_millis() as u64;
let distribution =
Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.);
Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.).unwrap();
let delay_millis = distribution.sample(&mut rand::thread_rng()).max(0.);
let delay = Duration::from_millis(delay_millis as u64);
@@ -79,20 +79,16 @@ async fn run() -> io::Result<()> {
let client = client::new::<String, String, _>(client::Config::default(), conn).await?;
// Proxy service
let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?
.filter_map(|r| future::ready(r.ok()));
let addr = listener.get_ref().local_addr();
let proxy_server = Server::<String, String>::default()
.incoming(listener)
.take(1)
.for_each(move |channel| {
let client = client.clone();
async move {
let channel = if let Ok(channel) = channel {
channel
} else {
return;
};
let client_addr = *channel.client_addr();
let client_addr = channel.get_ref().peer_addr().unwrap();
let handler = channel.respond_with(move |ctx, request| {
trace!("[{}/{}] Proxying request.", ctx.trace_id(), client_addr);
let mut client = client.clone();

View File

@@ -7,14 +7,18 @@
//! Tests client/server control flow.
#![feature(async_await)]
#![feature(async_closure)]
use futures::{
compat::{Executor01CompatExt, Future01CompatExt},
prelude::*,
};
use log::{error, info, trace};
use rand::distributions::{Distribution, Normal};
use rpc::{client, context, server::Server};
use rand_distr::{Distribution, Normal};
use rpc::{
client, context,
server::{Channel, Handler, Server},
};
use std::{
io,
time::{Duration, Instant, SystemTime},
@@ -33,18 +37,15 @@ impl AsDuration for SystemTime {
}
async fn run() -> io::Result<()> {
let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
let listener = tarpc_bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?
.filter_map(|r| future::ready(r.ok()));
let addr = listener.get_ref().local_addr();
let server = Server::<String, String>::default()
.incoming(listener)
.take(1)
.max_concurrent_requests_per_channel(19)
.for_each(async move |channel| {
let channel = if let Ok(channel) = channel {
channel
} else {
return;
};
let client_addr = *channel.client_addr();
let client_addr = channel.get_ref().get_ref().peer_addr().unwrap();
let handler = channel.respond_with(move |ctx, request| {
// Sleep for a time sampled from a normal distribution with:
// - mean: 1/2 the deadline.
@@ -52,7 +53,7 @@ async fn run() -> io::Result<()> {
let deadline: Duration = ctx.deadline.as_duration();
let deadline_millis = deadline.as_secs() * 1000 + deadline.subsec_millis() as u64;
let distribution =
Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.);
Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.).unwrap();
let delay_millis = distribution.sample(&mut rand::thread_rng()).max(0.);
let delay = Duration::from_millis(delay_millis as u64);
@@ -75,7 +76,7 @@ async fn run() -> io::Result<()> {
tokio_executor::spawn(server.unit_error().boxed().compat());
let mut config = client::Config::default();
config.max_in_flight_requests = 10;
config.max_in_flight_requests = 20;
config.pending_request_buffer = 10;
let conn = tarpc_bincode_transport::connect(&addr).await?;
@@ -103,17 +104,11 @@ async fn run() -> io::Result<()> {
}
#[test]
fn ping_pong() -> io::Result<()> {
fn pushback() -> io::Result<()> {
env_logger::init();
rpc::init(tokio::executor::DefaultExecutor::current().compat());
tokio::run(
run()
.map_ok(|_| println!("done"))
.map_err(|e| panic!(e.to_string()))
.boxed()
.compat(),
);
tokio::run(run().map_err(|e| panic!(e.to_string())).boxed().compat());
Ok(())
}

View File

@@ -13,12 +13,13 @@ readme = "../README.md"
description = "An example server built on tarpc."
[dependencies]
bincode-transport = { package = "tarpc-bincode-transport", version = "0.7", path = "../bincode-transport" }
json-transport = { package = "tarpc-json-transport", version = "0.1", path = "../json-transport" }
clap = "2.0"
futures-preview = { version = "0.3.0-alpha.16", features = ["compat"] }
futures-preview = { version = "0.3.0-alpha.17", features = ["compat"] }
serde = { version = "1.0" }
tarpc = { version = "0.18", path = "../tarpc", features = ["serde1"] }
tokio = "0.1"
env_logger = "0.6"
[lib]
name = "service"

View File

@@ -12,7 +12,7 @@ use std::{io, net::SocketAddr};
use tarpc::{client, context};
async fn run(server_addr: SocketAddr, name: String) -> io::Result<()> {
let transport = bincode_transport::connect(&server_addr).await?;
let transport = json_transport::connect(&server_addr).await?;
// new_stub is generated by the service! macro. Like Server, it takes a config and any
// Transport as input, and returns a Client, also generated by the macro.

View File

@@ -10,5 +10,9 @@
// It defines one RPC, hello, which takes one arg, name, and returns a String.
tarpc::service! {
/// Returns a greeting for name.
rpc hello(name: String) -> String;
rpc hello(#[serde(default = "default_name")] name: String) -> String;
}
fn default_name() -> String {
"DefaultName".into()
}

View File

@@ -15,13 +15,13 @@ use futures::{
use std::{io, net::SocketAddr};
use tarpc::{
context,
server::{Handler, Server},
server::{self, Channel, Handler},
};
// This is the type that implements the generated Service trait. It is the business logic
// and is used to start the server.
#[derive(Clone)]
struct HelloServer;
struct HelloServer(SocketAddr);
impl service::Service for HelloServer {
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and
@@ -30,29 +30,39 @@ impl service::Service for HelloServer {
type HelloFut = Ready<String>;
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
future::ready(format!("Hello, {}!", name))
future::ready(format!(
"Hello, {}! You are connected from {:?}.",
name, self.0
))
}
}
async fn run(server_addr: SocketAddr) -> io::Result<()> {
// bincode_transport is provided by the associated crate bincode-transport. It makes it easy
// to start up a serde-powered bincode serialization strategy over TCP.
let transport = bincode_transport::listen(&server_addr)?;
// The server is configured with the defaults.
let server = Server::default()
// Server can listen on any type that implements the Transport trait.
.incoming(transport)
json_transport::listen(&server_addr)?
// Ignore accept errors.
.filter_map(|r| future::ready(r.ok()))
.map(server::BaseChannel::with_defaults)
// Limit channels to 1 per IP.
.max_channels_per_key(1, |t| t.as_ref().peer_addr().unwrap().ip())
// serve is generated by the service! macro. It takes as input any type implementing
// the generated Service trait.
.respond_with(service::serve(HelloServer));
server.await;
.map(|channel| {
let server = HelloServer(channel.as_ref().as_ref().peer_addr().unwrap());
channel.respond_with(service::serve(server))
})
// Max 10 channels.
.buffer_unordered(10)
.for_each(|_| futures::future::ready(()))
.await;
Ok(())
}
fn main() {
env_logger::init();
let flags = App::new("Hello Server")
.version("0.1")
.author("Tim <tikue@google.com>")

View File

@@ -13,10 +13,9 @@ readme = "../README.md"
description = "A JSON-based transport for tarpc services."
[dependencies]
futures-preview = { version = "0.3.0-alpha.16", features = ["compat"] }
futures-preview = { version = "0.3.0-alpha.17", features = ["compat"] }
futures_legacy = { version = "0.1", package = "futures" }
pin-utils = "0.1.0-alpha.4"
rpc = { package = "tarpc-lib", version = "0.6", path = "../rpc", features = ["serde1"] }
serde = "1.0"
serde_json = "1.0"
tokio = "0.1"
@@ -27,9 +26,10 @@ tokio-tcp = "0.1"
[dev-dependencies]
env_logger = "0.6"
humantime = "1.0"
libtest = "0.0.1"
log = "0.4"
rand = "0.6"
rand = "0.7"
rand_distr = "0.2"
rpc = { package = "tarpc-lib", version = "0.6", path = "../rpc", features = ["serde1"] }
tokio = "0.1"
tokio-executor = "0.1"
tokio-reactor = "0.1"

View File

@@ -67,7 +67,7 @@ where
S: AsyncWrite,
SinkItem: Serialize,
{
type SinkError = io::Error;
type Error = io::Error;
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
self.inner()
@@ -88,7 +88,9 @@ where
}
}
fn convert<E: Into<Box<Error + Send + Sync>>>(poll: Poll<Result<(), E>>) -> Poll<io::Result<()>> {
fn convert<E: Into<Box<dyn Error + Send + Sync>>>(
poll: Poll<Result<(), E>>,
) -> Poll<io::Result<()>> {
match poll {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
@@ -96,15 +98,9 @@ fn convert<E: Into<Box<Error + Send + Sync>>>(poll: Poll<Result<(), E>>) -> Poll
}
}
impl<Item, SinkItem> rpc::Transport for Transport<TcpStream, Item, SinkItem>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
{
type Item = Item;
type SinkItem = SinkItem;
fn peer_addr(&self) -> io::Result<SocketAddr> {
impl<Item, SinkItem> Transport<TcpStream, Item, SinkItem> {
/// Returns the peer address of the underlying TcpStream.
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.inner
.get_ref()
.get_ref()
@@ -113,7 +109,8 @@ where
.peer_addr()
}
fn local_addr(&self) -> io::Result<SocketAddr> {
/// Returns the local address of the underlying TcpStream.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner
.get_ref()
.get_ref()

View File

@@ -8,8 +8,10 @@
#![feature(test, integer_atomics, async_await)]
extern crate test;
use futures::{compat::Executor01CompatExt, prelude::*};
use libtest::stats::Stats;
use test::stats::Stats;
use rpc::{
client, context,
server::{Handler, Server},
@@ -20,8 +22,9 @@ use std::{
};
async fn bench() -> io::Result<()> {
let listener = tarpc_json_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
let listener = tarpc_json_transport::listen(&"0.0.0.0:0".parse().unwrap())?
.filter_map(|r| future::ready(r.ok()));
let addr = listener.get_ref().local_addr();
tokio_executor::spawn(
Server::<u32, u32>::default()

View File

@@ -7,6 +7,7 @@
//! Tests client/server control flow.
#![feature(async_await)]
#![feature(async_closure)]
use futures::{
compat::{Executor01CompatExt, Future01CompatExt},
@@ -14,8 +15,11 @@ use futures::{
stream::FuturesUnordered,
};
use log::{info, trace};
use rand::distributions::{Distribution, Normal};
use rpc::{client, context, server::Server};
use rand_distr::{Distribution, Normal};
use rpc::{
client, context,
server::{Channel, Server},
};
use std::{
io,
time::{Duration, Instant, SystemTime},
@@ -34,18 +38,14 @@ impl AsDuration for SystemTime {
}
async fn run() -> io::Result<()> {
let listener = tarpc_json_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
let listener = tarpc_json_transport::listen(&"0.0.0.0:0".parse().unwrap())?
.filter_map(|r| future::ready(r.ok()));
let addr = listener.get_ref().local_addr();
let server = Server::<String, String>::default()
.incoming(listener)
.take(1)
.for_each(async move |channel| {
let channel = if let Ok(channel) = channel {
channel
} else {
return;
};
let client_addr = *channel.client_addr();
let client_addr = channel.get_ref().peer_addr().unwrap();
let handler = channel.respond_with(move |ctx, request| {
// Sleep for a time sampled from a normal distribution with:
// - mean: 1/2 the deadline.
@@ -53,7 +53,7 @@ async fn run() -> io::Result<()> {
let deadline: Duration = ctx.deadline.as_duration();
let deadline_millis = deadline.as_secs() * 1000 + deadline.subsec_millis() as u64;
let distribution =
Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.);
Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.).unwrap();
let delay_millis = distribution.sample(&mut rand::thread_rng()).max(0.);
let delay = Duration::from_millis(delay_millis as u64);
@@ -79,20 +79,16 @@ async fn run() -> io::Result<()> {
let client = client::new::<String, String, _>(client::Config::default(), conn).await?;
// Proxy service
let listener = tarpc_json_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
let listener = tarpc_json_transport::listen(&"0.0.0.0:0".parse().unwrap())?
.filter_map(|r| future::ready(r.ok()));
let addr = listener.get_ref().local_addr();
let proxy_server = Server::<String, String>::default()
.incoming(listener)
.take(1)
.for_each(move |channel| {
let client = client.clone();
async move {
let channel = if let Ok(channel) = channel {
channel
} else {
return;
};
let client_addr = *channel.client_addr();
let client_addr = channel.get_ref().peer_addr().unwrap();
let handler = channel.respond_with(move |ctx, request| {
trace!("[{}/{}] Proxying request.", ctx.trace_id(), client_addr);
let mut client = client.clone();

View File

@@ -7,14 +7,18 @@
//! Tests client/server control flow.
#![feature(async_await)]
#![feature(async_closure)]
use futures::{
compat::{Executor01CompatExt, Future01CompatExt},
prelude::*,
};
use log::{error, info, trace};
use rand::distributions::{Distribution, Normal};
use rpc::{client, context, server::Server};
use rand_distr::{Distribution, Normal};
use rpc::{
client, context,
server::{Channel, Server},
};
use std::{
io,
time::{Duration, Instant, SystemTime},
@@ -33,18 +37,14 @@ impl AsDuration for SystemTime {
}
async fn run() -> io::Result<()> {
let listener = tarpc_json_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
let listener = tarpc_json_transport::listen(&"0.0.0.0:0".parse().unwrap())?
.filter_map(|r| future::ready(r.ok()));
let addr = listener.get_ref().local_addr();
let server = Server::<String, String>::default()
.incoming(listener)
.take(1)
.for_each(async move |channel| {
let channel = if let Ok(channel) = channel {
channel
} else {
return;
};
let client_addr = *channel.client_addr();
let client_addr = channel.get_ref().peer_addr().unwrap();
let handler = channel.respond_with(move |ctx, request| {
// Sleep for a time sampled from a normal distribution with:
// - mean: 1/2 the deadline.
@@ -52,7 +52,7 @@ async fn run() -> io::Result<()> {
let deadline: Duration = ctx.deadline.as_duration();
let deadline_millis = deadline.as_secs() * 1000 + deadline.subsec_millis() as u64;
let distribution =
Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.);
Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.).unwrap();
let delay_millis = distribution.sample(&mut rand::thread_rng()).max(0.);
let delay = Duration::from_millis(delay_millis as u64);

View File

@@ -18,17 +18,17 @@ serde1 = ["trace/serde", "serde", "serde/derive"]
[dependencies]
fnv = "1.0"
futures-preview = { version = "0.3.0-alpha.16", features = ["compat"] }
futures-preview = { version = "0.3.0-alpha.17", features = ["compat"] }
humantime = "1.0"
log = "0.4"
pin-utils = "0.1.0-alpha.4"
rand = "0.6"
rand = "0.7"
tokio-timer = "0.2"
trace = { package = "tarpc-trace", version = "0.2", path = "../trace" }
serde = { optional = true, version = "1.0" }
[dev-dependencies]
futures-test-preview = { version = "0.3.0-alpha.16" }
futures-test-preview = { version = "0.3.0-alpha.17" }
env_logger = "0.6"
tokio = "0.1"
tokio-executor = "0.1"

View File

@@ -7,7 +7,7 @@
use crate::{
context,
util::{deadline_compat, AsDuration, Compact},
ClientMessage, ClientMessageKind, PollIo, Request, Response, Transport,
ClientMessage, PollIo, Request, Response, Transport,
};
use fnv::FnvHashMap;
use futures::{
@@ -24,7 +24,6 @@ use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::{
io,
marker::{self, Unpin},
net::SocketAddr,
pin::Pin,
sync::{
atomic::{AtomicU64, Ordering},
@@ -44,7 +43,6 @@ pub struct Channel<Req, Resp> {
cancellation: RequestCancellation,
/// The ID to use for the next request to stage.
next_request_id: Arc<AtomicU64>,
server_addr: SocketAddr,
}
impl<Req, Resp> Clone for Channel<Req, Resp> {
@@ -53,7 +51,6 @@ impl<Req, Resp> Clone for Channel<Req, Resp> {
to_dispatch: self.to_dispatch.clone(),
cancellation: self.cancellation.clone(),
next_request_id: self.next_request_id.clone(),
server_addr: self.server_addr,
}
}
}
@@ -122,9 +119,8 @@ impl<Req, Resp> Channel<Req, Resp> {
let timeout = ctx.deadline.as_duration();
let deadline = Instant::now() + timeout;
trace!(
"[{}/{}] Queuing request with deadline {} (timeout {:?}).",
"[{}] Queuing request with deadline {} (timeout {:?}).",
ctx.trace_id(),
self.server_addr,
format_rfc3339(ctx.deadline),
timeout,
);
@@ -132,7 +128,6 @@ impl<Req, Resp> Channel<Req, Resp> {
let (response_completion, response) = oneshot::channel();
let cancellation = self.cancellation.clone();
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
let server_addr = self.server_addr;
Send {
fut: MapOkDispatchResponse::new(
MapErrConnectionReset::new(self.to_dispatch.send(DispatchRequest {
@@ -147,7 +142,6 @@ impl<Req, Resp> Channel<Req, Resp> {
request_id,
cancellation,
ctx,
server_addr,
},
),
}
@@ -171,11 +165,9 @@ struct DispatchResponse<Resp> {
complete: bool,
cancellation: RequestCancellation,
request_id: u64,
server_addr: SocketAddr,
}
impl<Resp> DispatchResponse<Resp> {
unsafe_pinned!(server_addr: SocketAddr);
unsafe_pinned!(ctx: context::Context);
}
@@ -192,7 +184,6 @@ impl<Resp> Future for DispatchResponse<Resp> {
}
Err(e) => Err({
let trace_id = *self.as_mut().ctx().trace_id();
let server_addr = *self.as_mut().server_addr();
if e.is_elapsed() {
io::Error::new(
@@ -209,12 +200,9 @@ impl<Resp> Future for DispatchResponse<Resp> {
.to_string(),
)
} else if e.is_shutdown() {
panic!("[{}/{}] Timer was shutdown", trace_id, server_addr)
panic!("[{}] Timer was shutdown", trace_id)
} else {
panic!(
"[{}/{}] Unrecognized timer error: {}",
trace_id, server_addr, e
)
panic!("[{}] Unrecognized timer error: {}", trace_id, e)
}
} else if e.is_inner() {
// The oneshot is Canceled when the dispatch task ends. In that case,
@@ -223,10 +211,7 @@ impl<Resp> Future for DispatchResponse<Resp> {
self.complete = true;
io::Error::from(io::ErrorKind::ConnectionReset)
} else {
panic!(
"[{}/{}] Unrecognized deadline error: {}",
trace_id, server_addr, e
)
panic!("[{}] Unrecognized deadline error: {}", trace_id, e)
}
}),
})
@@ -255,15 +240,11 @@ impl<Resp> Drop for DispatchResponse<Resp> {
/// Spawns a dispatch task on the default executor that manages the lifecycle of requests initiated
/// by the returned [`Channel`].
pub async fn spawn<Req, Resp, C>(
config: Config,
transport: C,
server_addr: SocketAddr,
) -> io::Result<Channel<Req, Resp>>
pub async fn spawn<Req, Resp, C>(config: Config, transport: C) -> io::Result<Channel<Req, Resp>>
where
Req: marker::Send + 'static,
Resp: marker::Send + 'static,
C: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>> + marker::Send + 'static,
C: Transport<ClientMessage<Req>, Response<Resp>> + marker::Send + 'static,
{
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
let (cancellation, canceled_requests) = cancellations();
@@ -272,13 +253,12 @@ where
crate::spawn(
RequestDispatch {
config,
server_addr,
canceled_requests,
transport: transport.fuse(),
in_flight_requests: FnvHashMap::default(),
pending_requests: pending_requests.fuse(),
}
.unwrap_or_else(move |e| error!("[{}] Connection broken: {}", server_addr, e)),
.unwrap_or_else(move |e| error!("Connection broken: {}", e)),
)
.map_err(|e| {
io::Error::new(
@@ -293,7 +273,6 @@ where
Ok(Channel {
to_dispatch,
cancellation,
server_addr,
next_request_id: Arc::new(AtomicU64::new(0)),
})
}
@@ -311,17 +290,14 @@ struct RequestDispatch<Req, Resp, C> {
in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>,
/// Configures limits to prevent unlimited resource usage.
config: Config,
/// The address of the server connected to.
server_addr: SocketAddr,
}
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
where
Req: marker::Send,
Resp: marker::Send,
C: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>>,
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
unsafe_pinned!(server_addr: SocketAddr);
unsafe_pinned!(in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>);
unsafe_pinned!(canceled_requests: Fuse<CanceledRequests>);
unsafe_pinned!(pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>);
@@ -333,10 +309,7 @@ where
self.complete(response);
Some(Ok(()))
}
None => {
trace!("[{}] read half closed", self.as_mut().server_addr());
None
}
None => None,
})
}
@@ -415,10 +388,7 @@ where
return Poll::Ready(Some(Ok(request)));
}
None => {
trace!("[{}] pending_requests closed", self.as_mut().server_addr());
return Poll::Ready(None);
}
None => return Poll::Ready(None),
}
}
}
@@ -440,23 +410,11 @@ where
self.as_mut().in_flight_requests().remove(&request_id)
{
self.as_mut().in_flight_requests().compact(0.1);
debug!(
"[{}/{}] Removed request.",
in_flight_data.ctx.trace_id(),
self.as_mut().server_addr()
);
debug!("[{}] Removed request.", in_flight_data.ctx.trace_id());
return Poll::Ready(Some(Ok((in_flight_data.ctx, request_id))));
}
}
None => {
trace!(
"[{}] canceled_requests closed.",
self.as_mut().server_addr()
);
return Poll::Ready(None);
}
None => return Poll::Ready(None),
}
}
}
@@ -466,14 +424,14 @@ where
dispatch_request: DispatchRequest<Req, Resp>,
) -> io::Result<()> {
let request_id = dispatch_request.request_id;
let request = ClientMessage {
trace_context: dispatch_request.ctx.trace_context,
message: ClientMessageKind::Request(Request {
id: request_id,
message: dispatch_request.request,
let request = ClientMessage::Request(Request {
id: request_id,
message: dispatch_request.request,
context: context::Context {
deadline: dispatch_request.ctx.deadline,
}),
};
trace_context: dispatch_request.ctx.trace_context,
},
});
self.as_mut().transport().start_send(request)?;
self.as_mut().in_flight_requests().insert(
request_id,
@@ -491,16 +449,12 @@ where
request_id: u64,
) -> io::Result<()> {
let trace_id = *context.trace_id();
let cancel = ClientMessage {
let cancel = ClientMessage::Cancel {
trace_context: context.trace_context,
message: ClientMessageKind::Cancel { request_id },
request_id,
};
self.as_mut().transport().start_send(cancel)?;
trace!(
"[{}/{}] Cancel message sent.",
trace_id,
self.as_mut().server_addr()
);
trace!("[{}] Cancel message sent.", trace_id);
Ok(())
}
@@ -513,18 +467,13 @@ where
{
self.as_mut().in_flight_requests().compact(0.1);
trace!(
"[{}/{}] Received response.",
in_flight_data.ctx.trace_id(),
self.as_mut().server_addr()
);
trace!("[{}] Received response.", in_flight_data.ctx.trace_id());
let _ = in_flight_data.response_completion.send(response);
return true;
}
debug!(
"[{}] No in-flight request found for request_id = {}.",
self.as_mut().server_addr(),
"No in-flight request found for request_id = {}.",
response.request_id
);
@@ -537,58 +486,29 @@ impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
where
Req: marker::Send,
Resp: marker::Send,
C: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>>,
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
type Output = io::Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
trace!("[{}] RequestDispatch::poll", self.as_mut().server_addr());
loop {
match (self.pump_read(cx)?, self.pump_write(cx)?) {
(read, write @ Poll::Ready(None)) => {
(read, Poll::Ready(None)) => {
if self.as_mut().in_flight_requests().is_empty() {
info!(
"[{}] Shutdown: write half closed, and no requests in flight.",
self.as_mut().server_addr()
);
info!("Shutdown: write half closed, and no requests in flight.");
return Poll::Ready(Ok(()));
}
let addr = *self.as_mut().server_addr();
info!(
"[{}] {} requests in flight.",
addr,
"Shutdown: write half closed, and {} requests in flight.",
self.as_mut().in_flight_requests().len()
);
match read {
Poll::Ready(Some(())) => continue,
_ => {
trace!(
"[{}] read: {:?}, write: {:?}, (not ready)",
self.as_mut().server_addr(),
read,
write,
);
return Poll::Pending;
}
_ => return Poll::Pending,
}
}
(read @ Poll::Ready(Some(())), write) | (read, write @ Poll::Ready(Some(()))) => {
trace!(
"[{}] read: {:?}, write: {:?}",
self.as_mut().server_addr(),
read,
write,
)
}
(read, write) => {
trace!(
"[{}] read: {:?}, write: {:?} (not ready)",
self.as_mut().server_addr(),
read,
write,
);
return Poll::Pending;
}
(Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
_ => return Poll::Pending,
}
}
}
@@ -848,14 +768,7 @@ mod tests {
};
use futures_test::task::noop_waker_ref;
use std::time::Duration;
use std::{
marker,
net::{IpAddr, Ipv4Addr, SocketAddr},
pin::Pin,
sync::atomic::AtomicU64,
sync::Arc,
time::Instant,
};
use std::{marker, pin::Pin, sync::atomic::AtomicU64, sync::Arc, time::Instant};
#[test]
fn dispatch_response_cancels_on_timeout() {
@@ -869,7 +782,6 @@ mod tests {
request_id: 3,
cancellation,
ctx: context::current(),
server_addr: SocketAddr::from(([0, 0, 0, 0], 9999)),
};
{
pin_utils::pin_mut!(resp);
@@ -994,7 +906,6 @@ mod tests {
canceled_requests: CanceledRequests(canceled_requests).fuse(),
in_flight_requests: FnvHashMap::default(),
config: Config::default(),
server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
};
let cancellation = RequestCancellation(cancel_tx);
@@ -1002,7 +913,6 @@ mod tests {
to_dispatch,
cancellation,
next_request_id: Arc::new(AtomicU64::new(0)),
server_addr: SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
};
(dispatch, channel, server_channel)

View File

@@ -8,11 +8,7 @@
use crate::{context, ClientMessage, Response, Transport};
use futures::prelude::*;
use log::warn;
use std::{
io,
net::{Ipv4Addr, SocketAddr},
};
use std::io;
/// Provides a [`Client`] backed by a transport.
pub mod channel;
@@ -137,15 +133,7 @@ pub async fn new<Req, Resp, T>(config: Config, transport: T) -> io::Result<Chann
where
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>> + Send + 'static,
T: Transport<ClientMessage<Req>, Response<Resp>> + Send + 'static,
{
let server_addr = transport.peer_addr().unwrap_or_else(|e| {
warn!(
"Setting peer to unspecified because peer could not be determined: {}",
e
);
SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0)
});
Ok(channel::spawn(config, transport, server_addr).await?)
Ok(channel::spawn(config, transport).await?)
}

View File

@@ -16,10 +16,20 @@ use trace::{self, TraceId};
/// The context should not be stored directly in a server implementation, because the context will
/// be different for each request in scope.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Context {
/// When the client expects the request to be complete by. The server should cancel the request
/// if it is not complete by this time.
#[cfg_attr(
feature = "serde1",
serde(serialize_with = "crate::util::serde::serialize_epoch_secs")
)]
#[cfg_attr(
feature = "serde1",
serde(deserialize_with = "crate::util::serde::deserialize_epoch_secs")
)]
#[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))]
pub deadline: SystemTime,
/// Uniquely identifies requests originating from the same source.
/// When a service handles a request by making requests itself, those requests should
@@ -28,6 +38,11 @@ pub struct Context {
pub trace_context: trace::Context,
}
#[cfg(feature = "serde1")]
fn ten_seconds_from_now() -> SystemTime {
return SystemTime::now() + Duration::from_secs(10);
}
/// Returns the context for the current request, or a default Context if no request is active.
// TODO: populate Context with request-scoped data, with default fallbacks.
pub fn current() -> Context {

View File

@@ -5,11 +5,14 @@
// https://opensource.org/licenses/MIT.
#![feature(
weak_counts,
non_exhaustive,
integer_atomics,
try_trait,
arbitrary_self_types,
async_await
async_await,
trait_alias,
async_closure
)]
#![deny(missing_docs, missing_debug_implementations)]
@@ -49,19 +52,7 @@ use std::{cell::RefCell, io, sync::Once, time::SystemTime};
#[derive(Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct ClientMessage<T> {
/// The trace context associates the message with a specific chain of causally-related actions,
/// possibly orchestrated across many distributed systems.
pub trace_context: trace::Context,
/// The message payload.
pub message: ClientMessageKind<T>,
}
/// Different messages that can be sent from a client to a server.
#[derive(Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum ClientMessageKind<T> {
pub enum ClientMessage<T> {
/// A request initiated by a user. The server responds to a request by invoking a
/// service-provided request handler. The handler completes with a [`response`](Response), which
/// the server sends back to the client.
@@ -74,35 +65,30 @@ pub enum ClientMessageKind<T> {
/// not be canceled, because the framework layer does not
/// know about them.
Cancel {
/// The trace context associates the message with a specific chain of causally-related actions,
/// possibly orchestrated across many distributed systems.
#[cfg_attr(feature = "serde", serde(default))]
trace_context: trace::Context,
/// The ID of the request to cancel.
request_id: u64,
},
}
/// A request from a client to a server.
#[derive(Debug)]
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Request<T> {
/// Trace context, deadline, and other cross-cutting concerns.
pub context: context::Context,
/// Uniquely identifies the request across all requests sent over a single channel.
pub id: u64,
/// The request body.
pub message: T,
/// When the client expects the request to be complete by. The server will cancel the request
/// if it is not complete by this time.
#[cfg_attr(
feature = "serde1",
serde(serialize_with = "util::serde::serialize_epoch_secs")
)]
#[cfg_attr(
feature = "serde1",
serde(deserialize_with = "util::serde::deserialize_epoch_secs")
)]
pub deadline: SystemTime,
}
/// A response from a server to a client.
#[derive(Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Response<T> {
@@ -113,7 +99,7 @@ pub struct Response<T> {
}
/// An error response from a server to a client.
#[derive(Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct ServerError {
@@ -140,7 +126,7 @@ impl From<ServerError> for io::Error {
impl<T> Request<T> {
/// Returns the deadline for this request.
pub fn deadline(&self) -> &SystemTime {
&self.deadline
&self.context.deadline
}
}

View File

@@ -5,259 +5,331 @@
// https://opensource.org/licenses/MIT.
use crate::{
server::{Channel, Config},
server::{self, Channel},
util::Compact,
ClientMessage, PollIo, Response, Transport,
Response,
};
use fnv::FnvHashMap;
use futures::{
channel::mpsc,
future::AbortRegistration,
prelude::*,
ready,
stream::Fuse,
task::{Context, Poll},
};
use log::{debug, error, info, trace, warn};
use pin_utils::unsafe_pinned;
use log::{debug, info, trace};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::sync::{Arc, Weak};
use std::{
collections::hash_map::Entry,
io,
marker::PhantomData,
net::{IpAddr, SocketAddr},
ops::Try,
option::NoneError,
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, io, marker::Unpin, ops::Try,
pin::Pin,
};
/// Drops connections under configurable conditions:
///
/// 1. If the max number of connections is reached.
/// 2. If the max number of connections for a single IP is reached.
/// A single-threaded filter that drops channels based on per-key limits.
#[derive(Debug)]
pub struct ConnectionFilter<S, Req, Resp> {
pub struct ChannelFilter<S, K, F>
where
K: Eq + Hash,
{
listener: Fuse<S>,
closed_connections: mpsc::UnboundedSender<SocketAddr>,
closed_connections_rx: mpsc::UnboundedReceiver<SocketAddr>,
config: Config,
connections_per_ip: FnvHashMap<IpAddr, usize>,
open_connections: usize,
ghost: PhantomData<(Req, Resp)>,
channels_per_key: u32,
dropped_keys: mpsc::UnboundedReceiver<K>,
dropped_keys_tx: mpsc::UnboundedSender<K>,
key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
keymaker: F,
}
enum NewConnection<Req, Resp, C> {
Filtered,
Accepted(Channel<Req, Resp, C>),
/// A channel that is tracked by a ChannelFilter.
#[derive(Debug)]
pub struct TrackedChannel<C, K> {
inner: C,
tracker: Arc<Tracker<K>>,
}
impl<Req, Resp, C> Try for NewConnection<Req, Resp, C> {
type Ok = Channel<Req, Resp, C>;
type Error = NoneError;
impl<C, K> TrackedChannel<C, K> {
unsafe_pinned!(inner: C);
}
fn into_result(self) -> Result<Channel<Req, Resp, C>, NoneError> {
#[derive(Debug)]
struct Tracker<K> {
key: Option<K>,
dropped_keys: mpsc::UnboundedSender<K>,
}
impl<K> Drop for Tracker<K> {
fn drop(&mut self) {
// Don't care if the listener is dropped.
let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap());
}
}
/// A running handler serving all requests for a single client.
#[derive(Debug)]
pub struct TrackedHandler<K, Fut> {
inner: Fut,
tracker: Tracker<K>,
}
impl<K, Fut> TrackedHandler<K, Fut>
where
Fut: Future,
{
unsafe_pinned!(inner: Fut);
}
impl<K, Fut> Future for TrackedHandler<K, Fut>
where
Fut: Future,
{
type Output = Fut::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.inner().poll(cx)
}
}
impl<C, K> Stream for TrackedChannel<C, K>
where
C: Channel,
{
type Item = <C as Stream>::Item;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
self.channel().poll_next(cx)
}
}
impl<C, K> Sink<Response<C::Resp>> for TrackedChannel<C, K>
where
C: Channel,
{
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.channel().poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Response<C::Resp>) -> Result<(), Self::Error> {
self.channel().start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.channel().poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.channel().poll_close(cx)
}
}
impl<C, K> AsRef<C> for TrackedChannel<C, K> {
fn as_ref(&self) -> &C {
&self.inner
}
}
impl<C, K> Channel for TrackedChannel<C, K>
where
C: Channel,
{
type Req = C::Req;
type Resp = C::Resp;
fn config(&self) -> &server::Config {
self.inner.config()
}
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
self.inner().in_flight_requests()
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
self.inner().start_request(request_id)
}
}
impl<C, K> TrackedChannel<C, K> {
/// Returns the inner channel.
pub fn get_ref(&self) -> &C {
&self.inner
}
/// Returns the pinned inner channel.
fn channel<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut C> {
self.inner()
}
}
enum NewChannel<C, K> {
Accepted(TrackedChannel<C, K>),
Filtered(K),
}
impl<C, K> Try for NewChannel<C, K> {
type Ok = TrackedChannel<C, K>;
type Error = K;
fn into_result(self) -> Result<TrackedChannel<C, K>, K> {
match self {
NewConnection::Filtered => Err(NoneError),
NewConnection::Accepted(channel) => Ok(channel),
NewChannel::Accepted(channel) => Ok(channel),
NewChannel::Filtered(k) => Err(k),
}
}
fn from_error(_: NoneError) -> Self {
NewConnection::Filtered
fn from_error(k: K) -> Self {
NewChannel::Filtered(k)
}
fn from_ok(channel: Channel<Req, Resp, C>) -> Self {
NewConnection::Accepted(channel)
fn from_ok(channel: TrackedChannel<C, K>) -> Self {
NewChannel::Accepted(channel)
}
}
impl<S, Req, Resp> ConnectionFilter<S, Req, Resp> {
unsafe_pinned!(open_connections: usize);
unsafe_pinned!(config: Config);
unsafe_pinned!(connections_per_ip: FnvHashMap<IpAddr, usize>);
unsafe_pinned!(closed_connections_rx: mpsc::UnboundedReceiver<SocketAddr>);
impl<S, K, F> ChannelFilter<S, K, F>
where
K: fmt::Display + Eq + Hash + Clone,
{
unsafe_pinned!(listener: Fuse<S>);
unsafe_pinned!(dropped_keys: mpsc::UnboundedReceiver<K>);
unsafe_pinned!(dropped_keys_tx: mpsc::UnboundedSender<K>);
unsafe_unpinned!(key_counts: FnvHashMap<K, Weak<Tracker<K>>>);
unsafe_unpinned!(channels_per_key: u32);
unsafe_unpinned!(keymaker: F);
}
/// Sheds new connections to stay under configured limits.
pub fn filter<C>(listener: S, config: Config) -> Self
where
S: Stream<Item = Result<C, io::Error>>,
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
let (closed_connections, closed_connections_rx) = mpsc::unbounded();
ConnectionFilter {
impl<S, K, F> ChannelFilter<S, K, F>
where
K: Eq + Hash,
S: Stream,
{
/// Sheds new channels to stay under configured limits.
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded();
ChannelFilter {
listener: listener.fuse(),
closed_connections,
closed_connections_rx,
config,
connections_per_ip: FnvHashMap::default(),
open_connections: 0,
ghost: PhantomData,
channels_per_key,
dropped_keys,
dropped_keys_tx,
key_counts: FnvHashMap::default(),
keymaker,
}
}
}
fn handle_new_connection<C>(self: &mut Pin<&mut Self>, stream: C) -> NewConnection<Req, Resp, C>
where
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
let peer = match stream.peer_addr() {
Ok(peer) => peer,
Err(e) => {
warn!("Could not get peer_addr of new connection: {}", e);
return NewConnection::Filtered;
}
};
let open_connections = *self.as_mut().open_connections();
if open_connections >= self.as_mut().config().max_connections {
warn!(
"[{}] Shedding connection because the maximum open connections \
limit is reached ({}/{}).",
peer,
open_connections,
self.as_mut().config().max_connections
);
return NewConnection::Filtered;
}
let config = self.config.clone();
let open_connections_for_ip = self.increment_connections_for_ip(&peer)?;
*self.as_mut().open_connections() += 1;
impl<S, C, K, F> ChannelFilter<S, K, F>
where
S: Stream<Item = C>,
C: Channel,
K: fmt::Display + Eq + Hash + Clone + Unpin,
F: Fn(&C) -> K,
{
fn handle_new_channel(self: &mut Pin<&mut Self>, stream: C) -> NewChannel<C, K> {
let key = self.as_mut().keymaker()(&stream);
let tracker = self.increment_channels_for_key(key.clone())?;
let max = self.as_mut().channels_per_key();
debug!(
"[{}] Opening channel ({}/{} connections for IP, {} total).",
peer,
open_connections_for_ip,
config.max_connections_per_ip,
self.as_mut().open_connections(),
"[{}] Opening channel ({}/{}) channels for key.",
key,
Arc::strong_count(&tracker),
max
);
NewConnection::Accepted(Channel {
client_addr: peer,
closed_connections: self.closed_connections.clone(),
transport: stream.fuse(),
config,
ghost: PhantomData,
NewChannel::Accepted(TrackedChannel {
tracker,
inner: stream,
})
}
fn handle_closed_connection(self: &mut Pin<&mut Self>, addr: &SocketAddr) {
*self.as_mut().open_connections() -= 1;
debug!(
"[{}] Closing channel. {} open connections remaining.",
addr, self.open_connections
);
self.decrement_connections_for_ip(&addr);
self.as_mut().connections_per_ip().compact(0.1);
}
fn increment_channels_for_key(self: &mut Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
let channels_per_key = self.channels_per_key;
let dropped_keys = self.dropped_keys_tx.clone();
let key_counts = &mut self.as_mut().key_counts();
match key_counts.entry(key.clone()) {
Entry::Vacant(vacant) => {
let tracker = Arc::new(Tracker {
key: Some(key),
dropped_keys,
});
fn increment_connections_for_ip(self: &mut Pin<&mut Self>, peer: &SocketAddr) -> Option<usize> {
let max_connections_per_ip = self.as_mut().config().max_connections_per_ip;
let mut occupied;
let mut connections_per_ip = self.as_mut().connections_per_ip();
let occupied = match connections_per_ip.entry(peer.ip()) {
Entry::Vacant(vacant) => vacant.insert(0),
Entry::Occupied(o) => {
if *o.get() < max_connections_per_ip {
// Store the reference outside the block to extend the lifetime.
occupied = o;
occupied.get_mut()
} else {
vacant.insert(Arc::downgrade(&tracker));
Ok(tracker)
}
Entry::Occupied(mut o) => {
let count = o.get().strong_count();
if count >= channels_per_key.try_into().unwrap() {
info!(
"[{}] Opened max connections from IP ({}/{}).",
peer,
o.get(),
max_connections_per_ip
"[{}] Opened max channels from key ({}/{}).",
key, count, channels_per_key
);
return None;
}
}
};
*occupied += 1;
Some(*occupied)
}
fn decrement_connections_for_ip(self: &mut Pin<&mut Self>, addr: &SocketAddr) {
let should_compact = match self.as_mut().connections_per_ip().entry(addr.ip()) {
Entry::Vacant(_) => {
error!("[{}] Got vacant entry when closing connection.", addr);
return;
}
Entry::Occupied(mut occupied) => {
*occupied.get_mut() -= 1;
if *occupied.get() == 0 {
occupied.remove();
true
Err(key)
} else {
false
Ok(o.get().upgrade().unwrap_or_else(|| {
let tracker = Arc::new(Tracker {
key: Some(key),
dropped_keys,
});
*o.get_mut() = Arc::downgrade(&tracker);
tracker
}))
}
}
};
if should_compact {
self.as_mut().connections_per_ip().compact(0.1);
}
}
fn poll_listener<C>(
fn poll_listener(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<NewConnection<Req, Resp, C>>
where
S: Stream<Item = Result<C, io::Error>>,
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
match ready!(self.as_mut().listener().poll_next_unpin(cx)?) {
Some(codec) => Poll::Ready(Some(Ok(self.handle_new_connection(codec)))),
) -> Poll<Option<NewChannel<C, K>>> {
match ready!(self.as_mut().listener().poll_next_unpin(cx)) {
Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
None => Poll::Ready(None),
}
}
fn poll_closed_connections(
self: &mut Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
match ready!(self.as_mut().closed_connections_rx().poll_next_unpin(cx)) {
Some(addr) => {
self.handle_closed_connection(&addr);
Poll::Ready(Ok(()))
fn poll_closed_channels(self: &mut Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
match ready!(self.as_mut().dropped_keys().poll_next_unpin(cx)) {
Some(key) => {
debug!("All channels dropped for key [{}]", key);
self.as_mut().key_counts().remove(&key);
self.as_mut().key_counts().compact(0.1);
Poll::Ready(())
}
None => unreachable!("Holding a copy of closed_connections and didn't close it."),
None => unreachable!("Holding a copy of closed_channels and didn't close it."),
}
}
}
impl<S, Req, Resp, T> Stream for ConnectionFilter<S, Req, Resp>
impl<S, C, K, F> Stream for ChannelFilter<S, K, F>
where
S: Stream<Item = Result<T, io::Error>>,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
S: Stream<Item = C>,
C: Channel,
K: fmt::Display + Eq + Hash + Clone + Unpin,
F: Fn(&C) -> K,
{
type Item = io::Result<Channel<Req, Resp, T>>;
type Item = TrackedChannel<C, K>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Channel<Req, Resp, T>> {
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<TrackedChannel<C, K>>> {
loop {
match (
self.as_mut().poll_listener(cx)?,
self.poll_closed_connections(cx)?,
self.as_mut().poll_listener(cx),
self.poll_closed_channels(cx),
) {
(Poll::Ready(Some(NewConnection::Accepted(channel))), _) => {
return Poll::Ready(Some(Ok(channel)));
(Poll::Ready(Some(NewChannel::Accepted(channel))), _) => {
return Poll::Ready(Some(channel));
}
(Poll::Ready(Some(NewConnection::Filtered)), _) | (_, Poll::Ready(())) => {
trace!(
"Filtered a connection; {} open.",
self.as_mut().open_connections()
);
(Poll::Ready(Some(NewChannel::Filtered(_))), _) => {
continue;
}
(_, Poll::Ready(())) => continue,
(Poll::Pending, Poll::Pending) => return Poll::Pending,
(Poll::Ready(None), Poll::Pending) => {
if *self.as_mut().open_connections() > 0 {
trace!(
"Listener closed; {} open connections.",
self.as_mut().open_connections()
);
return Poll::Pending;
}
trace!("Shutting down listener: all connections closed, and no more coming.");
trace!("Shutting down listener.");
return Poll::Ready(None);
}
}

View File

@@ -7,27 +7,27 @@
//! Provides a server that concurrently handles many connections sending multiplexed requests.
use crate::{
context, util::deadline_compat, util::AsDuration, util::Compact, ClientMessage,
ClientMessageKind, PollIo, Request, Response, ServerError, Transport,
context, util::deadline_compat, util::AsDuration, util::Compact, ClientMessage, PollIo,
Request, Response, ServerError, Transport,
};
use fnv::FnvHashMap;
use futures::{
channel::mpsc,
future::{abortable, AbortHandle},
future::{AbortHandle, AbortRegistration, Abortable},
prelude::*,
ready,
stream::Fuse,
task::{Context, Poll},
try_ready,
};
use humantime::format_rfc3339;
use log::{debug, error, info, trace, warn};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::{
error::Error as StdError,
fmt,
hash::Hash,
io,
marker::PhantomData,
net::SocketAddr,
pin::Pin,
time::{Instant, SystemTime},
};
@@ -35,6 +35,14 @@ use tokio_timer::timeout;
use trace::{self, TraceId};
mod filter;
#[cfg(test)]
mod testing;
mod throttle;
pub use self::{
filter::ChannelFilter,
throttle::{Throttler, ThrottlerStream},
};
/// Manages clients, serving multiplexed requests over each connection.
#[derive(Debug)]
@@ -53,17 +61,6 @@ impl<Req, Resp> Default for Server<Req, Resp> {
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct Config {
/// The maximum number of clients that can be connected to the server at once. When at the
/// limit, existing connections are honored and new connections are rejected.
pub max_connections: usize,
/// The maximum number of clients per IP address that can be connected to the server at once.
/// When an IP is at the limit, existing connections are honored and new connections on that IP
/// address are rejected.
pub max_connections_per_ip: usize,
/// The maximum number of requests that can be in flight for each client. When a client is at
/// the in-flight request limit, existing requests are fulfilled and new requests are rejected.
/// Rejected requests are sent a response error.
pub max_in_flight_requests_per_connection: usize,
/// The number of responses per client that can be buffered server-side before being sent.
/// `pending_response_buffer` controls the buffer size of the channel that a server's
/// response tasks use to send responses to the client handler task.
@@ -73,14 +70,21 @@ pub struct Config {
impl Default for Config {
fn default() -> Self {
Config {
max_connections: 1_000_000,
max_connections_per_ip: 1_000,
max_in_flight_requests_per_connection: 1_000,
pending_response_buffer: 100,
}
}
}
impl Config {
/// Returns a channel backed by `transport` and configured with `self`.
pub fn channel<Req, Resp, T>(self, transport: T) -> BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>> + Send,
{
BaseChannel::new(self, transport)
}
}
/// Returns a new server with configuration specified `config`.
pub fn new<Req, Resp>(config: Config) -> Server<Req, Resp> {
Server {
@@ -95,18 +99,15 @@ impl<Req, Resp> Server<Req, Resp> {
&self.config
}
/// Returns a stream of the incoming connections to the server.
pub fn incoming<S, T>(
self,
listener: S,
) -> impl Stream<Item = io::Result<Channel<Req, Resp, T>>>
/// Returns a stream of server channels.
pub fn incoming<S, T>(self, listener: S) -> impl Stream<Item = BaseChannel<Req, Resp, T>>
where
Req: Send,
Resp: Send,
S: Stream<Item = io::Result<T>>,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
S: Stream<Item = T>,
T: Transport<Response<Resp>, ClientMessage<Req>> + Send,
{
self::filter::ConnectionFilter::filter(listener, self.config.clone())
listener.map(move |t| BaseChannel::new(self.config.clone(), t))
}
}
@@ -122,31 +123,21 @@ impl<S, F> Running<S, F> {
unsafe_unpinned!(request_handler: F);
}
impl<S, T, Req, Resp, F, Fut> Future for Running<S, F>
impl<S, C, F, Fut> Future for Running<S, F>
where
S: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send + 'static,
F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
S: Sized + Stream<Item = C>,
C: Channel + Send + 'static,
F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<C::Resp>> + Send + 'static,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
while let Some(channel) = ready!(self.as_mut().incoming().poll_next(cx)) {
match channel {
Ok(channel) => {
let peer = channel.client_addr;
if let Err(e) =
crate::spawn(channel.respond_with(self.as_mut().request_handler().clone()))
{
warn!("[{}] Failed to spawn connection handler: {:?}", peer, e);
}
}
Err(e) => {
warn!("Incoming connection error: {}", e);
}
if let Err(e) =
crate::spawn(channel.respond_with(self.as_mut().request_handler().clone()))
{
warn!("Failed to spawn channel handler: {:?}", e);
}
}
info!("Server shutting down.");
@@ -155,18 +146,30 @@ where
}
/// A utility trait enabling a stream to fluently chain a request handler.
pub trait Handler<T, Req, Resp>
pub trait Handler<C>
where
Self: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
Req: Send,
Resp: Send,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
Self: Sized + Stream<Item = C>,
C: Channel,
{
/// Enforces channel per-key limits.
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> filter::ChannelFilter<Self, K, KF>
where
K: fmt::Display + Eq + Hash + Clone + Unpin,
KF: Fn(&C) -> K,
{
ChannelFilter::new(self, n, keymaker)
}
/// Caps the number of concurrent requests per channel.
fn max_concurrent_requests_per_channel(self, n: usize) -> ThrottlerStream<Self> {
ThrottlerStream::new(self, n)
}
/// Responds to all requests with `request_handler`.
fn respond_with<F, Fut>(self, request_handler: F) -> Running<Self, F>
where
F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<C::Resp>> + Send + 'static,
{
Running {
incoming: self,
@@ -175,191 +178,276 @@ where
}
}
impl<T, Req, Resp, S> Handler<T, Req, Resp> for S
impl<S, C> Handler<C> for S
where
S: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
Req: Send,
Resp: Send,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
S: Sized + Stream<Item = C>,
C: Channel,
{
}
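Putting the `Handler` combinators together on a stream of channels: the keymaker receives a `&C` and returns any `Display + Eq + Hash + Clone + Unpin` key. This sketch keys on client IP and assumes the concrete transport exposes a `peer_addr()` accessor, since the `Transport` trait itself no longer carries addresses; `serve`/`HelloServer` stand in for a generated service:

    let running = server::new(server::Config::default())
        .incoming(listener)
        // At most 10 open channels per client IP.
        .max_channels_per_key(10, |chan| chan.get_ref().peer_addr().unwrap().ip())
        // At most 10 requests in flight on each channel.
        .max_concurrent_requests_per_channel(10)
        .respond_with(serve(HelloServer));
    tokio_executor::spawn(running.unit_error().boxed().compat());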
/// Responds to all requests with `request_handler`.
/// The server end of an open connection with a client.
/// BaseChannel lifts a Transport to a Channel by tracking in-flight requests.
#[derive(Debug)]
pub struct Channel<Req, Resp, T> {
pub struct BaseChannel<Req, Resp, T> {
config: Config,
/// Writes responses to the wire and reads requests off the wire.
transport: Fuse<T>,
/// Signals the connection is closed when `Channel` is dropped.
closed_connections: mpsc::UnboundedSender<SocketAddr>,
/// Channel limits to prevent unlimited resource usage.
config: Config,
/// The address of the server connected to.
client_addr: SocketAddr,
/// Number of requests currently being responded to.
in_flight_requests: FnvHashMap<u64, AbortHandle>,
/// Types the request and response.
ghost: PhantomData<(Req, Resp)>,
}
impl<Req, Resp, T> Drop for Channel<Req, Resp, T> {
fn drop(&mut self) {
trace!("[{}] Closing channel.", self.client_addr);
impl<Req, Resp, T> BaseChannel<Req, Resp, T> {
unsafe_unpinned!(in_flight_requests: FnvHashMap<u64, AbortHandle>);
}
// Even in a bounded channel, each connection would have a guaranteed slot, so using
// an unbounded sender is actually no different. And, the bound is on the maximum number
// of open connections.
if self
.closed_connections
.unbounded_send(self.client_addr)
.is_err()
{
warn!(
"[{}] Failed to send closed connection message.",
self.client_addr
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>> + Send,
{
/// Creates a new channel backed by `transport` and configured with `config`.
pub fn new(config: Config, transport: T) -> Self {
BaseChannel {
config,
transport: transport.fuse(),
in_flight_requests: FnvHashMap::default(),
ghost: PhantomData,
}
}
/// Creates a new channel backed by `transport` and configured with the defaults.
pub fn with_defaults(transport: T) -> Self {
Self::new(Config::default(), transport)
}
/// Returns the inner transport.
pub fn get_ref(&self) -> &T {
self.transport.get_ref()
}
/// Returns the pinned inner transport.
pub fn transport<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut T> {
unsafe { self.map_unchecked_mut(|me| me.transport.get_mut()) }
}
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
// It's possible the request was already completed, so it's fine
// if this is None.
if let Some(cancel_handle) = self.as_mut().in_flight_requests().remove(&request_id) {
self.as_mut().in_flight_requests().compact(0.1);
cancel_handle.abort();
let remaining = self.as_mut().in_flight_requests().len();
trace!(
"[{}] Request canceled. In-flight requests = {}",
trace_context.trace_id,
remaining,
);
} else {
trace!(
"[{}] Received cancellation, but response handler \
is already complete.",
trace_context.trace_id,
);
}
}
}
impl<Req, Resp, T> Channel<Req, Resp, T> {
unsafe_pinned!(transport: Fuse<T>);
}
impl<Req, Resp, T> Channel<Req, Resp, T>
/// The server end of an open connection with a client, streaming in requests from, and sinking
/// responses to, the client.
///
/// Channels may rely on the assumption that all in-flight requests are eventually either
/// [cancelled](Channel::cancel_request) or [responded to](Sink::start_send). Memory safety cannot
/// depend on this assumption, but `Channel` users should nevertheless account for all outstanding
/// requests.
pub trait Channel
where
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
Req: Send,
Resp: Send,
Self: Transport<Response<<Self as Channel>::Resp>, Request<<Self as Channel>::Req>>,
{
pub(crate) fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> io::Result<()> {
self.as_mut().transport().start_send(response)
/// Type of request item.
type Req: Send + 'static;
/// Type of response sink item.
type Resp: Send + 'static;
/// Configuration of the channel.
fn config(&self) -> &Config;
/// Returns the number of in-flight requests over this channel.
fn in_flight_requests(self: Pin<&mut Self>) -> usize;
/// Caps the number of concurrent requests.
fn max_concurrent_requests(self, n: usize) -> Throttler<Self>
where
Self: Sized,
{
Throttler::new(self, n)
}
pub(crate) fn poll_ready(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
self.as_mut().transport().poll_ready(cx)
}
pub(crate) fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
self.as_mut().transport().poll_flush(cx)
}
pub(crate) fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<ClientMessage<Req>> {
self.as_mut().transport().poll_next(cx)
}
/// Returns the address of the client connected to the channel.
pub fn client_addr(&self) -> &SocketAddr {
&self.client_addr
}
/// Tells the Channel that the request with ID `request_id` is being handled.
/// The request will be tracked until a response with the same ID is sent
/// to the Channel.
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration;
/// Respond to requests coming over the channel with `f`. Returns a future that drives the
/// responses and resolves when the connection is closed.
pub fn respond_with<F, Fut>(self, f: F) -> impl Future<Output = ()>
fn respond_with<F, Fut>(self, f: F) -> ResponseHandler<Self, F>
where
F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
Req: 'static,
Resp: 'static,
F: FnOnce(context::Context, Self::Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Self::Resp>> + Send + 'static,
Self: Sized,
{
let (responses_tx, responses) = mpsc::channel(self.config.pending_response_buffer);
let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer);
let responses = responses.fuse();
let peer = self.client_addr;
ClientHandler {
ResponseHandler {
channel: self,
f,
pending_responses: responses,
responses_tx,
in_flight_requests: FnvHashMap::default(),
}
.unwrap_or_else(move |e| {
info!("[{}] ClientHandler errored out: {}", peer, e);
})
}
}
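At the single-channel granularity, the same cap is one combinator away. A sketch using the hello example's generated `serve`:

    let handler = BaseChannel::with_defaults(transport)
        .max_concurrent_requests(10)
        .respond_with(serve(HelloServer));
    tokio::spawn(handler.unit_error().boxed().compat());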
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>> + Send + 'static,
Req: Send + 'static,
Resp: Send + 'static,
{
type Item = io::Result<Request<Req>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
loop {
match ready!(self.as_mut().transport().poll_next(cx)?) {
Some(message) => match message {
ClientMessage::Request(request) => {
return Poll::Ready(Some(Ok(request)));
}
ClientMessage::Cancel {
trace_context,
request_id,
} => {
self.as_mut().cancel_request(&trace_context, request_id);
}
},
None => return Poll::Ready(None),
}
}
}
}
impl<Req, Resp, T> Sink<Response<Resp>> for BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>> + Send + 'static,
Req: Send + 'static,
Resp: Send + 'static,
{
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.transport().poll_ready(cx)
}
fn start_send(
mut self: Pin<&mut Self>,
response: Response<Resp>,
) -> Result<(), Self::Error> {
if self
.as_mut()
.in_flight_requests()
.remove(&response.request_id)
.is_some()
{
self.as_mut().in_flight_requests().compact(0.1);
}
self.transport().start_send(response)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.transport().poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.transport().poll_close(cx)
}
}
impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
fn as_ref(&self) -> &T {
self.transport.get_ref()
}
}
impl<Req, Resp, T> Channel for BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>> + Send + 'static,
Req: Send + 'static,
Resp: Send + 'static,
{
type Req = Req;
type Resp = Resp;
fn config(&self) -> &Config {
&self.config
}
fn in_flight_requests(mut self: Pin<&mut Self>) -> usize {
self.as_mut().in_flight_requests().len()
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
let (abort_handle, abort_registration) = AbortHandle::new_pair();
assert!(self
.in_flight_requests()
.insert(request_id, abort_handle)
.is_none());
abort_registration
}
}
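For context, the abort plumbing behind `start_request` is the `futures` pairing of `AbortHandle` and `Abortable`: the channel keeps the handle keyed by request id, and a later cancellation message aborts the wrapped response future. A standalone sketch:

    use futures::future::{AbortHandle, Abortable, Aborted};

    async fn cancellation_sketch() {
        let (abort_handle, abort_registration) = AbortHandle::new_pair();
        // The registration wraps the request's response future.
        let response = Abortable::new(async { 42 }, abort_registration);
        // cancel_request() does the equivalent of this on a Cancel message:
        abort_handle.abort();
        assert_eq!(response.await, Err(Aborted));
    }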
/// A running handler serving all requests coming over a channel.
#[derive(Debug)]
struct ClientHandler<Req, Resp, T, F> {
channel: Channel<Req, Resp, T>,
pub struct ResponseHandler<C, F>
where
C: Channel,
{
channel: C,
/// Responses waiting to be written to the wire.
pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<Resp>)>>,
pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>,
/// Handed out to request handlers to fan in responses.
responses_tx: mpsc::Sender<(context::Context, Response<Resp>)>,
/// Number of requests currently being responded to.
in_flight_requests: FnvHashMap<u64, AbortHandle>,
responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
/// Request handler.
f: F,
}
impl<Req, Resp, T, F> ClientHandler<Req, Resp, T, F> {
unsafe_pinned!(channel: Channel<Req, Resp, T>);
unsafe_pinned!(in_flight_requests: FnvHashMap<u64, AbortHandle>);
unsafe_pinned!(pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<Resp>)>>);
unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response<Resp>)>);
impl<C, F> ResponseHandler<C, F>
where
C: Channel,
{
unsafe_pinned!(channel: C);
unsafe_pinned!(pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>);
unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>);
// For this to be safe, the field `f` must be private, and code in this module must never
// construct a `Pin<&mut F>`.
unsafe_unpinned!(f: F);
}
impl<Req, Resp, T, F, Fut> ClientHandler<Req, Resp, T, F>
impl<C, F, Fut> ResponseHandler<C, F>
where
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
C: Channel,
F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<C::Resp>> + Send + 'static,
{
/// If at max in-flight requests, check that there's room to immediately write a throttled
/// response.
fn poll_ready_if_throttling(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
if self.in_flight_requests.len()
>= self.channel.config.max_in_flight_requests_per_connection
{
let peer = self.as_mut().channel().client_addr;
while let Poll::Pending = self.as_mut().channel().poll_ready(cx)? {
info!(
"[{}] In-flight requests at max ({}), and transport is not ready.",
peer,
self.as_mut().in_flight_requests().len(),
);
try_ready!(self.as_mut().channel().poll_flush(cx));
}
}
Poll::Ready(Ok(()))
}
fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
ready!(self.as_mut().poll_ready_if_throttling(cx)?);
Poll::Ready(match ready!(self.as_mut().channel().poll_next(cx)?) {
Some(message) => {
match message.message {
ClientMessageKind::Request(request) => {
self.handle_request(message.trace_context, request)?;
}
ClientMessageKind::Cancel { request_id } => {
self.cancel_request(&message.trace_context, request_id);
}
}
Some(Ok(()))
match ready!(self.as_mut().channel().poll_next(cx)?) {
Some(request) => {
self.handle_request(request)?;
Poll::Ready(Some(Ok(())))
}
None => {
trace!("[{}] Read half closed", self.channel.client_addr);
None
}
})
None => Poll::Ready(None),
}
}
fn pump_write(
@@ -368,7 +456,12 @@ where
read_half_closed: bool,
) -> PollIo<()> {
match self.as_mut().poll_next_response(cx)? {
Poll::Ready(Some((_, response))) => {
Poll::Ready(Some((ctx, response))) => {
trace!(
"[{}] Staging response. In-flight requests = {}.",
ctx.trace_id(),
self.as_mut().channel().in_flight_requests(),
);
self.as_mut().channel().start_send(response)?;
Poll::Ready(Some(Ok(())))
}
@@ -384,7 +477,7 @@ where
// Being here means there are no staged requests and all written responses are
// fully flushed. So, if the read half is closed and there are no in-flight
// requests, then we can close the write half.
if read_half_closed && self.as_mut().in_flight_requests().is_empty() {
if read_half_closed && self.as_mut().channel().in_flight_requests() == 0 {
Poll::Ready(None)
} else {
Poll::Pending
@@ -396,90 +489,33 @@ where
fn poll_next_response(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<(context::Context, Response<Resp>)> {
) -> PollIo<(context::Context, Response<C::Resp>)> {
// Ensure there's room to write a response.
while let Poll::Pending = self.as_mut().channel().poll_ready(cx)? {
ready!(self.as_mut().channel().poll_flush(cx)?);
}
let peer = self.as_mut().channel().client_addr;
match ready!(self.as_mut().pending_responses().poll_next(cx)) {
Some((ctx, response)) => {
if self
.as_mut()
.in_flight_requests()
.remove(&response.request_id)
.is_some()
{
self.as_mut().in_flight_requests().compact(0.1);
}
trace!(
"[{}/{}] Staging response. In-flight requests = {}.",
ctx.trace_id(),
peer,
self.as_mut().in_flight_requests().len(),
);
Poll::Ready(Some(Ok((ctx, response))))
}
Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))),
None => {
// This branch likely won't happen, since the ClientHandler is holding a Sender.
trace!("[{}] No new responses.", peer);
// This branch likely won't happen, since the ResponseHandler is holding a Sender.
Poll::Ready(None)
}
}
}
fn handle_request(
mut self: Pin<&mut Self>,
trace_context: trace::Context,
request: Request<Req>,
) -> io::Result<()> {
fn handle_request(mut self: Pin<&mut Self>, request: Request<C::Req>) -> io::Result<()> {
let request_id = request.id;
let peer = self.as_mut().channel().client_addr;
let ctx = context::Context {
deadline: request.deadline,
trace_context,
};
let request = request.message;
if self.as_mut().in_flight_requests().len()
>= self
.as_mut()
.channel()
.config
.max_in_flight_requests_per_connection
{
debug!(
"[{}/{}] Client has reached in-flight request limit ({}/{}).",
ctx.trace_id(),
peer,
self.as_mut().in_flight_requests().len(),
self.as_mut()
.channel()
.config
.max_in_flight_requests_per_connection
);
self.as_mut().channel().start_send(Response {
request_id,
message: Err(ServerError {
kind: io::ErrorKind::WouldBlock,
detail: Some("Server throttled the request.".into()),
}),
})?;
return Ok(());
}
let deadline = ctx.deadline;
let deadline = request.context.deadline;
let timeout = deadline.as_duration();
trace!(
"[{}/{}] Received request with deadline {} (timeout {:?}).",
ctx.trace_id(),
peer,
"[{}] Received request with deadline {} (timeout {:?}).",
request.context.trace_id(),
format_rfc3339(deadline),
timeout,
);
let ctx = request.context;
let request = request.message;
let mut response_tx = self.as_mut().responses_tx().clone();
let trace_id = *ctx.trace_id();
@@ -490,18 +526,19 @@ where
request_id,
message: match result {
Ok(message) => Ok(message),
Err(e) => Err(make_server_error(e, trace_id, peer, deadline)),
Err(e) => Err(make_server_error(e, trace_id, deadline)),
},
};
trace!("[{}/{}] Sending response.", trace_id, peer);
trace!("[{}] Sending response.", trace_id);
response_tx
.send((ctx, response))
.unwrap_or_else(|_| ())
.await;
},
);
let (abortable_response, abort_handle) = abortable(response);
crate::spawn(abortable_response.map(|_| ())).map_err(|e| {
let abort_registration = self.as_mut().channel().start_request(request_id);
let response = Abortable::new(response, abort_registration);
crate::spawn(response.map(|_| ())).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!(
@@ -510,92 +547,49 @@ where
),
)
})?;
self.as_mut()
.in_flight_requests()
.insert(request_id, abort_handle);
Ok(())
}
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
// It's possible the request was already completed, so it's fine
// if this is None.
if let Some(cancel_handle) = self.as_mut().in_flight_requests().remove(&request_id) {
self.as_mut().in_flight_requests().compact(0.1);
cancel_handle.abort();
let remaining = self.as_mut().in_flight_requests().len();
trace!(
"[{}/{}] Request canceled. In-flight requests = {}",
trace_context.trace_id,
self.channel.client_addr,
remaining,
);
} else {
trace!(
"[{}/{}] Received cancellation, but response handler \
is already complete.",
trace_context.trace_id,
self.channel.client_addr
);
}
}
}
impl<Req, Resp, T, F, Fut> Future for ClientHandler<Req, Resp, T, F>
impl<C, F, Fut> Future for ResponseHandler<C, F>
where
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
C: Channel,
F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<C::Resp>> + Send + 'static,
{
type Output = io::Result<()>;
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
trace!("[{}] ClientHandler::poll", self.channel.client_addr);
loop {
let read = self.as_mut().pump_read(cx)?;
match (
read,
self.as_mut().pump_write(cx, read == Poll::Ready(None))?,
) {
(Poll::Ready(None), Poll::Ready(None)) => {
info!("[{}] Client disconnected.", self.channel.client_addr);
return Poll::Ready(Ok(()));
}
(read @ Poll::Ready(Some(())), write) | (read, write @ Poll::Ready(Some(()))) => {
trace!(
"[{}] read: {:?}, write: {:?}.",
self.channel.client_addr,
read,
write
)
}
(read, write) => {
trace!(
"[{}] read: {:?}, write: {:?} (not ready).",
self.channel.client_addr,
read,
write,
);
return Poll::Pending;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
move || -> Poll<io::Result<()>> {
loop {
let read = self.as_mut().pump_read(cx)?;
match (
read,
self.as_mut().pump_write(cx, read == Poll::Ready(None))?,
) {
(Poll::Ready(None), Poll::Ready(None)) => {
return Poll::Ready(Ok(()));
}
(Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
_ => {
return Poll::Pending;
}
}
}
}
}()
.map(|r| r.unwrap_or_else(|e| info!("ResponseHandler errored out: {}", e)))
}
}
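The immediately invoked closure above is a small but useful pattern: `?` can short-circuit on `io::Error` inside the closure while the future's output remains `()`, with errors demoted to a log line. In isolation:

    use std::io;

    fn poll_like() {
        (|| -> io::Result<()> {
            // `?` works here even though the enclosing fn returns ().
            let n: u64 = "42"
                .parse()
                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
            println!("parsed {}", n);
            Ok(())
        })()
        .unwrap_or_else(|e| eprintln!("errored out: {}", e));
    }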
fn make_server_error(
e: timeout::Error<io::Error>,
trace_id: TraceId,
peer: SocketAddr,
deadline: SystemTime,
) -> ServerError {
if e.is_elapsed() {
debug!(
"[{}/{}] Response did not complete before deadline of {}s.",
"[{}] Response did not complete before deadline of {}s.",
trace_id,
peer,
format_rfc3339(deadline)
);
// No point in responding, since the client will have dropped the request.
@@ -608,8 +602,8 @@ fn make_server_error(
}
} else if e.is_timer() {
error!(
"[{}/{}] Response failed because of an issue with a timer: {}",
trace_id, peer, e
"[{}] Response failed because of an issue with a timer: {}",
trace_id, e
);
ServerError {
@@ -623,7 +617,7 @@ fn make_server_error(
detail: Some(e.description().into()),
}
} else {
error!("[{}/{}] Unexpected response failure: {}", trace_id, peer, e);
error!("[{}] Unexpected response failure: {}", trace_id, e);
ServerError {
kind: io::ErrorKind::Other,

rpc/src/server/testing.rs Normal file

@@ -0,0 +1,125 @@
use crate::server::{Channel, Config};
use crate::{context, Request, Response};
use fnv::FnvHashSet;
use futures::future::{AbortHandle, AbortRegistration};
use futures::{Sink, Stream};
use futures_test::task::noop_waker_ref;
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::collections::VecDeque;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::SystemTime;
pub(crate) struct FakeChannel<In, Out> {
pub stream: VecDeque<In>,
pub sink: VecDeque<Out>,
pub config: Config,
pub in_flight_requests: FnvHashSet<u64>,
}
impl<In, Out> FakeChannel<In, Out> {
unsafe_pinned!(stream: VecDeque<In>);
unsafe_pinned!(sink: VecDeque<Out>);
unsafe_unpinned!(in_flight_requests: FnvHashSet<u64>);
}
impl<In, Out> Stream for FakeChannel<In, Out>
where
In: Unpin,
{
type Item = In;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
self.stream().poll_next(cx)
}
}
impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.sink().poll_ready(cx).map_err(|e| match e {})
}
fn start_send(
mut self: Pin<&mut Self>,
response: Response<Resp>,
) -> Result<(), Self::Error> {
self.as_mut()
.in_flight_requests()
.remove(&response.request_id);
self.sink().start_send(response).map_err(|e| match e {})
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.sink().poll_flush(cx).map_err(|e| match e {})
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.sink().poll_close(cx).map_err(|e| match e {})
}
}
impl<Req, Resp> Channel for FakeChannel<io::Result<Request<Req>>, Response<Resp>>
where
Req: Unpin + Send + 'static,
Resp: Send + 'static,
{
type Req = Req;
type Resp = Resp;
fn config(&self) -> &Config {
&self.config
}
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
self.in_flight_requests.len()
}
fn start_request(self: Pin<&mut Self>, id: u64) -> AbortRegistration {
self.in_flight_requests().insert(id);
AbortHandle::new_pair().1
}
}
impl<Req, Resp> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
pub fn push_req(&mut self, id: u64, message: Req) {
self.stream.push_back(Ok(Request {
context: context::Context {
deadline: SystemTime::UNIX_EPOCH,
trace_context: Default::default(),
},
id,
message,
}));
}
}
impl FakeChannel<(), ()> {
pub fn default<Req, Resp>() -> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
FakeChannel {
stream: VecDeque::default(),
sink: VecDeque::default(),
config: Config::default(),
in_flight_requests: FnvHashSet::default(),
}
}
}
pub trait PollExt {
fn is_done(&self) -> bool;
}
impl<T> PollExt for Poll<Option<T>> {
fn is_done(&self) -> bool {
match self {
Poll::Ready(None) => true,
_ => false,
}
}
}
pub fn cx() -> Context<'static> {
Context::from_waker(&noop_waker_ref())
}

rpc/src/server/throttle.rs Normal file

@@ -0,0 +1,332 @@
use super::{Channel, Config};
use crate::{Response, ServerError};
use futures::{
future::AbortRegistration,
prelude::*,
ready,
task::{Context, Poll},
};
use log::debug;
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::{io, pin::Pin};
/// A [`Channel`] that limits the number of concurrent
/// requests by throttling.
#[derive(Debug)]
pub struct Throttler<C> {
max_in_flight_requests: usize,
inner: C,
}
impl<C> Throttler<C> {
unsafe_unpinned!(max_in_flight_requests: usize);
unsafe_pinned!(inner: C);
/// Returns the inner channel.
pub fn get_ref(&self) -> &C {
&self.inner
}
}
impl<C> Throttler<C>
where
C: Channel,
{
/// Returns a new `Throttler` that wraps the given channel and limits concurrent requests to
/// `max_in_flight_requests`.
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
Throttler {
inner,
max_in_flight_requests,
}
}
}
impl<C> Stream for Throttler<C>
where
C: Channel,
{
type Item = <C as Stream>::Item;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
while self.as_mut().in_flight_requests() >= *self.as_mut().max_in_flight_requests() {
ready!(self.as_mut().inner().poll_ready(cx)?);
match ready!(self.as_mut().inner().poll_next(cx)?) {
Some(request) => {
debug!(
"[{}] Client has reached in-flight request limit ({}/{}).",
request.context.trace_id(),
self.as_mut().in_flight_requests(),
self.as_mut().max_in_flight_requests(),
);
self.as_mut().start_send(Response {
request_id: request.id,
message: Err(ServerError {
kind: io::ErrorKind::WouldBlock,
detail: Some("Server throttled the request.".into()),
}),
})?;
}
None => return Poll::Ready(None),
}
}
self.inner().poll_next(cx)
}
}
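When saturated, the throttler answers on the server's behalf rather than dropping the request on the floor. The rejection it writes has this shape (a sketch of the `start_send` above):

    use crate::{Response, ServerError};
    use std::io;

    // The response a rejected request receives while the channel is at its limit.
    fn throttle_rejection<Resp>(request_id: u64) -> Response<Resp> {
        Response {
            request_id,
            message: Err(ServerError {
                kind: io::ErrorKind::WouldBlock,
                detail: Some("Server throttled the request.".into()),
            }),
        }
    }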
impl<C> Sink<Response<<C as Channel>::Resp>> for Throttler<C>
where
C: Channel,
{
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.inner().poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Response<<C as Channel>::Resp>) -> io::Result<()> {
self.inner().start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.inner().poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.inner().poll_close(cx)
}
}
impl<C> AsRef<C> for Throttler<C> {
fn as_ref(&self) -> &C {
&self.inner
}
}
impl<C> Channel for Throttler<C>
where
C: Channel,
{
type Req = <C as Channel>::Req;
type Resp = <C as Channel>::Resp;
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
self.inner().in_flight_requests()
}
fn config(&self) -> &Config {
self.inner.config()
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
self.inner().start_request(request_id)
}
}
/// A stream of channels, each wrapped in a [`Throttler`].
#[derive(Debug)]
pub struct ThrottlerStream<S> {
inner: S,
max_in_flight_requests: usize,
}
impl<S> ThrottlerStream<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
unsafe_pinned!(inner: S);
unsafe_unpinned!(max_in_flight_requests: usize);
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
Self {
inner,
max_in_flight_requests,
}
}
}
impl<S> Stream for ThrottlerStream<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
type Item = Throttler<<S as Stream>::Item>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match ready!(self.as_mut().inner().poll_next(cx)) {
Some(channel) => Poll::Ready(Some(Throttler::new(
channel,
*self.max_in_flight_requests(),
))),
None => Poll::Ready(None),
}
}
}
#[cfg(test)]
use super::testing::{self, FakeChannel, PollExt};
#[cfg(test)]
use crate::Request;
#[cfg(test)]
use pin_utils::pin_mut;
#[cfg(test)]
use std::marker::PhantomData;
#[test]
fn throttler_in_flight_requests() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
for i in 0..5 {
throttler.inner.in_flight_requests.insert(i);
}
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
}
#[test]
fn throttler_start_request() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.as_mut().start_request(1);
assert_eq!(throttler.inner.in_flight_requests.len(), 1);
}
#[test]
fn throttler_poll_next_done() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
}
#[test]
fn throttler_poll_next_some() -> io::Result<()> {
let throttler = Throttler {
max_in_flight_requests: 1,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.push_req(0, 1);
assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());
assert_eq!(
throttler
.as_mut()
.poll_next(&mut testing::cx())?
.map(|r| r.map(|r| (r.id, r.message))),
Poll::Ready(Some((0, 1)))
);
Ok(())
}
#[test]
fn throttler_poll_next_throttled() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.push_req(1, 1);
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
assert_eq!(throttler.inner.sink.len(), 1);
let resp = throttler.inner.sink.get(0).unwrap();
assert_eq!(resp.request_id, 1);
assert!(resp.message.is_err());
}
#[test]
fn throttler_poll_next_throttled_sink_not_ready() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: PendingSink::default::<isize, isize>(),
};
pin_mut!(throttler);
assert!(throttler.poll_next(&mut testing::cx()).is_pending());
struct PendingSink<In, Out> {
ghost: PhantomData<fn(Out) -> In>,
}
impl PendingSink<(), ()> {
pub fn default<Req, Resp>() -> PendingSink<io::Result<Request<Req>>, Response<Resp>> {
PendingSink { ghost: PhantomData }
}
}
impl<In, Out> Stream for PendingSink<In, Out> {
type Item = In;
fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
unimplemented!()
}
}
impl<In, Out> Sink<Out> for PendingSink<In, Out> {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
Err(io::Error::from(io::ErrorKind::WouldBlock))
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
}
impl<Req, Resp> Channel for PendingSink<io::Result<Request<Req>>, Response<Resp>>
where
Req: Send + 'static,
Resp: Send + 'static,
{
type Req = Req;
type Resp = Resp;
fn config(&self) -> &Config {
unimplemented!()
}
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
0
}
fn start_request(self: Pin<&mut Self>, _: u64) -> AbortRegistration {
unimplemented!()
}
}
}
#[test]
fn throttler_start_send() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.in_flight_requests.insert(0);
throttler
.as_mut()
.start_send(Response {
request_id: 0,
message: Ok(1),
})
.unwrap();
assert!(throttler.inner.in_flight_requests.is_empty());
assert_eq!(
throttler.inner.sink.get(0),
Some(&Response {
request_id: 0,
message: Ok(1),
})
);
}


@@ -6,14 +6,11 @@
//! Transports backed by in-memory channels.
use crate::{PollIo, Transport};
use crate::PollIo;
use futures::{channel::mpsc, task::Context, Poll, Sink, Stream};
use pin_utils::unsafe_pinned;
use std::io;
use std::pin::Pin;
use std::{
io,
net::{IpAddr, Ipv4Addr, SocketAddr},
};
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
/// [`Sink`].
@@ -51,7 +48,7 @@ impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
}
impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
type SinkError = io::Error;
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.tx()
@@ -65,7 +62,7 @@ impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::SinkError>> {
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx()
.poll_flush(cx)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
@@ -78,19 +75,6 @@ impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
}
}
impl<Item, SinkItem> Transport for UnboundedChannel<Item, SinkItem> {
type SinkItem = SinkItem;
type Item = Item;
fn peer_addr(&self) -> io::Result<SocketAddr> {
Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
}
fn local_addr(&self) -> io::Result<SocketAddr> {
Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
}
}
#[cfg(test)]
mod tests {
use crate::{
@@ -110,7 +94,7 @@ mod tests {
let (client_channel, server_channel) = transport::channel::unbounded();
let server = Server::<String, u64>::default()
.incoming(stream::once(future::ready(Ok(server_channel))))
.incoming(stream::once(future::ready(server_channel)))
.respond_with(|_ctx, request| {
future::ready(request.parse::<u64>().map_err(|_| {
io::Error::new(


@@ -10,114 +10,10 @@
//! can be plugged in, using whatever protocol it wants.
use futures::prelude::*;
use std::{
io,
marker::PhantomData,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
};
use std::io;
pub mod channel;
/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
pub trait Transport
where
Self: Stream<Item = io::Result<<Self as Transport>::Item>>,
Self: Sink<<Self as Transport>::SinkItem, SinkError = io::Error>,
{
/// The type read off the transport.
type Item;
/// The type written to the transport.
type SinkItem;
/// The address of the remote peer this transport is in communication with.
fn peer_addr(&self) -> io::Result<SocketAddr>;
/// The address of the local half of this transport.
fn local_addr(&self) -> io::Result<SocketAddr>;
}
/// Returns a new Transport backed by the given Stream + Sink and connecting addresses.
pub fn new<S, SinkItem, Item>(
inner: S,
peer_addr: SocketAddr,
local_addr: SocketAddr,
) -> impl Transport<Item = Item, SinkItem = SinkItem>
where
S: Stream<Item = io::Result<Item>>,
S: Sink<SinkItem, SinkError = io::Error>,
{
TransportShim {
inner,
peer_addr,
local_addr,
_marker: PhantomData,
}
}
/// A transport created by adding peers to a Stream + Sink.
#[derive(Debug)]
struct TransportShim<S, SinkItem> {
peer_addr: SocketAddr,
local_addr: SocketAddr,
inner: S,
_marker: PhantomData<SinkItem>,
}
impl<S, SinkItem> TransportShim<S, SinkItem> {
pin_utils::unsafe_pinned!(inner: S);
}
impl<S, SinkItem> Stream for TransportShim<S, SinkItem>
where
S: Stream,
{
type Item = S::Item;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<S::Item>> {
self.inner().poll_next(cx)
}
}
impl<S, Item> Sink<Item> for TransportShim<S, Item>
where
S: Sink<Item>,
{
type SinkError = S::SinkError;
fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), S::SinkError> {
self.inner().start_send(item)
}
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), S::SinkError>> {
self.inner().poll_ready(cx)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), S::SinkError>> {
self.inner().poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), S::SinkError>> {
self.inner().poll_close(cx)
}
}
impl<S, SinkItem, Item> Transport for TransportShim<S, SinkItem>
where
S: Stream + Sink<SinkItem>,
Self: Stream<Item = io::Result<Item>>,
Self: Sink<SinkItem, SinkError = io::Error>,
{
type Item = Item;
type SinkItem = SinkItem;
/// The address of the remote peer this transport is in communication with.
fn peer_addr(&self) -> io::Result<SocketAddr> {
Ok(self.peer_addr)
}
/// The address of the local half of this transport.
fn local_addr(&self) -> io::Result<SocketAddr> {
Ok(self.local_addr)
}
}
pub trait Transport<SinkItem, Item> =
Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error>;
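With the address methods gone, the alias means any `Stream + Sink` pair with `io::Error` errors is a `Transport` with no explicit impl (trait aliases require the nightly `trait_alias` feature). A sketch of a bound written against the alias, reusing types from this commit:

    use crate::server::BaseChannel;
    use crate::{ClientMessage, Response, Transport};

    // Any qualifying Stream + Sink plugs in directly as a server transport.
    fn serve_over<Req, Resp, T>(transport: T) -> BaseChannel<Req, Resp, T>
    where
        T: Transport<Response<Resp>, ClientMessage<Req>> + Send,
    {
        BaseChannel::with_defaults(transport)
    }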


@@ -38,9 +38,11 @@ where
H: BuildHasher,
{
fn compact(&mut self, usage_ratio_threshold: f64) {
let usage_ratio = self.len() as f64 / self.capacity() as f64;
if usage_ratio < usage_ratio_threshold {
self.shrink_to_fit();
if self.capacity() > 1000 {
let usage_ratio = self.len() as f64 / self.capacity() as f64;
if usage_ratio < usage_ratio_threshold {
self.shrink_to_fit();
}
}
}
}
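The extra guard makes the shrink heuristic two-part: a map must be both large and mostly empty before paying for a reallocation, which avoids churn on the small in-flight maps that call `compact(0.1)` on every removal. The equivalent on a plain `HashMap`:

    use std::collections::HashMap;
    use std::hash::Hash;

    fn compact<K: Eq + Hash, V>(map: &mut HashMap<K, V>, usage_ratio_threshold: f64) {
        // Only large maps are worth shrinking at all.
        if map.capacity() > 1000 {
            let usage_ratio = map.len() as f64 / map.capacity() as f64;
            // And only when they are mostly empty.
            if usage_ratio < usage_ratio_threshold {
                map.shrink_to_fit();
            }
        }
    }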


@@ -19,7 +19,7 @@ serde1 = ["rpc/serde1", "serde", "serde/derive"]
travis-ci = { repository = "google/tarpc" }
[dependencies]
futures-preview = { version = "0.3.0-alpha.16", features = ["compat"] }
futures-preview = { version = "0.3.0-alpha.17", features = ["compat"] }
log = "0.4"
serde = { optional = true, version = "1.0" }
rpc = { package = "tarpc-lib", path = "../rpc", version = "0.6" }
@@ -31,7 +31,6 @@ bytes = { version = "0.4", features = ["serde"] }
humantime = "1.0"
bincode-transport = { package = "tarpc-bincode-transport", version = "0.7", path = "../bincode-transport" }
env_logger = "0.6"
libtest = "0.0.1"
tokio = "0.1"
tokio-executor = "0.1"
tokio-tcp = "0.1"


@@ -18,7 +18,7 @@ use futures::{
};
use rpc::{
client, context,
server::{self, Handler, Server},
server::{self, Handler},
};
use std::{
collections::HashMap,
@@ -60,8 +60,9 @@ impl subscriber::Service for Subscriber {
impl Subscriber {
async fn listen(id: u32, config: server::Config) -> io::Result<SocketAddr> {
let incoming = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = incoming.local_addr();
let incoming = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?
.filter_map(|r| future::ready(r.ok()));
let addr = incoming.get_ref().local_addr();
tokio_executor::spawn(
server::new(config)
.incoming(incoming)
@@ -140,12 +141,13 @@ impl publisher::Service for Publisher {
async fn run() -> io::Result<()> {
env_logger::init();
let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let publisher_addr = transport.local_addr();
let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?
.filter_map(|r| future::ready(r.ok()));
let publisher_addr = transport.get_ref().local_addr();
tokio_executor::spawn(
Server::default()
.incoming(transport)
transport
.take(1)
.map(server::BaseChannel::with_defaults)
.respond_with(publisher::serve(Publisher::new()))
.unit_error()
.boxed()


@@ -13,7 +13,7 @@ use futures::{
};
use rpc::{
client, context,
server::{Handler, Server},
server::{BaseChannel, Channel},
};
use std::io;
@@ -43,20 +43,22 @@ impl Service for HelloServer {
async fn run() -> io::Result<()> {
// bincode_transport is provided by the associated crate bincode-transport. It makes it easy
// to start up a serde-powered bincode serialization strategy over TCP.
let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let mut transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = transport.local_addr();
// The server is configured with the defaults.
let server = Server::default()
// Server can listen on any type that implements the Transport trait.
.incoming(transport)
// Close the stream after the client connects
.take(1)
// For this example, we're just going to wait for one connection.
let client = transport.next().await.unwrap()?;
// `Channel` is a trait representing a server-side connection. It is a trait to allow
// for some channels to be instrumented: for example, to track the number of open connections.
// BaseChannel is the most basic channel, simply wrapping a transport with no added
// functionality.
let server = BaseChannel::with_defaults(client)
// serve is generated by the tarpc::service! macro. It takes as input any type implementing
// the generated Service trait.
.respond_with(serve(HelloServer));
tokio_executor::spawn(server.unit_error().boxed().compat());
tokio::spawn(server.unit_error().boxed().compat());
let transport = bincode_transport::connect(&addr).await?;


@@ -69,8 +69,9 @@ impl DoubleService for DoubleServer {
}
async fn run() -> io::Result<()> {
let add_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = add_listener.local_addr();
let add_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?
.filter_map(|r| future::ready(r.ok()));
let addr = add_listener.get_ref().local_addr();
let add_server = Server::default()
.incoming(add_listener)
.take(1)
@@ -80,8 +81,9 @@ async fn run() -> io::Result<()> {
let to_add_server = bincode_transport::connect(&addr).await?;
let add_client = add::new_stub(client::Config::default(), to_add_server).await?;
let double_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = double_listener.local_addr();
let double_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?
.filter_map(|r| future::ready(r.ok()));
let addr = double_listener.get_ref().local_addr();
let double_server = rpc::Server::default()
.incoming(double_listener)
.take(1)


@@ -1,9 +1,4 @@
#![feature(
async_await,
arbitrary_self_types,
proc_macro_hygiene,
impl_trait_in_bindings
)]
#![feature(async_await, arbitrary_self_types, proc_macro_hygiene)]
mod registry {
use bytes::Bytes;
@@ -382,8 +377,9 @@ async fn run() -> io::Result<()> {
read_service::serve(server.clone()),
);
let listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let server_addr = listener.local_addr();
let listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?
.filter_map(|r| future::ready(r.ok()));
let server_addr = listener.get_ref().local_addr();
let server = tarpc::Server::default()
.incoming(listener)
.take(1)


@@ -58,13 +58,13 @@ macro_rules! service {
(
$(
$(#[$attr:meta])*
rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) $(-> $out:ty)*;
rpc $fn_name:ident( $( $(#[$argattr:meta])* $arg:ident : $in_:ty ),* ) $(-> $out:ty)*;
)*
) => {
$crate::service! {{
$(
$(#[$attr])*
rpc $fn_name( $( $arg : $in_ ),* ) $(-> $out)*;
rpc $fn_name( $( $(#[$argattr])* $arg : $in_ ),* ) $(-> $out)*;
)*
}}
};
@@ -72,7 +72,7 @@ macro_rules! service {
(
{
$(#[$attr:meta])*
rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* );
rpc $fn_name:ident( $( $(#[$argattr:meta])* $arg:ident : $in_:ty ),* );
$( $unexpanded:tt )*
}
@@ -84,14 +84,14 @@ macro_rules! service {
$( $expanded )*
$(#[$attr])*
rpc $fn_name( $( $arg : $in_ ),* ) -> ();
rpc $fn_name( $( $(#[$argattr])* $arg : $in_ ),* ) -> ();
}
};
// Pattern for when the next rpc has an explicit return type.
(
{
$(#[$attr:meta])*
rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty;
rpc $fn_name:ident( $( $(#[$argattr:meta])* $arg:ident : $in_:ty ),* ) -> $out:ty;
$( $unexpanded:tt )*
}
@@ -103,7 +103,7 @@ macro_rules! service {
$( $expanded )*
$(#[$attr])*
rpc $fn_name( $( $arg : $in_ ),* ) -> $out;
rpc $fn_name( $( $(#[$argattr])* $arg : $in_ ),* ) -> $out;
}
};
// Pattern for when all return types have been expanded
@@ -111,7 +111,7 @@ macro_rules! service {
{ } // none left to expand
$(
$(#[$attr:meta])*
rpc $fn_name:ident ( $( $arg:ident : $in_:ty ),* ) -> $out:ty;
rpc $fn_name:ident ( $( $(#[$argattr:meta])* $arg:ident : $in_:ty ),* ) -> $out:ty;
)*
) => {
$crate::add_serde_if_enabled! {
@@ -122,7 +122,7 @@ macro_rules! service {
pub enum Request {
$(
$(#[$attr])*
$fn_name{ $($arg: $in_,)* }
$fn_name{ $( $(#[$argattr])* $arg: $in_,)* }
),*
}
}
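Threaded through all four macro patterns, the new `$(#[$argattr:meta])*` matcher lets callers attach attributes to individual rpc arguments, which land on the fields of the generated `Request` enum. A hedged sketch (the `serde` attribute is illustrative and assumes the `serde1` feature):

    service! {
        /// Logs a message on the server.
        rpc log(#[serde(default)] severity: String, message: String);
    }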
@@ -218,9 +218,7 @@ macro_rules! service {
pub async fn new_stub<T>(config: $crate::client::Config, transport: T)
-> ::std::io::Result<Client>
where
T: $crate::Transport<
Item = $crate::Response<Response>,
SinkItem = $crate::ClientMessage<Request>> + Send + 'static,
T: $crate::Transport<$crate::ClientMessage<Request>, $crate::Response<Response>> + Send + 'static,
{
Ok(Client($crate::client::new(config, transport).await?))
}
@@ -321,7 +319,7 @@ mod functional_test {
let (tx, rx) = channel::unbounded();
tokio_executor::spawn(
crate::Server::default()
.incoming(stream::once(ready(Ok(rx))))
.incoming(stream::once(ready(rx)))
.respond_with(serve(Server))
.unit_error()
.boxed()
@@ -350,7 +348,7 @@ mod functional_test {
let (tx, rx) = channel::unbounded();
tokio_executor::spawn(
rpc::Server::default()
.incoming(stream::once(ready(Ok(rx))))
.incoming(stream::once(ready(rx)))
.respond_with(serve(Server))
.unit_error()
.boxed()


@@ -12,8 +12,10 @@
proc_macro_hygiene
)]
extern crate test;
use futures::{compat::Executor01CompatExt, future, prelude::*};
use libtest::stats::Stats;
use test::stats::Stats;
use rpc::{
client, context,
server::{Handler, Server},
@@ -41,8 +43,9 @@ impl ack::Service for Serve {
}
async fn bench() -> io::Result<()> {
let listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?;
let addr = listener.local_addr();
let listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?
.filter_map(|r| future::ready(r.ok()));
let addr = listener.get_ref().local_addr();
tokio_executor::spawn(
Server::default()


@@ -13,7 +13,7 @@ readme = "../README.md"
description = "foundations for tracing in tarpc"
[dependencies]
rand = "0.6"
rand = "0.7"
[dependencies.serde]
version = "1.0"


@@ -26,7 +26,7 @@ use std::{
///
/// Consists of a span identifying an event, an optional parent span identifying a causal event
/// that triggered the current span, and a trace with which all related spans are associated.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Context {
/// An identifier of the trace associated with the current context. A trace ID is typically
@@ -46,12 +46,12 @@ pub struct Context {
/// A 128-bit UUID identifying a trace. All spans caused by the same originating span share the
/// same trace ID.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TraceId(u128);
/// A 64-bit identifier of a span within a trace. The identifier is unique within the span's trace.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SpanId(u64);