33 Commits

Author SHA1 Message Date
Tim Kuehn
9e4c3a2b8b Fix poll_expired incorrectly returning Pending when there are no outstanding requests. 2021-10-08 22:01:57 -07:00
Tim Kuehn
d78b24b631 Assert that poll_expired yields None when DelayQueue is empty 2021-10-08 21:28:43 -07:00
Tim Kuehn
49900d7a35 Close TcpStream when client disconnects 2021-10-08 20:10:16 -07:00
Tim Kuehn
1e680e3a5a Fix typos in docs.
Fixes https://github.com/google/tarpc/issues/352.
2021-10-08 19:19:50 -07:00
Tim Kuehn
2591d21e94 Update release notes to mention io::Error => RpcError 2021-09-23 13:57:43 -07:00
Tim Kuehn
6632f68d95 Prepare for 0.27 release 2021-09-22 15:41:34 -07:00
Dmitry Kakurin
25985ad56a Update README.md (#350)
Fixed 2 typos
2021-09-01 17:58:49 -07:00
Tim Kuehn
d6a24e9420 Address Clippy lint 2021-08-24 12:40:18 -07:00
Tim Kuehn
281a78f3c7 Add tokio-serde-bincode feature 2021-08-24 12:37:57 -07:00
Julian Tescher
a0787d0091 Update to opentelemetry 0.16.x (#349) 2021-08-17 00:00:07 -04:00
Frederik-Baetens
d2acba0e8a add serde-transport-json feature flag (#346)
In general, it should be possible to use, or at least import, all functionality of a library when having only that library in your Cargo.toml.
2021-05-06 08:41:57 -07:00
Tim Kuehn
ea7b6763c4 Refactor server module.
In the interest of the user's attention, some ancillary APIs have been
moved to new submodules:

- server::limits contains what was previously called Throttler and
  ChannelFilter. Both of those names were very generic, when the methods
  applied by these types were very specific (and also simplistic). Renames
  have occurred:
  - ThrottlerStream => MaxRequestsPerChannel
  - Throttler => MaxRequests
  - ChannelFilter => MaxChannelsPerKey
- server::incoming contains the Incoming trait.
- server::tokio contains the tokio-specific helper types.

The 5 structs and 1 enum remaining in the base server module are all
core to the functioning of the server.
2021-04-21 17:05:49 -07:00
Tim Kuehn
eb67c540b9 Use more structured errors in client. 2021-04-21 14:54:45 -07:00
Tim Kuehn
4151d0abd3 Move Span creation into BaseChannel.
It's important for Channel decorators, like Throttler, to have access to
the Span. This means that the BaseChannel becomes responsible for
starting its own requests. Actually, this simplifies the integration for
the Channel users, as they can assume any yielded requests are already
tracked.

This entails the following breaking changes:

- removed trait method Channel::start_request as it is now done
  internally.
2021-04-21 14:54:45 -07:00
Tim Kuehn
d0c11a6efa Change RPC error type from io::Error => RpcError.
Because tarpc is a library, not an application, it should strive to
use structured errors in its API so that users have maximal flexibility
in how they handle errors. io::Error makes that hard, because it is a
kitchen-sink error type.

RPCs in particular only have 3 classes of errors:

- The connection breaks.
- The request expires.
- The server decides not to process the request.

(Of course, RPCs themselves can have application-specific errors, but
from the perspective of the RPC library, those can be classified as
successful responses).
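The three error classes above lend themselves to a small structured enum. The sketch below is a toy model under assumed variant names (`Disconnected`, `DeadlineExceeded`, `Refused`), not tarpc's actual `RpcError` definition:

```rust
// Toy model of the three RPC error classes described in the commit message.
// Variant names are illustrative assumptions, not tarpc's real API.
#[derive(Debug)]
enum RpcError {
    /// The connection to the server broke before a response arrived.
    Disconnected,
    /// The request's deadline passed before the server responded.
    DeadlineExceeded,
    /// The server decided not to process the request.
    Refused,
}

fn describe(e: &RpcError) -> &'static str {
    match e {
        RpcError::Disconnected => "connection broke",
        RpcError::DeadlineExceeded => "request expired",
        RpcError::Refused => "server declined the request",
    }
}

fn main() {
    // A caller can match exhaustively instead of string-matching io::Error.
    assert_eq!(describe(&RpcError::DeadlineExceeded), "request expired");
    println!("{}", describe(&RpcError::Disconnected));
}
```

Because the set of failure classes is closed, callers can match exhaustively, which a kitchen-sink `io::Error` cannot offer.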
2021-04-20 18:29:55 -07:00
Tim Kuehn
82c4da1743 Prepare release of v0.26.2 2021-04-20 11:28:15 -07:00
Tim Kuehn
0a15e0b75c Rustdoc: link RPC futures to their methods. 2021-04-20 11:25:26 -07:00
Tim Kuehn
0b315c29bf It's not currently possible to document the enum variants, which means
projects that use #[deny(missing_docs)] wouldn't compile if using tarpc
services.
2021-04-20 09:01:39 -07:00
Tim Kuehn
56f09bf61f Fix log that's split across lines. 2021-04-17 17:15:16 -07:00
Tim Kuehn
6d82e82419 Fix formatting 2021-04-16 16:51:21 -07:00
Tim Kuehn
9bebaf814a Address clippy lint 2021-04-14 17:49:27 -07:00
Tim Kuehn
5f4d6e6008 Prepare release of v0.26.0 2021-04-14 17:08:44 -07:00
Tim Kuehn
07d07d7ba3 Remove tracing_appender as it does not support build target mipsel-unknown-linux-gnu 2021-04-01 19:37:02 -07:00
Tim Kuehn
a41bbf65b2 Use rustfmt instead of cargo fmt so that diff is only printed once 2021-04-01 17:24:34 -07:00
Tim Kuehn
21e2f7ca62 Tear out requirement that Transport's error type is io::Error. 2021-04-01 17:24:34 -07:00
Tim Kuehn
7b7c182411 Instrument tarpc with tracing.
tarpc is now instrumented with tracing primitives extended with
OpenTelemetry traces. Using a compatible tracing-opentelemetry
subscriber like Jaeger, each RPC can be traced through the client,
server, and other dependencies downstream of the server. Even for
applications not connected to a distributed tracing collector, the
instrumentation can also be ingested by regular loggers like env_logger.

 # Breaking Changes

 ## Logging

Logged events are now structured using tracing. For applications using a
logger and not a tracing subscriber, these logs may look different or
contain information in a less consumable manner. The easiest solution is
to add a tracing subscriber that logs to stdout, such as
tracing_subscriber::fmt.

 ##  Context

- Context no longer has parent_span, which was actually never needed,
  because the context sent in an RPC is inherently the parent context.
  For purposes of distributed tracing, the client side of the RPC has all
  necessary information to link the span to its parent; the server side
  need do nothing more than export the (trace ID, span ID) tuple.
- Context has a new field, SamplingDecision, which has two variants,
  Sampled and Unsampled. This field can be used by downstream systems to
  determine whether a trace needs to be exported. If the parent span is
  sampled, the expectation is that all child spans be exported, as well;
  to do otherwise could result in lossy traces being exported. Note that
  if an OpenTelemetry tracing subscriber is not installed, the fallback
  context will still be used, but the Context's sampling decision will
  always be inherited from the parent Context's sampling decision.
- Context::scope has been removed. Context propagation is now done via
  tracing's task-local spans. Spans can be propagated across tasks via
  Span::in_scope. When a service receives a request, it attaches an
  OpenTelemetry context to the local Span created before request handling,
  and this context contains the request deadline. This span-local deadline
  is retrieved by Context::current, but it cannot be modified so that
  future Context::current calls contain a different deadline. However, the
  deadline in the context passed into an RPC call will override it, so
  users can retrieve the current context and then modify the deadline
  field, as has been historically possible.
- Context propagation precedence changes: when an RPC is initiated, the
  current Span's OpenTelemetry context takes precedence over the trace
  context passed into the RPC method. If there is no current Span, then
  the trace context argument is used as it has been historically. Note
  that OpenTelemetry context propagation requires an OpenTelemetry
  tracing subscriber to be installed.
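The sampling-decision inheritance described above can be modeled in a few lines. This is an illustrative sketch with hypothetical types, not tarpc's implementation; only the `Sampled`/`Unsampled` variant names come from the commit message:

```rust
// Minimal model of sampling-decision inheritance between parent and child
// contexts. Types and method names here are assumptions for illustration.
#[derive(Clone, Copy, Debug, PartialEq)]
enum SamplingDecision {
    Sampled,
    Unsampled,
}

struct Context {
    sampling_decision: SamplingDecision,
}

impl Context {
    /// A child context inherits its parent's sampling decision, so that a
    /// sampled trace never silently drops child spans (a "lossy" trace).
    fn child_of(parent: &Context) -> Context {
        Context {
            sampling_decision: parent.sampling_decision,
        }
    }
}

fn main() {
    let parent = Context {
        sampling_decision: SamplingDecision::Sampled,
    };
    let child = Context::child_of(&parent);
    // If the parent is sampled, every child is sampled too.
    assert_eq!(child.sampling_decision, SamplingDecision::Sampled);
}
```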

 ## Server

- The server::Channel trait now has an additional required associated
  type and method which returns the underlying transport. This makes it
  more ergonomic for users to retrieve transport-specific information,
  like IP Address. BaseChannel implements Channel::transport by returning
  the underlying transport, and channel decorators like Throttler just
  delegate to the Channel::transport method of the wrapped channel.
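The delegation described above is the classic decorator pattern. Here is a minimal sketch, with simplified trait and type names standing in for tarpc's real `Channel` API (which has many more items than shown):

```rust
// Simplified stand-in for the Channel trait's new associated type + method.
trait Channel {
    type Transport;
    fn transport(&self) -> &Self::Transport;
}

// The base channel owns the transport and returns it directly.
struct BaseChannel<T> {
    transport: T,
}

impl<T> Channel for BaseChannel<T> {
    type Transport = T;
    fn transport(&self) -> &T {
        &self.transport
    }
}

// A decorator (e.g. a throttler) wraps another channel and just delegates.
struct Throttler<C> {
    inner: C,
    max_in_flight: usize,
}

impl<C: Channel> Channel for Throttler<C> {
    type Transport = C::Transport;
    fn transport(&self) -> &C::Transport {
        self.inner.transport()
    }
}

fn main() {
    let base = BaseChannel { transport: "127.0.0.1:54321" };
    let throttled = Throttler { inner: base, max_in_flight: 100 };
    // Transport-specific info (here a peer-address string) stays reachable
    // through any number of decorator layers.
    assert_eq!(*throttled.transport(), "127.0.0.1:54321");
    let _ = throttled.max_in_flight;
}
```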

 # References

[1] https://github.com/tokio-rs/tracing
[2] https://opentelemetry.io
[3] https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger
[4] https://github.com/env-logger-rs/env_logger
2021-04-01 17:24:34 -07:00
Ben Ludewig
db0c778ead Serialize u128 TraceId as LE bytes (#344) 2021-03-30 08:41:19 -07:00
Tim Kuehn
c3efb83ac1 Add more context to errors returned by serde transport 2021-03-28 20:03:03 -07:00
Tim Kuehn
3d7b0171fe Fix cargo fmt portion of pre-commit 2021-03-26 19:39:56 -07:00
oblique
c191ff5b2e Do not enable tokio-serde/json by default (#345) 2021-03-26 18:22:44 -07:00
Tim Kuehn
90bc7f741d Fix up imports 2021-03-17 12:44:39 -07:00
Kitsu
d3f6c01df2 Reduce required tokio features (#343)
* Move async tests behind cfg-ed mod
* Use explicit tokio features for the example
* Use only relative crate path for example dependency
2021-03-17 12:30:18 -07:00
Tim Kuehn
c6450521e6 Add method to run a future in the current context.
Previously, `Context::current` would always return a new context. Now,
it uses tokio task-local data to look for the current context. Tokio
task locals are not actually tied to a tokio executor; instead, they
provide data scoped to a future.

The basic pattern is:

```rust
let ctx = Context::new_root();
ctx.scope(async {
    let ctx2 = context::current();
    assert_eq!(ctx2.trace_context.span_id, ctx.trace_context.span_id);
});
```

`server::InFlightRequest::execute` uses `Context::scope` to set the
current context before executing a request, so calls to
`context::current` in request handlers will return the context provided
by the client. This does not propagate to new spawned tasks. To
propagate the client context to child tasks, the following pattern will
work:

```rust
tokio::spawn(context::current().scope(async { /* do work here */ }));
```

This commit also introduces a breaking change to Context serialization.
Previously, the deadline only serialized second-level precision. Now, it
provides full fidelity serialization to the nanosecond.
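Full-fidelity deadline serialization amounts to carrying both seconds and nanoseconds since the epoch instead of seconds alone. A sketch of the idea, using an assumed `(secs, nanos)` encoding rather than tarpc's actual wire format:

```rust
// Encode a SystemTime deadline as (seconds, nanoseconds) since the Unix
// epoch, so sub-second precision survives a round trip. The tuple encoding
// is an assumption for illustration, not tarpc's exact serialized form.
use std::time::{Duration, SystemTime, UNIX_EPOCH};

fn encode(deadline: SystemTime) -> (u64, u32) {
    let since_epoch = deadline
        .duration_since(UNIX_EPOCH)
        .expect("deadline before epoch");
    (since_epoch.as_secs(), since_epoch.subsec_nanos())
}

fn decode((secs, nanos): (u64, u32)) -> SystemTime {
    UNIX_EPOCH + Duration::new(secs, nanos)
}

fn main() {
    let deadline = UNIX_EPOCH + Duration::new(1_615_680_000, 123_456_789);
    let encoded = encode(deadline);
    assert_eq!(encoded, (1_615_680_000, 123_456_789));
    // Round-trips without losing sub-second precision, unlike a
    // seconds-only encoding, which would drop the 123_456_789 nanos.
    assert_eq!(decode(encoded), deadline);
}
```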
2021-03-13 16:05:02 -08:00
36 changed files with 2381 additions and 1660 deletions


@@ -1,7 +1,11 @@
 [workspace]
+resolver = "2"
 members = [
     "example-service",
     "tarpc",
     "plugins",
 ]
+
+[profile.dev]
+split-debuginfo = "unpacked"


@@ -51,6 +51,14 @@ Some other features of tarpc:
   requests sent by the server that use the request context will propagate the request deadline.
   For example, if a server is handling a request with a 10s deadline, does 2s of work, then
   sends a request to another server, that server will see an 8s deadline.
+- Distributed tracing: tarpc is instrumented with
+  [tracing](https://github.com/tokio-rs/tracing) primitives extended with
+  [OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
+  [Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
+  each RPC can be traced through the client, server, and other dependencies downstream of the
+  server. Even for applications not connected to a distributed tracing collector, the
+  instrumentation can also be ingested by regular loggers like
+  [env_logger](https://github.com/env-logger-rs/env_logger/).
 - Serde serialization: enabling the `serde1` Cargo feature will make service requests and
   responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
   be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
@@ -59,7 +67,7 @@ Some other features of tarpc:
 Add to your `Cargo.toml` dependencies:

 ```toml
-tarpc = "0.25"
+tarpc = "0.27"
 ```

 The `tarpc::service` attribute expands to a collection of items that form an rpc service.
@@ -72,8 +80,9 @@ This example uses [tokio](https://tokio.rs), so add the following dependencies to
 your `Cargo.toml`:

 ```toml
-futures = "1.0"
-tarpc = { version = "0.25", features = ["tokio1"] }
+anyhow = "1.0"
+futures = "0.3"
+tarpc = { version = "0.27", features = ["tokio1"] }
 tokio = { version = "1.0", features = ["macros"] }
 ```
@@ -91,9 +100,8 @@ use futures::{
 };
 use tarpc::{
     client, context,
-    server::{self, Incoming},
+    server::{self, incoming::Incoming},
 };
-use std::io;

 // This is the service definition. It looks a lot like a trait definition.
 // It defines one RPC, hello, which takes one arg, name, and returns a String.
@@ -132,7 +140,7 @@ available behind the `tcp` feature.
 ```rust
 #[tokio::main]
-async fn main() -> io::Result<()> {
+async fn main() -> anyhow::Result<()> {
     let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
     let server = server::BaseChannel::with_defaults(server_transport);
@@ -140,7 +148,7 @@ async fn main() -> io::Result<()> {
     // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
     // that takes a config and any Transport as input.
-    let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
+    let mut client = WorldClient::new(client::Config::default(), client_transport).spawn();

     // The client has an RPC method for each RPC defined in the annotated trait. It takes the same
     // args as defined, with the addition of a Context, which is always the first arg. The Context


@@ -1,3 +1,111 @@
+## 0.27.1 (2021-09-22)
+
+### Breaking Changes
+
+#### RPC error type is changing
+
+RPC return types are changing from `Result<Response, io::Error>` to
+`Result<Response, tarpc::client::RpcError>`.
+
+Because tarpc is a library, not an application, it should strive to use
+structured errors in its API so that users have maximal flexibility in how
+they handle errors. io::Error makes that hard, because it is a kitchen-sink
+error type.
+
+RPCs in particular only have 3 classes of errors:
+
+- The connection breaks.
+- The request expires.
+- The server decides not to process the request.
+
+RPC responses can also contain application-specific errors, but from the
+perspective of the RPC library, those are opaque to the framework, classified
+as successful responses.
+
+#### OpenTelemetry
+
+The OpenTelemetry dependency is updated to version 0.16.x.
+
+## 0.27.0 (2021-09-22)
+
+This version was yanked due to tarpc-plugins version mismatches.
+
+## 0.26.0 (2021-04-14)
+
+### New Features
+
+#### Tracing
+
+tarpc is now instrumented with tracing primitives extended with OpenTelemetry
+traces. Using a compatible tracing-opentelemetry subscriber like Jaeger, each
+RPC can be traced through the client, server, and other dependencies
+downstream of the server. Even for applications not connected to a
+distributed tracing collector, the instrumentation can also be ingested by
+regular loggers like env_logger.
+
+### Breaking Changes
+
+#### Logging
+
+Logged events are now structured using tracing. For applications using a
+logger and not a tracing subscriber, these logs may look different or contain
+information in a less consumable manner. The easiest solution is to add a
+tracing subscriber that logs to stdout, such as tracing_subscriber::fmt.
+
+#### Context
+
+- Context no longer has parent_span, which was actually never needed, because
+  the context sent in an RPC is inherently the parent context. For purposes
+  of distributed tracing, the client side of the RPC has all necessary
+  information to link the span to its parent; the server side need do
+  nothing more than export the (trace ID, span ID) tuple.
+- Context has a new field, SamplingDecision, which has two variants, Sampled
+  and Unsampled. This field can be used by downstream systems to determine
+  whether a trace needs to be exported. If the parent span is sampled, the
+  expectation is that all child spans be exported, as well; to do otherwise
+  could result in lossy traces being exported. Note that if an OpenTelemetry
+  tracing subscriber is not installed, the fallback context will still be
+  used, but the Context's sampling decision will always be inherited from
+  the parent Context's sampling decision.
+- Context::scope has been removed. Context propagation is now done via
+  tracing's task-local spans. Spans can be propagated across tasks via
+  Span::in_scope. When a service receives a request, it attaches an
+  OpenTelemetry context to the local Span created before request handling,
+  and this context contains the request deadline. This span-local deadline
+  is retrieved by Context::current, but it cannot be modified so that future
+  Context::current calls contain a different deadline. However, the deadline
+  in the context passed into an RPC call will override it, so users can
+  retrieve the current context and then modify the deadline field, as has
+  been historically possible.
+- Context propagation precedence changes: when an RPC is initiated, the
+  current Span's OpenTelemetry context takes precedence over the trace
+  context passed into the RPC method. If there is no current Span, then the
+  trace context argument is used as it has been historically. Note that
+  OpenTelemetry context propagation requires an OpenTelemetry tracing
+  subscriber to be installed.
+
+#### Server
+
+- The server::Channel trait now has an additional required associated type
+  and method which returns the underlying transport. This makes it more
+  ergonomic for users to retrieve transport-specific information, like IP
+  address. BaseChannel implements Channel::transport by returning the
+  underlying transport, and channel decorators like Throttler just delegate
+  to the Channel::transport method of the wrapped channel.
+
+#### Client
+
+- NewClient::spawn no longer returns a result, as spawn can't fail.
+
+### References
+
+1. https://github.com/tokio-rs/tracing
+2. https://opentelemetry.io
+3. https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger
+4. https://github.com/env-logger-rs/env_logger
+
 ## 0.25.0 (2021-03-10)

 ### Breaking Changes
@@ -170,7 +278,7 @@ nameable futures and will just be boxing the return type anyway. This macro does
 ### Breaking Changes

-- Enums had _non_exhaustive fields replaced with the #[non_exhaustive] attribute.
+- Enums had `_non_exhaustive` fields replaced with the #[non_exhaustive] attribute.

 ### Bug Fixes


@@ -1,6 +1,6 @@
 [package]
 name = "tarpc-example-service"
-version = "0.9.0"
+version = "0.10.0"
 authors = ["Tim Kuehn <tikue@google.com>"]
 edition = "2018"
 license = "MIT"
@@ -13,12 +13,18 @@ readme = "../README.md"
 description = "An example server built on tarpc."

 [dependencies]
-clap = "2.33"
-env_logger = "0.8"
-log = "0.4"
+anyhow = "1.0"
+clap = "3.0.0-beta.2"
 futures = "0.3"
-serde = { version = "1.0" }
-tarpc = { version = "0.25", path = "../tarpc", features = ["full"] }
-tokio = { version = "1", features = ["full"] }
+opentelemetry = { version = "0.16", features = ["rt-tokio"] }
+opentelemetry-jaeger = { version = "0.15", features = ["rt-tokio"] }
+rand = "0.8"
+tarpc = { version = "0.27", path = "../tarpc", features = ["full"] }
+tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
+tracing = { version = "0.1" }
+tracing-opentelemetry = "0.15"
+tracing-subscriber = "0.2"

 [lib]
 name = "service"


@@ -4,57 +4,49 @@
 // license that can be found in the LICENSE file or at
 // https://opensource.org/licenses/MIT.
-use clap::{App, Arg};
-use std::{io, net::SocketAddr};
+use clap::Clap;
+use service::{init_tracing, WorldClient};
+use std::{net::SocketAddr, time::Duration};
 use tarpc::{client, context, tokio_serde::formats::Json};
+use tokio::time::sleep;
+use tracing::Instrument;
+
+#[derive(Clap)]
+struct Flags {
+    /// Sets the server address to connect to.
+    #[clap(long)]
+    server_addr: SocketAddr,
+
+    /// Sets the name to say hello to.
+    #[clap(long)]
+    name: String,
+}

 #[tokio::main]
-async fn main() -> io::Result<()> {
-    env_logger::init();
-
-    let flags = App::new("Hello Client")
-        .version("0.1")
-        .author("Tim <tikue@google.com>")
-        .about("Say hello!")
-        .arg(
-            Arg::with_name("server_addr")
-                .long("server_addr")
-                .value_name("ADDRESS")
-                .help("Sets the server address to connect to.")
-                .required(true)
-                .takes_value(true),
-        )
-        .arg(
-            Arg::with_name("name")
-                .short("n")
-                .long("name")
-                .value_name("STRING")
-                .help("Sets the name to say hello to.")
-                .required(true)
-                .takes_value(true),
-        )
-        .get_matches();
-
-    let server_addr = flags.value_of("server_addr").unwrap();
-    let server_addr = server_addr
-        .parse::<SocketAddr>()
-        .unwrap_or_else(|e| panic!(r#"--server_addr value "{}" invalid: {}"#, server_addr, e));
-    let name = flags.value_of("name").unwrap().into();
-
-    let mut transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default);
-    transport.config_mut().max_frame_length(usize::MAX);
+async fn main() -> anyhow::Result<()> {
+    let flags = Flags::parse();
+    init_tracing("Tarpc Example Client")?;
+
+    let transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);

     // WorldClient is generated by the service attribute. It has a constructor `new` that takes a
     // config and any Transport as input.
-    let client = service::WorldClient::new(client::Config::default(), transport.await?).spawn()?;
+    let client = WorldClient::new(client::Config::default(), transport.await?).spawn();

-    // The client has an RPC method for each RPC defined in the annotated trait. It takes the same
-    // args as defined, with the addition of a Context, which is always the first arg. The Context
-    // specifies a deadline and trace information which can be helpful in debugging requests.
-    let hello = client.hello(context::current(), name).await?;
+    let hello = async move {
+        // Send the request twice, just to be safe! ;)
+        tokio::select! {
+            hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 }
+            hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 }
+        }
+    }
+    .instrument(tracing::info_span!("Two Hellos"))
+    .await;

-    println!("{}", hello);
+    tracing::info!("{:?}", hello);
+
+    // Let the background span processor finish.
+    sleep(Duration::from_micros(1)).await;
+    opentelemetry::global::shutdown_tracer_provider();

     Ok(())
 }


@@ -4,6 +4,9 @@
 // license that can be found in the LICENSE file or at
 // https://opensource.org/licenses/MIT.
+use std::env;
+use tracing_subscriber::{fmt::format::FmtSpan, prelude::*};

 /// This is the service definition. It looks a lot like a trait definition.
 /// It defines one RPC, hello, which takes one arg, name, and returns a String.
 #[tarpc::service]
@@ -11,3 +14,21 @@ pub trait World {
     /// Returns a greeting for name.
     async fn hello(name: String) -> String;
 }
+
+/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
+pub fn init_tracing(service_name: &str) -> anyhow::Result<()> {
+    env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
+    let tracer = opentelemetry_jaeger::new_pipeline()
+        .with_service_name(service_name)
+        .with_max_packet_size(2usize.pow(13))
+        .install_batch(opentelemetry::runtime::Tokio)?;
+
+    tracing_subscriber::registry()
+        .with(tracing_subscriber::EnvFilter::from_default_env())
+        .with(tracing_subscriber::fmt::layer().with_span_events(FmtSpan::NEW | FmtSpan::CLOSE))
+        .with(tracing_opentelemetry::layer().with_tracer(tracer))
+        .try_init()?;
+
+    Ok(())
+}


@@ -4,18 +4,30 @@
 // license that can be found in the LICENSE file or at
 // https://opensource.org/licenses/MIT.
-use clap::{App, Arg};
+use clap::Clap;
 use futures::{future, prelude::*};
-use service::World;
+use rand::{
+    distributions::{Distribution, Uniform},
+    thread_rng,
+};
+use service::{init_tracing, World};
 use std::{
-    io,
-    net::{IpAddr, SocketAddr},
+    net::{IpAddr, Ipv6Addr, SocketAddr},
+    time::Duration,
 };
 use tarpc::{
     context,
-    server::{self, Channel, Incoming},
+    server::{self, incoming::Incoming, Channel},
     tokio_serde::formats::Json,
 };
+use tokio::time;
+
+#[derive(Clap)]
+struct Flags {
+    /// Sets the port number to listen on.
+    #[clap(long)]
+    port: u16,
+}

 // This is the type that implements the generated World trait. It is the business logic
 // and is used to start the server.
@@ -25,35 +37,19 @@ struct HelloServer(SocketAddr);
 #[tarpc::server]
 impl World for HelloServer {
     async fn hello(self, _: context::Context, name: String) -> String {
-        format!("Hello, {}! You are connected from {:?}.", name, self.0)
+        let sleep_time =
+            Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng()));
+        time::sleep(sleep_time).await;
+        format!("Hello, {}! You are connected from {}", name, self.0)
     }
 }

 #[tokio::main]
-async fn main() -> io::Result<()> {
-    env_logger::init();
-
-    let flags = App::new("Hello Server")
-        .version("0.1")
-        .author("Tim <tikue@google.com>")
-        .about("Say hello!")
-        .arg(
-            Arg::with_name("port")
-                .short("p")
-                .long("port")
-                .value_name("NUMBER")
-                .help("Sets the port number to listen on")
-                .required(true)
-                .takes_value(true),
-        )
-        .get_matches();
-
-    let port = flags.value_of("port").unwrap();
-    let port = port
-        .parse()
-        .unwrap_or_else(|e| panic!(r#"--port value "{}" invalid: {}"#, port, e));
-    let server_addr = (IpAddr::from([0, 0, 0, 0]), port);
+async fn main() -> anyhow::Result<()> {
+    let flags = Flags::parse();
+    init_tracing("Tarpc Example Server")?;
+
+    let server_addr = (IpAddr::V6(Ipv6Addr::LOCALHOST), flags.port);

     // JSON transport is provided by the json_transport tarpc module. It makes it easy
     // to start up a serde-powered json serialization strategy over TCP.
@@ -64,12 +60,12 @@ async fn main() -> io::Result<()> {
         .filter_map(|r| future::ready(r.ok()))
         .map(server::BaseChannel::with_defaults)
         // Limit channels to 1 per IP.
-        .max_channels_per_key(1, |t| t.as_ref().peer_addr().unwrap().ip())
+        .max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip())
         // serve is generated by the service attribute. It takes as input any type implementing
         // the generated World trait.
         .map(|channel| {
-            let server = HelloServer(channel.as_ref().as_ref().peer_addr().unwrap());
-            channel.requests().execute(server.serve())
+            let server = HelloServer(channel.transport().peer_addr().unwrap());
+            channel.execute(server.serve())
         })
         // Max 10 channels.
         .buffer_unordered(10)


@@ -67,7 +67,7 @@ else
 fi

 printf "${PREFIX} Checking for rustfmt ... "
-command -v cargo fmt &>/dev/null
+command -v rustfmt &>/dev/null
 if [ $? == 0 ]; then
     printf "${SUCCESS}\n"
 else
@@ -93,19 +93,19 @@ diff=""
 for file in $(git diff --name-only --cached);
 do
     if [ ${file: -3} == ".rs" ]; then
-        diff="$diff$(cargo fmt -- --check $file)"
-        if [ $? != 0 ]; then
-            FMTRESULT=1
-        fi
+        diff="$diff$(rustfmt --edition 2018 --check $file)"
     fi
 done
+
+if grep --quiet "^[-+]" <<< "$diff"; then
+    FMTRESULT=1
+fi

 if [ "${TARPC_SKIP_RUSTFMT}" == 1 ]; then
     printf "${SKIPPED}\n"$?
 elif [ ${FMTRESULT} != 0 ]; then
     FAILED=1
     printf "${FAILURE}\n"
-    echo "$diff" | sed 's/Using rustfmt config file.*$/d/'
+    echo "$diff"
 else
     printf "${SUCCESS}\n"
 fi


@@ -1,6 +1,6 @@
 [package]
 name = "tarpc-plugins"
-version = "0.10.0"
+version = "0.12.0"
 authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
 edition = "2018"
 license = "MIT"
@@ -30,4 +30,4 @@ proc-macro = true
 assert-type-eq = "0.1.0"
 futures = "0.3"
 serde = { version = "1.0", features = ["derive"] }
-tarpc = { path = "../tarpc" }
+tarpc = { path = "../tarpc", features = ["serde1"] }


@@ -267,18 +267,25 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
None None
}; };
let methods = rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>();
let request_names = methods
.iter()
.map(|m| format!("{}.{}", ident, m))
.collect::<Vec<_>>();
ServiceGenerator { ServiceGenerator {
response_fut_name, response_fut_name,
service_ident: ident, service_ident: ident,
server_ident: &format_ident!("Serve{}", ident), server_ident: &format_ident!("Serve{}", ident),
response_fut_ident: &Ident::new(&response_fut_name, ident.span()), response_fut_ident: &Ident::new(response_fut_name, ident.span()),
client_ident: &format_ident!("{}Client", ident), client_ident: &format_ident!("{}Client", ident),
request_ident: &format_ident!("{}Request", ident), request_ident: &format_ident!("{}Request", ident),
response_ident: &format_ident!("{}Response", ident), response_ident: &format_ident!("{}Response", ident),
vis, vis,
args, args,
method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(), method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
method_idents: &rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>(), method_idents: &methods,
request_names: &*request_names,
attrs, attrs,
rpcs, rpcs,
return_types: &rpcs return_types: &rpcs
@@ -399,11 +406,7 @@ fn verify_types_were_provided(
) -> syn::Result<()> { ) -> syn::Result<()> {
let mut result = Ok(()); let mut result = Ok(());
for (method, expected) in expected { for (method, expected) in expected {
if provided if !provided.iter().any(|typedecl| typedecl.ident == expected) {
.iter()
.find(|typedecl| typedecl.ident == expected)
.is_none()
{
let mut e = syn::Error::new( let mut e = syn::Error::new(
span, span,
format!("not all trait items implemented, missing: `{}`", expected), format!("not all trait items implemented, missing: `{}`", expected),
@@ -441,6 +444,7 @@ struct ServiceGenerator<'a> {
     camel_case_idents: &'a [Ident],
     future_types: &'a [Type],
     method_idents: &'a [&'a Ident],
+    request_names: &'a [String],
     method_attrs: &'a [&'a [Attribute]],
     args: &'a [&'a [PatType]],
     return_types: &'a [&'a Type],
@@ -475,7 +479,7 @@ impl<'a> ServiceGenerator<'a> {
         ),
         output,
     )| {
-        let ty_doc = format!("The response future returned by {}.", ident);
+        let ty_doc = format!("The response future returned by [`{}::{}`].", service_ident, ident);
         quote! {
             #[doc = #ty_doc]
             type #future_type: std::future::Future<Output = #output>;
@@ -524,6 +528,7 @@ impl<'a> ServiceGenerator<'a> {
     camel_case_idents,
     arg_pats,
     method_idents,
+    request_names,
     ..
 } = self;
@@ -534,6 +539,16 @@ impl<'a> ServiceGenerator<'a> {
     type Resp = #response_ident;
     type Fut = #response_fut_ident<S>;
+    fn method(&self, req: &#request_ident) -> Option<&'static str> {
+        Some(match req {
+            #(
+                #request_ident::#camel_case_idents{..} => {
+                    #request_names
+                }
+            )*
+        })
+    }
     fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
         match req {
             #(
@@ -563,6 +578,7 @@ impl<'a> ServiceGenerator<'a> {
     quote! {
         /// The request sent over the wire from the client to the server.
+        #[allow(missing_docs)]
         #[derive(Debug)]
         #derive_serialize
         #vis enum #request_ident {
@@ -583,6 +599,7 @@ impl<'a> ServiceGenerator<'a> {
     quote! {
         /// The response sent over the wire from the server to the client.
+        #[allow(missing_docs)]
         #[derive(Debug)]
         #derive_serialize
         #vis enum #response_ident {
@@ -603,6 +620,7 @@ impl<'a> ServiceGenerator<'a> {
     quote! {
         /// A future resolving to a server response.
+        #[allow(missing_docs)]
         #vis enum #response_fut_ident<S: #service_ident> {
             #( #camel_case_idents(<S as #service_ident>::#future_types) ),*
         }
@@ -714,6 +732,7 @@ impl<'a> ServiceGenerator<'a> {
     method_attrs,
     vis,
     method_idents,
+    request_names,
     args,
     return_types,
     arg_pats,
@@ -727,9 +746,9 @@ impl<'a> ServiceGenerator<'a> {
     #[allow(unused)]
     #( #method_attrs )*
     #vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*)
-        -> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
+        -> impl std::future::Future<Output = Result<#return_types, tarpc::client::RpcError>> + '_ {
         let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
-        let resp = self.0.call(ctx, request);
+        let resp = self.0.call(ctx, #request_names, request);
         async move {
             match resp.await? {
                 #response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
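The new `request_names` above are built with `format!("{}.{}", ident, m)`, so each request variant maps to a `"Service.method"` string that the generated `method()` returns for tracing. A standalone sketch of that naming scheme (service and method names here are illustrative, not macro input):

```rust
// Builds "Service.method" names the way the macro's request_names vec is built.
fn request_names(service: &str, methods: &[&str]) -> Vec<String> {
    methods
        .iter()
        .map(|m| format!("{}.{}", service, m))
        .collect()
}

fn main() {
    let names = request_names("World", &["hello", "add"]);
    assert_eq!(names, vec!["World.hello".to_string(), "World.add".to_string()]);
    println!("{:?}", names);
}
```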


@@ -1,6 +1,6 @@
 [package]
 name = "tarpc"
-version = "0.25.1"
+version = "0.27.1"
 authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
 edition = "2018"
 license = "MIT"
@@ -17,10 +17,12 @@ default = []
 serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"]
 tokio1 = ["tokio/rt-multi-thread"]
-serde-transport = ["serde1", "tokio1", "tokio-serde/json", "tokio-util/codec"]
+serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
+serde-transport-json = ["tokio-serde/json"]
+serde-transport-bincode = ["tokio-serde/bincode"]
 tcp = ["tokio/net"]
-full = ["serde1", "tokio1", "serde-transport", "tcp"]
+full = ["serde1", "tokio1", "serde-transport", "serde-transport-json", "serde-transport-bincode", "tcp"]
 [badges]
 travis-ci = { repository = "google/tarpc" }
@@ -30,26 +32,31 @@ anyhow = "1.0"
 fnv = "1.0"
 futures = "0.3"
 humantime = "2.0"
-log = "0.4"
 pin-project = "1.0"
-rand = "0.7"
+rand = "0.8"
 serde = { optional = true, version = "1.0", features = ["derive"] }
 static_assertions = "1.1.0"
-tarpc-plugins = { path = "../plugins", version = "0.10" }
+tarpc-plugins = { path = "../plugins", version = "0.12" }
+thiserror = "1.0"
 tokio = { version = "1", features = ["time"] }
 tokio-util = { version = "0.6.3", features = ["time"] }
 tokio-serde = { optional = true, version = "0.8" }
+tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
+tracing-opentelemetry = { version = "0.15", default-features = false }
+opentelemetry = { version = "0.16", default-features = false }
 [dev-dependencies]
 assert_matches = "1.4"
 bincode = "1.3"
 bytes = { version = "1", features = ["serde"] }
-env_logger = "0.8"
 flate2 = "1.0"
 futures-test = "0.3"
-log = "0.4"
+opentelemetry = { version = "0.16", default-features = false, features = ["rt-tokio"] }
+opentelemetry-jaeger = { version = "0.15", features = ["rt-tokio"] }
 pin-utils = "0.1.0-alpha"
 serde_bytes = "0.11"
+tracing-subscriber = "0.2"
 tokio = { version = "1", features = ["full", "test-util"] }
 tokio-serde = { version = "0.8", features = ["json", "bincode"] }
 trybuild = "1.0"
@@ -63,7 +70,7 @@ name = "compression"
 required-features = ["serde-transport", "tcp"]
 [[example]]
-name = "server_calling_server"
+name = "tracing"
 required-features = ["full"]
 [[example]]


@@ -7,8 +7,8 @@ use tarpc::{
     client, context,
     serde_transport::tcp,
     server::{BaseChannel, Channel},
+    tokio_serde::formats::Bincode,
 };
-use tokio_serde::formats::Bincode;
 /// Type of compression that should be enabled on the request. The transport is free to ignore this.
 #[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)]
@@ -118,7 +118,7 @@ async fn main() -> anyhow::Result<()> {
     });
     let transport = tcp::connect(addr, Bincode::default).await?;
-    let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn()?;
+    let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn();
     println!(
         "{}",


@@ -1,9 +1,7 @@
-use futures::future;
-use tarpc::context::Context;
 use tarpc::serde_transport as transport;
 use tarpc::server::{BaseChannel, Channel};
+use tarpc::{context::Context, tokio_serde::formats::Bincode};
 use tokio::net::{UnixListener, UnixStream};
-use tokio_serde::formats::Bincode;
 use tokio_util::codec::length_delimited::LengthDelimitedCodec;
 #[tarpc::service]
@@ -14,16 +12,13 @@ pub trait PingService {
 #[derive(Clone)]
 struct Service;
+#[tarpc::server]
 impl PingService for Service {
-    type PingFut = future::Ready<()>;
-    fn ping(self, _: Context) -> Self::PingFut {
-        future::ready(())
-    }
+    async fn ping(self, _: Context) {}
 }
 #[tokio::main]
-async fn main() -> std::io::Result<()> {
+async fn main() -> anyhow::Result<()> {
     let bind_addr = "/tmp/tarpc_on_unix_example.sock";
     let _ = std::fs::remove_file(bind_addr);
@@ -44,7 +39,9 @@ async fn main() -> std::io::Result<()> {
     let conn = UnixStream::connect(bind_addr).await?;
     let transport = transport::new(codec_builder.new_framed(conn), Bincode::default());
     PingServiceClient::new(Default::default(), transport)
-        .spawn()?
+        .spawn()
         .ping(tarpc::context::current())
-        .await
+        .await?;
+    Ok(())
 }
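The example's `main` switches from `std::io::Result` to `anyhow::Result`, so `?` can propagate errors of any type rather than only `io::Error`. A dependency-free sketch of the same idea, using `Box<dyn Error>` in place of `anyhow::Error` (an assumption, since anyhow isn't in std):

```rust
use std::error::Error;
use std::num::ParseIntError;

// Two operations with different error types; a boxed error type lets `?` unify them.
fn read_config() -> Result<String, std::io::Error> {
    Ok("42".to_string()) // stand-in for real I/O
}

fn parse(s: &str) -> Result<i32, ParseIntError> {
    s.parse()
}

fn run() -> Result<i32, Box<dyn Error>> {
    let raw = read_config()?; // io::Error coerces into Box<dyn Error>
    let n = parse(&raw)?; // ParseIntError coerces too
    Ok(n)
}

fn main() {
    assert_eq!(run().unwrap(), 42);
    println!("ok");
}
```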


@@ -38,10 +38,11 @@ use futures::{
     future::{self, AbortHandle},
     prelude::*,
 };
-use log::info;
 use publisher::Publisher as _;
 use std::{
     collections::HashMap,
+    env,
+    error::Error,
     io,
     net::SocketAddr,
     sync::{Arc, Mutex, RwLock},
@@ -54,6 +55,8 @@ use tarpc::{
 };
 use tokio::net::ToSocketAddrs;
 use tokio_serde::formats::Json;
+use tracing::info;
+use tracing_subscriber::prelude::*;
 pub mod subscriber {
     #[tarpc::service]
@@ -83,10 +86,7 @@ impl subscriber::Subscriber for Subscriber {
     }
     async fn receive(self, _: context::Context, topic: String, message: String) {
-        info!(
-            "[{}] received message on topic '{}': {}",
-            self.local_addr, topic, message
-        );
+        info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage")
     }
 }
@@ -120,7 +120,7 @@ impl Subscriber {
     let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve()));
     tokio::spawn(async move {
         match handler.await {
-            Ok(()) | Err(future::Aborted) => info!("[{}] subscriber shutdown.", local_addr),
+            Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."),
         }
     });
     Ok(SubscriberHandle(abort_handle))
@@ -153,13 +153,13 @@ impl Publisher {
         subscriptions: self.clone().start_subscription_manager().await?,
     };
-    info!("[{}] listening for publishers.", publisher_addrs.publisher);
+    info!(publisher_addr = %publisher_addrs.publisher, "listening for publishers.",);
     tokio::spawn(async move {
         // Because this is just an example, we know there will only be one publisher. In more
         // realistic code, this would be a loop to continually accept new publisher
         // connections.
         let publisher = connecting_publishers.next().await.unwrap().unwrap();
-        info!("[{}] publisher connected.", publisher.peer_addr().unwrap());
+        info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected.");
         server::BaseChannel::with_defaults(publisher)
             .execute(self.serve())
@@ -174,7 +174,7 @@ impl Publisher {
         .await?
         .filter_map(|r| future::ready(r.ok()));
     let new_subscriber_addr = connecting_subscribers.get_ref().local_addr();
-    info!("[{}] listening for subscribers.", new_subscriber_addr);
+    info!(?new_subscriber_addr, "listening for subscribers.");
     tokio::spawn(async move {
         while let Some(conn) = connecting_subscribers.next().await {
@@ -215,7 +215,7 @@ impl Publisher {
         },
     );
-    info!("[{}] subscribed to topics: {:?}", subscriber_addr, topics);
+    info!(%subscriber_addr, ?topics, "subscribed to new topics");
     let mut subscriptions = self.subscriptions.write().unwrap();
     for topic in topics {
         subscriptions
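The loop above registers a subscriber under each of its topics in a shared map. A std-only sketch of that topic index (the names and `Addr` alias here are illustrative, not tarpc's):

```rust
use std::collections::{HashMap, HashSet};

type Addr = &'static str; // stand-in for SocketAddr

// Topic -> set of subscriber addresses, as the publisher's subscription map maintains.
fn subscribe(subscriptions: &mut HashMap<String, HashSet<Addr>>, addr: Addr, topics: Vec<String>) {
    for topic in topics {
        subscriptions.entry(topic).or_default().insert(addr);
    }
}

fn main() {
    let mut subs = HashMap::new();
    subscribe(&mut subs, "127.0.0.1:4000", vec!["calculus".into(), "cool math".into()]);
    subscribe(&mut subs, "127.0.0.1:5000", vec!["calculus".into()]);
    assert_eq!(subs["calculus"].len(), 2);
    assert_eq!(subs["cool math"].len(), 1);
    println!("ok");
}
```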
@@ -226,18 +226,18 @@ impl Publisher {
         }
     }
-    fn start_subscriber_gc(
+    fn start_subscriber_gc<E: Error>(
         self,
         subscriber_addr: SocketAddr,
-        client_dispatch: impl Future<Output = anyhow::Result<()>> + Send + 'static,
+        client_dispatch: impl Future<Output = Result<(), E>> + Send + 'static,
         subscriber_ready: oneshot::Receiver<()>,
     ) {
         tokio::spawn(async move {
             if let Err(e) = client_dispatch.await {
                 info!(
-                    "[{}] subscriber connection broken: {:?}",
-                    subscriber_addr, e
-                )
+                    %subscriber_addr,
+                    error = %e,
+                    "subscriber connection broken");
             }
             // Don't clean up the subscriber until initialization is done.
             let _ = subscriber_ready.await;
@@ -281,13 +281,29 @@ impl publisher::Publisher for Publisher {
     }
 }
+/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
+fn init_tracing(service_name: &str) -> anyhow::Result<()> {
+    env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
+    let tracer = opentelemetry_jaeger::new_pipeline()
+        .with_service_name(service_name)
+        .with_max_packet_size(2usize.pow(13))
+        .install_batch(opentelemetry::runtime::Tokio)?;
+    tracing_subscriber::registry()
+        .with(tracing_subscriber::EnvFilter::from_default_env())
+        .with(tracing_subscriber::fmt::layer())
+        .with(tracing_opentelemetry::layer().with_tracer(tracer))
+        .try_init()?;
+    Ok(())
+}
 #[tokio::main]
 async fn main() -> anyhow::Result<()> {
-    env_logger::init();
+    init_tracing("Pub/Sub")?;
-    let clients = Arc::new(Mutex::new(HashMap::new()));
     let addrs = Publisher {
-        clients,
+        clients: Arc::new(Mutex::new(HashMap::new())),
         subscriptions: Arc::new(RwLock::new(HashMap::new())),
     }
     .start()
@@ -309,7 +325,7 @@ async fn main() -> anyhow::Result<()> {
         client::Config::default(),
         tcp::connect(addrs.publisher, Json::default).await?,
     )
-    .spawn()?;
+    .spawn();
     publisher
         .publish(context::current(), "calculus".into(), "sqrt(2)".into())
@@ -337,6 +353,7 @@ async fn main() -> anyhow::Result<()> {
     )
     .await?;
+    opentelemetry::global::shutdown_tracer_provider();
     info!("done.");
     Ok(())
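Throughout this example, positional `log`-style messages become structured `tracing` events whose fields (`%subscriber_addr`, `?topics`) are recorded as key-value pairs instead of being interpolated into the text. A dependency-free sketch of the difference (illustrative only; `tracing` does far more than this):

```rust
use std::collections::BTreeMap;

// Old style: fields interpolated into the message, opaque to tooling.
fn format_log(addr: &str, topics: &[&str]) -> String {
    format!("[{}] subscribed to topics: {:?}", addr, topics)
}

// New style: a message plus machine-readable fields, like a structured event.
fn structured_log(addr: &str, topics: &[&str]) -> (String, BTreeMap<&'static str, String>) {
    let mut fields = BTreeMap::new();
    fields.insert("subscriber_addr", addr.to_string());
    fields.insert("topics", format!("{:?}", topics));
    ("subscribed to new topics".to_string(), fields)
}

fn main() {
    let (msg, fields) = structured_log("127.0.0.1:4000", &["calculus"]);
    assert_eq!(msg, "subscribed to new topics");
    assert_eq!(fields["subscriber_addr"], "127.0.0.1:4000");
    println!("{}", format_log("127.0.0.1:4000", &["calculus"]));
}
```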


@@ -5,7 +5,6 @@
 // https://opensource.org/licenses/MIT.
 use futures::future::{self, Ready};
-use std::io;
 use tarpc::{
     client, context,
     server::{self, Channel},
@@ -35,7 +34,7 @@ impl World for HelloServer {
 }
 #[tokio::main]
-async fn main() -> io::Result<()> {
+async fn main() -> anyhow::Result<()> {
     let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
     let server = server::BaseChannel::with_defaults(server_transport);
@@ -43,7 +42,7 @@ async fn main() -> io::Result<()> {
     // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
     // that takes a config and any Transport as input.
-    let client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
+    let client = WorldClient::new(client::Config::default(), client_transport).spawn();
     // The client has an RPC method for each RPC defined in the annotated trait. It takes the same
     // args as defined, with the addition of a Context, which is always the first arg. The Context


@@ -6,12 +6,13 @@
 use crate::{add::Add as AddService, double::Double as DoubleService};
 use futures::{future, prelude::*};
-use std::io;
+use std::env;
 use tarpc::{
     client, context,
-    server::{BaseChannel, Incoming},
+    server::{incoming::Incoming, BaseChannel},
 };
 use tokio_serde::formats::Json;
+use tracing_subscriber::prelude::*;
 pub mod add {
     #[tarpc::service]
@@ -54,9 +55,25 @@ impl DoubleService for DoubleServer {
     }
 }
+fn init_tracing(service_name: &str) -> anyhow::Result<()> {
+    env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
+    let tracer = opentelemetry_jaeger::new_pipeline()
+        .with_service_name(service_name)
+        .with_max_packet_size(2usize.pow(13))
+        .install_batch(opentelemetry::runtime::Tokio)?;
+    tracing_subscriber::registry()
+        .with(tracing_subscriber::EnvFilter::from_default_env())
+        .with(tracing_subscriber::fmt::layer())
+        .with(tracing_opentelemetry::layer().with_tracer(tracer))
+        .try_init()?;
+    Ok(())
+}
 #[tokio::main]
-async fn main() -> io::Result<()> {
-    env_logger::init();
+async fn main() -> anyhow::Result<()> {
+    init_tracing("tarpc_tracing_example")?;
     let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
         .await?
@@ -69,7 +86,7 @@ async fn main() -> io::Result<()> {
     tokio::spawn(add_server);
     let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
-    let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn()?;
+    let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn();
     let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
         .await?
@@ -83,10 +100,14 @@ async fn main() -> io::Result<()> {
     let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
     let double_client =
-        double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?;
+        double::DoubleClient::new(client::Config::default(), to_double_server).spawn();
+    let ctx = context::current();
-    for i in 1..=5 {
-        eprintln!("{:?}", double_client.double(context::current(), i).await?);
+    for _ in 1..=5 {
+        tracing::info!("{:?}", double_client.double(ctx, 1).await?);
     }
+    opentelemetry::global::shutdown_tracer_provider();
     Ok(())
 }


@@ -8,16 +8,14 @@
 mod in_flight_requests;
-use crate::{
-    context, trace::SpanId, ClientMessage, PollContext, PollIo, Request, Response, Transport,
-};
+use crate::{context, trace, ClientMessage, Request, Response, ServerError, Transport};
 use futures::{prelude::*, ready, stream::Fuse, task::*};
-use in_flight_requests::InFlightRequests;
+use in_flight_requests::{DeadlineExceededError, InFlightRequests};
-use log::{info, trace};
 use pin_project::pin_project;
 use std::{
     convert::TryFrom,
-    fmt, io,
+    error::Error,
+    fmt, mem,
     pin::Pin,
     sync::{
         atomic::{AtomicUsize, Ordering},
@@ -25,6 +23,7 @@ use std::{
     },
 };
 use tokio::sync::{mpsc, oneshot};
+use tracing::Span;
 /// Settings that control the behavior of the client.
 #[derive(Clone, Debug)]
@@ -61,19 +60,18 @@ pub struct NewClient<C, D> {
 impl<C, D, E> NewClient<C, D>
 where
     D: Future<Output = Result<(), E>> + Send + 'static,
-    E: std::fmt::Display,
+    E: std::error::Error + Send + Sync + 'static,
 {
     /// Helper method to spawn the dispatch on the default executor.
     #[cfg(feature = "tokio1")]
     #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
-    pub fn spawn(self) -> io::Result<C> {
-        use log::warn;
-        let dispatch = self
-            .dispatch
-            .unwrap_or_else(move |e| warn!("Connection broken: {}", e));
+    pub fn spawn(self) -> C {
+        let dispatch = self.dispatch.unwrap_or_else(move |e| {
+            let e = anyhow::Error::new(e);
+            tracing::warn!("Connection broken: {:?}", e);
+        });
         tokio::spawn(dispatch);
-        Ok(self.client)
+        self.client
     }
 }
@@ -114,100 +112,117 @@ impl<Req, Resp> Clone for Channel<Req, Resp> {
 impl<Req, Resp> Channel<Req, Resp> {
     /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
-    /// resolves when the request is sent (not when the response is received).
-    fn send(
+    /// resolves to the response.
+    #[tracing::instrument(
+        name = "RPC",
+        skip(self, ctx, request_name, request),
+        fields(
+            rpc.trace_id = tracing::field::Empty,
+            otel.kind = "client",
+            otel.name = request_name)
+    )]
+    pub async fn call(
         &self,
         mut ctx: context::Context,
+        request_name: &str,
         request: Req,
-    ) -> impl Future<Output = io::Result<DispatchResponse<Resp>>> + '_ {
-        // Convert the context to the call context.
-        ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
-        ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
-        let (response_completion, response) = oneshot::channel();
-        let cancellation = self.cancellation.clone();
+    ) -> Result<Resp, RpcError> {
+        let span = Span::current();
+        ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
+            tracing::warn!(
+                "OpenTelemetry subscriber not installed; making unsampled child context."
+            );
+            ctx.trace_context.new_child()
+        });
+        span.record("rpc.trace_id", &tracing::field::display(ctx.trace_id()));
+        let (response_completion, mut response) = oneshot::channel();
         let request_id =
             u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
-        // DispatchResponse impls Drop to cancel in-flight requests. It should be created before
+        // ResponseGuard impls Drop to cancel in-flight requests. It should be created before
         // sending out the request; otherwise, the response future could be dropped after the
-        // request is sent out but before DispatchResponse is created, rendering the cancellation
+        // request is sent out but before ResponseGuard is created, rendering the cancellation
        // logic inactive.
-        let response = DispatchResponse {
-            response,
+        let response_guard = ResponseGuard {
+            response: &mut response,
             request_id,
-            cancellation: Some(cancellation),
-            ctx,
+            cancellation: &self.cancellation,
         };
-        async move {
-            self.to_dispatch
-                .send(DispatchRequest {
-                    ctx,
-                    request_id,
-                    request,
-                    response_completion,
-                })
-                .await
-                .map_err(|mpsc::error::SendError(_)| {
-                    io::Error::from(io::ErrorKind::ConnectionReset)
-                })?;
-            Ok(response)
-        }
-    }
-    /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
-    /// resolves to the response.
-    pub async fn call(&self, ctx: context::Context, request: Req) -> io::Result<Resp> {
-        let dispatch_response = self.send(ctx, request).await?;
-        dispatch_response.await
-    }
+        self.to_dispatch
+            .send(DispatchRequest {
+                ctx,
+                span,
+                request_id,
+                request,
+                response_completion,
+            })
+            .await
+            .map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?;
+        response_guard.response().await
     }
 }
 /// A server response that is completed by request dispatch when the corresponding response
 /// arrives off the wire.
-#[derive(Debug)]
-struct DispatchResponse<Resp> {
-    response: oneshot::Receiver<Response<Resp>>,
-    ctx: context::Context,
-    cancellation: Option<RequestCancellation>,
+struct ResponseGuard<'a, Resp> {
+    response: &'a mut oneshot::Receiver<Result<Response<Resp>, DeadlineExceededError>>,
+    cancellation: &'a RequestCancellation,
     request_id: u64,
 }
-impl<Resp> Future for DispatchResponse<Resp> {
-    type Output = io::Result<Resp>;
-    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> {
-        let resp = ready!(self.response.poll_unpin(cx));
-        self.cancellation.take();
-        Poll::Ready(match resp {
-            Ok(resp) => Ok(resp.message?),
-            Err(oneshot::error::RecvError { .. }) => {
-                // The oneshot is Canceled when the dispatch task ends. In that case,
-                // there's nothing listening on the other side, so there's no point in
-                // propagating cancellation.
-                Err(io::Error::from(io::ErrorKind::ConnectionReset))
-            }
-        })
-    }
-}
+/// An error that can occur in the processing of an RPC. This is not request-specific errors but
+/// rather cross-cutting errors that can always occur.
+#[derive(thiserror::Error, Debug)]
+pub enum RpcError {
+    /// The client disconnected from the server.
+    #[error("the client disconnected from the server")]
+    Disconnected,
+    /// The request exceeded its deadline.
+    #[error("the request exceeded its deadline")]
+    DeadlineExceeded,
+    /// The server aborted request processing.
+    #[error("the server aborted request processing")]
+    Server(#[from] ServerError),
+}
+impl From<DeadlineExceededError> for RpcError {
+    fn from(_: DeadlineExceededError) -> Self {
+        RpcError::DeadlineExceeded
+    }
+}
+impl<Resp> ResponseGuard<'_, Resp> {
+    async fn response(mut self) -> Result<Resp, RpcError> {
+        let response = (&mut self.response).await;
+        // Cancel drop logic once a response has been received.
+        mem::forget(self);
+        match response {
+            Ok(resp) => Ok(resp?.message?),
+            Err(oneshot::error::RecvError { .. }) => {
+                // The oneshot is Canceled when the dispatch task ends. In that case,
+                // there's nothing listening on the other side, so there's no point in
+                // propagating cancellation.
+                Err(RpcError::Disconnected)
+            }
+        }
+    }
+}
 // Cancels the request when dropped, if not already complete.
-impl<Resp> Drop for DispatchResponse<Resp> {
+impl<Resp> Drop for ResponseGuard<'_, Resp> {
     fn drop(&mut self) {
-        if let Some(cancellation) = &mut self.cancellation {
-            // The receiver needs to be closed to handle the edge case that the request has not
-            // yet been received by the dispatch task. It is possible for the cancel message to
-            // arrive before the request itself, in which case the request could get stuck in the
-            // dispatch map forever if the server never responds (e.g. if the server dies while
-            // responding). Even if the server does respond, it will have unnecessarily done work
-            // for a client no longer waiting for a response. To avoid this, the dispatch task
-            // checks if the receiver is closed before inserting the request in the map. By
-            // closing the receiver before sending the cancel message, it is guaranteed that if the
-            // dispatch task misses an early-arriving cancellation message, then it will see the
-            // receiver as closed.
-            self.response.close();
-            cancellation.cancel(self.request_id);
-        }
+        // The receiver needs to be closed to handle the edge case that the request has not
+        // yet been received by the dispatch task. It is possible for the cancel message to
+        // arrive before the request itself, in which case the request could get stuck in the
+        // dispatch map forever if the server never responds (e.g. if the server dies while
+        // responding). Even if the server does respond, it will have unnecessarily done work
+        // for a client no longer waiting for a response. To avoid this, the dispatch task
+        // checks if the receiver is closed before inserting the request in the map. By
+        // closing the receiver before sending the cancel message, it is guaranteed that if the
        // dispatch task misses an early-arriving cancellation message, then it will see the
+        // receiver as closed.
+        self.response.close();
+        self.cancellation.cancel(self.request_id);
     }
 }
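The `ResponseGuard` above cancels its request on drop and calls `mem::forget` in `response()` to disarm that cleanup once a response actually arrives. A std-only sketch of the same drop-guard pattern (the `Cancellation` type here is illustrative, not tarpc's):

```rust
use std::cell::RefCell;
use std::mem;

// Records which request ids were cancelled; stands in for tarpc's RequestCancellation.
#[derive(Default)]
struct Cancellation {
    cancelled: RefCell<Vec<u64>>,
}

struct Guard<'a> {
    cancellation: &'a Cancellation,
    request_id: u64,
}

impl Guard<'_> {
    // On success, forget `self` so Drop never runs and no cancellation is sent.
    fn complete(self) {
        mem::forget(self);
    }
}

impl Drop for Guard<'_> {
    // Runs only if the guard is dropped before `complete`, i.e. mid-flight.
    fn drop(&mut self) {
        self.cancellation.cancelled.borrow_mut().push(self.request_id);
    }
}

fn main() {
    let c = Cancellation::default();
    Guard { cancellation: &c, request_id: 1 }.complete(); // completed: no cancel
    drop(Guard { cancellation: &c, request_id: 2 }); // dropped mid-flight: cancelled
    assert_eq!(*c.cancelled.borrow(), vec![2]);
    println!("ok");
}
```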
@@ -258,6 +273,32 @@ pub struct RequestDispatch<Req, Resp, C> {
     config: Config,
 }
+/// Critical errors that result in a Channel disconnecting.
+#[derive(thiserror::Error, Debug)]
+pub enum ChannelError<E>
+where
+    E: Error + Send + Sync + 'static,
+{
+    /// Could not read from the transport.
+    #[error("could not read from the transport")]
+    Read(#[source] E),
+    /// Could not ready the transport for writes.
+    #[error("could not ready the transport for writes")]
+    Ready(#[source] E),
+    /// Could not write to the transport.
+    #[error("could not write to the transport")]
+    Write(#[source] E),
+    /// Could not flush the transport.
+    #[error("could not flush the transport")]
+    Flush(#[source] E),
+    /// Could not close the write end of the transport.
+    #[error("could not close the write end of the transport")]
+    Close(#[source] E),
+    /// Could not poll expired requests.
+    #[error("could not poll expired requests")]
+    Timer(#[source] tokio::time::error::Error),
+}
 impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
 where
     C: Transport<ClientMessage<Req>, Response<Resp>>,
@@ -270,6 +311,42 @@ where
         self.as_mut().project().transport
     }
+    fn poll_ready<'a>(
+        self: &'a mut Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Result<(), ChannelError<C::Error>>> {
+        self.transport_pin_mut()
+            .poll_ready(cx)
+            .map_err(ChannelError::Ready)
+    }
+    fn start_send(
+        self: &mut Pin<&mut Self>,
+        message: ClientMessage<Req>,
+    ) -> Result<(), ChannelError<C::Error>> {
+        self.transport_pin_mut()
+            .start_send(message)
+            .map_err(ChannelError::Write)
+    }
+    fn poll_flush<'a>(
+        self: &'a mut Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Result<(), ChannelError<C::Error>>> {
+        self.transport_pin_mut()
+            .poll_flush(cx)
+            .map_err(ChannelError::Flush)
+    }
+    fn poll_close<'a>(
+        self: &'a mut Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Result<(), ChannelError<C::Error>>> {
+        self.transport_pin_mut()
+            .poll_close(cx)
+            .map_err(ChannelError::Close)
+    }
    fn canceled_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut CanceledRequests {
        self.as_mut().project().canceled_requests
    }
@@ -280,17 +357,22 @@ where
        self.as_mut().project().pending_requests
    }
    fn pump_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
        self.transport_pin_mut()
            .poll_next(cx)
            .map_err(ChannelError::Read)
            .map_ok(|response| {
                self.complete(response);
            })
    }
    fn pump_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
        enum ReceiverStatus {
            Pending,
            Closed,
@@ -311,7 +393,11 @@ where
        // Receiving Poll::Ready(None) when polling expired requests never indicates "Closed",
        // because there can temporarily be zero in-flight requests. Therefore, there is no need to
        // track the status like is done with pending and cancelled requests.
        if let Poll::Ready(Some(_)) = self
            .in_flight_requests()
            .poll_expired(cx)
            .map_err(ChannelError::Timer)?
        {
            // Expired requests are considered complete; there is no compelling reason to send a
            // cancellation message to the server, since it will have already exhausted its
            // allotted processing time.
@@ -320,12 +406,12 @@ where
        match (pending_requests_status, canceled_requests_status) {
            (ReceiverStatus::Closed, ReceiverStatus::Closed) => {
                ready!(self.poll_close(cx)?);
                Poll::Ready(None)
            }
            (ReceiverStatus::Pending, _) | (_, ReceiverStatus::Pending) => {
                // No more messages to process, so flush any messages buffered in the transport.
                ready!(self.poll_flush(cx)?);

                // Even if we fully-flush, we return Pending, because we have no more requests
                // or cancellations right now.
@@ -341,9 +427,9 @@ where
    fn poll_next_request(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<DispatchRequest<Req, Resp>, ChannelError<C::Error>>>> {
        if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
            tracing::info!(
                "At in-flight request capacity ({}/{}).",
                self.in_flight_requests().len(),
                self.config.max_in_flight_requests
@@ -360,10 +446,8 @@ where
            match ready!(self.pending_requests_mut().poll_recv(cx)) {
                Some(request) => {
                    if request.response_completion.is_closed() {
                        let _entered = request.span.enter();
                        tracing::info!("AbortRequest");
                        continue;
                    }
@@ -381,14 +465,15 @@ where
    fn poll_next_cancellation(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(context::Context, Span, u64), ChannelError<C::Error>>>> {
        ready!(self.ensure_writeable(cx)?);

        loop {
            match ready!(self.canceled_requests_mut().poll_next_unpin(cx)) {
                Some(request_id) => {
                    if let Some((ctx, span)) = self.in_flight_requests().cancel_request(request_id)
                    {
                        return Poll::Ready(Some(Ok((ctx, span, request_id))));
                    }
                }
                None => return Poll::Ready(None),
@@ -399,54 +484,73 @@ where
    /// Returns Ready if writing a message to the transport (i.e. via write_request or
    /// write_cancel) would not fail due to a full buffer. If the transport is not ready to be
    /// written to, flushes it until it is ready.
    fn ensure_writeable<'a>(
        self: &'a mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
        while self.poll_ready(cx)?.is_pending() {
            ready!(self.poll_flush(cx)?);
        }
        Poll::Ready(Some(Ok(())))
    }
    fn poll_write_request<'a>(
        self: &'a mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
        let DispatchRequest {
            ctx,
            span,
            request_id,
            request,
            response_completion,
        } = match ready!(self.as_mut().poll_next_request(cx)?) {
            Some(dispatch_request) => dispatch_request,
            None => return Poll::Ready(None),
        };
        let entered = span.enter();
        // poll_next_request only returns Ready if there is room to buffer another request.
        // Therefore, we can call write_request without fear of erroring due to a full
        // buffer.
        let request = ClientMessage::Request(Request {
            id: request_id,
            message: request,
            context: context::Context {
                deadline: ctx.deadline,
                trace_context: ctx.trace_context,
            },
        });
        self.start_send(request)?;
        let deadline = ctx.deadline;
        tracing::info!(
            tarpc.deadline = %humantime::format_rfc3339(deadline),
            "SendRequest"
        );
        drop(entered);
        self.in_flight_requests()
            .insert_request(request_id, ctx, span, response_completion)
            .expect("Request IDs should be unique");
        Poll::Ready(Some(Ok(())))
    }
    fn poll_write_cancel<'a>(
        self: &'a mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
        let (context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
            Some(triple) => triple,
            None => return Poll::Ready(None),
        };
        let _entered = span.enter();

        let cancel = ClientMessage::Cancel {
            trace_context: context.trace_context,
            request_id,
        };
        self.start_send(cancel)?;
        tracing::info!("CancelRequest");
        Poll::Ready(Some(Ok(())))
    }
@@ -460,28 +564,24 @@ impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
where
    C: Transport<ClientMessage<Req>, Response<Resp>>,
{
    type Output = Result<(), ChannelError<C::Error>>;

    fn poll(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), ChannelError<C::Error>>> {
        loop {
            match (self.as_mut().pump_read(cx)?, self.as_mut().pump_write(cx)?) {
                (Poll::Ready(None), _) => {
                    tracing::info!("Shutdown: read half closed, so shutting down.");
                    return Poll::Ready(Ok(()));
                }
                (read, Poll::Ready(None)) => {
                    if self.in_flight_requests.is_empty() {
                        tracing::info!("Shutdown: write half closed, and no requests in flight.");
                        return Poll::Ready(Ok(()));
                    }
                    tracing::info!(
                        "Shutdown: write half closed, and {} requests in flight.",
                        self.in_flight_requests().len()
                    );
@@ -502,9 +602,10 @@ where
#[derive(Debug)]
struct DispatchRequest<Req, Resp> {
    pub ctx: context::Context,
    pub span: Span,
    pub request_id: u64,
    pub request: Req,
    pub response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
}
/// Sends request cancellation signals.
@@ -518,16 +619,14 @@ struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
/// Returns a channel to send request cancellation messages.
fn cancellations() -> (RequestCancellation, CanceledRequests) {
    // Unbounded because messages are sent in the drop fn. This is fine, because it's still
    // bounded by the number of in-flight requests.
    let (tx, rx) = mpsc::unbounded_channel();
    (RequestCancellation(tx), CanceledRequests(rx))
}
impl RequestCancellation {
    /// Cancels the request with ID `request_id`.
    fn cancel(&self, request_id: u64) {
        let _ = self.0.send(request_id);
    }
}
@@ -549,28 +648,58 @@ impl Stream for CanceledRequests {
#[cfg(test)]
mod tests {
    use super::{
        cancellations, CanceledRequests, Channel, DispatchRequest, RequestCancellation,
        RequestDispatch, ResponseGuard,
    };
    use crate::{
        client::{
            in_flight_requests::{DeadlineExceededError, InFlightRequests},
            Config,
        },
        context,
        transport::{self, channel::UnboundedChannel},
        ClientMessage, Response,
    };
    use assert_matches::assert_matches;
    use futures::{prelude::*, task::*};
    use std::{
        convert::TryFrom,
        pin::Pin,
        sync::atomic::{AtomicUsize, Ordering},
        sync::Arc,
    };
    use tokio::sync::{mpsc, oneshot};
    use tracing::Span;
#[tokio::test]
async fn response_completes_request_future() {
let (mut dispatch, mut _channel, mut server_channel) = set_up();
let cx = &mut Context::from_waker(&noop_waker_ref());
let (tx, mut rx) = oneshot::channel();
dispatch
.in_flight_requests
.insert_request(0, context::current(), Span::current(), tx)
.unwrap();
server_channel
.send(Response {
request_id: 0,
message: Ok("Resp".into()),
})
.await
.unwrap();
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp");
}
    #[tokio::test]
    async fn dispatch_response_cancels_on_drop() {
        let (cancellation, mut canceled_requests) = cancellations();
        let (_, mut response) = oneshot::channel();
        drop(ResponseGuard::<u32> {
            response: &mut response,
            cancellation: &cancellation,
            request_id: 3,
        });
        // resp's drop() is run, which should send a cancel message.
        let cx = &mut Context::from_waker(&noop_waker_ref());
@@ -580,23 +709,22 @@ mod tests {
    #[tokio::test]
    async fn dispatch_response_doesnt_cancel_after_complete() {
        let (cancellation, mut canceled_requests) = cancellations();
        let (tx, mut response) = oneshot::channel();
        tx.send(Ok(Response {
            request_id: 0,
            message: Ok("well done"),
        }))
        .unwrap();
        // resp's drop() is run, but should not send a cancel message.
        ResponseGuard {
            response: &mut response,
            cancellation: &cancellation,
            request_id: 3,
        }
        .response()
        .await
        .unwrap();
        drop(cancellation);
        let cx = &mut Context::from_waker(&noop_waker_ref());
        assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(None));
    }
@@ -604,12 +732,12 @@ mod tests {
    #[tokio::test]
    async fn stage_request() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let cx = &mut Context::from_waker(&noop_waker_ref());
        let (tx, mut rx) = oneshot::channel();

        let _resp = send_request(&mut channel, "hi", tx, &mut rx).await;

        let req = dispatch.as_mut().poll_next_request(cx).ready();
        assert!(req.is_some());

        let req = req.unwrap();
@@ -621,10 +749,10 @@ mod tests {
    #[tokio::test]
    async fn stage_request_channel_dropped_doesnt_panic() {
        let (mut dispatch, mut channel, mut server_channel) = set_up();
        let cx = &mut Context::from_waker(&noop_waker_ref());
        let (tx, mut rx) = oneshot::channel();

        let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
        drop(channel);

        assert!(dispatch.as_mut().poll(cx).is_ready());
@@ -642,61 +770,68 @@ mod tests {
    #[tokio::test]
    async fn stage_request_response_future_dropped_is_canceled_before_sending() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let cx = &mut Context::from_waker(&noop_waker_ref());
        let (tx, mut rx) = oneshot::channel();

        let _ = send_request(&mut channel, "hi", tx, &mut rx).await;

        // Drop the channel so polling returns none if no requests are currently ready.
        drop(channel);

        // Test that a request future dropped before it's processed by dispatch will cause the request
        // to not be added to the in-flight request map.
        assert!(dispatch.as_mut().poll_next_request(cx).ready().is_none());
    }
    #[tokio::test]
    async fn stage_request_response_future_dropped_is_canceled_after_sending() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let cx = &mut Context::from_waker(&noop_waker_ref());
        let (tx, mut rx) = oneshot::channel();

        let req = send_request(&mut channel, "hi", tx, &mut rx).await;
        assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
        assert!(!dispatch.in_flight_requests.is_empty());

        // Test that a request future dropped after it's processed by dispatch will cause the request
        // to be removed from the in-flight request map.
        drop(req);
        assert_matches!(
            dispatch.as_mut().poll_next_cancellation(cx),
            Poll::Ready(Some(Ok(_)))
        );
        assert!(dispatch.in_flight_requests.is_empty());
    }
    #[tokio::test]
    async fn stage_request_response_closed_skipped() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let cx = &mut Context::from_waker(&noop_waker_ref());
        let (tx, mut rx) = oneshot::channel();

        // Test that a request future that's closed its receiver but not yet canceled its request --
        // i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
        // map.
        let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
        resp.response.close();

        assert!(dispatch.as_mut().poll_next_request(cx).is_pending());
    }
    fn set_up() -> (
        Pin<
            Box<
                RequestDispatch<
                    String,
                    String,
                    UnboundedChannel<Response<String>, ClientMessage<String>>,
                >,
            >,
        >,
        Channel<String, String>,
        UnboundedChannel<ClientMessage<String>, Response<String>>,
    ) {
        let _ = tracing_subscriber::fmt().with_test_writer().try_init();

        let (to_dispatch, pending_requests) = mpsc::channel(1);
        let (cancel_tx, canceled_requests) = mpsc::unbounded_channel();
@@ -717,17 +852,31 @@ mod tests {
            next_request_id: Arc::new(AtomicUsize::new(0)),
        };

        (Box::pin(dispatch), channel, server_channel)
    }

    async fn send_request<'a>(
        channel: &'a mut Channel<String, String>,
        request: &str,
        response_completion: oneshot::Sender<Result<Response<String>, DeadlineExceededError>>,
        response: &'a mut oneshot::Receiver<Result<Response<String>, DeadlineExceededError>>,
    ) -> ResponseGuard<'a, String> {
        let request_id =
            u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
        let request = DispatchRequest {
            ctx: context::current(),
            span: Span::current(),
            request_id,
            request: request.to_string(),
            response_completion,
        };
        channel.to_dispatch.send(request).await.unwrap();
        ResponseGuard {
            response,
            cancellation: &channel.cancellation,
            request_id,
        }
    }

    async fn send_response(


@@ -1,18 +1,16 @@
use crate::{
    context,
    util::{Compact, TimeUntil},
    Response,
};
use fnv::FnvHashMap;
use std::{
    collections::hash_map,
    task::{Context, Poll},
};
use tokio::sync::oneshot;
use tokio_util::time::delay_queue::{self, DelayQueue};
use tracing::Span;
/// Requests already written to the wire that haven't yet received responses.
#[derive(Debug)]
@@ -30,10 +28,17 @@ impl<Resp> Default for InFlightRequests<Resp> {
    }
}
/// The request exceeded its deadline.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
#[error("the request exceeded its deadline")]
pub struct DeadlineExceededError;
#[derive(Debug)]
struct RequestData<Resp> {
    ctx: context::Context,
    span: Span,
    response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
    /// The key to remove the timer for the request's deadline.
    deadline_key: delay_queue::Key,
}
@@ -59,20 +64,16 @@ impl<Resp> InFlightRequests<Resp> {
        &mut self,
        request_id: u64,
        ctx: context::Context,
        span: Span,
        response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
    ) -> Result<(), AlreadyExistsError> {
        match self.request_data.entry(request_id) {
            hash_map::Entry::Vacant(vacant) => {
                let timeout = ctx.deadline.time_until();
                let deadline_key = self.deadlines.insert(request_id, timeout);
                vacant.insert(RequestData {
                    ctx,
                    span,
                    response_completion,
                    deadline_key,
                });
@@ -85,15 +86,15 @@ impl<Resp> InFlightRequests<Resp> {
    /// Removes a request without aborting. Returns true iff the request was found.
    pub fn complete_request(&mut self, response: Response<Resp>) -> bool {
        if let Some(request_data) = self.request_data.remove(&response.request_id) {
            let _entered = request_data.span.enter();
            tracing::info!("ReceiveResponse");
            self.request_data.compact(0.1);
            self.deadlines.remove(&request_data.deadline_key);
            let _ = request_data.response_completion.send(Ok(response));
            return true;
        }

        tracing::debug!(
            "No in-flight request found for request_id = {}.",
            response.request_id
        );
@@ -104,12 +105,11 @@ impl<Resp> InFlightRequests<Resp> {
    /// Cancels a request without completing (typically used when a request handle was dropped
    /// before the request completed).
    pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> {
        if let Some(request_data) = self.request_data.remove(&request_id) {
            self.request_data.compact(0.1);
            self.deadlines.remove(&request_data.deadline_key);
            Some((request_data.ctx, request_data.span))
        } else {
            None
        }
@@ -117,46 +117,21 @@ impl<Resp> InFlightRequests<Resp> {
    /// Yields a request that has expired, completing it with a deadline-exceeded error.
    /// The caller should send cancellation messages for any yielded request ID.
    pub fn poll_expired(
        &mut self,
        cx: &mut Context,
    ) -> Poll<Option<Result<u64, tokio::time::error::Error>>> {
        self.deadlines.poll_expired(cx).map_ok(|expired| {
            let request_id = expired.into_inner();
            if let Some(request_data) = self.request_data.remove(&request_id) {
                let _entered = request_data.span.enter();
                tracing::error!("DeadlineExceeded");
                self.request_data.compact(0.1);
                let _ = request_data
                    .response_completion
                    .send(Err(DeadlineExceededError));
            }
            request_id
        })
    }
}


@@ -8,8 +8,13 @@
//! client to server and is used by the server to enforce response deadlines.

use crate::trace::{self, TraceId};
use opentelemetry::trace::TraceContextExt;
use static_assertions::assert_impl_all;
use std::{
    convert::TryFrom,
    time::{Duration, SystemTime},
};
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// A request context that carries request-scoped information like deadlines and trace information.
/// It is sent from client to server and is used by the server to enforce response deadlines.
@@ -22,14 +27,6 @@ use std::time::{Duration, SystemTime};
pub struct Context {
    /// When the client expects the request to be complete by. The server should cancel the request
    /// if it is not complete by this time.
    #[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))]
    pub deadline: SystemTime,
    /// Uniquely identifies requests originating from the same source.
@@ -41,23 +38,65 @@ pub struct Context {
assert_impl_all!(Context: Send, Sync);

fn ten_seconds_from_now() -> SystemTime {
    SystemTime::now() + Duration::from_secs(10)
}

/// Returns the context for the current request, or a default Context if no request is active.
pub fn current() -> Context {
    Context::current()
}

#[derive(Clone)]
struct Deadline(SystemTime);

impl Default for Deadline {
    fn default() -> Self {
        Self(ten_seconds_from_now())
    }
}

impl Context {
    /// Returns the context for the current request, or a default Context if no request is active.
    pub fn current() -> Self {
        let span = tracing::Span::current();
        Self {
            trace_context: trace::Context::try_from(&span)
                .unwrap_or_else(|_| trace::Context::default()),
            deadline: span
                .context()
                .get::<Deadline>()
                .cloned()
                .unwrap_or_default()
                .0,
        }
    }

    /// Returns the ID of the request-scoped trace.
    pub fn trace_id(&self) -> &TraceId {
        &self.trace_context.trace_id
    }
}
/// An extension trait for [`tracing::Span`] for propagating tarpc Contexts.
pub(crate) trait SpanExt {
/// Sets the given context on this span. Newly-created spans will be children of the given
/// context's trace context.
fn set_context(&self, context: &Context);
}
impl SpanExt for tracing::Span {
fn set_context(&self, context: &Context) {
self.set_parent(
opentelemetry::Context::new()
.with_remote_span_context(opentelemetry::trace::SpanContext::new(
opentelemetry::trace::TraceId::from(context.trace_context.trace_id),
opentelemetry::trace::SpanId::from(context.trace_context.span_id),
opentelemetry::trace::TraceFlags::from(context.trace_context.sampling_decision),
true,
opentelemetry::trace::TraceState::default(),
))
.with_value(Deadline(context.deadline)),
);
}
}

@@ -38,6 +38,14 @@
 //! requests sent by the server that use the request context will propagate the request deadline.
 //! For example, if a server is handling a request with a 10s deadline, does 2s of work, then
 //! sends a request to another server, that server will see an 8s deadline.
+//! - Distributed tracing: tarpc is instrumented with
+//!   [tracing](https://github.com/tokio-rs/tracing) primitives extended with
+//!   [OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
+//!   [Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
+//!   each RPC can be traced through the client, server, and other dependencies downstream of the
+//!   server. Even for applications not connected to a distributed tracing collector, the
+//!   instrumentation can be ingested by regular loggers like
+//!   [env_logger](https://github.com/env-logger-rs/env_logger/).
 //! - Serde serialization: enabling the `serde1` Cargo feature will make service requests and
 //!   responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
 //!   be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
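The deadline propagation described above is plain time arithmetic: the absolute deadline travels with the request, and each hop sees whatever budget remains. A minimal std-only sketch (the function name is illustrative, not tarpc's API):

```rust
use std::time::{Duration, SystemTime};

// The absolute deadline is forwarded unchanged; the remaining budget is
// simply deadline - now, saturating to zero once the deadline has passed.
fn remaining_budget(deadline: SystemTime, now: SystemTime) -> Duration {
    deadline.duration_since(now).unwrap_or(Duration::ZERO)
}
```

With a 10s deadline set at t0, a server that finishes 2s of work forwards the same absolute deadline, leaving an 8s budget downstream.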
@@ -46,7 +54,7 @@
 //! Add to your `Cargo.toml` dependencies:
 //!
 //! ```toml
-//! tarpc = "0.25"
+//! tarpc = "0.27"
 //! ```
 //!
 //! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
@@ -59,8 +67,9 @@
 //! your `Cargo.toml`:
 //!
 //! ```toml
-//! futures = "1.0"
-//! tarpc = { version = "0.25", features = ["tokio1"] }
+//! anyhow = "1.0"
+//! futures = "0.3"
+//! tarpc = { version = "0.27", features = ["tokio1"] }
 //! tokio = { version = "1.0", features = ["macros"] }
 //! ```
//! //!
@@ -79,9 +88,8 @@
 //! };
 //! use tarpc::{
 //!     client, context,
-//!     server::{self, Incoming},
+//!     server::{self, incoming::Incoming},
 //! };
-//! use std::io;
 //!
 //! // This is the service definition. It looks a lot like a trait definition.
 //! // It defines one RPC, hello, which takes one arg, name, and returns a String.
@@ -103,9 +111,8 @@
 //! # };
 //! # use tarpc::{
 //! #     client, context,
-//! #     server::{self, Incoming},
+//! #     server::{self, incoming::Incoming},
 //! # };
-//! # use std::io;
 //! # // This is the service definition. It looks a lot like a trait definition.
 //! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
 //! # #[tarpc::service]
@@ -145,7 +152,6 @@
 //! #     client, context,
 //! #     server::{self, Channel},
 //! # };
-//! # use std::io;
 //! # // This is the service definition. It looks a lot like a trait definition.
 //! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
 //! # #[tarpc::service]
@@ -169,7 +175,7 @@
 //! # fn main() {}
 //! # #[cfg(feature = "tokio1")]
 //! #[tokio::main]
-//! async fn main() -> io::Result<()> {
+//! async fn main() -> anyhow::Result<()> {
 //!     let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
 //!
 //!     let server = server::BaseChannel::with_defaults(server_transport);
@@ -177,7 +183,7 @@
 //!
 //!     // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
 //!     // that takes a config and any Transport as input.
-//!     let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
+//!     let mut client = WorldClient::new(client::Config::default(), client_transport).spawn();
 //!
 //!     // The client has an RPC method for each RPC defined in the annotated trait. It takes the same
 //!     // args as defined, with the addition of a Context, which is always the first arg. The Context
@@ -199,6 +205,7 @@
 #![cfg_attr(docsrs, feature(doc_cfg))]
 #[cfg(feature = "serde1")]
+#[doc(hidden)]
 pub use serde;
 #[cfg(feature = "serde-transport")]
@@ -303,7 +310,7 @@ pub use crate::transport::sealed::Transport;
 use anyhow::Context as _;
 use futures::task::*;
-use std::{fmt::Display, io, time::SystemTime};
+use std::{error::Error, fmt::Display, io, time::SystemTime};
 /// A message from a client to a server.
 #[derive(Debug)]
@@ -355,8 +362,9 @@ pub struct Response<T> {
     pub message: Result<T, ServerError>,
 }
-/// An error response from a server to a client.
-#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+/// An error indicating the server aborted the request early, e.g., due to request throttling.
+#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
+#[error("{kind:?}: {detail}")]
 #[non_exhaustive]
 #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
 pub struct ServerError {
@@ -371,13 +379,7 @@ pub struct ServerError {
     /// The type of error that occurred to fail the request.
     pub kind: io::ErrorKind,
     /// A message describing more detail about the error that occurred.
-    pub detail: Option<String>,
-}
-impl From<ServerError> for io::Error {
-    fn from(e: ServerError) -> io::Error {
-        io::Error::new(e.kind, e.detail.unwrap_or_default())
-    }
+    pub detail: String,
 }
 impl<T> Request<T> {
@@ -387,7 +389,6 @@ impl<T> Request<T> {
     }
 }
-pub(crate) type PollIo<T> = Poll<Option<io::Result<T>>>;
 pub(crate) trait PollContext<T> {
     fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
     where
@@ -399,7 +400,10 @@ pub(crate) trait PollContext<T> {
         F: FnOnce() -> C;
 }
-impl<T> PollContext<T> for PollIo<T> {
+impl<T, E> PollContext<T> for Poll<Option<Result<T, E>>>
+where
+    E: Error + Send + Sync + 'static,
+{
     fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
     where
         C: Display + Send + Sync + 'static,

@@ -42,14 +42,10 @@ where
     type Item = io::Result<Item>;
     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
-        match self.project().inner.poll_next(cx) {
-            Poll::Pending => Poll::Pending,
-            Poll::Ready(None) => Poll::Ready(None),
-            Poll::Ready(Some(Ok::<_, CodecError>(next))) => Poll::Ready(Some(Ok(next))),
-            Poll::Ready(Some(Err::<_, CodecError>(e))) => {
-                Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e))))
-            }
-        }
+        self.project()
+            .inner
+            .poll_next(cx)
+            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
     }
 }
@@ -65,7 +61,10 @@ where
     type Error = io::Error;
     fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
-        convert(self.project().inner.poll_ready(cx))
+        self.project()
+            .inner
+            .poll_ready(cx)
+            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
     }
     fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
@@ -76,20 +75,20 @@ where
     }
     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
-        convert(self.project().inner.poll_flush(cx))
+        self.project()
+            .inner
+            .poll_flush(cx)
+            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
     }
     fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
-        convert(self.project().inner.poll_close(cx))
+        self.project()
+            .inner
+            .poll_close(cx)
+            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
     }
 }
-fn convert<E: Into<Box<dyn Error + Send + Sync>>>(
-    poll: Poll<Result<(), E>>,
-) -> Poll<io::Result<()>> {
-    poll.map(|ready| ready.map_err(|e| io::Error::new(io::ErrorKind::Other, e)))
-}
 /// Constructs a new transport from a framed transport and a serialization codec.
 pub fn new<S, Item, SinkItem, Codec>(
     framed_io: Framed<S, LengthDelimitedCodec>,
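The deleted `convert` helper only mapped codec errors into `io::Error`, which `std::task::Poll::map_err` already does directly on a `Poll<Result<_, _>>`. A self-contained illustration of the same pattern, with stand-in types rather than the transport's real ones:

```rust
use std::task::Poll;

// Same shape as the Sink methods above: preserve the readiness state while
// rewriting only the error variant (there, into io::Error::new(Other, e)).
fn map_codec_err(poll: Poll<Result<(), u8>>) -> Poll<Result<(), String>> {
    poll.map_err(|e| format!("codec error: {}", e))
}
```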

@@ -1,19 +1,13 @@
-use crate::{
-    util::{Compact, TimeUntil},
-    PollIo,
-};
+use crate::util::{Compact, TimeUntil};
 use fnv::FnvHashMap;
-use futures::{
-    future::{AbortHandle, AbortRegistration},
-    ready,
-};
+use futures::future::{AbortHandle, AbortRegistration};
 use std::{
     collections::hash_map,
-    io,
     task::{Context, Poll},
     time::SystemTime,
 };
 use tokio_util::time::delay_queue::{self, DelayQueue};
+use tracing::Span;
 /// A data structure that tracks in-flight requests. It aborts requests,
 /// either on demand or when a request deadline expires.
@@ -23,13 +17,15 @@ pub struct InFlightRequests {
     deadlines: DelayQueue<u64>,
 }
-#[derive(Debug)]
 /// Data needed to clean up a single in-flight request.
+#[derive(Debug)]
 struct RequestData {
     /// Aborts the response handler for the associated request.
     abort_handle: AbortHandle,
     /// The key to remove the timer for the request's deadline.
     deadline_key: delay_queue::Key,
+    /// The client span.
+    span: Span,
 }
 /// An error returned when a request attempted to start with the same ID as a request already
@@ -48,6 +44,7 @@ impl InFlightRequests {
         &mut self,
         request_id: u64,
         deadline: SystemTime,
+        span: Span,
     ) -> Result<AbortRegistration, AlreadyExistsError> {
         match self.request_data.entry(request_id) {
             hash_map::Entry::Vacant(vacant) => {
@@ -57,6 +54,7 @@ impl InFlightRequests {
                 vacant.insert(RequestData {
                     abort_handle,
                     deadline_key,
+                    span,
                 });
                 Ok(abort_registration)
             }
@@ -66,12 +64,17 @@ impl InFlightRequests {
     /// Cancels an in-flight request. Returns true iff the request was found.
     pub fn cancel_request(&mut self, request_id: u64) -> bool {
-        if let Some(request_data) = self.request_data.remove(&request_id) {
+        if let Some(RequestData {
+            span,
+            abort_handle,
+            deadline_key,
+        }) = self.request_data.remove(&request_id)
+        {
+            let _entered = span.enter();
             self.request_data.compact(0.1);
-            request_data.abort_handle.abort();
-            self.deadlines.remove(&request_data.deadline_key);
+            abort_handle.abort();
+            self.deadlines.remove(&deadline_key);
+            tracing::info!("ReceiveCancel");
             true
         } else {
             false
@@ -80,30 +83,37 @@ impl InFlightRequests {
     /// Removes a request without aborting. Returns true iff the request was found.
     /// This method should be used when a response is being sent.
-    pub fn remove_request(&mut self, request_id: u64) -> bool {
+    pub fn remove_request(&mut self, request_id: u64) -> Option<Span> {
         if let Some(request_data) = self.request_data.remove(&request_id) {
             self.request_data.compact(0.1);
             self.deadlines.remove(&request_data.deadline_key);
-            true
+            Some(request_data.span)
         } else {
-            false
+            None
         }
     }
     /// Yields a request that has expired, aborting any ongoing processing of that request.
-    pub fn poll_expired(&mut self, cx: &mut Context) -> PollIo<u64> {
-        Poll::Ready(match ready!(self.deadlines.poll_expired(cx)) {
-            Some(Ok(expired)) => {
-                if let Some(request_data) = self.request_data.remove(expired.get_ref()) {
-                    self.request_data.compact(0.1);
-                    request_data.abort_handle.abort();
-                }
-                Some(Ok(expired.into_inner()))
-            }
-            Some(Err(e)) => Some(Err(io::Error::new(io::ErrorKind::Other, e))),
-            None => None,
+    pub fn poll_expired(
+        &mut self,
+        cx: &mut Context,
+    ) -> Poll<Option<Result<u64, tokio::time::error::Error>>> {
+        if self.deadlines.is_empty() {
+            // TODO(https://github.com/tokio-rs/tokio/issues/4161)
+            // This is a workaround for DelayQueue not always treating this case correctly.
+            return Poll::Ready(None);
+        }
+        self.deadlines.poll_expired(cx).map_ok(|expired| {
+            if let Some(RequestData {
+                abort_handle, span, ..
+            }) = self.request_data.remove(expired.get_ref())
+            {
+                let _entered = span.enter();
+                self.request_data.compact(0.1);
+                abort_handle.abort();
+                tracing::error!("DeadlineExceeded");
+            }
+            expired.into_inner()
         })
     }
 }
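The early `Poll::Ready(None)` guard above works around the linked tokio issue: polling an empty `DelayQueue` could return `Poll::Pending` with no wakeup ever scheduled, wedging `poll_expired`. The shape of that guard, with the queue stubbed out by a `Vec` (illustrative only, not tokio's `DelayQueue` API):

```rust
use std::task::Poll;

// Stand-in for the deadlines queue: when nothing is tracked, report stream
// exhaustion (Ready(None)) instead of parking the caller with no wakeup.
// For simplicity every tracked deadline is treated as already expired.
fn poll_expired_guarded(deadlines: &mut Vec<u64>) -> Poll<Option<u64>> {
    if deadlines.is_empty() {
        // The workaround: no entries means no future expirations to wait for.
        return Poll::Ready(None);
    }
    Poll::Ready(deadlines.pop())
}
```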
@@ -118,75 +128,96 @@ impl Drop for InFlightRequests {
 }
 #[cfg(test)]
-use {
-    assert_matches::assert_matches,
-    futures::{
-        future::{pending, Abortable},
-        FutureExt,
-    },
-    futures_test::task::noop_context,
-};
+mod tests {
+    use super::*;
+
+    use assert_matches::assert_matches;
+    use futures::{
+        future::{pending, Abortable},
+        FutureExt,
+    };
+    use futures_test::task::noop_context;
 
     #[tokio::test]
     async fn start_request_increases_len() {
         let mut in_flight_requests = InFlightRequests::default();
         assert_eq!(in_flight_requests.len(), 0);
         in_flight_requests
-            .start_request(0, SystemTime::now())
+            .start_request(0, SystemTime::now(), Span::current())
             .unwrap();
         assert_eq!(in_flight_requests.len(), 1);
     }
 
     #[tokio::test]
     async fn polling_expired_aborts() {
         let mut in_flight_requests = InFlightRequests::default();
         let abort_registration = in_flight_requests
-            .start_request(0, SystemTime::now())
+            .start_request(0, SystemTime::now(), Span::current())
             .unwrap();
         let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
 
         tokio::time::pause();
         tokio::time::advance(std::time::Duration::from_secs(1000)).await;
 
         assert_matches!(
             in_flight_requests.poll_expired(&mut noop_context()),
             Poll::Ready(Some(Ok(_)))
         );
         assert_matches!(
             abortable_future.poll_unpin(&mut noop_context()),
             Poll::Ready(Err(_))
         );
         assert_eq!(in_flight_requests.len(), 0);
     }
 
     #[tokio::test]
     async fn cancel_request_aborts() {
         let mut in_flight_requests = InFlightRequests::default();
         let abort_registration = in_flight_requests
-            .start_request(0, SystemTime::now())
+            .start_request(0, SystemTime::now(), Span::current())
             .unwrap();
         let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
 
         assert_eq!(in_flight_requests.cancel_request(0), true);
         assert_matches!(
             abortable_future.poll_unpin(&mut noop_context()),
             Poll::Ready(Err(_))
         );
         assert_eq!(in_flight_requests.len(), 0);
     }
 
     #[tokio::test]
     async fn remove_request_doesnt_abort() {
         let mut in_flight_requests = InFlightRequests::default();
+        assert!(in_flight_requests.deadlines.is_empty());
+
         let abort_registration = in_flight_requests
-            .start_request(0, SystemTime::now())
+            .start_request(
+                0,
+                SystemTime::now() + std::time::Duration::from_secs(10),
+                Span::current(),
+            )
             .unwrap();
         let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
 
-        assert_eq!(in_flight_requests.remove_request(0), true);
+        // Precondition: Pending expiration
+        assert_matches!(
+            in_flight_requests.poll_expired(&mut noop_context()),
+            Poll::Pending
+        );
+        assert!(!in_flight_requests.deadlines.is_empty());
+
+        assert_matches!(in_flight_requests.remove_request(0), Some(_));
+
+        // Postcondition: No pending expirations
+        assert!(in_flight_requests.deadlines.is_empty());
+        assert_matches!(
+            in_flight_requests.poll_expired(&mut noop_context()),
+            Poll::Ready(None)
+        );
+
         assert_matches!(
             abortable_future.poll_unpin(&mut noop_context()),
             Poll::Pending
         );
         assert_eq!(in_flight_requests.len(), 0);
     }
+}

@@ -0,0 +1,49 @@
+use super::{
+    limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel},
+    Channel,
+};
+use futures::prelude::*;
+use std::{fmt, hash::Hash};
+
+#[cfg(feature = "tokio1")]
+use super::{tokio::TokioServerExecutor, Serve};
+
+/// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel).
+pub trait Incoming<C>
+where
+    Self: Sized + Stream<Item = C>,
+    C: Channel,
+{
+    /// Enforces channel per-key limits.
+    fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> MaxChannelsPerKey<Self, K, KF>
+    where
+        K: fmt::Display + Eq + Hash + Clone + Unpin,
+        KF: Fn(&C) -> K,
+    {
+        MaxChannelsPerKey::new(self, n, keymaker)
+    }
+
+    /// Caps the number of concurrent requests per channel.
+    fn max_concurrent_requests_per_channel(self, n: usize) -> MaxRequestsPerChannel<Self> {
+        MaxRequestsPerChannel::new(self, n)
+    }
+
+    /// [Executes](Channel::execute) each incoming channel. Each channel will be handled
+    /// concurrently by spawning on tokio's default executor, and each request will also be
+    /// spawned on tokio's default executor.
+    #[cfg(feature = "tokio1")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
+    fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
+    where
+        S: Serve<C::Req, Resp = C::Resp>,
+    {
+        TokioServerExecutor::new(self, serve)
+    }
+}
+
+impl<S, C> Incoming<C> for S
+where
+    S: Sized + Stream<Item = C>,
+    C: Channel,
+{
+}
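The `Incoming` trait above follows the standard extension-trait pattern: default methods plus an empty blanket impl, so every qualifying stream picks the combinators up for free. The same shape on std's `Iterator`, with a hypothetical trait and method name:

```rust
// An extension trait with one default method and a blanket impl, mirroring
// how Incoming attaches max_channels_per_key et al. to any Stream of Channels.
trait TakeAtMost: Sized + Iterator {
    // Caps the number of items yielded, loosely analogous to a request limit.
    fn take_at_most(self, n: usize) -> std::iter::Take<Self> {
        self.take(n)
    }
}

// The blanket impl is empty: every Sized Iterator gets the method for free.
impl<I: Sized + Iterator> TakeAtMost for I {}
```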

@@ -0,0 +1,5 @@
+/// Provides functionality to limit the number of active channels.
+pub mod channels_per_key;
+/// Provides a [channel](crate::server::Channel) that limits the number of in-flight requests.
+pub mod requests_per_channel;

@@ -9,20 +9,23 @@ use crate::{
     util::Compact,
 };
 use fnv::FnvHashMap;
-use futures::{future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
-use log::{debug, info, trace};
+use futures::{prelude::*, ready, stream::Fuse, task::*};
 use pin_project::pin_project;
 use std::sync::{Arc, Weak};
 use std::{
     collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin,
-    time::SystemTime,
 };
 use tokio::sync::mpsc;
+use tracing::{debug, info, trace};
-/// A single-threaded filter that drops channels based on per-key limits.
+/// An [`Incoming`](crate::server::incoming::Incoming) stream that drops new channels based on
+/// per-key limits.
+///
+/// The decision to drop a Channel is made once at the time the Channel materializes. Once a
+/// Channel is yielded, it will not be prematurely dropped.
 #[pin_project]
 #[derive(Debug)]
-pub struct ChannelFilter<S, K, F>
+pub struct MaxChannelsPerKey<S, K, F>
 where
     K: Eq + Hash,
 {
@@ -35,7 +38,7 @@ where
     keymaker: F,
 }
-/// A channel that is tracked by a ChannelFilter.
+/// A channel that is tracked by [`MaxChannelsPerKey`].
 #[pin_project]
 #[derive(Debug)]
 pub struct TrackedChannel<C, K> {
@@ -103,6 +106,7 @@ where
 {
     type Req = C::Req;
     type Resp = C::Resp;
+    type Transport = C::Transport;
 
     fn config(&self) -> &server::Config {
         self.inner.config()
@@ -112,12 +116,8 @@ where
         self.inner.in_flight_requests()
     }
-    fn start_request(
-        mut self: Pin<&mut Self>,
-        id: u64,
-        deadline: SystemTime,
-    ) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
-        self.inner_pin_mut().start_request(id, deadline)
+    fn transport(&self) -> &Self::Transport {
+        self.inner.transport()
     }
 }
@@ -133,7 +133,7 @@ impl<C, K> TrackedChannel<C, K> {
     }
 }
-impl<S, K, F> ChannelFilter<S, K, F>
+impl<S, K, F> MaxChannelsPerKey<S, K, F>
 where
     K: Eq + Hash,
     S: Stream,
@@ -142,7 +142,7 @@ where
     /// Sheds new channels to stay under configured limits.
     pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
         let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel();
-        ChannelFilter {
+        MaxChannelsPerKey {
             listener: listener.fuse(),
             channels_per_key,
             dropped_keys,
@@ -153,7 +153,7 @@ where
     }
 }
-impl<S, K, F> ChannelFilter<S, K, F>
+impl<S, K, F> MaxChannelsPerKey<S, K, F>
 where
     S: Stream,
     K: fmt::Display + Eq + Hash + Clone + Unpin,
@@ -171,11 +171,10 @@ where
         let tracker = self.as_mut().increment_channels_for_key(key.clone())?;
 
         trace!(
-            "[{}] Opening channel ({}/{}) channels for key.",
-            key,
-            Arc::strong_count(&tracker),
-            self.channels_per_key
-        );
+            channel_filter_key = %key,
+            open_channels = Arc::strong_count(&tracker),
+            max_open_channels = self.channels_per_key,
+            "Opening channel");
 
         Ok(TrackedChannel {
             tracker,
@@ -200,9 +199,10 @@ where
                 let count = o.get().strong_count();
                 if count >= TryFrom::try_from(*self_.channels_per_key).unwrap() {
                     info!(
-                        "[{}] Opened max channels from key ({}/{}).",
-                        key, count, self_.channels_per_key
-                    );
+                        channel_filter_key = %key,
+                        open_channels = count,
+                        max_open_channels = *self_.channels_per_key,
+                        "At open channel limit");
                     Err(key)
                 } else {
                     Ok(o.get().upgrade().unwrap_or_else(|| {
@@ -233,7 +233,9 @@ where
         let self_ = self.project();
         match ready!(self_.dropped_keys.poll_recv(cx)) {
             Some(key) => {
-                debug!("All channels dropped for key [{}]", key);
+                debug!(
+                    channel_filter_key = %key,
+                    "All channels dropped");
                 self_.key_counts.remove(&key);
                 self_.key_counts.compact(0.1);
                 Poll::Ready(())
@@ -243,7 +245,7 @@ where
     }
 }
-impl<S, K, F> Stream for ChannelFilter<S, K, F>
+impl<S, K, F> Stream for MaxChannelsPerKey<S, K, F>
 where
     S: Stream,
     K: fmt::Display + Eq + Hash + Clone + Unpin,
@@ -346,7 +348,7 @@ fn channel_filter_increment_channels_for_key() {
         key: &'static str,
     }
     let (_, listener) = futures::channel::mpsc::unbounded();
-    let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
+    let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
     pin_mut!(filter);
     let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
     assert_eq!(Arc::strong_count(&tracker1), 1);
@@ -367,7 +369,7 @@ fn channel_filter_handle_new_channel() {
         key: &'static str,
     }
     let (_, listener) = futures::channel::mpsc::unbounded();
-    let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
+    let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
     pin_mut!(filter);
     let channel1 = filter
         .as_mut()
@@ -399,7 +401,7 @@ fn channel_filter_poll_listener() {
         key: &'static str,
     }
     let (new_channels, listener) = futures::channel::mpsc::unbounded();
-    let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
+    let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
     pin_mut!(filter);
     new_channels
@@ -435,7 +437,7 @@ fn channel_filter_poll_closed_channels() {
         key: &'static str,
     }
     let (new_channels, listener) = futures::channel::mpsc::unbounded();
-    let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
+    let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
     pin_mut!(filter);
     new_channels
@@ -463,7 +465,7 @@ fn channel_filter_stream() {
         key: &'static str,
     }
     let (new_channels, listener) = futures::channel::mpsc::unbounded();
-    let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
+    let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
     pin_mut!(filter);
     new_channels

@@ -0,0 +1,349 @@
+// Copyright 2020 Google LLC
+//
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file or at
+// https://opensource.org/licenses/MIT.
+
+use crate::{
+    server::{Channel, Config},
+    Response, ServerError,
+};
+use futures::{prelude::*, ready, task::*};
+use pin_project::pin_project;
+use std::{io, pin::Pin};
+
+/// A [`Channel`] that limits the number of concurrent requests by throttling.
+///
+/// Note that this is a very basic throttling heuristic. It is easy to set a number that is too low
+/// for the resources available to the server. For production use cases, a more advanced throttler
+/// is likely needed.
+#[pin_project]
+#[derive(Debug)]
+pub struct MaxRequests<C> {
+    max_in_flight_requests: usize,
+    #[pin]
+    inner: C,
+}
+
+impl<C> MaxRequests<C> {
+    /// Returns the inner channel.
+    pub fn get_ref(&self) -> &C {
+        &self.inner
+    }
+}
+
+impl<C> MaxRequests<C>
+where
+    C: Channel,
+{
+    /// Returns a new `MaxRequests` that wraps the given channel and limits concurrent requests to
+    /// `max_in_flight_requests`.
+    pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
+        MaxRequests {
+            max_in_flight_requests,
+            inner,
+        }
+    }
+}
+
+impl<C> Stream for MaxRequests<C>
+where
+    C: Channel,
+{
+    type Item = <C as Stream>::Item;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
+        while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
+        {
+            ready!(self.as_mut().project().inner.poll_ready(cx)?);
+            match ready!(self.as_mut().project().inner.poll_next(cx)?) {
+                Some(r) => {
+                    let _entered = r.span.enter();
+                    tracing::info!(
+                        in_flight_requests = self.as_mut().in_flight_requests(),
+                        "ThrottleRequest",
+                    );
+                    self.as_mut().start_send(Response {
+                        request_id: r.request.id,
+                        message: Err(ServerError {
+                            kind: io::ErrorKind::WouldBlock,
+                            detail: "server throttled the request.".into(),
+                        }),
+                    })?;
+                }
+                None => return Poll::Ready(None),
+            }
+        }
+        self.project().inner.poll_next(cx)
+    }
+}
+
+impl<C> Sink<Response<<C as Channel>::Resp>> for MaxRequests<C>
+where
+    C: Channel,
+{
+    type Error = C::Error;
+
+    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+        self.project().inner.poll_ready(cx)
+    }
+
+    fn start_send(
+        self: Pin<&mut Self>,
+        item: Response<<C as Channel>::Resp>,
+    ) -> Result<(), Self::Error> {
self.project().inner.start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_close(cx)
}
}
impl<C> AsRef<C> for MaxRequests<C> {
fn as_ref(&self) -> &C {
&self.inner
}
}
impl<C> Channel for MaxRequests<C>
where
C: Channel,
{
type Req = <C as Channel>::Req;
type Resp = <C as Channel>::Resp;
type Transport = <C as Channel>::Transport;
fn in_flight_requests(&self) -> usize {
self.inner.in_flight_requests()
}
fn config(&self) -> &Config {
self.inner.config()
}
fn transport(&self) -> &Self::Transport {
self.inner.transport()
}
}
/// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on
/// the number of in-flight requests.
#[pin_project]
#[derive(Debug)]
pub struct MaxRequestsPerChannel<S> {
#[pin]
inner: S,
max_in_flight_requests: usize,
}
impl<S> MaxRequestsPerChannel<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
Self {
inner,
max_in_flight_requests,
}
}
}
impl<S> Stream for MaxRequestsPerChannel<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
type Item = MaxRequests<<S as Stream>::Item>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match ready!(self.as_mut().project().inner.poll_next(cx)) {
Some(channel) => Poll::Ready(Some(MaxRequests::new(
channel,
*self.project().max_in_flight_requests,
))),
None => Poll::Ready(None),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::server::{
testing::{self, FakeChannel, PollExt},
TrackedRequest,
};
use pin_utils::pin_mut;
use std::{
marker::PhantomData,
time::{Duration, SystemTime},
};
use tracing::Span;
#[tokio::test]
async fn throttler_in_flight_requests() {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
for i in 0..5 {
throttler
.inner
.in_flight_requests
.start_request(
i,
SystemTime::now() + Duration::from_secs(1),
Span::current(),
)
.unwrap();
}
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
}
#[test]
fn throttler_poll_next_done() {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
}
#[test]
fn throttler_poll_next_some() -> io::Result<()> {
let throttler = MaxRequests {
max_in_flight_requests: 1,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.push_req(0, 1);
assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());
assert_eq!(
throttler
.as_mut()
.poll_next(&mut testing::cx())?
.map(|r| r.map(|r| (r.request.id, r.request.message))),
Poll::Ready(Some((0, 1)))
);
Ok(())
}
#[test]
fn throttler_poll_next_throttled() {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.push_req(1, 1);
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
assert_eq!(throttler.inner.sink.len(), 1);
let resp = throttler.inner.sink.get(0).unwrap();
assert_eq!(resp.request_id, 1);
assert!(resp.message.is_err());
}
#[test]
fn throttler_poll_next_throttled_sink_not_ready() {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: PendingSink::default::<isize, isize>(),
};
pin_mut!(throttler);
assert!(throttler.poll_next(&mut testing::cx()).is_pending());
struct PendingSink<In, Out> {
ghost: PhantomData<fn(Out) -> In>,
}
impl PendingSink<(), ()> {
pub fn default<Req, Resp>(
) -> PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
PendingSink { ghost: PhantomData }
}
}
impl<In, Out> Stream for PendingSink<In, Out> {
type Item = In;
fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
unimplemented!()
}
}
impl<In, Out> Sink<Out> for PendingSink<In, Out> {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
Err(io::Error::from(io::ErrorKind::WouldBlock))
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
}
impl<Req, Resp> Channel for PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
type Req = Req;
type Resp = Resp;
type Transport = ();
fn config(&self) -> &Config {
unimplemented!()
}
fn in_flight_requests(&self) -> usize {
0
}
fn transport(&self) -> &() {
&()
}
}
}
#[tokio::test]
async fn throttler_start_send() {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler
.inner
.in_flight_requests
.start_request(
0,
SystemTime::now() + Duration::from_secs(1),
Span::current(),
)
.unwrap();
throttler
.as_mut()
.start_send(Response {
request_id: 0,
message: Ok(1),
})
.unwrap();
assert_eq!(throttler.inner.in_flight_requests.len(), 0);
assert_eq!(
throttler.inner.sink.get(0),
Some(&Response {
request_id: 0,
message: Ok(1),
})
);
}
}
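Stripped of the async machinery, the limit enforced by `MaxRequests` is a simple gate: admit a request while the in-flight count is below the cap, otherwise answer immediately with a retryable rejection. A minimal synchronous sketch of that logic (the `Gate` and `Reply` names here are illustrative, not tarpc API):

```rust
// Synchronous sketch of the in-flight-request cap enforced by `MaxRequests`:
// requests past the limit are rejected instead of being admitted.
#[derive(Debug, PartialEq)]
enum Reply {
    Ok(u64),        // request admitted; echoes the request id
    Throttled(u64), // request rejected because the channel is at capacity
}

struct Gate {
    max_in_flight: usize,
    in_flight: usize,
}

impl Gate {
    fn new(max_in_flight: usize) -> Self {
        Gate { max_in_flight, in_flight: 0 }
    }

    // Mirrors the poll_next loop: at capacity, immediately answer with a
    // throttle response rather than admitting the request.
    fn admit(&mut self, request_id: u64) -> Reply {
        if self.in_flight >= self.max_in_flight {
            Reply::Throttled(request_id)
        } else {
            self.in_flight += 1;
            Reply::Ok(request_id)
        }
    }

    // Called when a request completes, freeing capacity.
    fn complete(&mut self) {
        self.in_flight -= 1;
    }
}

fn main() {
    let mut gate = Gate::new(2);
    assert_eq!(gate.admit(1), Reply::Ok(1));
    assert_eq!(gate.admit(2), Reply::Ok(2));
    assert_eq!(gate.admit(3), Reply::Throttled(3)); // over the cap of 2
    gate.complete();
    assert_eq!(gate.admit(4), Reply::Ok(4)); // capacity freed up again
}
```

Note that the real implementation flushes the rejection through the channel's `Sink` before polling for more requests, which is why `poll_ready` is awaited inside the loop.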


@@ -6,15 +6,13 @@
use crate::{ use crate::{
context, context,
server::{Channel, Config}, server::{Channel, Config, TrackedRequest},
Request, Response, Request, Response,
}; };
use futures::{future::AbortRegistration, task::*, Sink, Stream}; use futures::{task::*, Sink, Stream};
use pin_project::pin_project; use pin_project::pin_project;
use std::collections::VecDeque; use std::{collections::VecDeque, io, pin::Pin, time::SystemTime};
use std::io; use tracing::Span;
use std::pin::Pin;
use std::time::SystemTime;
#[pin_project] #[pin_project]
pub(crate) struct FakeChannel<In, Out> { pub(crate) struct FakeChannel<In, Out> {
@@ -64,12 +62,13 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
} }
} }
impl<Req, Resp> Channel for FakeChannel<io::Result<Request<Req>>, Response<Resp>> impl<Req, Resp> Channel for FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>>
where where
Req: Unpin, Req: Unpin,
{ {
type Req = Req; type Req = Req;
type Resp = Resp; type Resp = Resp;
type Transport = ();
fn config(&self) -> &Config { fn config(&self) -> &Config {
&self.config &self.config
@@ -79,32 +78,31 @@ where
self.in_flight_requests.len() self.in_flight_requests.len()
} }
fn start_request( fn transport(&self) -> &() {
self: Pin<&mut Self>, &()
id: u64,
deadline: SystemTime,
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
self.project()
.in_flight_requests
.start_request(id, deadline)
} }
} }
impl<Req, Resp> FakeChannel<io::Result<Request<Req>>, Response<Resp>> { impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
pub fn push_req(&mut self, id: u64, message: Req) { pub fn push_req(&mut self, id: u64, message: Req) {
self.stream.push_back(Ok(Request { let (_, abort_registration) = futures::future::AbortHandle::new_pair();
context: context::Context { self.stream.push_back(Ok(TrackedRequest {
deadline: SystemTime::UNIX_EPOCH, request: Request {
trace_context: Default::default(), context: context::Context {
deadline: SystemTime::UNIX_EPOCH,
trace_context: Default::default(),
},
id,
message,
}, },
id, abort_registration,
message, span: Span::none(),
})); }));
} }
} }
impl FakeChannel<(), ()> { impl FakeChannel<(), ()> {
pub fn default<Req, Resp>() -> FakeChannel<io::Result<Request<Req>>, Response<Resp>> { pub fn default<Req, Resp>() -> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
FakeChannel { FakeChannel {
stream: Default::default(), stream: Default::default(),
sink: Default::default(), sink: Default::default(),


@@ -1,347 +0,0 @@
// Copyright 2020 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use super::{Channel, Config};
use crate::{Response, ServerError};
use futures::{future::AbortRegistration, prelude::*, ready, task::*};
use log::debug;
use pin_project::pin_project;
use std::{io, pin::Pin, time::SystemTime};
/// A [`Channel`] that limits the number of concurrent
/// requests by throttling.
#[pin_project]
#[derive(Debug)]
pub struct Throttler<C> {
max_in_flight_requests: usize,
#[pin]
inner: C,
}
impl<C> Throttler<C> {
/// Returns the inner channel.
pub fn get_ref(&self) -> &C {
&self.inner
}
}
impl<C> Throttler<C>
where
C: Channel,
{
/// Returns a new `Throttler` that wraps the given channel and limits concurrent requests to
/// `max_in_flight_requests`.
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
Throttler {
max_in_flight_requests,
inner,
}
}
}
impl<C> Stream for Throttler<C>
where
C: Channel,
{
type Item = <C as Stream>::Item;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
{
ready!(self.as_mut().project().inner.poll_ready(cx)?);
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
Some(request) => {
debug!(
"[{}] Client has reached in-flight request limit ({}/{}).",
request.context.trace_id(),
self.as_mut().in_flight_requests(),
self.as_mut().project().max_in_flight_requests,
);
self.as_mut().start_send(Response {
request_id: request.id,
message: Err(ServerError {
kind: io::ErrorKind::WouldBlock,
detail: Some("Server throttled the request.".into()),
}),
})?;
}
None => return Poll::Ready(None),
}
}
self.project().inner.poll_next(cx)
}
}
impl<C> Sink<Response<<C as Channel>::Resp>> for Throttler<C>
where
C: Channel,
{
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Response<<C as Channel>::Resp>) -> io::Result<()> {
self.project().inner.start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.project().inner.poll_close(cx)
}
}
impl<C> AsRef<C> for Throttler<C> {
fn as_ref(&self) -> &C {
&self.inner
}
}
impl<C> Channel for Throttler<C>
where
C: Channel,
{
type Req = <C as Channel>::Req;
type Resp = <C as Channel>::Resp;
fn in_flight_requests(&self) -> usize {
self.inner.in_flight_requests()
}
fn config(&self) -> &Config {
self.inner.config()
}
fn start_request(
self: Pin<&mut Self>,
id: u64,
deadline: SystemTime,
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
self.project().inner.start_request(id, deadline)
}
}
/// A stream of throttling channels.
#[pin_project]
#[derive(Debug)]
pub struct ThrottlerStream<S> {
#[pin]
inner: S,
max_in_flight_requests: usize,
}
impl<S> ThrottlerStream<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
Self {
inner,
max_in_flight_requests,
}
}
}
impl<S> Stream for ThrottlerStream<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
type Item = Throttler<<S as Stream>::Item>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match ready!(self.as_mut().project().inner.poll_next(cx)) {
Some(channel) => Poll::Ready(Some(Throttler::new(
channel,
*self.project().max_in_flight_requests,
))),
None => Poll::Ready(None),
}
}
}
#[cfg(test)]
use super::testing::{self, FakeChannel, PollExt};
#[cfg(test)]
use crate::Request;
#[cfg(test)]
use pin_utils::pin_mut;
#[cfg(test)]
use std::{marker::PhantomData, time::Duration};
#[tokio::test]
async fn throttler_in_flight_requests() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
for i in 0..5 {
throttler
.inner
.in_flight_requests
.start_request(i, SystemTime::now() + Duration::from_secs(1))
.unwrap();
}
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
}
#[tokio::test]
async fn throttler_start_request() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler
.as_mut()
.start_request(1, SystemTime::now() + Duration::from_secs(1))
.unwrap();
assert_eq!(throttler.inner.in_flight_requests.len(), 1);
}
#[test]
fn throttler_poll_next_done() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
}
#[test]
fn throttler_poll_next_some() -> io::Result<()> {
let throttler = Throttler {
max_in_flight_requests: 1,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.push_req(0, 1);
assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());
assert_eq!(
throttler
.as_mut()
.poll_next(&mut testing::cx())?
.map(|r| r.map(|r| (r.id, r.message))),
Poll::Ready(Some((0, 1)))
);
Ok(())
}
#[test]
fn throttler_poll_next_throttled() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.push_req(1, 1);
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
assert_eq!(throttler.inner.sink.len(), 1);
let resp = throttler.inner.sink.get(0).unwrap();
assert_eq!(resp.request_id, 1);
assert!(resp.message.is_err());
}
#[test]
fn throttler_poll_next_throttled_sink_not_ready() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: PendingSink::default::<isize, isize>(),
};
pin_mut!(throttler);
assert!(throttler.poll_next(&mut testing::cx()).is_pending());
struct PendingSink<In, Out> {
ghost: PhantomData<fn(Out) -> In>,
}
impl PendingSink<(), ()> {
pub fn default<Req, Resp>() -> PendingSink<io::Result<Request<Req>>, Response<Resp>> {
PendingSink { ghost: PhantomData }
}
}
impl<In, Out> Stream for PendingSink<In, Out> {
type Item = In;
fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
unimplemented!()
}
}
impl<In, Out> Sink<Out> for PendingSink<In, Out> {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
Err(io::Error::from(io::ErrorKind::WouldBlock))
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
}
impl<Req, Resp> Channel for PendingSink<io::Result<Request<Req>>, Response<Resp>> {
type Req = Req;
type Resp = Resp;
fn config(&self) -> &Config {
unimplemented!()
}
fn in_flight_requests(&self) -> usize {
0
}
fn start_request(
self: Pin<&mut Self>,
_id: u64,
_deadline: SystemTime,
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
unimplemented!()
}
}
}
#[tokio::test]
async fn throttler_start_send() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler
.inner
.in_flight_requests
.start_request(0, SystemTime::now() + Duration::from_secs(1))
.unwrap();
throttler
.as_mut()
.start_send(Response {
request_id: 0,
message: Ok(1),
})
.unwrap();
assert_eq!(throttler.inner.in_flight_requests.len(), 0);
assert_eq!(
throttler.inner.sink.get(0),
Some(&Response {
request_id: 0,
message: Ok(1),
})
);
}

tarpc/src/server/tokio.rs Normal file

@@ -0,0 +1,111 @@
use super::{Channel, Requests, Serve};
use futures::{prelude::*, ready, task::*};
use pin_project::pin_project;
use std::pin::Pin;
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
/// for each new channel. Returned by
/// [`Incoming::execute`](crate::server::incoming::Incoming::execute).
#[pin_project]
#[derive(Debug)]
pub struct TokioServerExecutor<T, S> {
#[pin]
inner: T,
serve: S,
}
impl<T, S> TokioServerExecutor<T, S> {
pub(crate) fn new(inner: T, serve: S) -> Self {
Self { inner, serve }
}
}
/// A future that drives the server by [spawning](tokio::spawn) each [response
/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by
/// [`Channel::execute`](crate::server::Channel::execute).
#[pin_project]
#[derive(Debug)]
pub struct TokioChannelExecutor<T, S> {
#[pin]
inner: T,
serve: S,
}
impl<T, S> TokioServerExecutor<T, S> {
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
self.as_mut().project().inner
}
}
impl<T, S> TokioChannelExecutor<T, S> {
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
self.as_mut().project().inner
}
}
// Send + 'static execution helper methods.
impl<C> Requests<C>
where
C: Channel,
C::Req: Send + 'static,
C::Resp: Send + 'static,
{
/// Executes all requests using the given service function. Requests are handled concurrently
/// by [spawning](::tokio::spawn) each handler on tokio's default executor.
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
where
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
{
TokioChannelExecutor { inner: self, serve }
}
}
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
where
St: Sized + Stream<Item = C>,
C: Channel + Send + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
Se::Fut: Send,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
tokio::spawn(channel.execute(self.serve.clone()));
}
tracing::info!("Server shutting down.");
Poll::Ready(())
}
}
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
where
C: Channel + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
S::Fut: Send,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
match response_handler {
Ok(resp) => {
let server = self.serve.clone();
tokio::spawn(async move {
resp.execute(server).await;
});
}
Err(e) => {
tracing::warn!("Requests stream errored out: {}", e);
break;
}
}
}
Poll::Ready(())
}
}
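Both executors share one shape: drain a stream and spawn each yielded item as an independent task, finishing when the stream ends. A thread-based analogue of that loop, assuming only std (here `std::thread::spawn` stands in for `tokio::spawn`, and doubling a number stands in for `channel.execute(serve)`):

```rust
use std::sync::mpsc;
use std::thread;

// Each received item gets its own handler task, mirroring how
// TokioServerExecutor spawns one task per incoming channel.
fn execute_all(incoming: mpsc::Receiver<u64>) -> Vec<thread::JoinHandle<u64>> {
    let mut handles = Vec::new();
    // The loop ends when the sender side is dropped, i.e. when the
    // stream of incoming channels yields None.
    for channel in incoming {
        handles.push(thread::spawn(move || channel * 2)); // stand-in handler
    }
    handles
}

fn main() {
    let (tx, rx) = mpsc::channel();
    for i in 1..=3 {
        tx.send(i).unwrap();
    }
    drop(tx); // close the stream so the executor loop terminates

    let mut results: Vec<u64> = execute_all(rx)
        .into_iter()
        .map(|h| h.join().unwrap())
        .collect();
    results.sort();
    assert_eq!(results, vec![2, 4, 6]);
}
```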


@@ -16,11 +16,14 @@
//! This crate's design is based on [opencensus //! This crate's design is based on [opencensus
//! tracing](https://opencensus.io/core-concepts/tracing/). //! tracing](https://opencensus.io/core-concepts/tracing/).
use opentelemetry::trace::TraceContextExt;
use rand::Rng; use rand::Rng;
use std::{ use std::{
convert::TryFrom,
fmt::{self, Formatter}, fmt::{self, Formatter},
mem, num::{NonZeroU128, NonZeroU64},
}; };
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// A context for tracing the execution of processes, distributed or otherwise. /// A context for tracing the execution of processes, distributed or otherwise.
/// ///
@@ -36,33 +39,50 @@ pub struct Context {
/// before making an RPC, and the span ID is sent to the server. The server is free to create /// before making an RPC, and the span ID is sent to the server. The server is free to create
/// its own spans, for which it sets the client's span as the parent span. /// its own spans, for which it sets the client's span as the parent span.
pub span_id: SpanId, pub span_id: SpanId,
/// An identifier of the span that originated the current span. For example, if a server sends /// Indicates whether a sampler has already decided whether or not to sample the trace
/// an RPC in response to a client request that included a span, the server would create a span /// associated with the Context. If `sampling_decision` is None, then a decision has not yet
/// for the RPC and set its parent to the span_id in the incoming request's context. /// been made. Downstream samplers do not need to abide by "no sample" decisions--for example,
/// /// an upstream client may choose to never sample, which may not make sense for the client's
/// If `parent_id` is `None`, then this is a root context. /// dependencies. On the other hand, if an upstream process has chosen to sample this trace,
pub parent_id: Option<SpanId>, /// then the downstream samplers are expected to respect that decision and also sample the
/// trace. Otherwise, the full trace would not be able to be reconstructed.
pub sampling_decision: SamplingDecision,
} }
/// A 128-bit UUID identifying a trace. All spans caused by the same originating span share the /// A 128-bit UUID identifying a trace. All spans caused by the same originating span share the
/// same trace ID. /// same trace ID.
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)] #[derive(Default, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct TraceId(u128); pub struct TraceId(#[cfg_attr(feature = "serde1", serde(with = "u128_serde"))] u128);
/// A 64-bit identifier of a span within a trace. The identifier is unique within the span's trace. /// A 64-bit identifier of a span within a trace. The identifier is unique within the span's trace.
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)] #[derive(Default, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct SpanId(u64); pub struct SpanId(u64);
/// Indicates whether a sampler has decided whether or not to sample the trace associated with the
/// Context. Downstream samplers do not need to abide by "no sample" decisions--for example, an
/// upstream client may choose to never sample, which may not make sense for the client's
/// dependencies. On the other hand, if an upstream process has chosen to sample this trace, then
/// the downstream samplers are expected to respect that decision and also sample the trace.
/// Otherwise, the full trace would not be able to be reconstructed reliably.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum SamplingDecision {
/// The associated span was sampled by its creating process. Child spans must also be sampled.
Sampled,
/// The associated span was not sampled by its creating process.
Unsampled,
}
impl Context { impl Context {
/// Constructs a new root context. A root context is one with no parent span. /// Constructs a new context with the trace ID and sampling decision inherited from the parent.
pub fn new_root() -> Self { pub(crate) fn new_child(&self) -> Self {
let rng = &mut rand::thread_rng(); Self {
Context { trace_id: self.trace_id,
trace_id: TraceId::random(rng), span_id: SpanId::random(&mut rand::thread_rng()),
span_id: SpanId::random(rng), sampling_decision: self.sampling_decision,
parent_id: None,
} }
} }
} }
@@ -71,17 +91,128 @@ impl TraceId {
/// Returns a random trace ID that can be assumed to be globally unique if `rng` generates /// Returns a random trace ID that can be assumed to be globally unique if `rng` generates
/// actually-random numbers. /// actually-random numbers.
pub fn random<R: Rng>(rng: &mut R) -> Self { pub fn random<R: Rng>(rng: &mut R) -> Self {
TraceId(u128::from(rng.next_u64()) << mem::size_of::<u64>() | u128::from(rng.next_u64())) TraceId(rng.gen::<NonZeroU128>().get())
}
/// Returns true iff the trace ID is 0.
pub fn is_none(&self) -> bool {
self.0 == 0
} }
} }
impl SpanId { impl SpanId {
/// Returns a random span ID that can be assumed to be unique within a single trace. /// Returns a random span ID that can be assumed to be unique within a single trace.
pub fn random<R: Rng>(rng: &mut R) -> Self { pub fn random<R: Rng>(rng: &mut R) -> Self {
SpanId(rng.next_u64()) SpanId(rng.gen::<NonZeroU64>().get())
}
/// Returns true iff the span ID is 0.
pub fn is_none(&self) -> bool {
self.0 == 0
} }
} }
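Switching ID generation from `rng.next_u64()` to `rng.gen::<NonZeroU64>()` guarantees a freshly generated ID can never collide with the reserved zero value that `is_none` checks for. A std-only sketch of that invariant (this local `SpanId` is a stand-in for illustration, not the tarpc type):

```rust
use std::num::NonZeroU64;

// Zero is reserved as the "absent" sentinel, so random IDs must be drawn
// from the non-zero range; NonZeroU64 makes that unrepresentable otherwise.
#[derive(Debug, PartialEq)]
struct SpanId(u64);

impl SpanId {
    // Constructing from a NonZeroU64 can never produce the sentinel.
    fn from_non_zero(id: NonZeroU64) -> Self {
        SpanId(id.get())
    }

    // true iff the span ID is the reserved zero value
    fn is_none(&self) -> bool {
        self.0 == 0
    }
}

fn main() {
    let id = SpanId::from_non_zero(NonZeroU64::new(42).unwrap());
    assert!(!id.is_none());
    assert!(SpanId(0).is_none()); // only the sentinel reports "none"
}
```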
impl From<TraceId> for u128 {
fn from(trace_id: TraceId) -> Self {
trace_id.0
}
}
impl From<u128> for TraceId {
fn from(trace_id: u128) -> Self {
Self(trace_id)
}
}
impl From<SpanId> for u64 {
fn from(span_id: SpanId) -> Self {
span_id.0
}
}
impl From<u64> for SpanId {
fn from(span_id: u64) -> Self {
Self(span_id)
}
}
impl From<opentelemetry::trace::TraceId> for TraceId {
fn from(trace_id: opentelemetry::trace::TraceId) -> Self {
Self::from(trace_id.to_u128())
}
}
impl From<TraceId> for opentelemetry::trace::TraceId {
fn from(trace_id: TraceId) -> Self {
Self::from_u128(trace_id.into())
}
}
impl From<opentelemetry::trace::SpanId> for SpanId {
fn from(span_id: opentelemetry::trace::SpanId) -> Self {
Self::from(span_id.to_u64())
}
}
impl From<SpanId> for opentelemetry::trace::SpanId {
fn from(span_id: SpanId) -> Self {
Self::from_u64(span_id.0)
}
}
impl TryFrom<&tracing::Span> for Context {
type Error = NoActiveSpan;
fn try_from(span: &tracing::Span) -> Result<Self, NoActiveSpan> {
let context = span.context();
if context.has_active_span() {
Ok(Self::from(context.span()))
} else {
Err(NoActiveSpan)
}
}
}
impl From<opentelemetry::trace::SpanRef<'_>> for Context {
fn from(span: opentelemetry::trace::SpanRef<'_>) -> Self {
let otel_ctx = span.span_context();
Self {
trace_id: TraceId::from(otel_ctx.trace_id()),
span_id: SpanId::from(otel_ctx.span_id()),
sampling_decision: SamplingDecision::from(otel_ctx),
}
}
}
impl From<SamplingDecision> for opentelemetry::trace::TraceFlags {
fn from(decision: SamplingDecision) -> Self {
match decision {
SamplingDecision::Sampled => opentelemetry::trace::TraceFlags::SAMPLED,
SamplingDecision::Unsampled => opentelemetry::trace::TraceFlags::default(),
}
}
}
impl From<&opentelemetry::trace::SpanContext> for SamplingDecision {
fn from(context: &opentelemetry::trace::SpanContext) -> Self {
if context.is_sampled() {
SamplingDecision::Sampled
} else {
SamplingDecision::Unsampled
}
}
}
impl Default for SamplingDecision {
fn default() -> Self {
Self::Unsampled
}
}
/// Returned when a [`Context`] cannot be constructed from a [`Span`](tracing::Span).
#[derive(Debug)]
pub struct NoActiveSpan;
impl fmt::Display for TraceId { impl fmt::Display for TraceId {
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> { fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
write!(f, "{:02x}", self.0)?; write!(f, "{:02x}", self.0)?;
@@ -89,9 +220,42 @@ impl fmt::Display for TraceId {
} }
} }
impl fmt::Debug for TraceId {
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
write!(f, "{:02x}", self.0)?;
Ok(())
}
}
impl fmt::Display for SpanId { impl fmt::Display for SpanId {
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> { fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
write!(f, "{:02x}", self.0)?; write!(f, "{:02x}", self.0)?;
Ok(()) Ok(())
} }
} }
impl fmt::Debug for SpanId {
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
write!(f, "{:02x}", self.0)?;
Ok(())
}
}
#[cfg(feature = "serde1")]
mod u128_serde {
pub fn serialize<S>(u: &u128, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serde::Serialize::serialize(&u.to_le_bytes(), serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<u128, D::Error>
where
D: serde::Deserializer<'de>,
{
Ok(u128::from_le_bytes(serde::Deserialize::deserialize(
deserializer,
)?))
}
}
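The `u128_serde` shim round-trips the trace ID through a fixed 16-byte little-endian array, sidestepping serde formats that lack native 128-bit integer support. The byte-level invariant it relies on, shown here with serde removed:

```rust
fn main() {
    let trace_id: u128 = 0x0123_4567_89ab_cdef_0123_4567_89ab_cdef;

    // Serialize: 16 little-endian bytes, the representation the shim emits.
    let bytes: [u8; 16] = trace_id.to_le_bytes();
    assert_eq!(bytes[0], 0xef); // least significant byte comes first

    // Deserialize: reconstruct the identical u128 from those bytes.
    let roundtripped = u128::from_le_bytes(bytes);
    assert_eq!(roundtripped, trace_id);
}
```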


@@ -9,22 +9,32 @@
//! The rpc crate is transport- and protocol-agnostic. Any transport that impls [`Transport`](sealed::Transport) //! The rpc crate is transport- and protocol-agnostic. Any transport that impls [`Transport`](sealed::Transport)
//! can be plugged in, using whatever protocol it wants. //! can be plugged in, using whatever protocol it wants.
use futures::prelude::*;
use std::io;
pub mod channel; pub mod channel;
pub(crate) mod sealed { pub(crate) mod sealed {
use super::*; use futures::prelude::*;
use std::error::Error;
/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages. /// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
pub trait Transport<SinkItem, Item>: pub trait Transport<SinkItem, Item>
Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error> where
Self: Stream<Item = Result<Item, <Self as Sink<SinkItem>>::Error>>,
Self: Sink<SinkItem, Error = <Self as Transport<SinkItem, Item>>::TransportError>,
<Self as Sink<SinkItem>>::Error: Error,
{ {
/// Where-clauses on associated types are not elaborated at use sites; this
/// associated type lets users who bound types by `Transport` avoid having to
/// explicitly add `T::Error: Error` to their own bounds.
type TransportError: Error + Send + Sync + 'static;
} }
impl<T, SinkItem, Item> Transport<SinkItem, Item> for T where impl<T, SinkItem, Item, E> Transport<SinkItem, Item> for T
T: Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error> + ?Sized where
T: ?Sized,
T: Stream<Item = Result<Item, E>>,
T: Sink<SinkItem, Error = E>,
T::Error: Error + Send + Sync + 'static,
{ {
type TransportError = E;
} }
} }
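The `TransportError` associated type exists because where-clauses on associated types are not elaborated at use sites: bounding a type by `Transport` alone would not teach the compiler that the sink's error implements `Error`. A freestanding sketch of the same trick, with a toy `Send1` trait standing in for `Sink` (all names here are illustrative):

```rust
use std::error::Error;
use std::fmt;

// Toy stand-in for Sink: an operation that can fail with an associated error.
trait Send1 {
    type Err;
    fn send(&mut self, item: u32) -> Result<(), Self::Err>;
}

// The trick: re-export the error type through a new associated type that
// carries the `Error` bound, so a plain `T: Transport` bound implies it.
trait Transport
where
    Self: Send1<Err = <Self as Transport>::TransportError>,
{
    type TransportError: Error + Send + Sync + 'static;
}

// Blanket impl, mirroring the one in the diff: any sender whose error
// type implements Error automatically implements Transport.
impl<T, E> Transport for T
where
    T: Send1<Err = E>,
    E: Error + Send + Sync + 'static,
{
    type TransportError = E;
}

// No extra `T::Err: Error` bound needed here; `Transport` carries it.
fn describe_failure<T: Transport>(t: &mut T) -> Option<String> {
    t.send(0).err().map(|e| e.to_string())
}

#[derive(Debug)]
struct Closed;
impl fmt::Display for Closed {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "closed")
    }
}
impl Error for Closed {}

struct AlwaysClosed;
impl Send1 for AlwaysClosed {
    type Err = Closed;
    fn send(&mut self, _item: u32) -> Result<(), Closed> {
        Err(Closed)
    }
}

fn main() {
    assert_eq!(describe_failure(&mut AlwaysClosed), Some("closed".to_string()));
}
```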


@@ -6,13 +6,19 @@
//! Transports backed by in-memory channels. //! Transports backed by in-memory channels.
use crate::PollIo;
use futures::{task::*, Sink, Stream}; use futures::{task::*, Sink, Stream};
use pin_project::pin_project; use pin_project::pin_project;
use std::io; use std::{error::Error, pin::Pin};
use std::pin::Pin;
use tokio::sync::mpsc; use tokio::sync::mpsc;
/// Errors that occur in the sending or receiving of messages over a channel.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError {
/// An error occurred sending over the channel.
#[error("an error occurred sending over the channel")]
Send(#[source] Box<dyn Error + Send + Sync + 'static>),
}
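Here `thiserror` derives a `Display` using the `#[error(...)]` message and a `source()` pointing at the boxed inner error. A hand-rolled equivalent of the same pattern, assuming only std (no `thiserror`), to show what the derive produces:

```rust
use std::error::Error;
use std::fmt;

// Hand-rolled version of what the thiserror derive generates for
// ChannelError::Send: a fixed message, with the boxed cause via source().
#[derive(Debug)]
enum ChannelError {
    Send(Box<dyn Error + Send + Sync + 'static>),
}

impl fmt::Display for ChannelError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ChannelError::Send(_) => {
                write!(f, "an error occurred sending over the channel")
            }
        }
    }
}

impl Error for ChannelError {
    // Expose the underlying cause so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            ChannelError::Send(inner) => Some(inner.as_ref()),
        }
    }
}

fn main() {
    let err = ChannelError::Send("the channel is closed".into());
    assert_eq!(err.to_string(), "an error occurred sending over the channel");
    assert_eq!(err.source().unwrap().to_string(), "the channel is closed");
}
```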
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's /// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
/// [`Sink`]. /// [`Sink`].
pub fn unbounded<SinkItem, Item>() -> ( pub fn unbounded<SinkItem, Item>() -> (
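What the `#[derive(thiserror::Error)]` on `ChannelError` expands to can be written out by hand with the stdlib alone — a sketch of the equivalent `Display` and `Error` impls (the real code uses the derive, so this is illustrative only):

```rust
use std::error::Error;
use std::fmt;

// Hand-rolled equivalent of the diff's `ChannelError`: `Display` yields the
// `#[error(...)]` message, and `source` exposes the boxed `#[source]` cause.
#[derive(Debug)]
pub enum ChannelError {
    Send(Box<dyn Error + Send + Sync + 'static>),
}

impl fmt::Display for ChannelError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ChannelError::Send(_) => f.write_str("an error occurred sending over the channel"),
        }
    }
}

impl Error for ChannelError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            ChannelError::Send(cause) => {
                // Coerce away the Send + Sync auto traits for the `source` signature.
                let cause: &(dyn Error + 'static) = cause.as_ref();
                Some(cause)
            }
        }
    }
}

fn main() {
    // `CLOSED_MESSAGE.into()` in the diff works because `&str` converts into a
    // boxed error via the stdlib's `From<&str> for Box<dyn Error + Send + Sync>`.
    let cause: Box<dyn Error + Send + Sync + 'static> =
        "the channel is closed and cannot accept new items for sending".into();
    let err = ChannelError::Send(cause);
    assert_eq!(err.to_string(), "an error occurred sending over the channel");
    assert!(err.source().is_some());
    println!("ok");
}
```

Boxing the cause keeps `ChannelError` a single concrete type even though the two channel flavors fail with different underlying error types.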
@@ -36,28 +42,33 @@ pub struct UnboundedChannel<Item, SinkItem> {
 }
 
 impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
-    type Item = Result<Item, io::Error>;
+    type Item = Result<Item, ChannelError>;
 
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Item> {
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Result<Item, ChannelError>>> {
         self.rx.poll_recv(cx).map(|option| option.map(Ok))
     }
 }
 
+const CLOSED_MESSAGE: &str = "the channel is closed and cannot accept new items for sending";
+
 impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
-    type Error = io::Error;
+    type Error = ChannelError;
 
-    fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
+    fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
         Poll::Ready(if self.tx.is_closed() {
-            Err(io::Error::from(io::ErrorKind::NotConnected))
+            Err(ChannelError::Send(CLOSED_MESSAGE.into()))
         } else {
             Ok(())
         })
     }
 
-    fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
+    fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
         self.tx
             .send(item)
-            .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
+            .map_err(|_| ChannelError::Send(CLOSED_MESSAGE.into()))
     }
 
     fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@@ -65,7 +76,7 @@ impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
         Poll::Ready(Ok(()))
     }
 
-    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
+    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
         // UnboundedSender can't initiate closure.
         Poll::Ready(Ok(()))
     }
@@ -93,52 +104,45 @@ pub struct Channel<Item, SinkItem> {
 }
 
 impl<Item, SinkItem> Stream for Channel<Item, SinkItem> {
-    type Item = Result<Item, io::Error>;
+    type Item = Result<Item, ChannelError>;
 
-    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Item> {
+    fn poll_next(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Result<Item, ChannelError>>> {
         self.project().rx.poll_next(cx).map(|option| option.map(Ok))
     }
 }
 
 impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
-    type Error = io::Error;
+    type Error = ChannelError;
 
-    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
         self.project()
             .tx
             .poll_ready(cx)
-            .map_err(convert_send_err_to_io)
+            .map_err(|e| ChannelError::Send(Box::new(e)))
     }
 
-    fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
+    fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
         self.project()
             .tx
             .start_send(item)
-            .map_err(convert_send_err_to_io)
+            .map_err(|e| ChannelError::Send(Box::new(e)))
     }
 
     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
         self.project()
             .tx
             .poll_flush(cx)
-            .map_err(convert_send_err_to_io)
+            .map_err(|e| ChannelError::Send(Box::new(e)))
     }
 
-    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
         self.project()
             .tx
             .poll_close(cx)
-            .map_err(convert_send_err_to_io)
-    }
-}
-
-fn convert_send_err_to_io(e: futures::channel::mpsc::SendError) -> io::Error {
-    if e.is_disconnected() {
-        io::Error::from(io::ErrorKind::NotConnected)
-    } else if e.is_full() {
-        io::Error::from(io::ErrorKind::WouldBlock)
-    } else {
-        io::Error::new(io::ErrorKind::Other, e)
+            .map_err(|e| ChannelError::Send(Box::new(e)))
     }
 }
@@ -147,17 +151,27 @@ fn convert_send_err_to_io(e: futures::channel::mpsc::SendError) -> io::Error {
 mod tests {
     use crate::{
         client, context,
-        server::{BaseChannel, Incoming},
-        transport,
+        server::{incoming::Incoming, BaseChannel},
+        transport::{
+            self,
+            channel::{Channel, UnboundedChannel},
+        },
     };
     use assert_matches::assert_matches;
     use futures::{prelude::*, stream};
-    use log::trace;
     use std::io;
+    use tracing::trace;
+
+    #[test]
+    fn ensure_is_transport() {
+        fn is_transport<SinkItem, Item, T: crate::Transport<SinkItem, Item>>() {}
+        is_transport::<(), (), UnboundedChannel<(), ()>>();
+        is_transport::<(), (), Channel<(), ()>>();
+    }
 
     #[tokio::test]
-    async fn integration() -> io::Result<()> {
-        let _ = env_logger::try_init();
+    async fn integration() -> anyhow::Result<()> {
+        let _ = tracing_subscriber::fmt::try_init();
 
         let (client_channel, server_channel) = transport::channel::unbounded();
         tokio::spawn(
@@ -173,10 +187,10 @@ mod tests {
         }),
     );
 
-    let client = client::new(client::Config::default(), client_channel).spawn()?;
+    let client = client::new(client::Config::default(), client_channel).spawn();
 
-    let response1 = client.call(context::current(), "123".into()).await?;
-    let response2 = client.call(context::current(), "abc".into()).await?;
+    let response1 = client.call(context::current(), "", "123".into()).await?;
+    let response2 = client.call(context::current(), "", "abc".into()).await?;
     trace!("response1: {:?}, response2: {:?}", response1, response2);
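The crossed-wires construction behind `transport::channel::unbounded()` — two peers, each reading what the other writes — can be sketched with the stdlib's synchronous `std::sync::mpsc`. `Peer` and this `unbounded` are hypothetical analogs; the real version is async over tokio's mpsc:

```rust
use std::sync::mpsc;

// One side of an in-memory duplex link: receives `Item`, sends `SinkItem`.
struct Peer<Item, SinkItem> {
    rx: mpsc::Receiver<Item>,
    tx: mpsc::Sender<SinkItem>,
}

// Build two peers by crossing the wires of two plain channels.
fn unbounded<A, B>() -> (Peer<A, B>, Peer<B, A>) {
    let (a_tx, a_rx) = mpsc::channel::<A>();
    let (b_tx, b_rx) = mpsc::channel::<B>();
    (
        Peer { rx: a_rx, tx: b_tx }, // receives A, sends B
        Peer { rx: b_rx, tx: a_tx }, // receives B, sends A
    )
}

fn main() {
    let (client, server) = unbounded::<String, String>();
    client.tx.send("ping".to_string()).unwrap();
    assert_eq!(server.rx.recv().unwrap(), "ping");
    server.tx.send("pong".to_string()).unwrap();
    assert_eq!(client.rx.recv().unwrap(), "pong");

    // Dropping one peer closes the other's sink, analogous to the diff's
    // `tx.is_closed()` check surfacing a `ChannelError::Send`.
    drop(server);
    assert!(client.tx.send("late".to_string()).is_err());
    println!("ok");
}
```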

View File

@@ -5,31 +5,7 @@
 // https://opensource.org/licenses/MIT.
 
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
-use std::{
-    io,
-    time::{Duration, SystemTime},
-};
+use std::io;
 
-/// Serializes `system_time` as a `u64` equal to the number of seconds since the epoch.
-pub fn serialize_epoch_secs<S>(system_time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
-where
-    S: Serializer,
-{
-    const ZERO_SECS: Duration = Duration::from_secs(0);
-    system_time
-        .duration_since(SystemTime::UNIX_EPOCH)
-        .unwrap_or(ZERO_SECS)
-        .as_secs() // Only care about second precision
-        .serialize(serializer)
-}
-
-/// Deserializes [`SystemTime`] from a `u64` equal to the number of seconds since the epoch.
-pub fn deserialize_epoch_secs<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
-where
-    D: Deserializer<'de>,
-{
-    Ok(SystemTime::UNIX_EPOCH + Duration::from_secs(u64::deserialize(deserializer)?))
-}
-
 /// Serializes [`io::ErrorKind`] as a `u32`.
 #[allow(clippy::trivially_copy_pass_by_ref)] // Exact fn signature required by serde derive
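The helpers removed above reduced a `SystemTime` to whole seconds since the Unix epoch, clamping pre-epoch times to zero. A stdlib-only round trip of that scheme (serde stripped out; `to_epoch_secs`/`from_epoch_secs` are illustrative names):

```rust
use std::time::{Duration, SystemTime};

// Second-precision epoch encoding, as the removed `serialize_epoch_secs` did.
fn to_epoch_secs(t: SystemTime) -> u64 {
    t.duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or(Duration::from_secs(0)) // pre-epoch times clamp to 0
        .as_secs() // sub-second precision is deliberately dropped
}

// Inverse, as the removed `deserialize_epoch_secs` did.
fn from_epoch_secs(secs: u64) -> SystemTime {
    SystemTime::UNIX_EPOCH + Duration::from_secs(secs)
}

fn main() {
    let t = SystemTime::UNIX_EPOCH + Duration::new(1_000, 999_000_000);
    let secs = to_epoch_secs(t);
    assert_eq!(secs, 1_000); // nanoseconds truncated, not rounded
    assert_eq!(from_epoch_secs(secs), SystemTime::UNIX_EPOCH + Duration::from_secs(1_000));

    // Times before the epoch clamp to zero rather than erroring.
    let before = SystemTime::UNIX_EPOCH - Duration::from_secs(5);
    assert_eq!(to_epoch_secs(before), 0);
    println!("ok");
}
```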

View File

@@ -1,9 +1,8 @@
 use futures::prelude::*;
-use std::io;
 use tarpc::serde_transport;
 use tarpc::{
     client, context,
-    server::{BaseChannel, Incoming},
+    server::{incoming::Incoming, BaseChannel},
 };
 use tokio_serde::formats::Json;
@@ -33,7 +32,7 @@ impl ColorProtocol for ColorServer {
 }
 
 #[tokio::test]
-async fn test_call() -> io::Result<()> {
+async fn test_call() -> anyhow::Result<()> {
     let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
     let addr = transport.local_addr();
     tokio::spawn(
@@ -45,7 +44,7 @@ async fn test_call() -> io::Result<()> {
     );
 
     let transport = serde_transport::tcp::connect(addr, Json::default).await?;
-    let client = ColorProtocolClient::new(client::Config::default(), transport).spawn()?;
+    let client = ColorProtocolClient::new(client::Config::default(), transport).spawn();
 
     let color = client
         .get_opposite_color(context::current(), TestData::White)

View File

@@ -3,14 +3,11 @@ use futures::{
     future::{join_all, ready, Ready},
     prelude::*,
 };
-use std::{
-    io,
-    time::{Duration, SystemTime},
-};
+use std::time::{Duration, SystemTime};
 use tarpc::{
     client::{self},
     context,
-    server::{self, BaseChannel, Channel, Incoming},
+    server::{self, incoming::Incoming, BaseChannel, Channel},
     transport::channel,
 };
 use tokio::join;
@@ -39,8 +36,8 @@ impl Service for Server {
 }
 
 #[tokio::test]
-async fn sequential() -> io::Result<()> {
-    let _ = env_logger::try_init();
+async fn sequential() -> anyhow::Result<()> {
+    let _ = tracing_subscriber::fmt::try_init();
 
     let (tx, rx) = channel::unbounded();
@@ -50,7 +47,7 @@ async fn sequential() -> io::Result<()> {
             .execute(Server.serve()),
     );
 
-    let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
+    let client = ServiceClient::new(client::Config::default(), tx).spawn();
 
     assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
     assert_matches!(
@@ -61,7 +58,7 @@ async fn sequential() -> io::Result<()> {
 }
 
 #[tokio::test]
-async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
+async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> {
     #[tarpc_plugins::service]
     trait Loop {
         async fn r#loop();
@@ -82,16 +79,14 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
         }
     }
 
-    let _ = env_logger::try_init();
+    let _ = tracing_subscriber::fmt::try_init();
 
     let (tx, rx) = channel::unbounded();
 
     // Set up a client that initiates a long-lived request.
     // The request will complete in error when the server drops the connection.
     tokio::spawn(async move {
-        let client = LoopClient::new(client::Config::default(), tx)
-            .spawn()
-            .unwrap();
+        let client = LoopClient::new(client::Config::default(), tx).spawn();
 
         let mut ctx = context::current();
         ctx.deadline = SystemTime::now() + Duration::from_secs(60 * 60);
@@ -113,11 +108,11 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
 
 #[cfg(all(feature = "serde-transport", feature = "tcp"))]
 #[tokio::test]
-async fn serde() -> io::Result<()> {
+async fn serde() -> anyhow::Result<()> {
     use tarpc::serde_transport;
     use tokio_serde::formats::Json;
 
-    let _ = env_logger::try_init();
+    let _ = tracing_subscriber::fmt::try_init();
 
     let transport = tarpc::serde_transport::tcp::listen("localhost:56789", Json::default).await?;
     let addr = transport.local_addr();
@@ -130,7 +125,7 @@ async fn serde() -> io::Result<()> {
     );
 
     let transport = serde_transport::tcp::connect(addr, Json::default).await?;
-    let client = ServiceClient::new(client::Config::default(), transport).spawn()?;
+    let client = ServiceClient::new(client::Config::default(), transport).spawn();
 
     assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
     assert_matches!(
@@ -142,8 +137,8 @@ async fn serde() -> io::Result<()> {
 }
 
 #[tokio::test]
-async fn concurrent() -> io::Result<()> {
-    let _ = env_logger::try_init();
+async fn concurrent() -> anyhow::Result<()> {
+    let _ = tracing_subscriber::fmt::try_init();
 
     let (tx, rx) = channel::unbounded();
     tokio::spawn(
@@ -152,7 +147,7 @@ async fn concurrent() -> io::Result<()> {
             .execute(Server.serve()),
     );
 
-    let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
+    let client = ServiceClient::new(client::Config::default(), tx).spawn();
 
     let req1 = client.add(context::current(), 1, 2);
     let req2 = client.add(context::current(), 3, 4);
@@ -166,8 +161,8 @@ async fn concurrent() -> io::Result<()> {
 }
 
 #[tokio::test]
-async fn concurrent_join() -> io::Result<()> {
-    let _ = env_logger::try_init();
+async fn concurrent_join() -> anyhow::Result<()> {
+    let _ = tracing_subscriber::fmt::try_init();
 
     let (tx, rx) = channel::unbounded();
     tokio::spawn(
@@ -176,7 +171,7 @@ async fn concurrent_join() -> io::Result<()> {
             .execute(Server.serve()),
     );
 
-    let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
+    let client = ServiceClient::new(client::Config::default(), tx).spawn();
 
     let req1 = client.add(context::current(), 1, 2);
     let req2 = client.add(context::current(), 3, 4);
@@ -191,8 +186,8 @@ async fn concurrent_join() -> io::Result<()> {
 }
 
 #[tokio::test]
-async fn concurrent_join_all() -> io::Result<()> {
-    let _ = env_logger::try_init();
+async fn concurrent_join_all() -> anyhow::Result<()> {
+    let _ = tracing_subscriber::fmt::try_init();
 
     let (tx, rx) = channel::unbounded();
     tokio::spawn(
@@ -201,7 +196,7 @@ async fn concurrent_join_all() -> io::Result<()> {
             .execute(Server.serve()),
     );
 
-    let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
+    let client = ServiceClient::new(client::Config::default(), tx).spawn();
 
     let req1 = client.add(context::current(), 1, 2);
     let req2 = client.add(context::current(), 3, 4);
@@ -214,7 +209,7 @@ async fn concurrent_join_all() -> io::Result<()> {
 }
 
 #[tokio::test]
-async fn counter() -> io::Result<()> {
+async fn counter() -> anyhow::Result<()> {
     #[tarpc::service]
     trait Counter {
         async fn count() -> u32;
@@ -241,7 +236,7 @@ async fn counter() -> io::Result<()> {
     });
 
-    let client = CounterClient::new(client::Config::default(), tx).spawn()?;
+    let client = CounterClient::new(client::Config::default(), tx).spawn();
 
     assert_matches!(client.count(context::current()).await, Ok(1));
     assert_matches!(client.count(context::current()).await, Ok(2));
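The `concurrent*` tests above all follow one pattern: issue several calls before awaiting any of them, so the requests overlap in flight. A stdlib-only analog of that pattern using threads instead of futures (`add_after_delay` is a hypothetical stand-in for an RPC with latency):

```rust
use std::thread;
use std::time::{Duration, Instant};

// A stand-in "RPC": like the tests' `add`, but with simulated network latency.
fn add_after_delay(x: u64, y: u64) -> u64 {
    thread::sleep(Duration::from_millis(50));
    x + y
}

fn main() {
    let start = Instant::now();

    // Start all three "requests" first...
    let handles: Vec<_> = [(1, 2), (3, 4), (5, 6)]
        .into_iter()
        .map(|(x, y)| thread::spawn(move || add_after_delay(x, y)))
        .collect();

    // ...then collect the results, as `join!`/`join_all` do for futures.
    let results: Vec<u64> = handles.into_iter().map(|h| h.join().unwrap()).collect();
    assert_eq!(results, vec![3, 7, 11]);

    // The three 50ms "requests" overlapped instead of running back to back.
    assert!(start.elapsed() < Duration::from_millis(150));
    println!("ok");
}
```

Sequential awaiting, as in the `sequential` test, would instead pay the full 150ms; issuing the calls up front is what lets the total latency track the slowest request.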