mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-02-23 15:49:54 +01:00
Compare commits
100 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
99bf3e62a3 | ||
|
|
68863e3db0 | ||
|
|
453ba1c074 | ||
|
|
e3eac1b4f5 | ||
|
|
0e102288a5 | ||
|
|
4c8ba41b2f | ||
|
|
946c627579 | ||
|
|
104dd71bba | ||
|
|
012c481861 | ||
|
|
dc12bd09aa | ||
|
|
2594ea8ce9 | ||
|
|
839b87e394 | ||
|
|
57d0638a99 | ||
|
|
a3a6404a30 | ||
|
|
b36eac80b1 | ||
|
|
d7070e4bc3 | ||
|
|
b5d1828308 | ||
|
|
92cfe63c4f | ||
|
|
839a2f067c | ||
|
|
b5d593488c | ||
|
|
eea38b8bf4 | ||
|
|
70493c15f4 | ||
|
|
f7c5d6a7c3 | ||
|
|
98c5d2a18b | ||
|
|
46b534f7c6 | ||
|
|
42b4fc52b1 | ||
|
|
350dbcdad0 | ||
|
|
b1b4461d89 | ||
|
|
f694b7573a | ||
|
|
1e680e3a5a | ||
|
|
2591d21e94 | ||
|
|
6632f68d95 | ||
|
|
25985ad56a | ||
|
|
d6a24e9420 | ||
|
|
281a78f3c7 | ||
|
|
a0787d0091 | ||
|
|
d2acba0e8a | ||
|
|
ea7b6763c4 | ||
|
|
eb67c540b9 | ||
|
|
4151d0abd3 | ||
|
|
d0c11a6efa | ||
|
|
82c4da1743 | ||
|
|
0a15e0b75c | ||
|
|
0b315c29bf | ||
|
|
56f09bf61f | ||
|
|
6d82e82419 | ||
|
|
9bebaf814a | ||
|
|
5f4d6e6008 | ||
|
|
07d07d7ba3 | ||
|
|
a41bbf65b2 | ||
|
|
21e2f7ca62 | ||
|
|
7b7c182411 | ||
|
|
db0c778ead | ||
|
|
c3efb83ac1 | ||
|
|
3d7b0171fe | ||
|
|
c191ff5b2e | ||
|
|
90bc7f741d | ||
|
|
d3f6c01df2 | ||
|
|
c6450521e6 | ||
|
|
1da6bcec57 | ||
|
|
75a5591158 | ||
|
|
9462aad3bf | ||
|
|
0964fc51ff | ||
|
|
27aacab432 | ||
|
|
3feb465ad3 | ||
|
|
66cdc99ae0 | ||
|
|
66419db6fd | ||
|
|
72d5dbba89 | ||
|
|
e75193c191 | ||
|
|
ce4fd49161 | ||
|
|
3c978c5bf6 | ||
|
|
6f419e9a9a | ||
|
|
b3eb8d0b7a | ||
|
|
3b422eb179 | ||
|
|
4b513bad73 | ||
|
|
e71e17866d | ||
|
|
7e3fbec077 | ||
|
|
e4bc5e8e32 | ||
|
|
bc982c5584 | ||
|
|
d440e12c19 | ||
|
|
bc8128af69 | ||
|
|
1d87c14262 | ||
|
|
ca929c2178 | ||
|
|
569039734b | ||
|
|
3d43310e6a | ||
|
|
d21cbddb0d | ||
|
|
25aa857edf | ||
|
|
0bb2e2bbbe | ||
|
|
dc376343d6 | ||
|
|
2e7d1f8a88 | ||
|
|
6314591c65 | ||
|
|
7dd7494420 | ||
|
|
6c10e3649f | ||
|
|
4c6dee13d2 | ||
|
|
e45abe953a | ||
|
|
dec3e491b5 | ||
|
|
6ce341cf79 | ||
|
|
b9868250f8 | ||
|
|
a3f1064efe | ||
|
|
026083d653 |
48
.github/workflows/main.yml
vendored
48
.github/workflows/main.yml
vendored
@@ -1,4 +1,10 @@
|
||||
on: [push, pull_request]
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
|
||||
name: Continuous integration
|
||||
|
||||
@@ -7,27 +13,59 @@ jobs:
|
||||
name: Check
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cancel previous
|
||||
uses: styfle/cancel-workflow-action@0.7.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v1
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
target: mipsel-unknown-linux-gnu
|
||||
override: true
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: check
|
||||
args: --all-features
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: check
|
||||
args: --all-features --target mipsel-unknown-linux-gnu
|
||||
|
||||
test:
|
||||
name: Test Suite
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cancel previous
|
||||
uses: styfle/cancel-workflow-action@0.7.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v1
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
override: true
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features serde1
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features tokio1
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features serde-transport
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features tcp
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
@@ -37,6 +75,10 @@ jobs:
|
||||
name: Rustfmt
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cancel previous
|
||||
uses: styfle/cancel-workflow-action@0.7.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v1
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
@@ -53,6 +95,10 @@ jobs:
|
||||
name: Clippy
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cancel previous
|
||||
uses: styfle/cancel-workflow-action@0.7.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v1
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
[workspace]
|
||||
resolver = "2"
|
||||
|
||||
members = [
|
||||
"example-service",
|
||||
"tarpc",
|
||||
"plugins",
|
||||
]
|
||||
|
||||
[profile.dev]
|
||||
split-debuginfo = "unpacked"
|
||||
|
||||
44
README.md
44
README.md
@@ -40,7 +40,7 @@ rather than in a separate language such as .proto. This means there's no separat
|
||||
process, and no context switching between different languages.
|
||||
|
||||
Some other features of tarpc:
|
||||
- Pluggable transport: any type impling `Stream<Item = Request> + Sink<Response>` can be
|
||||
- Pluggable transport: any type implementing `Stream<Item = Request> + Sink<Response>` can be
|
||||
used as a transport to connect the client and server.
|
||||
- `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
|
||||
- Cascading cancellation: dropping a request will send a cancellation message to the server.
|
||||
@@ -51,6 +51,14 @@ Some other features of tarpc:
|
||||
requests sent by the server that use the request context will propagate the request deadline.
|
||||
For example, if a server is handling a request with a 10s deadline, does 2s of work, then
|
||||
sends a request to another server, that server will see an 8s deadline.
|
||||
- Distributed tracing: tarpc is instrumented with
|
||||
[tracing](https://github.com/tokio-rs/tracing) primitives extended with
|
||||
[OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
|
||||
[Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
|
||||
each RPC can be traced through the client, server, and other dependencies downstream of the
|
||||
server. Even for applications not connected to a distributed tracing collector, the
|
||||
instrumentation can also be ingested by regular loggers like
|
||||
[env_logger](https://github.com/env-logger-rs/env_logger/).
|
||||
- Serde serialization: enabling the `serde1` Cargo feature will make service requests and
|
||||
responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
|
||||
be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
|
||||
@@ -59,7 +67,7 @@ Some other features of tarpc:
|
||||
Add to your `Cargo.toml` dependencies:
|
||||
|
||||
```toml
|
||||
tarpc = "0.22.0"
|
||||
tarpc = "0.29"
|
||||
```
|
||||
|
||||
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||
@@ -68,12 +76,14 @@ Simply implement the generated service trait, and you're off to the races!
|
||||
|
||||
## Example
|
||||
|
||||
For this example, in addition to tarpc, also add two other dependencies to
|
||||
This example uses [tokio](https://tokio.rs), so add the following dependencies to
|
||||
your `Cargo.toml`:
|
||||
|
||||
```toml
|
||||
anyhow = "1.0"
|
||||
futures = "0.3"
|
||||
tokio = "0.2"
|
||||
tarpc = { version = "0.29", features = ["tokio1"] }
|
||||
tokio = { version = "1.0", features = ["macros"] }
|
||||
```
|
||||
|
||||
In the following example, we use an in-process channel for communication between
|
||||
@@ -90,9 +100,8 @@ use futures::{
|
||||
};
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{self, Handler},
|
||||
server::{self, incoming::Incoming, Channel},
|
||||
};
|
||||
use std::io;
|
||||
|
||||
// This is the service definition. It looks a lot like a trait definition.
|
||||
// It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
@@ -119,39 +128,34 @@ impl World for HelloServer {
|
||||
type HelloFut = Ready<String>;
|
||||
|
||||
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
||||
future::ready(format!("Hello, {}!", name))
|
||||
future::ready(format!("Hello, {name}!"))
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Lastly let's write our `main` that will start the server. While this example uses an
|
||||
[in-process channel](rpc::transport::channel), tarpc also ships a generic [`serde_transport`]
|
||||
[in-process channel](transport::channel), tarpc also ships a generic [`serde_transport`]
|
||||
behind the `serde-transport` feature, with additional [TCP](serde_transport::tcp) functionality
|
||||
available behind the `tcp` feature.
|
||||
|
||||
```rust
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
|
||||
let server = server::new(server::Config::default())
|
||||
// incoming() takes a stream of transports such as would be returned by
|
||||
// TcpListener::incoming (but a stream instead of an iterator).
|
||||
.incoming(stream::once(future::ready(server_transport)))
|
||||
.respond_with(HelloServer.serve());
|
||||
let server = server::BaseChannel::with_defaults(server_transport);
|
||||
tokio::spawn(server.execute(HelloServer.serve()));
|
||||
|
||||
tokio::spawn(server);
|
||||
|
||||
// WorldClient is generated by the macro. It has a constructor `new` that takes a config and
|
||||
// any Transport as input
|
||||
let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
|
||||
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||
// that takes a config and any Transport as input.
|
||||
let mut client = WorldClient::new(client::Config::default(), client_transport).spawn();
|
||||
|
||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||
// specifies a deadline and trace information which can be helpful in debugging requests.
|
||||
let hello = client.hello(context::current(), "Stim".to_string()).await?;
|
||||
|
||||
println!("{}", hello);
|
||||
println!("{hello}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
238
RELEASES.md
238
RELEASES.md
@@ -1,3 +1,239 @@
|
||||
## 0.30.0 (2022-08-12)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- Some types that impl Future are now annotated with `#[must_use]`. Code that previously created
|
||||
these types but did not use them will now receive a warning. Code that disallows warnings will
|
||||
receive a compilation error.
|
||||
|
||||
### Fixes
|
||||
|
||||
- Servers will more reliably clean up request state for requests with long deadlines when response
|
||||
processing is aborted without sending a response.
|
||||
|
||||
### Other Changes
|
||||
|
||||
- `TrackedRequest` now contains a response guard that can be used to ensure state cleanup for
|
||||
aborted requests. (This was already handled automatically by `InFlightRequests`).
|
||||
- When the feature serde-transport is enabled, the crate tokio_serde is now re-exported.
|
||||
|
||||
## 0.29.0 (2022-05-26)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
`Context.deadline` is now serialized as a Duration. This prevents clock skew from affecting deadline
|
||||
behavior. For more details see https://github.com/google/tarpc/pull/367 and its [related
|
||||
issue](https://github.com/google/tarpc/issues/366).
|
||||
|
||||
## 0.28.0 (2022-04-06)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- The minimum supported Rust version has increased to 1.58.0.
|
||||
- The version of opentelemetry depended on by tarpc has increased to 0.17.0.
|
||||
|
||||
## 0.27.2 (2021-10-08)
|
||||
|
||||
### Fixes
|
||||
|
||||
Clients will now close their transport before dropping it. An attempt at a clean shutdown can help
|
||||
the server drop its connections more quickly.
|
||||
|
||||
## 0.27.1 (2021-09-22)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
#### RPC error type is changing
|
||||
|
||||
RPC return types are changing from `Result<Response, io::Error>` to `Result<Response,
|
||||
tarpc::client::RpcError>`.
|
||||
|
||||
Becaue tarpc is a library, not an application, it should strive to
|
||||
use structured errors in its API so that users have maximal flexibility
|
||||
in how they handle errors. io::Error makes that hard, because it is a
|
||||
kitchen-sink error type.
|
||||
|
||||
RPCs in particular only have 3 classes of errors:
|
||||
|
||||
- The connection breaks.
|
||||
- The request expires.
|
||||
- The server decides not to process the request.
|
||||
|
||||
RPC responses can also contain application-specific errors, but from the
|
||||
perspective of the RPC library, those are opaque to the framework, classified
|
||||
as successful responsees.
|
||||
|
||||
### Open Telemetry
|
||||
|
||||
The Opentelemetry dependency is updated to version 0.16.x.
|
||||
|
||||
## 0.27.0 (2021-09-22)
|
||||
|
||||
This version was yanked due to tarpc-plugins version mismatches.
|
||||
|
||||
|
||||
## 0.26.0 (2021-04-14)
|
||||
|
||||
### New Features
|
||||
|
||||
#### Tracing
|
||||
|
||||
tarpc is now instrumented with tracing primitives extended with
|
||||
OpenTelemetry traces. Using a compatible tracing-opentelemetry
|
||||
subscriber like Jaeger, each RPC can be traced through the client,
|
||||
server, amd other dependencies downstream of the server. Even for
|
||||
applications not connected to a distributed tracing collector, the
|
||||
instrumentation can also be ingested by regular loggers like env_logger.
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
#### Logging
|
||||
|
||||
Logged events are now structured using tracing. For applications using a
|
||||
logger and not a tracing subscriber, these logs may look different or
|
||||
contain information in a less consumable manner. The easiest solution is
|
||||
to add a tracing subscriber that logs to stdout, such as
|
||||
tracing_subscriber::fmt.
|
||||
|
||||
#### Context
|
||||
|
||||
- Context no longer has parent_span, which was actually never needed,
|
||||
because the context sent in an RPC is inherently the parent context.
|
||||
For purposes of distributed tracing, the client side of the RPC has all
|
||||
necessary information to link the span to its parent; the server side
|
||||
need do nothing more than export the (trace ID, span ID) tuple.
|
||||
- Context has a new field, SamplingDecision, which has two variants,
|
||||
Sampled and Unsampled. This field can be used by downstream systems to
|
||||
determine whether a trace needs to be exported. If the parent span is
|
||||
sampled, the expectation is that all child spans be exported, as well;
|
||||
to do otherwise could result in lossy traces being exported. Note that
|
||||
if an Openetelemetry tracing subscriber is not installed, the fallback
|
||||
context will still be used, but the Context's sampling decision will
|
||||
always be inherited by the parent Context's sampling decision.
|
||||
- Context::scope has been removed. Context propagation is now done via
|
||||
tracing's task-local spans. Spans can be propagated across tasks via
|
||||
Span::in_scope. When a service receives a request, it attaches an
|
||||
Opentelemetry context to the local Span created before request handling,
|
||||
and this context contains the request deadline. This span-local deadline
|
||||
is retrieved by Context::current, but it cannot be modified so that
|
||||
future Context::current calls contain a different deadline. However, the
|
||||
deadline in the context passed into an RPC call will override it, so
|
||||
users can retrieve the current context and then modify the deadline
|
||||
field, as has been historically possible.
|
||||
- Context propgation precedence changes: when an RPC is initiated, the
|
||||
current Span's Opentelemetry context takes precedence over the trace
|
||||
context passed into the RPC method. If there is no current Span, then
|
||||
the trace context argument is used as it has been historically. Note
|
||||
that Opentelemetry context propagation requires an Opentelemetry
|
||||
tracing subscriber to be installed.
|
||||
|
||||
#### Server
|
||||
|
||||
- The server::Channel trait now has an additional required associated
|
||||
type and method which returns the underlying transport. This makes it
|
||||
more ergonomic for users to retrieve transport-specific information,
|
||||
like IP Address. BaseChannel implements Channel::transport by returning
|
||||
the underlying transport, and channel decorators like Throttler just
|
||||
delegate to the Channel::transport method of the wrapped channel.
|
||||
|
||||
#### Client
|
||||
|
||||
- NewClient::spawn no longer returns a result, as spawn can't fail.
|
||||
|
||||
### References
|
||||
|
||||
1. https://github.com/tokio-rs/tracing
|
||||
2. https://opentelemetry.io
|
||||
3. https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger
|
||||
4. https://github.com/env-logger-rs/env_logger
|
||||
|
||||
## 0.25.0 (2021-03-10)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
#### Major server module refactoring
|
||||
|
||||
1. Renames
|
||||
|
||||
Some of the items in this module were renamed to be less generic:
|
||||
|
||||
- Handler => Incoming
|
||||
- ClientHandler => Requests
|
||||
- ResponseHandler => InFlightRequest
|
||||
- Channel::{respond_with => requests}
|
||||
|
||||
In the case of Handler: handler of *what*? Now it's a bit clearer that this is a stream of Channels
|
||||
(aka *incoming* connections).
|
||||
|
||||
Similarly, ClientHandler was a stream of requests over a single connection. Hopefully Requests
|
||||
better reflects that.
|
||||
|
||||
ResponseHandler was renamed InFlightRequest because it no longer contains the serving function.
|
||||
Instead, it is just the request, plus the response channel and an abort hook. As a result of this,
|
||||
Channel::respond_with underwent a big change: it used to take the serving function and return a
|
||||
ClientHandler; now it has been renamed Channel::requests and does not take any args.
|
||||
|
||||
2. Execute methods
|
||||
|
||||
All methods thats actually result in responses being generated have been consolidated into methods
|
||||
named `execute`:
|
||||
|
||||
- InFlightRequest::execute returns a future that completes when a response has been generated and
|
||||
sent to the server Channel.
|
||||
- Requests::execute automatically spawns response handlers for all requests over a single channel.
|
||||
- Channel::execute is a convenience for `channel.requests().execute()`.
|
||||
- Incoming::execute automatically spawns response handlers for all requests over all channels.
|
||||
|
||||
3. Removal of Server.
|
||||
|
||||
server::Server was removed, as it provided no value over the Incoming/Channel abstractions.
|
||||
Additionally, server::new was removed, since it just returned a Server.
|
||||
|
||||
#### Client RPC methods now take &self
|
||||
|
||||
This required the breaking change of removing the Client trait. The intent of the Client trait was
|
||||
to facilitate the decorator pattern by allowing users to create their own Clients that added
|
||||
behavior on top of the base client. Unfortunately, this trait had become a maintenance burden,
|
||||
consistently causing issues with lifetimes and the lack of generic associated types. Specifically,
|
||||
it meant that Client impls could not use async fns, which is no longer tenable today, with channel
|
||||
libraries moving to async fns.
|
||||
|
||||
#### Servers no longer send deadline-exceed responses.
|
||||
|
||||
The deadline-exceeded response was largely redundant, because the client
|
||||
shouldn't normally be waiting for such a response, anyway -- the normal
|
||||
client will automatically remove the in-flight request when it reaches
|
||||
the deadline.
|
||||
|
||||
This also allows for internalizing the expiration+cleanup logic entirely
|
||||
within BaseChannel, without having it leak into the Channel trait and
|
||||
requiring action taken by the Requests struct.
|
||||
|
||||
#### Clients no longer send cancel messages when the request deadline is exceeded.
|
||||
|
||||
The server already knows when the request deadline was exceeded, so the client didn't need to inform
|
||||
it.
|
||||
|
||||
### Fixes
|
||||
|
||||
- When a channel is dropped, all in-flight requests for that channel are now aborted.
|
||||
|
||||
## 0.24.1 (2020-12-28)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
Upgrades tokio to 1.0.
|
||||
|
||||
## 0.24.0 (2020-12-28)
|
||||
|
||||
This release was yanked.
|
||||
|
||||
## 0.23.0 (2020-10-19)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
Upgrades tokio to 0.3.
|
||||
|
||||
## 0.22.0 (2020-08-02)
|
||||
|
||||
This release adds some flexibility and consistency to `serde_transport`, with one new feature and
|
||||
@@ -83,7 +319,7 @@ nameable futures and will just be boxing the return type anyway. This macro does
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- Enums had _non_exhaustive fields replaced with the #[non_exhaustive] attribute.
|
||||
- Enums had `_non_exhaustive` fields replaced with the #[non_exhaustive] attribute.
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
[package]
|
||||
name = "tarpc-example-service"
|
||||
version = "0.6.0"
|
||||
version = "0.12.0"
|
||||
rust-version = "1.56"
|
||||
authors = ["Tim Kuehn <tikue@google.com>"]
|
||||
edition = "2018"
|
||||
edition = "2021"
|
||||
license = "MIT"
|
||||
documentation = "https://docs.rs/tarpc-example-service"
|
||||
homepage = "https://github.com/google/tarpc"
|
||||
@@ -13,14 +14,18 @@ readme = "../README.md"
|
||||
description = "An example server built on tarpc."
|
||||
|
||||
[dependencies]
|
||||
clap = "2.0"
|
||||
anyhow = "1.0"
|
||||
clap = { version = "3.0.0-rc.9", features = ["derive"] }
|
||||
log = "0.4"
|
||||
futures = "0.3"
|
||||
serde = { version = "1.0" }
|
||||
tarpc = { version = "0.22", path = "../tarpc", features = ["full"] }
|
||||
tokio = { version = "0.2", features = ["full"] }
|
||||
tokio-serde = { version = "0.6", features = ["json"] }
|
||||
tokio-util = { version = "0.3", features = ["codec"] }
|
||||
env_logger = "0.6"
|
||||
opentelemetry = { version = "0.17", features = ["rt-tokio"] }
|
||||
opentelemetry-jaeger = { version = "0.16", features = ["rt-tokio"] }
|
||||
rand = "0.8"
|
||||
tarpc = { version = "0.30", path = "../tarpc", features = ["full"] }
|
||||
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
|
||||
tracing = { version = "0.1" }
|
||||
tracing-opentelemetry = "0.17"
|
||||
tracing-subscriber = {version = "0.3", features = ["env-filter"]}
|
||||
|
||||
[lib]
|
||||
name = "service"
|
||||
|
||||
@@ -4,59 +4,49 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use clap::{App, Arg};
|
||||
use std::{io, net::SocketAddr};
|
||||
use tarpc::{client, context};
|
||||
use tokio_serde::formats::Json;
|
||||
use clap::Parser;
|
||||
use service::{init_tracing, WorldClient};
|
||||
use std::{net::SocketAddr, time::Duration};
|
||||
use tarpc::{client, context, tokio_serde::formats::Json};
|
||||
use tokio::time::sleep;
|
||||
use tracing::Instrument;
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Flags {
|
||||
/// Sets the server address to connect to.
|
||||
#[clap(long)]
|
||||
server_addr: SocketAddr,
|
||||
/// Sets the name to say hello to.
|
||||
#[clap(long)]
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
env_logger::init();
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let flags = Flags::parse();
|
||||
init_tracing("Tarpc Example Client")?;
|
||||
|
||||
let flags = App::new("Hello Client")
|
||||
.version("0.1")
|
||||
.author("Tim <tikue@google.com>")
|
||||
.about("Say hello!")
|
||||
.arg(
|
||||
Arg::with_name("server_addr")
|
||||
.long("server_addr")
|
||||
.value_name("ADDRESS")
|
||||
.help("Sets the server address to connect to.")
|
||||
.required(true)
|
||||
.takes_value(true),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("name")
|
||||
.short("n")
|
||||
.long("name")
|
||||
.value_name("STRING")
|
||||
.help("Sets the name to say hello to.")
|
||||
.required(true)
|
||||
.takes_value(true),
|
||||
)
|
||||
.get_matches();
|
||||
|
||||
let server_addr = flags.value_of("server_addr").unwrap();
|
||||
let server_addr = server_addr
|
||||
.parse::<SocketAddr>()
|
||||
.unwrap_or_else(|e| panic!(r#"--server_addr value "{}" invalid: {}"#, server_addr, e));
|
||||
|
||||
let name = flags.value_of("name").unwrap().into();
|
||||
|
||||
let mut transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default);
|
||||
transport.config_mut().max_frame_length(4294967296);
|
||||
let transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
|
||||
|
||||
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
|
||||
// config and any Transport as input.
|
||||
let mut client =
|
||||
service::WorldClient::new(client::Config::default(), transport.await?).spawn()?;
|
||||
let client = WorldClient::new(client::Config::default(), transport.await?).spawn();
|
||||
|
||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||
// specifies a deadline and trace information which can be helpful in debugging requests.
|
||||
let hello = client.hello(context::current(), name).await?;
|
||||
let hello = async move {
|
||||
// Send the request twice, just to be safe! ;)
|
||||
tokio::select! {
|
||||
hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 }
|
||||
hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 }
|
||||
}
|
||||
}
|
||||
.instrument(tracing::info_span!("Two Hellos"))
|
||||
.await;
|
||||
|
||||
println!("{}", hello);
|
||||
tracing::info!("{:?}", hello);
|
||||
|
||||
// Let the background span processor finish.
|
||||
sleep(Duration::from_micros(1)).await;
|
||||
opentelemetry::global::shutdown_tracer_provider();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -4,6 +4,9 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use std::env;
|
||||
use tracing_subscriber::{fmt::format::FmtSpan, prelude::*};
|
||||
|
||||
/// This is the service definition. It looks a lot like a trait definition.
|
||||
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
#[tarpc::service]
|
||||
@@ -11,3 +14,21 @@ pub trait World {
|
||||
/// Returns a greeting for name.
|
||||
async fn hello(name: String) -> String;
|
||||
}
|
||||
|
||||
/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
|
||||
pub fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
||||
|
||||
let tracer = opentelemetry_jaeger::new_pipeline()
|
||||
.with_service_name(service_name)
|
||||
.with_max_packet_size(2usize.pow(13))
|
||||
.install_batch(opentelemetry::runtime::Tokio)?;
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.with(tracing_subscriber::fmt::layer().with_span_events(FmtSpan::NEW | FmtSpan::CLOSE))
|
||||
.with(tracing_opentelemetry::layer().with_tracer(tracer))
|
||||
.try_init()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -4,18 +4,30 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use clap::{App, Arg};
|
||||
use clap::Parser;
|
||||
use futures::{future, prelude::*};
|
||||
use service::World;
|
||||
use rand::{
|
||||
distributions::{Distribution, Uniform},
|
||||
thread_rng,
|
||||
};
|
||||
use service::{init_tracing, World};
|
||||
use std::{
|
||||
io,
|
||||
net::{IpAddr, SocketAddr},
|
||||
net::{IpAddr, Ipv6Addr, SocketAddr},
|
||||
time::Duration,
|
||||
};
|
||||
use tarpc::{
|
||||
context,
|
||||
server::{self, Channel, Handler},
|
||||
server::{self, incoming::Incoming, Channel},
|
||||
tokio_serde::formats::Json,
|
||||
};
|
||||
use tokio_serde::formats::Json;
|
||||
use tokio::time;
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Flags {
|
||||
/// Sets the port number to listen on.
|
||||
#[clap(long)]
|
||||
port: u16,
|
||||
}
|
||||
|
||||
// This is the type that implements the generated World trait. It is the business logic
|
||||
// and is used to start the server.
|
||||
@@ -25,51 +37,35 @@ struct HelloServer(SocketAddr);
|
||||
#[tarpc::server]
|
||||
impl World for HelloServer {
|
||||
async fn hello(self, _: context::Context, name: String) -> String {
|
||||
format!("Hello, {}! You are connected from {:?}.", name, self.0)
|
||||
let sleep_time =
|
||||
Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng()));
|
||||
time::sleep(sleep_time).await;
|
||||
format!("Hello, {name}! You are connected from {}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
env_logger::init();
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let flags = Flags::parse();
|
||||
init_tracing("Tarpc Example Server")?;
|
||||
|
||||
let flags = App::new("Hello Server")
|
||||
.version("0.1")
|
||||
.author("Tim <tikue@google.com>")
|
||||
.about("Say hello!")
|
||||
.arg(
|
||||
Arg::with_name("port")
|
||||
.short("p")
|
||||
.long("port")
|
||||
.value_name("NUMBER")
|
||||
.help("Sets the port number to listen on")
|
||||
.required(true)
|
||||
.takes_value(true),
|
||||
)
|
||||
.get_matches();
|
||||
|
||||
let port = flags.value_of("port").unwrap();
|
||||
let port = port
|
||||
.parse()
|
||||
.unwrap_or_else(|e| panic!(r#"--port value "{}" invalid: {}"#, port, e));
|
||||
|
||||
let server_addr = (IpAddr::from([0, 0, 0, 0]), port);
|
||||
let server_addr = (IpAddr::V6(Ipv6Addr::LOCALHOST), flags.port);
|
||||
|
||||
// JSON transport is provided by the json_transport tarpc module. It makes it easy
|
||||
// to start up a serde-powered json serialization strategy over TCP.
|
||||
let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?;
|
||||
listener.config_mut().max_frame_length(4294967296);
|
||||
listener.config_mut().max_frame_length(usize::MAX);
|
||||
listener
|
||||
// Ignore accept errors.
|
||||
.filter_map(|r| future::ready(r.ok()))
|
||||
.map(server::BaseChannel::with_defaults)
|
||||
// Limit channels to 1 per IP.
|
||||
.max_channels_per_key(1, |t| t.as_ref().peer_addr().unwrap().ip())
|
||||
.max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip())
|
||||
// serve is generated by the service attribute. It takes as input any type implementing
|
||||
// the generated World trait.
|
||||
.map(|channel| {
|
||||
let server = HelloServer(channel.as_ref().as_ref().peer_addr().unwrap());
|
||||
channel.respond_with(server.serve()).execute()
|
||||
let server = HelloServer(channel.transport().peer_addr().unwrap());
|
||||
channel.execute(server.serve())
|
||||
})
|
||||
// Max 10 channels.
|
||||
.buffer_unordered(10)
|
||||
|
||||
@@ -67,7 +67,7 @@ else
|
||||
fi
|
||||
|
||||
printf "${PREFIX} Checking for rustfmt ... "
|
||||
command -v cargo fmt &>/dev/null
|
||||
command -v rustfmt &>/dev/null
|
||||
if [ $? == 0 ]; then
|
||||
printf "${SUCCESS}\n"
|
||||
else
|
||||
@@ -93,19 +93,19 @@ diff=""
|
||||
for file in $(git diff --name-only --cached);
|
||||
do
|
||||
if [ ${file: -3} == ".rs" ]; then
|
||||
diff="$diff$(cargo fmt -- --unstable-features --skip-children --check $file)"
|
||||
diff="$diff$(rustfmt --edition 2018 --check $file)"
|
||||
if [ $? != 0 ]; then
|
||||
FMTRESULT=1
|
||||
fi
|
||||
fi
|
||||
done
|
||||
if grep --quiet "^[-+]" <<< "$diff"; then
|
||||
FMTRESULT=1
|
||||
fi
|
||||
|
||||
if [ "${TARPC_SKIP_RUSTFMT}" == 1 ]; then
|
||||
printf "${SKIPPED}\n"$?
|
||||
elif [ ${FMTRESULT} != 0 ]; then
|
||||
FAILED=1
|
||||
printf "${FAILURE}\n"
|
||||
echo "$diff" | sed 's/Using rustfmt config file.*$/d/'
|
||||
echo "$diff"
|
||||
else
|
||||
printf "${SUCCESS}\n"
|
||||
fi
|
||||
|
||||
@@ -84,11 +84,6 @@ command -v rustup &>/dev/null
|
||||
if [ "$?" == 0 ]; then
|
||||
printf "${SUCCESS}\n"
|
||||
|
||||
check_toolchain nightly
|
||||
if [ ${TOOLCHAIN_RESULT} == 1 ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
try_run "Building ... " cargo +stable build --color=always
|
||||
try_run "Testing ... " cargo +stable test --color=always
|
||||
try_run "Testing with all features enabled ... " cargo +stable test --all-features --color=always
|
||||
@@ -97,6 +92,12 @@ if [ "$?" == 0 ]; then
|
||||
try_run "Running example \"$EXAMPLE\" ... " cargo +stable run --example $EXAMPLE
|
||||
done
|
||||
|
||||
check_toolchain nightly
|
||||
if [ ${TOOLCHAIN_RESULT} != 1 ]; then
|
||||
try_run "Running clippy ... " cargo +nightly clippy --color=always -Z unstable-options -- --deny warnings
|
||||
fi
|
||||
|
||||
|
||||
fi
|
||||
|
||||
exit $PREPUSH_RESULT
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
[package]
|
||||
name = "tarpc-plugins"
|
||||
version = "0.8.0"
|
||||
version = "0.12.0"
|
||||
rust-version = "1.56"
|
||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||
edition = "2018"
|
||||
edition = "2021"
|
||||
license = "MIT"
|
||||
documentation = "https://docs.rs/tarpc-plugins"
|
||||
homepage = "https://github.com/google/tarpc"
|
||||
@@ -19,15 +20,15 @@ serde1 = []
|
||||
travis-ci = { repository = "google/tarpc" }
|
||||
|
||||
[dependencies]
|
||||
syn = { version = "1.0.11", features = ["full"] }
|
||||
quote = "1.0.2"
|
||||
proc-macro2 = "1.0.6"
|
||||
proc-macro2 = "1.0"
|
||||
quote = "1.0"
|
||||
syn = { version = "1.0", features = ["full"] }
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
||||
[dev-dependencies]
|
||||
assert-type-eq = "0.1.0"
|
||||
futures = "0.3"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tarpc = { path = "../tarpc" }
|
||||
assert-type-eq = "0.1.0"
|
||||
tarpc = { path = "../tarpc", features = ["serde1"] }
|
||||
|
||||
9
plugins/LICENSE
Normal file
9
plugins/LICENSE
Normal file
@@ -0,0 +1,9 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
@@ -83,7 +83,7 @@ impl Parse for Service {
|
||||
ident_errors,
|
||||
syn::Error::new(
|
||||
rpc.ident.span(),
|
||||
format!("method name conflicts with generated fn `{}::serve`", ident)
|
||||
format!("method name conflicts with generated fn `{ident}::serve`")
|
||||
)
|
||||
);
|
||||
}
|
||||
@@ -215,6 +215,25 @@ impl Parse for DeriveSerde {
|
||||
}
|
||||
}
|
||||
|
||||
/// A helper attribute to avoid a direct dependency on Serde.
|
||||
///
|
||||
/// Adds the following annotations to the annotated item:
|
||||
///
|
||||
/// ```rust
|
||||
/// #[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
|
||||
/// #[serde(crate = "tarpc::serde")]
|
||||
/// # struct Foo;
|
||||
/// ```
|
||||
#[proc_macro_attribute]
|
||||
pub fn derive_serde(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let mut gen: proc_macro2::TokenStream = quote! {
|
||||
#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
|
||||
#[serde(crate = "tarpc::serde")]
|
||||
};
|
||||
gen.extend(proc_macro2::TokenStream::from(item));
|
||||
proc_macro::TokenStream::from(gen)
|
||||
}
|
||||
|
||||
/// Generates:
|
||||
/// - service trait
|
||||
/// - serve fn
|
||||
@@ -240,23 +259,33 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
|
||||
let response_fut_name = &format!("{}ResponseFut", ident.unraw());
|
||||
let derive_serialize = if derive_serde.0 {
|
||||
Some(quote!(#[derive(serde::Serialize, serde::Deserialize)]))
|
||||
Some(
|
||||
quote! {#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
|
||||
#[serde(crate = "tarpc::serde")]},
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let methods = rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>();
|
||||
let request_names = methods
|
||||
.iter()
|
||||
.map(|m| format!("{ident}.{m}"))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
ServiceGenerator {
|
||||
response_fut_name,
|
||||
service_ident: ident,
|
||||
server_ident: &format_ident!("Serve{}", ident),
|
||||
response_fut_ident: &Ident::new(&response_fut_name, ident.span()),
|
||||
response_fut_ident: &Ident::new(response_fut_name, ident.span()),
|
||||
client_ident: &format_ident!("{}Client", ident),
|
||||
request_ident: &format_ident!("{}Request", ident),
|
||||
response_ident: &format_ident!("{}Response", ident),
|
||||
vis,
|
||||
args,
|
||||
method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
|
||||
method_idents: &rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>(),
|
||||
method_idents: &methods,
|
||||
request_names: &*request_names,
|
||||
attrs,
|
||||
rpcs,
|
||||
return_types: &rpcs
|
||||
@@ -277,7 +306,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
.collect::<Vec<_>>(),
|
||||
future_types: &camel_case_fn_names
|
||||
.iter()
|
||||
.map(|name| parse_str(&format!("{}Fut", name)).unwrap())
|
||||
.map(|name| parse_str(&format!("{name}Fut")).unwrap())
|
||||
.collect::<Vec<_>>(),
|
||||
derive_serialize: derive_serialize.as_ref(),
|
||||
}
|
||||
@@ -377,14 +406,10 @@ fn verify_types_were_provided(
|
||||
) -> syn::Result<()> {
|
||||
let mut result = Ok(());
|
||||
for (method, expected) in expected {
|
||||
if provided
|
||||
.iter()
|
||||
.find(|typedecl| typedecl.ident == expected)
|
||||
.is_none()
|
||||
{
|
||||
if !provided.iter().any(|typedecl| typedecl.ident == expected) {
|
||||
let mut e = syn::Error::new(
|
||||
span,
|
||||
format!("not all trait items implemented, missing: `{}`", expected),
|
||||
format!("not all trait items implemented, missing: `{expected}`"),
|
||||
);
|
||||
let fn_span = method.sig.fn_token.span();
|
||||
e.extend(syn::Error::new(
|
||||
@@ -419,6 +444,7 @@ struct ServiceGenerator<'a> {
|
||||
camel_case_idents: &'a [Ident],
|
||||
future_types: &'a [Type],
|
||||
method_idents: &'a [&'a Ident],
|
||||
request_names: &'a [String],
|
||||
method_attrs: &'a [&'a [Attribute]],
|
||||
args: &'a [&'a [PatType]],
|
||||
return_types: &'a [&'a Type],
|
||||
@@ -453,7 +479,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
),
|
||||
output,
|
||||
)| {
|
||||
let ty_doc = format!("The response future returned by {}.", ident);
|
||||
let ty_doc = format!("The response future returned by [`{service_ident}::{ident}`].");
|
||||
quote! {
|
||||
#[doc = #ty_doc]
|
||||
type #future_type: std::future::Future<Output = #output>;
|
||||
@@ -466,10 +492,11 @@ impl<'a> ServiceGenerator<'a> {
|
||||
|
||||
quote! {
|
||||
#( #attrs )*
|
||||
#vis trait #service_ident: Clone {
|
||||
#vis trait #service_ident: Sized {
|
||||
#( #types_and_fns )*
|
||||
|
||||
/// Returns a serving function to use with [tarpc::server::Channel::respond_with].
|
||||
/// Returns a serving function to use with
|
||||
/// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute).
|
||||
fn serve(self) -> #server_ident<Self> {
|
||||
#server_ident { service: self }
|
||||
}
|
||||
@@ -483,7 +510,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
/// A serving function to use with [tarpc::server::Channel::respond_with].
|
||||
/// A serving function to use with [tarpc::server::InFlightRequest::execute].
|
||||
#[derive(Clone)]
|
||||
#vis struct #server_ident<S> {
|
||||
service: S,
|
||||
@@ -501,6 +528,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
camel_case_idents,
|
||||
arg_pats,
|
||||
method_idents,
|
||||
request_names,
|
||||
..
|
||||
} = self;
|
||||
|
||||
@@ -511,6 +539,16 @@ impl<'a> ServiceGenerator<'a> {
|
||||
type Resp = #response_ident;
|
||||
type Fut = #response_fut_ident<S>;
|
||||
|
||||
fn method(&self, req: &#request_ident) -> Option<&'static str> {
|
||||
Some(match req {
|
||||
#(
|
||||
#request_ident::#camel_case_idents{..} => {
|
||||
#request_names
|
||||
}
|
||||
)*
|
||||
})
|
||||
}
|
||||
|
||||
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
|
||||
match req {
|
||||
#(
|
||||
@@ -540,6 +578,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
|
||||
quote! {
|
||||
/// The request sent over the wire from the client to the server.
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug)]
|
||||
#derive_serialize
|
||||
#vis enum #request_ident {
|
||||
@@ -560,6 +599,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
|
||||
quote! {
|
||||
/// The response sent over the wire from the server to the client.
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug)]
|
||||
#derive_serialize
|
||||
#vis enum #response_ident {
|
||||
@@ -580,6 +620,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
|
||||
quote! {
|
||||
/// A future resolving to a server response.
|
||||
#[allow(missing_docs)]
|
||||
#vis enum #response_fut_ident<S: #service_ident> {
|
||||
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
|
||||
}
|
||||
@@ -646,27 +687,9 @@ impl<'a> ServiceGenerator<'a> {
|
||||
quote! {
|
||||
#[allow(unused)]
|
||||
#[derive(Clone, Debug)]
|
||||
/// The client stub that makes RPC calls to the server. Exposes a Future interface.
|
||||
#vis struct #client_ident<C = tarpc::client::Channel<#request_ident, #response_ident>>(C);
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_from_for_client(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
client_ident,
|
||||
request_ident,
|
||||
response_ident,
|
||||
..
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
impl<C> From<C> for #client_ident<C>
|
||||
where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
|
||||
{
|
||||
fn from(client: C) -> Self {
|
||||
#client_ident(client)
|
||||
}
|
||||
}
|
||||
/// The client stub that makes RPC calls to the server. All request methods return
|
||||
/// [Futures](std::future::Future).
|
||||
#vis struct #client_ident(tarpc::client::Channel<#request_ident, #response_ident>);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -685,7 +708,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
#vis fn new<T>(config: tarpc::client::Config, transport: T)
|
||||
-> tarpc::client::NewClient<
|
||||
Self,
|
||||
tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T>
|
||||
tarpc::client::RequestDispatch<#request_ident, #response_ident, T>
|
||||
>
|
||||
where
|
||||
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>>
|
||||
@@ -709,6 +732,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
method_attrs,
|
||||
vis,
|
||||
method_idents,
|
||||
request_names,
|
||||
args,
|
||||
return_types,
|
||||
arg_pats,
|
||||
@@ -717,16 +741,14 @@ impl<'a> ServiceGenerator<'a> {
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
impl<C> #client_ident<C>
|
||||
where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
|
||||
{
|
||||
impl #client_ident {
|
||||
#(
|
||||
#[allow(unused)]
|
||||
#( #method_attrs )*
|
||||
#vis fn #method_idents(&mut self, ctx: tarpc::context::Context, #( #args ),*)
|
||||
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
|
||||
#vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*)
|
||||
-> impl std::future::Future<Output = Result<#return_types, tarpc::client::RpcError>> + '_ {
|
||||
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
|
||||
let resp = tarpc::Client::call(&mut self.0, ctx, request);
|
||||
let resp = self.0.call(ctx, #request_names, request);
|
||||
async move {
|
||||
match resp.await? {
|
||||
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
|
||||
@@ -752,7 +774,6 @@ impl<'a> ToTokens for ServiceGenerator<'a> {
|
||||
self.impl_debug_for_response_future(),
|
||||
self.impl_future_for_response_future(),
|
||||
self.struct_client(),
|
||||
self.impl_from_for_client(),
|
||||
self.impl_client_new(),
|
||||
self.impl_client_rpc_methods(),
|
||||
])
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
[package]
|
||||
name = "tarpc"
|
||||
version = "0.22.0"
|
||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||
edition = "2018"
|
||||
version = "0.30.0"
|
||||
rust-version = "1.58.0"
|
||||
authors = [
|
||||
"Adam Wright <adam.austin.wright@gmail.com>",
|
||||
"Tim Kuehn <timothy.j.kuehn@gmail.com>",
|
||||
]
|
||||
edition = "2021"
|
||||
license = "MIT"
|
||||
documentation = "https://docs.rs/tarpc"
|
||||
homepage = "https://github.com/google/tarpc"
|
||||
@@ -16,11 +20,20 @@ description = "An RPC framework for Rust with a focus on ease of use."
|
||||
default = []
|
||||
|
||||
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"]
|
||||
tokio1 = []
|
||||
serde-transport = ["tokio-serde", "tokio-util/codec"]
|
||||
tcp = ["tokio/net", "tokio/stream"]
|
||||
tokio1 = ["tokio/rt"]
|
||||
serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
|
||||
serde-transport-json = ["tokio-serde/json"]
|
||||
serde-transport-bincode = ["tokio-serde/bincode"]
|
||||
tcp = ["tokio/net"]
|
||||
|
||||
full = ["serde1", "tokio1", "serde-transport", "tcp"]
|
||||
full = [
|
||||
"serde1",
|
||||
"tokio1",
|
||||
"serde-transport",
|
||||
"serde-transport-json",
|
||||
"serde-transport-bincode",
|
||||
"tcp",
|
||||
]
|
||||
|
||||
[badges]
|
||||
travis-ci = { repository = "google/tarpc" }
|
||||
@@ -29,30 +42,39 @@ travis-ci = { repository = "google/tarpc" }
|
||||
anyhow = "1.0"
|
||||
fnv = "1.0"
|
||||
futures = "0.3"
|
||||
humantime = "1.0"
|
||||
log = "0.4"
|
||||
pin-project = "0.4.17"
|
||||
rand = "0.7"
|
||||
tokio = { version = "0.2", features = ["time"] }
|
||||
humantime = "2.0"
|
||||
pin-project = "1.0"
|
||||
rand = "0.8"
|
||||
serde = { optional = true, version = "1.0", features = ["derive"] }
|
||||
static_assertions = "1.1.0"
|
||||
tarpc-plugins = { path = "../plugins", version = "0.8" }
|
||||
tokio-util = { optional = true, version = "0.3" }
|
||||
tokio-serde = { optional = true, version = "0.6" }
|
||||
tarpc-plugins = { path = "../plugins", version = "0.12" }
|
||||
thiserror = "1.0"
|
||||
tokio = { version = "1", features = ["time"] }
|
||||
tokio-util = { version = "0.7.3", features = ["time"] }
|
||||
tokio-serde = { optional = true, version = "0.8" }
|
||||
tracing = { version = "0.1", default-features = false, features = [
|
||||
"attributes",
|
||||
"log",
|
||||
] }
|
||||
tracing-opentelemetry = { version = "0.17.2", default-features = false }
|
||||
opentelemetry = { version = "0.17.0", default-features = false }
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
assert_matches = "1.0"
|
||||
assert_matches = "1.4"
|
||||
bincode = "1.3"
|
||||
bytes = { version = "0.5", features = ["serde"] }
|
||||
env_logger = "0.6"
|
||||
flate2 = "1.0.16"
|
||||
futures = "0.3"
|
||||
humantime = "1.0"
|
||||
log = "0.4"
|
||||
bytes = { version = "1", features = ["serde"] }
|
||||
flate2 = "1.0"
|
||||
futures-test = "0.3"
|
||||
opentelemetry = { version = "0.17.0", default-features = false, features = [
|
||||
"rt-tokio",
|
||||
] }
|
||||
opentelemetry-jaeger = { version = "0.16.0", features = ["rt-tokio"] }
|
||||
pin-utils = "0.1.0-alpha"
|
||||
serde_bytes = "0.11"
|
||||
tokio = { version = "0.2", features = ["full"] }
|
||||
tokio-serde = { version = "0.6", features = ["json", "bincode"] }
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
tokio = { version = "1", features = ["full", "test-util"] }
|
||||
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
|
||||
trybuild = "1.0"
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
@@ -60,7 +82,11 @@ all-features = true
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
|
||||
[[example]]
|
||||
name = "server_calling_server"
|
||||
name = "compression"
|
||||
required-features = ["serde-transport", "tcp"]
|
||||
|
||||
[[example]]
|
||||
name = "tracing"
|
||||
required-features = ["full"]
|
||||
|
||||
[[example]]
|
||||
@@ -70,3 +96,15 @@ required-features = ["full"]
|
||||
[[example]]
|
||||
name = "pubsub"
|
||||
required-features = ["full"]
|
||||
|
||||
[[example]]
|
||||
name = "custom_transport"
|
||||
required-features = ["serde1", "tokio1", "serde-transport"]
|
||||
|
||||
[[test]]
|
||||
name = "service_functional"
|
||||
required-features = ["serde-transport"]
|
||||
|
||||
[[test]]
|
||||
name = "dataservice"
|
||||
required-features = ["serde-transport", "tcp"]
|
||||
|
||||
9
tarpc/LICENSE
Normal file
9
tarpc/LICENSE
Normal file
@@ -0,0 +1,9 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
@@ -7,8 +7,8 @@ use tarpc::{
|
||||
client, context,
|
||||
serde_transport::tcp,
|
||||
server::{BaseChannel, Channel},
|
||||
tokio_serde::formats::Bincode,
|
||||
};
|
||||
use tokio_serde::formats::Bincode;
|
||||
|
||||
/// Type of compression that should be enabled on the request. The transport is free to ignore this.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)]
|
||||
@@ -54,7 +54,7 @@ where
|
||||
if algorithm != CompressionAlgorithm::Deflate {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("Compression algorithm {:?} not supported", algorithm),
|
||||
format!("Compression algorithm {algorithm:?} not supported"),
|
||||
));
|
||||
}
|
||||
let mut deflater = DeflateDecoder::new(payload.as_slice());
|
||||
@@ -102,7 +102,7 @@ struct HelloServer;
|
||||
#[tarpc::server]
|
||||
impl World for HelloServer {
|
||||
async fn hello(self, _: context::Context, name: String) -> String {
|
||||
format!("Hey, {}!", name)
|
||||
format!("Hey, {name}!")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,14 +113,12 @@ async fn main() -> anyhow::Result<()> {
|
||||
tokio::spawn(async move {
|
||||
let transport = incoming.next().await.unwrap().unwrap();
|
||||
BaseChannel::with_defaults(add_compression(transport))
|
||||
.respond_with(HelloServer.serve())
|
||||
.execute()
|
||||
.execute(HelloServer.serve())
|
||||
.await;
|
||||
});
|
||||
|
||||
let transport = tcp::connect(addr, Bincode::default).await?;
|
||||
let mut client =
|
||||
WorldClient::new(client::Config::default(), add_compression(transport)).spawn()?;
|
||||
let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn();
|
||||
|
||||
println!(
|
||||
"{}",
|
||||
|
||||
48
tarpc/examples/custom_transport.rs
Normal file
48
tarpc/examples/custom_transport.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use tarpc::context::Context;
|
||||
use tarpc::serde_transport as transport;
|
||||
use tarpc::server::{BaseChannel, Channel};
|
||||
use tarpc::tokio_serde::formats::Bincode;
|
||||
use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec;
|
||||
use tokio::net::{UnixListener, UnixStream};
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait PingService {
|
||||
async fn ping();
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Service;
|
||||
|
||||
#[tarpc::server]
|
||||
impl PingService for Service {
|
||||
async fn ping(self, _: Context) {}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let bind_addr = "/tmp/tarpc_on_unix_example.sock";
|
||||
|
||||
let _ = std::fs::remove_file(bind_addr);
|
||||
|
||||
let listener = UnixListener::bind(bind_addr).unwrap();
|
||||
let codec_builder = LengthDelimitedCodec::builder();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (conn, _addr) = listener.accept().await.unwrap();
|
||||
let framed = codec_builder.new_framed(conn);
|
||||
let transport = transport::new(framed, Bincode::default());
|
||||
|
||||
let fut = BaseChannel::with_defaults(transport).execute(Service.serve());
|
||||
tokio::spawn(fut);
|
||||
}
|
||||
});
|
||||
|
||||
let conn = UnixStream::connect(bind_addr).await?;
|
||||
let transport = transport::new(codec_builder.new_framed(conn), Bincode::default());
|
||||
PingServiceClient::new(Default::default(), transport)
|
||||
.spawn()
|
||||
.ping(tarpc::context::current())
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -38,10 +38,11 @@ use futures::{
|
||||
future::{self, AbortHandle},
|
||||
prelude::*,
|
||||
};
|
||||
use log::info;
|
||||
use publisher::Publisher as _;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
env,
|
||||
error::Error,
|
||||
io,
|
||||
net::SocketAddr,
|
||||
sync::{Arc, Mutex, RwLock},
|
||||
@@ -51,9 +52,11 @@ use tarpc::{
|
||||
client, context,
|
||||
serde_transport::tcp,
|
||||
server::{self, Channel},
|
||||
tokio_serde::formats::Json,
|
||||
};
|
||||
use tokio::net::ToSocketAddrs;
|
||||
use tokio_serde::formats::Json;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
pub mod subscriber {
|
||||
#[tarpc::service]
|
||||
@@ -83,10 +86,7 @@ impl subscriber::Subscriber for Subscriber {
|
||||
}
|
||||
|
||||
async fn receive(self, _: context::Context, topic: String, message: String) {
|
||||
info!(
|
||||
"[{}] received message on topic '{}': {}",
|
||||
self.local_addr, topic, message
|
||||
);
|
||||
info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,11 +105,11 @@ impl Subscriber {
|
||||
) -> anyhow::Result<SubscriberHandle> {
|
||||
let publisher = tcp::connect(publisher_addr, Json::default).await?;
|
||||
let local_addr = publisher.local_addr()?;
|
||||
let mut handler = server::BaseChannel::with_defaults(publisher)
|
||||
.respond_with(Subscriber { local_addr, topics }.serve());
|
||||
// The first request is for the topics being subscriibed to.
|
||||
let mut handler = server::BaseChannel::with_defaults(publisher).requests();
|
||||
let subscriber = Subscriber { local_addr, topics };
|
||||
// The first request is for the topics being subscribed to.
|
||||
match handler.next().await {
|
||||
Some(init_topics) => init_topics?.await,
|
||||
Some(init_topics) => init_topics?.execute(subscriber.clone().serve()).await,
|
||||
None => {
|
||||
return Err(anyhow!(
|
||||
"[{}] Server never initialized the subscriber.",
|
||||
@@ -117,10 +117,10 @@ impl Subscriber {
|
||||
))
|
||||
}
|
||||
};
|
||||
let (handler, abort_handle) = future::abortable(handler.execute());
|
||||
let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve()));
|
||||
tokio::spawn(async move {
|
||||
match handler.await {
|
||||
Ok(()) | Err(future::Aborted) => info!("[{}] subscriber shutdown.", local_addr),
|
||||
Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."),
|
||||
}
|
||||
});
|
||||
Ok(SubscriberHandle(abort_handle))
|
||||
@@ -129,7 +129,6 @@ impl Subscriber {
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Subscription {
|
||||
subscriber: subscriber::SubscriberClient,
|
||||
topics: Vec<String>,
|
||||
}
|
||||
|
||||
@@ -153,17 +152,16 @@ impl Publisher {
|
||||
subscriptions: self.clone().start_subscription_manager().await?,
|
||||
};
|
||||
|
||||
info!("[{}] listening for publishers.", publisher_addrs.publisher);
|
||||
info!(publisher_addr = %publisher_addrs.publisher, "listening for publishers.",);
|
||||
tokio::spawn(async move {
|
||||
// Because this is just an example, we know there will only be one publisher. In more
|
||||
// realistic code, this would be a loop to continually accept new publisher
|
||||
// connections.
|
||||
let publisher = connecting_publishers.next().await.unwrap().unwrap();
|
||||
info!("[{}] publisher connected.", publisher.peer_addr().unwrap());
|
||||
info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected.");
|
||||
|
||||
server::BaseChannel::with_defaults(publisher)
|
||||
.respond_with(self.serve())
|
||||
.execute()
|
||||
.execute(self.serve())
|
||||
.await
|
||||
});
|
||||
|
||||
@@ -175,7 +173,7 @@ impl Publisher {
|
||||
.await?
|
||||
.filter_map(|r| future::ready(r.ok()));
|
||||
let new_subscriber_addr = connecting_subscribers.get_ref().local_addr();
|
||||
info!("[{}] listening for subscribers.", new_subscriber_addr);
|
||||
info!(?new_subscriber_addr, "listening for subscribers.");
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(conn) = connecting_subscribers.next().await {
|
||||
@@ -204,19 +202,18 @@ impl Publisher {
|
||||
async fn initialize_subscription(
|
||||
&mut self,
|
||||
subscriber_addr: SocketAddr,
|
||||
mut subscriber: subscriber::SubscriberClient,
|
||||
subscriber: subscriber::SubscriberClient,
|
||||
) {
|
||||
// Populate the topics
|
||||
if let Ok(topics) = subscriber.topics(context::current()).await {
|
||||
self.clients.lock().unwrap().insert(
|
||||
subscriber_addr,
|
||||
Subscription {
|
||||
subscriber: subscriber.clone(),
|
||||
topics: topics.clone(),
|
||||
},
|
||||
);
|
||||
|
||||
info!("[{}] subscribed to topics: {:?}", subscriber_addr, topics);
|
||||
info!(%subscriber_addr, ?topics, "subscribed to new topics");
|
||||
let mut subscriptions = self.subscriptions.write().unwrap();
|
||||
for topic in topics {
|
||||
subscriptions
|
||||
@@ -227,18 +224,18 @@ impl Publisher {
|
||||
}
|
||||
}
|
||||
|
||||
fn start_subscriber_gc(
|
||||
fn start_subscriber_gc<E: Error>(
|
||||
self,
|
||||
subscriber_addr: SocketAddr,
|
||||
client_dispatch: impl Future<Output = anyhow::Result<()>> + Send + 'static,
|
||||
client_dispatch: impl Future<Output = Result<(), E>> + Send + 'static,
|
||||
subscriber_ready: oneshot::Receiver<()>,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = client_dispatch.await {
|
||||
info!(
|
||||
"[{}] subscriber connection broken: {:?}",
|
||||
subscriber_addr, e
|
||||
)
|
||||
%subscriber_addr,
|
||||
error = %e,
|
||||
"subscriber connection broken");
|
||||
}
|
||||
// Don't clean up the subscriber until initialization is done.
|
||||
let _ = subscriber_ready.await;
|
||||
@@ -282,13 +279,29 @@ impl publisher::Publisher for Publisher {
|
||||
}
|
||||
}
|
||||
|
||||
/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
|
||||
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
||||
let tracer = opentelemetry_jaeger::new_pipeline()
|
||||
.with_service_name(service_name)
|
||||
.with_max_packet_size(2usize.pow(13))
|
||||
.install_batch(opentelemetry::runtime::Tokio)?;
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::filter::EnvFilter::from_default_env())
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(tracing_opentelemetry::layer().with_tracer(tracer))
|
||||
.try_init()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
env_logger::init();
|
||||
init_tracing("Pub/Sub")?;
|
||||
|
||||
let clients = Arc::new(Mutex::new(HashMap::new()));
|
||||
let addrs = Publisher {
|
||||
clients,
|
||||
clients: Arc::new(Mutex::new(HashMap::new())),
|
||||
subscriptions: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
.start()
|
||||
@@ -306,11 +319,11 @@ async fn main() -> anyhow::Result<()> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut publisher = publisher::PublisherClient::new(
|
||||
let publisher = publisher::PublisherClient::new(
|
||||
client::Config::default(),
|
||||
tcp::connect(addrs.publisher, Json::default).await?,
|
||||
)
|
||||
.spawn()?;
|
||||
.spawn();
|
||||
|
||||
publisher
|
||||
.publish(context::current(), "calculus".into(), "sqrt(2)".into())
|
||||
@@ -338,6 +351,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
opentelemetry::global::shutdown_tracer_provider();
|
||||
info!("done.");
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -4,16 +4,11 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use futures::{
|
||||
future::{self, Ready},
|
||||
prelude::*,
|
||||
};
|
||||
use std::io;
|
||||
use futures::future::{self, Ready};
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{BaseChannel, Channel},
|
||||
server::{self, Channel},
|
||||
};
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
/// This is the service definition. It looks a lot like a trait definition.
|
||||
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
@@ -34,46 +29,27 @@ impl World for HelloServer {
|
||||
type HelloFut = Ready<String>;
|
||||
|
||||
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
||||
future::ready(format!("Hello, {}!", name))
|
||||
future::ready(format!("Hello, {name}!"))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
// tarpc_json_transport is provided by the associated crate json_transport. It makes it
|
||||
// easy to start up a serde-powered JSON serialization strategy over TCP.
|
||||
let mut transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default).await?;
|
||||
let addr = transport.local_addr();
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
|
||||
let server = async move {
|
||||
// For this example, we're just going to wait for one connection.
|
||||
let client = transport.next().await.unwrap().unwrap();
|
||||
let server = server::BaseChannel::with_defaults(server_transport);
|
||||
tokio::spawn(server.execute(HelloServer.serve()));
|
||||
|
||||
// `Channel` is a trait representing a server-side connection. It is a trait to allow
|
||||
// for some channels to be instrumented: for example, to track the number of open connections.
|
||||
// BaseChannel is the most basic channel, simply wrapping a transport with no added
|
||||
// functionality.
|
||||
BaseChannel::with_defaults(client)
|
||||
// serve_world is generated by the tarpc::service attribute. It takes as input any type
|
||||
// implementing the generated World trait.
|
||||
.respond_with(HelloServer.serve())
|
||||
.execute()
|
||||
.await;
|
||||
};
|
||||
tokio::spawn(server);
|
||||
|
||||
let transport = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
|
||||
// WorldClient is generated by the tarpc::service attribute. It has a constructor `new` that
|
||||
// takes a config and any Transport as input.
|
||||
let mut client = WorldClient::new(client::Config::default(), transport).spawn()?;
|
||||
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||
// that takes a config and any Transport as input.
|
||||
let client = WorldClient::new(client::Config::default(), client_transport).spawn();
|
||||
|
||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||
// specifies a deadline and trace information which can be helpful in debugging requests.
|
||||
let hello = client.hello(context::current(), "Stim".to_string()).await?;
|
||||
|
||||
eprintln!("{}", hello);
|
||||
println!("{hello}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -6,12 +6,12 @@
|
||||
|
||||
use crate::{add::Add as AddService, double::Double as DoubleService};
|
||||
use futures::{future, prelude::*};
|
||||
use std::io;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{Handler, Server},
|
||||
server::{incoming::Incoming, BaseChannel},
|
||||
tokio_serde::formats::Json,
|
||||
};
|
||||
use tokio_serde::formats::Json;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
pub mod add {
|
||||
#[tarpc::service]
|
||||
@@ -46,7 +46,7 @@ struct DoubleServer {
|
||||
|
||||
#[tarpc::server]
|
||||
impl DoubleService for DoubleServer {
|
||||
async fn double(mut self, _: context::Context, x: i32) -> Result<i32, String> {
|
||||
async fn double(self, _: context::Context, x: i32) -> Result<i32, String> {
|
||||
self.add_client
|
||||
.add(context::current(), x, x)
|
||||
.await
|
||||
@@ -54,39 +54,59 @@ impl DoubleService for DoubleServer {
|
||||
}
|
||||
}
|
||||
|
||||
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||
let tracer = opentelemetry_jaeger::new_pipeline()
|
||||
.with_service_name(service_name)
|
||||
.with_auto_split_batch(true)
|
||||
.with_max_packet_size(2usize.pow(13))
|
||||
.install_batch(opentelemetry::runtime::Tokio)?;
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(tracing_opentelemetry::layer().with_tracer(tracer))
|
||||
.try_init()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
env_logger::init();
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
init_tracing("tarpc_tracing_example")?;
|
||||
|
||||
let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
||||
.await?
|
||||
.filter_map(|r| future::ready(r.ok()));
|
||||
let addr = add_listener.get_ref().local_addr();
|
||||
let add_server = Server::default()
|
||||
.incoming(add_listener)
|
||||
let add_server = add_listener
|
||||
.map(BaseChannel::with_defaults)
|
||||
.take(1)
|
||||
.respond_with(AddServer.serve());
|
||||
.execute(AddServer.serve());
|
||||
tokio::spawn(add_server);
|
||||
|
||||
let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn()?;
|
||||
let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn();
|
||||
|
||||
let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
||||
.await?
|
||||
.filter_map(|r| future::ready(r.ok()));
|
||||
let addr = double_listener.get_ref().local_addr();
|
||||
let double_server = tarpc::Server::default()
|
||||
.incoming(double_listener)
|
||||
let double_server = double_listener
|
||||
.map(BaseChannel::with_defaults)
|
||||
.take(1)
|
||||
.respond_with(DoubleServer { add_client }.serve());
|
||||
.execute(DoubleServer { add_client }.serve());
|
||||
tokio::spawn(double_server);
|
||||
|
||||
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
let mut double_client =
|
||||
double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?;
|
||||
let double_client =
|
||||
double::DoubleClient::new(client::Config::default(), to_double_server).spawn();
|
||||
|
||||
for i in 1..=5 {
|
||||
eprintln!("{:?}", double_client.double(context::current(), i).await?);
|
||||
let ctx = context::current();
|
||||
for _ in 1..=5 {
|
||||
tracing::info!("{:?}", double_client.double(ctx, 1).await?);
|
||||
}
|
||||
|
||||
opentelemetry::global::shutdown_tracer_provider();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
49
tarpc/src/cancellations.rs
Normal file
49
tarpc/src/cancellations.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use futures::{prelude::*, task::*};
|
||||
use std::pin::Pin;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Sends request cancellation signals.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RequestCancellation(mpsc::UnboundedSender<u64>);
|
||||
|
||||
/// A stream of IDs of requests that have been canceled.
|
||||
#[derive(Debug)]
|
||||
pub struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
|
||||
|
||||
/// Returns a channel to send request cancellation messages.
|
||||
pub fn cancellations() -> (RequestCancellation, CanceledRequests) {
|
||||
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
|
||||
// bounded by the number of in-flight requests.
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
(RequestCancellation(tx), CanceledRequests(rx))
|
||||
}
|
||||
|
||||
impl RequestCancellation {
|
||||
/// Cancels the request with ID `request_id`.
|
||||
///
|
||||
/// No validation is done of `request_id`. There is no way to know if the request id provided
|
||||
/// corresponds to a request actually tracked by the backing channel. `RequestCancellation` is
|
||||
/// a one-way communication channel.
|
||||
///
|
||||
/// Once request data is cleaned up, a response will never be received by the client. This is
|
||||
/// useful primarily when request processing ends prematurely for requests with long deadlines
|
||||
/// which would otherwise continue to be tracked by the backing channel—a kind of leak.
|
||||
pub fn cancel(&self, request_id: u64) {
|
||||
let _ = self.0.send(request_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl CanceledRequests {
|
||||
/// Polls for a cancelled request.
|
||||
pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<u64>> {
|
||||
self.0.poll_recv(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for CanceledRequests {
|
||||
type Item = u64;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
|
||||
self.poll_recv(cx)
|
||||
}
|
||||
}
|
||||
875
tarpc/src/client.rs
Normal file
875
tarpc/src/client.rs
Normal file
@@ -0,0 +1,875 @@
|
||||
// Copyright 2018 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! Provides a client that connects to a server and sends multiplexed requests.
|
||||
|
||||
mod in_flight_requests;
|
||||
|
||||
use crate::{
|
||||
cancellations::{cancellations, CanceledRequests, RequestCancellation},
|
||||
context, trace, ClientMessage, Request, Response, ServerError, Transport,
|
||||
};
|
||||
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
||||
use in_flight_requests::{DeadlineExceededError, InFlightRequests};
|
||||
use pin_project::pin_project;
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
error::Error,
|
||||
fmt, mem,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tracing::Span;
|
||||
|
||||
/// Settings that control the behavior of the client.
|
||||
#[derive(Clone, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub struct Config {
|
||||
/// The number of requests that can be in flight at once.
|
||||
/// `max_in_flight_requests` controls the size of the map used by the client
|
||||
/// for storing pending requests.
|
||||
pub max_in_flight_requests: usize,
|
||||
/// The number of requests that can be buffered client-side before being sent.
|
||||
/// `pending_requests_buffer` controls the size of the channel clients use
|
||||
/// to communicate with the request dispatch task.
|
||||
pub pending_request_buffer: usize,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Config {
|
||||
max_in_flight_requests: 1_000,
|
||||
pending_request_buffer: 100,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A channel and dispatch pair. The dispatch drives the sending and receiving of requests
|
||||
/// and must be polled continuously or spawned.
|
||||
pub struct NewClient<C, D> {
|
||||
/// The new client.
|
||||
pub client: C,
|
||||
/// The client's dispatch.
|
||||
pub dispatch: D,
|
||||
}
|
||||
|
||||
impl<C, D, E> NewClient<C, D>
|
||||
where
|
||||
D: Future<Output = Result<(), E>> + Send + 'static,
|
||||
E: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
/// Helper method to spawn the dispatch on the default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub fn spawn(self) -> C {
|
||||
let dispatch = self.dispatch.unwrap_or_else(move |e| {
|
||||
let e = anyhow::Error::new(e);
|
||||
tracing::warn!("Connection broken: {:?}", e);
|
||||
});
|
||||
tokio::spawn(dispatch);
|
||||
self.client
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, D> fmt::Debug for NewClient<C, D> {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(fmt, "NewClient")
|
||||
}
|
||||
}
|
||||
|
||||
const _CHECK_USIZE: () = assert!(
|
||||
std::mem::size_of::<usize>() <= std::mem::size_of::<u64>(),
|
||||
"usize is too big to fit in u64"
|
||||
);
|
||||
|
||||
/// Handles communication from the client to request dispatch.
|
||||
#[derive(Debug)]
|
||||
pub struct Channel<Req, Resp> {
|
||||
to_dispatch: mpsc::Sender<DispatchRequest<Req, Resp>>,
|
||||
/// Channel to send a cancel message to the dispatcher.
|
||||
cancellation: RequestCancellation,
|
||||
/// The ID to use for the next request to stage.
|
||||
next_request_id: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl<Req, Resp> Clone for Channel<Req, Resp> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
to_dispatch: self.to_dispatch.clone(),
|
||||
cancellation: self.cancellation.clone(),
|
||||
next_request_id: self.next_request_id.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> Channel<Req, Resp> {
|
||||
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
|
||||
/// resolves to the response.
|
||||
#[tracing::instrument(
|
||||
name = "RPC",
|
||||
skip(self, ctx, request_name, request),
|
||||
fields(
|
||||
rpc.trace_id = tracing::field::Empty,
|
||||
rpc.deadline = %humantime::format_rfc3339(ctx.deadline),
|
||||
otel.kind = "client",
|
||||
otel.name = request_name)
|
||||
)]
|
||||
pub async fn call(
|
||||
&self,
|
||||
mut ctx: context::Context,
|
||||
request_name: &str,
|
||||
request: Req,
|
||||
) -> Result<Resp, RpcError> {
|
||||
let span = Span::current();
|
||||
ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
|
||||
tracing::trace!(
|
||||
"OpenTelemetry subscriber not installed; making unsampled child context."
|
||||
);
|
||||
ctx.trace_context.new_child()
|
||||
});
|
||||
span.record("rpc.trace_id", &tracing::field::display(ctx.trace_id()));
|
||||
let (response_completion, mut response) = oneshot::channel();
|
||||
let request_id =
|
||||
u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
|
||||
|
||||
// ResponseGuard impls Drop to cancel in-flight requests. It should be created before
|
||||
// sending out the request; otherwise, the response future could be dropped after the
|
||||
// request is sent out but before ResponseGuard is created, rendering the cancellation
|
||||
// logic inactive.
|
||||
let response_guard = ResponseGuard {
|
||||
response: &mut response,
|
||||
request_id,
|
||||
cancellation: &self.cancellation,
|
||||
};
|
||||
self.to_dispatch
|
||||
.send(DispatchRequest {
|
||||
ctx,
|
||||
span,
|
||||
request_id,
|
||||
request,
|
||||
response_completion,
|
||||
})
|
||||
.await
|
||||
.map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?;
|
||||
response_guard.response().await
|
||||
}
|
||||
}
|
||||
|
||||
/// A server response that is completed by request dispatch when the corresponding response
|
||||
/// arrives off the wire.
|
||||
struct ResponseGuard<'a, Resp> {
|
||||
response: &'a mut oneshot::Receiver<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
cancellation: &'a RequestCancellation,
|
||||
request_id: u64,
|
||||
}
|
||||
|
||||
/// An error that can occur in the processing of an RPC. This is not request-specific errors but
|
||||
/// rather cross-cutting errors that can always occur.
|
||||
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub enum RpcError {
|
||||
/// The client disconnected from the server.
|
||||
#[error("the client disconnected from the server")]
|
||||
Disconnected,
|
||||
/// The request exceeded its deadline.
|
||||
#[error("the request exceeded its deadline")]
|
||||
DeadlineExceeded,
|
||||
/// The server aborted request processing.
|
||||
#[error("the server aborted request processing")]
|
||||
Server(#[from] ServerError),
|
||||
}
|
||||
|
||||
impl From<DeadlineExceededError> for RpcError {
|
||||
fn from(_: DeadlineExceededError) -> Self {
|
||||
RpcError::DeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
impl<Resp> ResponseGuard<'_, Resp> {
|
||||
async fn response(mut self) -> Result<Resp, RpcError> {
|
||||
let response = (&mut self.response).await;
|
||||
// Cancel drop logic once a response has been received.
|
||||
mem::forget(self);
|
||||
match response {
|
||||
Ok(resp) => Ok(resp?.message?),
|
||||
Err(oneshot::error::RecvError { .. }) => {
|
||||
// The oneshot is Canceled when the dispatch task ends. In that case,
|
||||
// there's nothing listening on the other side, so there's no point in
|
||||
// propagating cancellation.
|
||||
Err(RpcError::Disconnected)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cancels the request when dropped, if not already complete.
|
||||
impl<Resp> Drop for ResponseGuard<'_, Resp> {
|
||||
fn drop(&mut self) {
|
||||
// The receiver needs to be closed to handle the edge case that the request has not
|
||||
// yet been received by the dispatch task. It is possible for the cancel message to
|
||||
// arrive before the request itself, in which case the request could get stuck in the
|
||||
// dispatch map forever if the server never responds (e.g. if the server dies while
|
||||
// responding). Even if the server does respond, it will have unnecessarily done work
|
||||
// for a client no longer waiting for a response. To avoid this, the dispatch task
|
||||
// checks if the receiver is closed before inserting the request in the map. By
|
||||
// closing the receiver before sending the cancel message, it is guaranteed that if the
|
||||
// dispatch task misses an early-arriving cancellation message, then it will see the
|
||||
// receiver as closed.
|
||||
self.response.close();
|
||||
self.cancellation.cancel(self.request_id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the
|
||||
/// channel.
|
||||
pub fn new<Req, Resp, C>(
|
||||
config: Config,
|
||||
transport: C,
|
||||
) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C>>
|
||||
where
|
||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
||||
{
|
||||
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
|
||||
let (cancellation, canceled_requests) = cancellations();
|
||||
let canceled_requests = canceled_requests;
|
||||
|
||||
NewClient {
|
||||
client: Channel {
|
||||
to_dispatch,
|
||||
cancellation,
|
||||
next_request_id: Arc::new(AtomicUsize::new(0)),
|
||||
},
|
||||
dispatch: RequestDispatch {
|
||||
config,
|
||||
canceled_requests,
|
||||
transport: transport.fuse(),
|
||||
in_flight_requests: InFlightRequests::default(),
|
||||
pending_requests,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
|
||||
/// and dispatching responses to the appropriate channel.
|
||||
#[must_use]
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct RequestDispatch<Req, Resp, C> {
|
||||
/// Writes requests to the wire and reads responses off the wire.
|
||||
#[pin]
|
||||
transport: Fuse<C>,
|
||||
/// Requests waiting to be written to the wire.
|
||||
pending_requests: mpsc::Receiver<DispatchRequest<Req, Resp>>,
|
||||
/// Requests that were dropped.
|
||||
canceled_requests: CanceledRequests,
|
||||
/// Requests already written to the wire that haven't yet received responses.
|
||||
in_flight_requests: InFlightRequests<Resp>,
|
||||
/// Configures limits to prevent unlimited resource usage.
|
||||
config: Config,
|
||||
}
|
||||
|
||||
/// Critical errors that result in a Channel disconnecting.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ChannelError<E>
|
||||
where
|
||||
E: Error + Send + Sync + 'static,
|
||||
{
|
||||
/// Could not read from the transport.
|
||||
#[error("could not read from the transport")]
|
||||
Read(#[source] E),
|
||||
/// Could not ready the transport for writes.
|
||||
#[error("could not ready the transport for writes")]
|
||||
Ready(#[source] E),
|
||||
/// Could not write to the transport.
|
||||
#[error("could not write to the transport")]
|
||||
Write(#[source] E),
|
||||
/// Could not flush the transport.
|
||||
#[error("could not flush the transport")]
|
||||
Flush(#[source] E),
|
||||
/// Could not close the write end of the transport.
|
||||
#[error("could not close the write end of the transport")]
|
||||
Close(#[source] E),
|
||||
/// Could not poll expired requests.
|
||||
#[error("could not poll expired requests")]
|
||||
Timer(#[source] tokio::time::error::Error),
|
||||
}
|
||||
|
||||
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
|
||||
where
|
||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
||||
{
|
||||
fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests<Resp> {
|
||||
self.as_mut().project().in_flight_requests
|
||||
}
|
||||
|
||||
fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<C>> {
|
||||
self.as_mut().project().transport
|
||||
}
|
||||
|
||||
fn poll_ready<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), ChannelError<C::Error>>> {
|
||||
self.transport_pin_mut()
|
||||
.poll_ready(cx)
|
||||
.map_err(ChannelError::Ready)
|
||||
}
|
||||
|
||||
fn start_send(
|
||||
self: &mut Pin<&mut Self>,
|
||||
message: ClientMessage<Req>,
|
||||
) -> Result<(), ChannelError<C::Error>> {
|
||||
self.transport_pin_mut()
|
||||
.start_send(message)
|
||||
.map_err(ChannelError::Write)
|
||||
}
|
||||
|
||||
fn poll_flush<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), ChannelError<C::Error>>> {
|
||||
self.transport_pin_mut()
|
||||
.poll_flush(cx)
|
||||
.map_err(ChannelError::Flush)
|
||||
}
|
||||
|
||||
fn poll_close<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), ChannelError<C::Error>>> {
|
||||
self.transport_pin_mut()
|
||||
.poll_close(cx)
|
||||
.map_err(ChannelError::Close)
|
||||
}
|
||||
|
||||
fn canceled_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut CanceledRequests {
|
||||
self.as_mut().project().canceled_requests
|
||||
}
|
||||
|
||||
fn pending_requests_mut<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
) -> &'a mut mpsc::Receiver<DispatchRequest<Req, Resp>> {
|
||||
self.as_mut().project().pending_requests
|
||||
}
|
||||
|
||||
fn pump_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||
self.transport_pin_mut()
|
||||
.poll_next(cx)
|
||||
.map_err(ChannelError::Read)
|
||||
.map_ok(|response| {
|
||||
self.complete(response);
|
||||
})
|
||||
}
|
||||
|
||||
fn pump_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||
enum ReceiverStatus {
|
||||
Pending,
|
||||
Closed,
|
||||
}
|
||||
|
||||
let pending_requests_status = match self.as_mut().poll_write_request(cx)? {
|
||||
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
|
||||
Poll::Ready(None) => ReceiverStatus::Closed,
|
||||
Poll::Pending => ReceiverStatus::Pending,
|
||||
};
|
||||
|
||||
let canceled_requests_status = match self.as_mut().poll_write_cancel(cx)? {
|
||||
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
|
||||
Poll::Ready(None) => ReceiverStatus::Closed,
|
||||
Poll::Pending => ReceiverStatus::Pending,
|
||||
};
|
||||
|
||||
// Receiving Poll::Ready(None) when polling expired requests never indicates "Closed",
|
||||
// because there can temporarily be zero in-flight rquests. Therefore, there is no need to
|
||||
// track the status like is done with pending and cancelled requests.
|
||||
if let Poll::Ready(Some(_)) = self.in_flight_requests().poll_expired(cx) {
|
||||
// Expired requests are considered complete; there is no compelling reason to send a
|
||||
// cancellation message to the server, since it will have already exhausted its
|
||||
// allotted processing time.
|
||||
return Poll::Ready(Some(Ok(())));
|
||||
}
|
||||
|
||||
match (pending_requests_status, canceled_requests_status) {
|
||||
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
|
||||
ready!(self.poll_close(cx)?);
|
||||
Poll::Ready(None)
|
||||
}
|
||||
(ReceiverStatus::Pending, _) | (_, ReceiverStatus::Pending) => {
|
||||
// No more messages to process, so flush any messages buffered in the transport.
|
||||
ready!(self.poll_flush(cx)?);
|
||||
|
||||
// Even if we fully-flush, we return Pending, because we have no more requests
|
||||
// or cancellations right now.
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Yields the next pending request, if one is ready to be sent.
|
||||
///
|
||||
/// Note that a request will only be yielded if the transport is *ready* to be written to (i.e.
|
||||
/// start_send would succeed).
|
||||
fn poll_next_request(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<DispatchRequest<Req, Resp>, ChannelError<C::Error>>>> {
|
||||
if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
|
||||
tracing::info!(
|
||||
"At in-flight request capacity ({}/{}).",
|
||||
self.in_flight_requests().len(),
|
||||
self.config.max_in_flight_requests
|
||||
);
|
||||
|
||||
// No need to schedule a wakeup, because timers and responses are responsible
|
||||
// for clearing out in-flight requests.
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
ready!(self.ensure_writeable(cx)?);
|
||||
|
||||
loop {
|
||||
match ready!(self.pending_requests_mut().poll_recv(cx)) {
|
||||
Some(request) => {
|
||||
if request.response_completion.is_closed() {
|
||||
let _entered = request.span.enter();
|
||||
tracing::info!("AbortRequest");
|
||||
continue;
|
||||
}
|
||||
|
||||
return Poll::Ready(Some(Ok(request)));
|
||||
}
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Yields the next pending cancellation, and, if one is ready, cancels the associated request.
|
||||
///
|
||||
/// Note that a request to cancel will only be yielded if the transport is *ready* to be
|
||||
/// written to (i.e. start_send would succeed).
|
||||
fn poll_next_cancellation(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<(context::Context, Span, u64), ChannelError<C::Error>>>> {
|
||||
ready!(self.ensure_writeable(cx)?);
|
||||
|
||||
loop {
|
||||
match ready!(self.canceled_requests_mut().poll_next_unpin(cx)) {
|
||||
Some(request_id) => {
|
||||
if let Some((ctx, span)) = self.in_flight_requests().cancel_request(request_id)
|
||||
{
|
||||
return Poll::Ready(Some(Ok((ctx, span, request_id))));
|
||||
}
|
||||
}
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns Ready if writing a message to the transport (i.e. via write_request or
|
||||
/// write_cancel) would not fail due to a full buffer. If the transport is not ready to be
|
||||
/// written to, flushes it until it is ready.
|
||||
fn ensure_writeable<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||
while self.poll_ready(cx)?.is_pending() {
|
||||
ready!(self.poll_flush(cx)?);
|
||||
}
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
|
||||
fn poll_write_request<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||
let DispatchRequest {
|
||||
ctx,
|
||||
span,
|
||||
request_id,
|
||||
request,
|
||||
response_completion,
|
||||
} = match ready!(self.as_mut().poll_next_request(cx)?) {
|
||||
Some(dispatch_request) => dispatch_request,
|
||||
None => return Poll::Ready(None),
|
||||
};
|
||||
let entered = span.enter();
|
||||
// poll_next_request only returns Ready if there is room to buffer another request.
|
||||
// Therefore, we can call write_request without fear of erroring due to a full
|
||||
// buffer.
|
||||
let request_id = request_id;
|
||||
let request = ClientMessage::Request(Request {
|
||||
id: request_id,
|
||||
message: request,
|
||||
context: context::Context {
|
||||
deadline: ctx.deadline,
|
||||
trace_context: ctx.trace_context,
|
||||
},
|
||||
});
|
||||
self.start_send(request)?;
|
||||
tracing::info!("SendRequest");
|
||||
drop(entered);
|
||||
|
||||
self.in_flight_requests()
|
||||
.insert_request(request_id, ctx, span, response_completion)
|
||||
.expect("Request IDs should be unique");
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
|
||||
fn poll_write_cancel<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||
let (context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
|
||||
Some(triple) => triple,
|
||||
None => return Poll::Ready(None),
|
||||
};
|
||||
let _entered = span.enter();
|
||||
|
||||
let cancel = ClientMessage::Cancel {
|
||||
trace_context: context.trace_context,
|
||||
request_id,
|
||||
};
|
||||
self.start_send(cancel)?;
|
||||
tracing::info!("CancelRequest");
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
|
||||
/// Sends a server response to the client task that initiated the associated request.
|
||||
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
|
||||
self.in_flight_requests().complete_request(response)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
|
||||
where
|
||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
||||
{
|
||||
type Output = Result<(), ChannelError<C::Error>>;
|
||||
|
||||
fn poll(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), ChannelError<C::Error>>> {
|
||||
loop {
|
||||
match (self.as_mut().pump_read(cx)?, self.as_mut().pump_write(cx)?) {
|
||||
(Poll::Ready(None), _) => {
|
||||
tracing::info!("Shutdown: read half closed, so shutting down.");
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
(read, Poll::Ready(None)) => {
|
||||
if self.in_flight_requests.is_empty() {
|
||||
tracing::info!("Shutdown: write half closed, and no requests in flight.");
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
tracing::info!(
|
||||
"Shutdown: write half closed, and {} requests in flight.",
|
||||
self.in_flight_requests().len()
|
||||
);
|
||||
match read {
|
||||
Poll::Ready(Some(())) => continue,
|
||||
_ => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
(Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
|
||||
_ => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage
|
||||
/// the lifecycle of the request.
|
||||
#[derive(Debug)]
|
||||
struct DispatchRequest<Req, Resp> {
|
||||
pub ctx: context::Context,
|
||||
pub span: Span,
|
||||
pub request_id: u64,
|
||||
pub request: Req,
|
||||
pub response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard};
|
||||
use crate::{
|
||||
client::{
|
||||
in_flight_requests::{DeadlineExceededError, InFlightRequests},
|
||||
Config,
|
||||
},
|
||||
context,
|
||||
transport::{self, channel::UnboundedChannel},
|
||||
ClientMessage, Response,
|
||||
};
|
||||
use assert_matches::assert_matches;
|
||||
use futures::{prelude::*, task::*};
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
pin::Pin,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
sync::Arc,
|
||||
};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tracing::Span;
|
||||
|
||||
#[tokio::test]
|
||||
async fn response_completes_request_future() {
|
||||
let (mut dispatch, mut _channel, mut server_channel) = set_up();
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
dispatch
|
||||
.in_flight_requests
|
||||
.insert_request(0, context::current(), Span::current(), tx)
|
||||
.unwrap();
|
||||
server_channel
|
||||
.send(Response {
|
||||
request_id: 0,
|
||||
message: Ok("Resp".into()),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
|
||||
assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_response_cancels_on_drop() {
|
||||
let (cancellation, mut canceled_requests) = cancellations();
|
||||
let (_, mut response) = oneshot::channel();
|
||||
drop(ResponseGuard::<u32> {
|
||||
response: &mut response,
|
||||
cancellation: &cancellation,
|
||||
request_id: 3,
|
||||
});
|
||||
// resp's drop() is run, which should send a cancel message.
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
assert_eq!(canceled_requests.poll_recv(cx), Poll::Ready(Some(3)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_response_doesnt_cancel_after_complete() {
|
||||
let (cancellation, mut canceled_requests) = cancellations();
|
||||
let (tx, mut response) = oneshot::channel();
|
||||
tx.send(Ok(Response {
|
||||
request_id: 0,
|
||||
message: Ok("well done"),
|
||||
}))
|
||||
.unwrap();
|
||||
// resp's drop() is run, but should not send a cancel message.
|
||||
ResponseGuard {
|
||||
response: &mut response,
|
||||
cancellation: &cancellation,
|
||||
request_id: 3,
|
||||
}
|
||||
.response()
|
||||
.await
|
||||
.unwrap();
|
||||
drop(cancellation);
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
assert_eq!(canceled_requests.poll_recv(cx), Poll::Ready(None));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stage_request() {
|
||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
let _resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
|
||||
#[allow(unstable_name_collisions)]
|
||||
let req = dispatch.as_mut().poll_next_request(cx).ready();
|
||||
assert!(req.is_some());
|
||||
|
||||
let req = req.unwrap();
|
||||
assert_eq!(req.request_id, 0);
|
||||
assert_eq!(req.request, "hi".to_string());
|
||||
}
|
||||
|
||||
// Regression test for https://github.com/google/tarpc/issues/220
|
||||
#[tokio::test]
|
||||
async fn stage_request_channel_dropped_doesnt_panic() {
|
||||
let (mut dispatch, mut channel, mut server_channel) = set_up();
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
drop(channel);
|
||||
|
||||
assert!(dispatch.as_mut().poll(cx).is_ready());
|
||||
send_response(
|
||||
&mut server_channel,
|
||||
Response {
|
||||
request_id: 0,
|
||||
message: Ok("hello".into()),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
dispatch.await.unwrap();
|
||||
}
|
||||
|
||||
#[allow(unstable_name_collisions)]
|
||||
#[tokio::test]
|
||||
async fn stage_request_response_future_dropped_is_canceled_before_sending() {
|
||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
|
||||
// Drop the channel so polling returns none if no requests are currently ready.
|
||||
drop(channel);
|
||||
// Test that a request future dropped before it's processed by dispatch will cause the request
|
||||
// to not be added to the in-flight request map.
|
||||
assert!(dispatch.as_mut().poll_next_request(cx).ready().is_none());
|
||||
}
|
||||
|
||||
#[allow(unstable_name_collisions)]
|
||||
#[tokio::test]
|
||||
async fn stage_request_response_future_dropped_is_canceled_after_sending() {
|
||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
let req = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
|
||||
assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
|
||||
assert!(!dispatch.in_flight_requests.is_empty());
|
||||
|
||||
// Test that a request future dropped after it's processed by dispatch will cause the request
|
||||
// to be removed from the in-flight request map.
|
||||
drop(req);
|
||||
assert_matches!(
|
||||
dispatch.as_mut().poll_next_cancellation(cx),
|
||||
Poll::Ready(Some(Ok(_)))
|
||||
);
|
||||
assert!(dispatch.in_flight_requests.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stage_request_response_closed_skipped() {
|
||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
// Test that a request future that's closed its receiver but not yet canceled its request --
|
||||
// i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
|
||||
// map.
|
||||
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
resp.response.close();
|
||||
|
||||
assert!(dispatch.as_mut().poll_next_request(cx).is_pending());
|
||||
}
|
||||
|
||||
fn set_up() -> (
|
||||
Pin<
|
||||
Box<
|
||||
RequestDispatch<
|
||||
String,
|
||||
String,
|
||||
UnboundedChannel<Response<String>, ClientMessage<String>>,
|
||||
>,
|
||||
>,
|
||||
>,
|
||||
Channel<String, String>,
|
||||
UnboundedChannel<ClientMessage<String>, Response<String>>,
|
||||
) {
|
||||
let _ = tracing_subscriber::fmt().with_test_writer().try_init();
|
||||
|
||||
let (to_dispatch, pending_requests) = mpsc::channel(1);
|
||||
let (cancellation, canceled_requests) = cancellations();
|
||||
let (client_channel, server_channel) = transport::channel::unbounded();
|
||||
|
||||
let dispatch = RequestDispatch::<String, String, _> {
|
||||
transport: client_channel.fuse(),
|
||||
pending_requests: pending_requests,
|
||||
canceled_requests,
|
||||
in_flight_requests: InFlightRequests::default(),
|
||||
config: Config::default(),
|
||||
};
|
||||
|
||||
let channel = Channel {
|
||||
to_dispatch,
|
||||
cancellation,
|
||||
next_request_id: Arc::new(AtomicUsize::new(0)),
|
||||
};
|
||||
|
||||
(Box::pin(dispatch), channel, server_channel)
|
||||
}
|
||||
|
||||
async fn send_request<'a>(
|
||||
channel: &'a mut Channel<String, String>,
|
||||
request: &str,
|
||||
response_completion: oneshot::Sender<Result<Response<String>, DeadlineExceededError>>,
|
||||
response: &'a mut oneshot::Receiver<Result<Response<String>, DeadlineExceededError>>,
|
||||
) -> ResponseGuard<'a, String> {
|
||||
let request_id =
|
||||
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
|
||||
let request = DispatchRequest {
|
||||
ctx: context::current(),
|
||||
span: Span::current(),
|
||||
request_id,
|
||||
request: request.to_string(),
|
||||
response_completion,
|
||||
};
|
||||
let response_guard = ResponseGuard {
|
||||
response,
|
||||
cancellation: &channel.cancellation,
|
||||
request_id,
|
||||
};
|
||||
channel.to_dispatch.send(request).await.unwrap();
|
||||
response_guard
|
||||
}
|
||||
|
||||
async fn send_response(
|
||||
channel: &mut UnboundedChannel<ClientMessage<String>, Response<String>>,
|
||||
response: Response<String>,
|
||||
) {
|
||||
channel.send(response).await.unwrap();
|
||||
}
|
||||
|
||||
trait PollTest {
|
||||
type T;
|
||||
fn unwrap(self) -> Poll<Self::T>;
|
||||
fn ready(self) -> Self::T;
|
||||
}
|
||||
|
||||
impl<T, E> PollTest for Poll<Option<Result<T, E>>>
|
||||
where
|
||||
E: ::std::fmt::Display,
|
||||
{
|
||||
type T = Option<T>;
|
||||
|
||||
fn unwrap(self) -> Poll<Option<T>> {
|
||||
match self {
|
||||
Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)),
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Ready(Some(Err(e))) => panic!("{}", e.to_string()),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn ready(self) -> Option<T> {
|
||||
match self {
|
||||
Poll::Ready(Some(Ok(t))) => Some(t),
|
||||
Poll::Ready(None) => None,
|
||||
Poll::Ready(Some(Err(e))) => panic!("{}", e.to_string()),
|
||||
Poll::Pending => panic!("Pending"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
134
tarpc/src/client/in_flight_requests.rs
Normal file
134
tarpc/src/client/in_flight_requests.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use crate::{
|
||||
context,
|
||||
util::{Compact, TimeUntil},
|
||||
Response,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use std::{
|
||||
collections::hash_map,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_util::time::delay_queue::{self, DelayQueue};
|
||||
use tracing::Span;
|
||||
|
||||
/// Requests already written to the wire that haven't yet received responses.
|
||||
#[derive(Debug)]
|
||||
pub struct InFlightRequests<Resp> {
|
||||
request_data: FnvHashMap<u64, RequestData<Resp>>,
|
||||
deadlines: DelayQueue<u64>,
|
||||
}
|
||||
|
||||
impl<Resp> Default for InFlightRequests<Resp> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
request_data: Default::default(),
|
||||
deadlines: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The request exceeded its deadline.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
#[non_exhaustive]
|
||||
#[error("the request exceeded its deadline")]
|
||||
pub struct DeadlineExceededError;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RequestData<Resp> {
|
||||
ctx: context::Context,
|
||||
span: Span,
|
||||
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
/// The key to remove the timer for the request's deadline.
|
||||
deadline_key: delay_queue::Key,
|
||||
}
|
||||
|
||||
/// An error returned when an attempt is made to insert a request with an ID that is already in
|
||||
/// use.
|
||||
#[derive(Debug)]
|
||||
pub struct AlreadyExistsError;
|
||||
|
||||
impl<Resp> InFlightRequests<Resp> {
|
||||
/// Returns the number of in-flight requests.
|
||||
pub fn len(&self) -> usize {
|
||||
self.request_data.len()
|
||||
}
|
||||
|
||||
/// Returns true iff there are no requests in flight.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.request_data.is_empty()
|
||||
}
|
||||
|
||||
/// Starts a request, unless a request with the same ID is already in flight.
|
||||
pub fn insert_request(
|
||||
&mut self,
|
||||
request_id: u64,
|
||||
ctx: context::Context,
|
||||
span: Span,
|
||||
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
) -> Result<(), AlreadyExistsError> {
|
||||
match self.request_data.entry(request_id) {
|
||||
hash_map::Entry::Vacant(vacant) => {
|
||||
let timeout = ctx.deadline.time_until();
|
||||
let deadline_key = self.deadlines.insert(request_id, timeout);
|
||||
vacant.insert(RequestData {
|
||||
ctx,
|
||||
span,
|
||||
response_completion,
|
||||
deadline_key,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
hash_map::Entry::Occupied(_) => Err(AlreadyExistsError),
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes a request without aborting. Returns true iff the request was found.
|
||||
pub fn complete_request(&mut self, response: Response<Resp>) -> bool {
|
||||
if let Some(request_data) = self.request_data.remove(&response.request_id) {
|
||||
let _entered = request_data.span.enter();
|
||||
tracing::info!("ReceiveResponse");
|
||||
self.request_data.compact(0.1);
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
let _ = request_data.response_completion.send(Ok(response));
|
||||
return true;
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
"No in-flight request found for request_id = {}.",
|
||||
response.request_id
|
||||
);
|
||||
|
||||
// If the response completion was absent, then the request was already canceled.
|
||||
false
|
||||
}
|
||||
|
||||
/// Cancels a request without completing (typically used when a request handle was dropped
|
||||
/// before the request completed).
|
||||
pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> {
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
self.request_data.compact(0.1);
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
Some((request_data.ctx, request_data.span))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Yields a request that has expired, completing it with a TimedOut error.
|
||||
/// The caller should send cancellation messages for any yielded request ID.
|
||||
pub fn poll_expired(&mut self, cx: &mut Context) -> Poll<Option<u64>> {
|
||||
self.deadlines.poll_expired(cx).map(|expired| {
|
||||
let request_id = expired?.into_inner();
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
let _entered = request_data.span.enter();
|
||||
tracing::error!("DeadlineExceeded");
|
||||
self.request_data.compact(0.1);
|
||||
let _ = request_data
|
||||
.response_completion
|
||||
.send(Err(DeadlineExceededError));
|
||||
}
|
||||
Some(request_id)
|
||||
})
|
||||
}
|
||||
}
|
||||
152
tarpc/src/context.rs
Normal file
152
tarpc/src/context.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
// Copyright 2018 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! Provides a request context that carries a deadline and trace context. This context is sent from
|
||||
//! client to server and is used by the server to enforce response deadlines.
|
||||
|
||||
use crate::trace::{self, TraceId};
|
||||
use opentelemetry::trace::TraceContextExt;
|
||||
use static_assertions::assert_impl_all;
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
|
||||
/// A request context that carries request-scoped information like deadlines and trace information.
|
||||
/// It is sent from client to server and is used by the server to enforce response deadlines.
|
||||
///
|
||||
/// The context should not be stored directly in a server implementation, because the context will
|
||||
/// be different for each request in scope.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Context {
|
||||
/// When the client expects the request to be complete by. The server should cancel the request
|
||||
/// if it is not complete by this time.
|
||||
#[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))]
|
||||
// Serialized as a Duration to prevent clock skew issues.
|
||||
#[cfg_attr(feature = "serde1", serde(with = "absolute_to_relative_time"))]
|
||||
pub deadline: SystemTime,
|
||||
/// Uniquely identifies requests originating from the same source.
|
||||
/// When a service handles a request by making requests itself, those requests should
|
||||
/// include the same `trace_id` as that included on the original request. This way,
|
||||
/// users can trace related actions across a distributed system.
|
||||
pub trace_context: trace::Context,
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde1")]
|
||||
mod absolute_to_relative_time {
|
||||
pub use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
pub use std::time::{Duration, SystemTime};
|
||||
|
||||
pub fn serialize<S>(deadline: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let deadline = deadline
|
||||
.duration_since(SystemTime::now())
|
||||
.unwrap_or(Duration::ZERO);
|
||||
deadline.serialize(serializer)
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let deadline = Duration::deserialize(deserializer)?;
|
||||
Ok(SystemTime::now() + deadline)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct AbsoluteToRelative(#[serde(with = "self")] SystemTime);
|
||||
|
||||
#[test]
|
||||
fn test_serialize() {
|
||||
let now = SystemTime::now();
|
||||
let deadline = now + Duration::from_secs(10);
|
||||
let serialized_deadline = bincode::serialize(&AbsoluteToRelative(deadline)).unwrap();
|
||||
let deserialized_deadline: Duration = bincode::deserialize(&serialized_deadline).unwrap();
|
||||
// TODO: how to avoid flakiness?
|
||||
assert!(deserialized_deadline > Duration::from_secs(9));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize() {
|
||||
let deadline = Duration::from_secs(10);
|
||||
let serialized_deadline = bincode::serialize(&deadline).unwrap();
|
||||
let AbsoluteToRelative(deserialized_deadline) =
|
||||
bincode::deserialize(&serialized_deadline).unwrap();
|
||||
// TODO: how to avoid flakiness?
|
||||
assert!(deserialized_deadline > SystemTime::now() + Duration::from_secs(9));
|
||||
}
|
||||
}
|
||||
|
||||
assert_impl_all!(Context: Send, Sync);
|
||||
|
||||
fn ten_seconds_from_now() -> SystemTime {
|
||||
SystemTime::now() + Duration::from_secs(10)
|
||||
}
|
||||
|
||||
/// Returns the context for the current request, or a default Context if no request is active.
|
||||
pub fn current() -> Context {
|
||||
Context::current()
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Deadline(SystemTime);
|
||||
|
||||
impl Default for Deadline {
|
||||
fn default() -> Self {
|
||||
Self(ten_seconds_from_now())
|
||||
}
|
||||
}
|
||||
|
||||
impl Context {
|
||||
/// Returns the context for the current request, or a default Context if no request is active.
|
||||
pub fn current() -> Self {
|
||||
let span = tracing::Span::current();
|
||||
Self {
|
||||
trace_context: trace::Context::try_from(&span)
|
||||
.unwrap_or_else(|_| trace::Context::default()),
|
||||
deadline: span
|
||||
.context()
|
||||
.get::<Deadline>()
|
||||
.cloned()
|
||||
.unwrap_or_default()
|
||||
.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the ID of the request-scoped trace.
|
||||
pub fn trace_id(&self) -> &TraceId {
|
||||
&self.trace_context.trace_id
|
||||
}
|
||||
}
|
||||
|
||||
/// An extension trait for [`tracing::Span`] for propagating tarpc Contexts.
|
||||
pub(crate) trait SpanExt {
|
||||
/// Sets the given context on this span. Newly-created spans will be children of the given
|
||||
/// context's trace context.
|
||||
fn set_context(&self, context: &Context);
|
||||
}
|
||||
|
||||
impl SpanExt for tracing::Span {
|
||||
fn set_context(&self, context: &Context) {
|
||||
self.set_parent(
|
||||
opentelemetry::Context::new()
|
||||
.with_remote_span_context(opentelemetry::trace::SpanContext::new(
|
||||
opentelemetry::trace::TraceId::from(context.trace_context.trace_id),
|
||||
opentelemetry::trace::SpanId::from(context.trace_context.span_id),
|
||||
opentelemetry::trace::TraceFlags::from(context.trace_context.sampling_decision),
|
||||
true,
|
||||
opentelemetry::trace::TraceState::default(),
|
||||
))
|
||||
.with_value(Deadline(context.deadline)),
|
||||
);
|
||||
}
|
||||
}
|
||||
192
tarpc/src/lib.rs
192
tarpc/src/lib.rs
@@ -3,7 +3,6 @@
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! *Disclaimer*: This is not an official Google product.
|
||||
//!
|
||||
//! tarpc is an RPC framework for rust with a focus on ease of use. Defining a
|
||||
@@ -28,7 +27,7 @@
|
||||
//! process, and no context switching between different languages.
|
||||
//!
|
||||
//! Some other features of tarpc:
|
||||
//! - Pluggable transport: any type impling `Stream<Item = Request> + Sink<Response>` can be
|
||||
//! - Pluggable transport: any type implementing `Stream<Item = Request> + Sink<Response>` can be
|
||||
//! used as a transport to connect the client and server.
|
||||
//! - `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
|
||||
//! - Cascading cancellation: dropping a request will send a cancellation message to the server.
|
||||
@@ -39,6 +38,14 @@
|
||||
//! requests sent by the server that use the request context will propagate the request deadline.
|
||||
//! For example, if a server is handling a request with a 10s deadline, does 2s of work, then
|
||||
//! sends a request to another server, that server will see an 8s deadline.
|
||||
//! - Distributed tracing: tarpc is instrumented with
|
||||
//! [tracing](https://github.com/tokio-rs/tracing) primitives extended with
|
||||
//! [OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
|
||||
//! [Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
|
||||
//! each RPC can be traced through the client, server, and other dependencies downstream of the
|
||||
//! server. Even for applications not connected to a distributed tracing collector, the
|
||||
//! instrumentation can also be ingested by regular loggers like
|
||||
//! [env_logger](https://github.com/env-logger-rs/env_logger/).
|
||||
//! - Serde serialization: enabling the `serde1` Cargo feature will make service requests and
|
||||
//! responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
|
||||
//! be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
|
||||
@@ -47,7 +54,7 @@
|
||||
//! Add to your `Cargo.toml` dependencies:
|
||||
//!
|
||||
//! ```toml
|
||||
//! tarpc = "0.22.0"
|
||||
//! tarpc = "0.29"
|
||||
//! ```
|
||||
//!
|
||||
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||
@@ -56,12 +63,14 @@
|
||||
//!
|
||||
//! ## Example
|
||||
//!
|
||||
//! For this example, in addition to tarpc, also add two other dependencies to
|
||||
//! This example uses [tokio](https://tokio.rs), so add the following dependencies to
|
||||
//! your `Cargo.toml`:
|
||||
//!
|
||||
//! ```toml
|
||||
//! anyhow = "1.0"
|
||||
//! futures = "0.3"
|
||||
//! tokio = "0.2"
|
||||
//! tarpc = { version = "0.29", features = ["tokio1"] }
|
||||
//! tokio = { version = "1.0", features = ["macros"] }
|
||||
//! ```
|
||||
//!
|
||||
//! In the following example, we use an in-process channel for communication between
|
||||
@@ -79,9 +88,8 @@
|
||||
//! };
|
||||
//! use tarpc::{
|
||||
//! client, context,
|
||||
//! server::{self, Handler},
|
||||
//! server::{self, incoming::Incoming, Channel},
|
||||
//! };
|
||||
//! use std::io;
|
||||
//!
|
||||
//! // This is the service definition. It looks a lot like a trait definition.
|
||||
//! // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
@@ -103,9 +111,8 @@
|
||||
//! # };
|
||||
//! # use tarpc::{
|
||||
//! # client, context,
|
||||
//! # server::{self, Handler},
|
||||
//! # server::{self, incoming::Incoming},
|
||||
//! # };
|
||||
//! # use std::io;
|
||||
//! # // This is the service definition. It looks a lot like a trait definition.
|
||||
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
//! # #[tarpc::service]
|
||||
@@ -125,13 +132,13 @@
|
||||
//! type HelloFut = Ready<String>;
|
||||
//!
|
||||
//! fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
||||
//! future::ready(format!("Hello, {}!", name))
|
||||
//! future::ready(format!("Hello, {name}!"))
|
||||
//! }
|
||||
//! }
|
||||
//! ```
|
||||
//!
|
||||
//! Lastly let's write our `main` that will start the server. While this example uses an
|
||||
//! [in-process channel](rpc::transport::channel), tarpc also ships a generic [`serde_transport`]
|
||||
//! [in-process channel](transport::channel), tarpc also ships a generic [`serde_transport`]
|
||||
//! behind the `serde-transport` feature, with additional [TCP](serde_transport::tcp) functionality
|
||||
//! available behind the `tcp` feature.
|
||||
//!
|
||||
@@ -143,9 +150,8 @@
|
||||
//! # };
|
||||
//! # use tarpc::{
|
||||
//! # client, context,
|
||||
//! # server::{self, Handler},
|
||||
//! # server::{self, Channel},
|
||||
//! # };
|
||||
//! # use std::io;
|
||||
//! # // This is the service definition. It looks a lot like a trait definition.
|
||||
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
//! # #[tarpc::service]
|
||||
@@ -162,31 +168,29 @@
|
||||
//! # // an associated type representing the future output by the fn.
|
||||
//! # type HelloFut = Ready<String>;
|
||||
//! # fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
||||
//! # future::ready(format!("Hello, {}!", name))
|
||||
//! # future::ready(format!("Hello, {name}!"))
|
||||
//! # }
|
||||
//! # }
|
||||
//! # #[cfg(not(feature = "tokio1"))]
|
||||
//! # fn main() {}
|
||||
//! # #[cfg(feature = "tokio1")]
|
||||
//! #[tokio::main]
|
||||
//! async fn main() -> io::Result<()> {
|
||||
//! async fn main() -> anyhow::Result<()> {
|
||||
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
//!
|
||||
//! let server = server::new(server::Config::default())
|
||||
//! // incoming() takes a stream of transports such as would be returned by
|
||||
//! // TcpListener::incoming (but a stream instead of an iterator).
|
||||
//! .incoming(stream::once(future::ready(server_transport)))
|
||||
//! .respond_with(HelloServer.serve());
|
||||
//! let server = server::BaseChannel::with_defaults(server_transport);
|
||||
//! tokio::spawn(server.execute(HelloServer.serve()));
|
||||
//!
|
||||
//! tokio::spawn(server);
|
||||
//!
|
||||
//! // WorldClient is generated by the macro. It has a constructor `new` that takes a config and
|
||||
//! // any Transport as input
|
||||
//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
|
||||
//! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||
//! // that takes a config and any Transport as input.
|
||||
//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn();
|
||||
//!
|
||||
//! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||
//! // args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||
//! // specifies a deadline and trace information which can be helpful in debugging requests.
|
||||
//! let hello = client.hello(context::current(), "Stim".to_string()).await?;
|
||||
//!
|
||||
//! println!("{}", hello);
|
||||
//! println!("{hello}");
|
||||
//!
|
||||
//! Ok(())
|
||||
//! }
|
||||
@@ -200,8 +204,12 @@
|
||||
#![allow(clippy::type_complexity)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
pub mod rpc;
|
||||
pub use rpc::*;
|
||||
#[cfg(feature = "serde1")]
|
||||
#[doc(hidden)]
|
||||
pub use serde;
|
||||
|
||||
#[cfg(feature = "serde-transport")]
|
||||
pub use {tokio_serde, tokio_util};
|
||||
|
||||
#[cfg(feature = "serde-transport")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "serde-transport")))]
|
||||
@@ -209,6 +217,9 @@ pub mod serde_transport;
|
||||
|
||||
pub mod trace;
|
||||
|
||||
#[cfg(feature = "serde1")]
|
||||
pub use tarpc_plugins::derive_serde;
|
||||
|
||||
/// The main macro that creates RPC services.
|
||||
///
|
||||
/// Rpc methods are specified, mirroring trait syntax:
|
||||
@@ -253,7 +264,7 @@ pub use tarpc_plugins::service;
|
||||
/// #[tarpc::server]
|
||||
/// impl World for HelloServer {
|
||||
/// async fn hello(self, _: context::Context, name: String) -> String {
|
||||
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
|
||||
/// format!("Hello, {name}! You are connected from {:?}.", self.0)
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
@@ -279,7 +290,7 @@ pub use tarpc_plugins::service;
|
||||
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
|
||||
/// + Send>> {
|
||||
/// Box::pin(async move {
|
||||
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
|
||||
/// format!("Hello, {name}! You are connected from {:?}.", self.0)
|
||||
/// })
|
||||
/// }
|
||||
/// }
|
||||
@@ -288,3 +299,124 @@ pub use tarpc_plugins::service;
|
||||
/// Note that this won't touch functions unless they have been annotated with
|
||||
/// `async`, meaning that this should not break existing code.
|
||||
pub use tarpc_plugins::server;
|
||||
|
||||
pub(crate) mod cancellations;
|
||||
pub mod client;
|
||||
pub mod context;
|
||||
pub mod server;
|
||||
pub mod transport;
|
||||
pub(crate) mod util;
|
||||
|
||||
pub use crate::transport::sealed::Transport;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use futures::task::*;
|
||||
use std::{error::Error, fmt::Display, io, time::SystemTime};
|
||||
|
||||
/// A message from a client to a server.
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[non_exhaustive]
|
||||
pub enum ClientMessage<T> {
|
||||
/// A request initiated by a user. The server responds to a request by invoking a
|
||||
/// service-provided request handler. The handler completes with a [`response`](Response), which
|
||||
/// the server sends back to the client.
|
||||
Request(Request<T>),
|
||||
/// A command to cancel an in-flight request, automatically sent by the client when a response
|
||||
/// future is dropped.
|
||||
///
|
||||
/// When received, the server will immediately cancel the main task (top-level future) of the
|
||||
/// request handler for the associated request. Any tasks spawned by the request handler will
|
||||
/// not be canceled, because the framework layer does not
|
||||
/// know about them.
|
||||
Cancel {
|
||||
/// The trace context associates the message with a specific chain of causally-related actions,
|
||||
/// possibly orchestrated across many distributed systems.
|
||||
#[cfg_attr(feature = "serde1", serde(default))]
|
||||
trace_context: trace::Context,
|
||||
/// The ID of the request to cancel.
|
||||
request_id: u64,
|
||||
},
|
||||
}
|
||||
|
||||
/// A request from a client to a server.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Request<T> {
|
||||
/// Trace context, deadline, and other cross-cutting concerns.
|
||||
pub context: context::Context,
|
||||
/// Uniquely identifies the request across all requests sent over a single channel.
|
||||
pub id: u64,
|
||||
/// The request body.
|
||||
pub message: T,
|
||||
}
|
||||
|
||||
/// A response from a server to a client.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Response<T> {
|
||||
/// The ID of the request being responded to.
|
||||
pub request_id: u64,
|
||||
/// The response body, or an error if the request failed.
|
||||
pub message: Result<T, ServerError>,
|
||||
}
|
||||
|
||||
/// An error indicating the server aborted the request early, e.g., due to request throttling.
|
||||
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[error("{kind:?}: {detail}")]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ServerError {
|
||||
#[cfg_attr(
|
||||
feature = "serde1",
|
||||
serde(serialize_with = "util::serde::serialize_io_error_kind_as_u32")
|
||||
)]
|
||||
#[cfg_attr(
|
||||
feature = "serde1",
|
||||
serde(deserialize_with = "util::serde::deserialize_io_error_kind_from_u32")
|
||||
)]
|
||||
/// The type of error that occurred to fail the request.
|
||||
pub kind: io::ErrorKind,
|
||||
/// A message describing more detail about the error that occurred.
|
||||
pub detail: String,
|
||||
}
|
||||
|
||||
impl<T> Request<T> {
|
||||
/// Returns the deadline for this request.
|
||||
pub fn deadline(&self) -> &SystemTime {
|
||||
&self.context.deadline
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait PollContext<T> {
|
||||
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static;
|
||||
|
||||
fn with_context<C, F>(self, f: F) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static,
|
||||
F: FnOnce() -> C;
|
||||
}
|
||||
|
||||
impl<T, E> PollContext<T> for Poll<Option<Result<T, E>>>
|
||||
where
|
||||
E: Error + Send + Sync + 'static,
|
||||
{
|
||||
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static,
|
||||
{
|
||||
self.map(|o| o.map(|r| r.context(context)))
|
||||
}
|
||||
|
||||
fn with_context<C, F>(self, f: F) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static,
|
||||
F: FnOnce() -> C,
|
||||
{
|
||||
self.map(|o| o.map(|r| r.with_context(f)))
|
||||
}
|
||||
}
|
||||
|
||||
148
tarpc/src/rpc.rs
148
tarpc/src/rpc.rs
@@ -1,148 +0,0 @@
|
||||
// Copyright 2018 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
#![deny(missing_docs, missing_debug_implementations)]
|
||||
|
||||
//! An RPC framework providing client and server.
|
||||
//!
|
||||
//! Features:
|
||||
//! * RPC deadlines, both client- and server-side.
|
||||
//! * Cascading cancellation (works with multiple hops).
|
||||
//! * Configurable limits
|
||||
//! * In-flight requests, both client and server-side.
|
||||
//! * Server-side limit is per-connection.
|
||||
//! * When the server reaches the in-flight request maximum, it returns a throttled error
|
||||
//! to the client.
|
||||
//! * When the client reaches the in-flight request max, messages are buffered up to a
|
||||
//! configurable maximum, beyond which the requests are back-pressured.
|
||||
//! * Server connections.
|
||||
//! * Total and per-IP limits.
|
||||
//! * When an incoming connection is accepted, if already at maximum, the connection is
|
||||
//! dropped.
|
||||
//! * Transport agnostic.
|
||||
|
||||
pub mod client;
|
||||
pub mod context;
|
||||
pub mod server;
|
||||
pub mod transport;
|
||||
pub(crate) mod util;
|
||||
|
||||
pub use crate::{client::Client, server::Server, trace, transport::sealed::Transport};
|
||||
|
||||
use anyhow::Context as _;
|
||||
use futures::task::*;
|
||||
use std::{fmt::Display, io, time::SystemTime};
|
||||
|
||||
/// A message from a client to a server.
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[non_exhaustive]
|
||||
pub enum ClientMessage<T> {
|
||||
/// A request initiated by a user. The server responds to a request by invoking a
|
||||
/// service-provided request handler. The handler completes with a [`response`](Response), which
|
||||
/// the server sends back to the client.
|
||||
Request(Request<T>),
|
||||
/// A command to cancel an in-flight request, automatically sent by the client when a response
|
||||
/// future is dropped.
|
||||
///
|
||||
/// When received, the server will immediately cancel the main task (top-level future) of the
|
||||
/// request handler for the associated request. Any tasks spawned by the request handler will
|
||||
/// not be canceled, because the framework layer does not
|
||||
/// know about them.
|
||||
Cancel {
|
||||
/// The trace context associates the message with a specific chain of causally-related actions,
|
||||
/// possibly orchestrated across many distributed systems.
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
trace_context: trace::Context,
|
||||
/// The ID of the request to cancel.
|
||||
request_id: u64,
|
||||
},
|
||||
}
|
||||
|
||||
/// A request from a client to a server.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Request<T> {
|
||||
/// Trace context, deadline, and other cross-cutting concerns.
|
||||
pub context: context::Context,
|
||||
/// Uniquely identifies the request across all requests sent over a single channel.
|
||||
pub id: u64,
|
||||
/// The request body.
|
||||
pub message: T,
|
||||
}
|
||||
|
||||
/// A response from a server to a client.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Response<T> {
|
||||
/// The ID of the request being responded to.
|
||||
pub request_id: u64,
|
||||
/// The response body, or an error if the request failed.
|
||||
pub message: Result<T, ServerError>,
|
||||
}
|
||||
|
||||
/// An error response from a server to a client.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ServerError {
|
||||
#[cfg_attr(
|
||||
feature = "serde1",
|
||||
serde(serialize_with = "util::serde::serialize_io_error_kind_as_u32")
|
||||
)]
|
||||
#[cfg_attr(
|
||||
feature = "serde1",
|
||||
serde(deserialize_with = "util::serde::deserialize_io_error_kind_from_u32")
|
||||
)]
|
||||
/// The type of error that occurred to fail the request.
|
||||
pub kind: io::ErrorKind,
|
||||
/// A message describing more detail about the error that occurred.
|
||||
pub detail: Option<String>,
|
||||
}
|
||||
|
||||
impl From<ServerError> for io::Error {
|
||||
fn from(e: ServerError) -> io::Error {
|
||||
io::Error::new(e.kind, e.detail.unwrap_or_default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Request<T> {
|
||||
/// Returns the deadline for this request.
|
||||
pub fn deadline(&self) -> &SystemTime {
|
||||
&self.context.deadline
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) type PollIo<T> = Poll<Option<io::Result<T>>>;
|
||||
pub(crate) trait PollContext<T> {
|
||||
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static;
|
||||
|
||||
fn with_context<C, F>(self, f: F) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static,
|
||||
F: FnOnce() -> C;
|
||||
}
|
||||
|
||||
impl<T> PollContext<T> for PollIo<T> {
|
||||
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static,
|
||||
{
|
||||
self.map(|o| o.map(|r| r.context(context)))
|
||||
}
|
||||
|
||||
fn with_context<C, F>(self, f: F) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static,
|
||||
F: FnOnce() -> C,
|
||||
{
|
||||
self.map(|o| o.map(|r| r.with_context(f)))
|
||||
}
|
||||
}
|
||||
@@ -1,155 +0,0 @@
|
||||
// Copyright 2018 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! Provides a client that connects to a server and sends multiplexed requests.
|
||||
|
||||
use crate::context;
|
||||
use futures::prelude::*;
|
||||
use std::io;
|
||||
|
||||
/// Provides a [`Client`] backed by a transport.
|
||||
pub mod channel;
|
||||
pub use channel::{new, Channel};
|
||||
|
||||
/// Sends multiplexed requests to, and receives responses from, a server.
|
||||
pub trait Client<'a, Req> {
|
||||
/// The response type.
|
||||
type Response;
|
||||
|
||||
/// The future response.
|
||||
type Future: Future<Output = io::Result<Self::Response>> + 'a;
|
||||
|
||||
/// Initiates a request, sending it to the dispatch task.
|
||||
///
|
||||
/// Returns a [`Future`] that resolves to this client and the future response
|
||||
/// once the request is successfully enqueued.
|
||||
///
|
||||
/// [`Future`]: futures::Future
|
||||
fn call(&'a mut self, ctx: context::Context, request: Req) -> Self::Future;
|
||||
|
||||
/// Returns a Client that applies a post-processing function to the returned response.
|
||||
fn map_response<F, R>(self, f: F) -> MapResponse<Self, F>
|
||||
where
|
||||
F: FnMut(Self::Response) -> R,
|
||||
Self: Sized,
|
||||
{
|
||||
MapResponse { inner: self, f }
|
||||
}
|
||||
|
||||
/// Returns a Client that applies a pre-processing function to the request.
|
||||
fn with_request<F, Req2>(self, f: F) -> WithRequest<Self, F>
|
||||
where
|
||||
F: FnMut(Req2) -> Req,
|
||||
Self: Sized,
|
||||
{
|
||||
WithRequest { inner: self, f }
|
||||
}
|
||||
}
|
||||
|
||||
/// A Client that applies a function to the returned response.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MapResponse<C, F> {
|
||||
inner: C,
|
||||
f: F,
|
||||
}
|
||||
|
||||
impl<'a, C, F, Req, Resp, Resp2> Client<'a, Req> for MapResponse<C, F>
|
||||
where
|
||||
C: Client<'a, Req, Response = Resp>,
|
||||
F: FnMut(Resp) -> Resp2 + 'a,
|
||||
{
|
||||
type Response = Resp2;
|
||||
type Future = futures::future::MapOk<<C as Client<'a, Req>>::Future, &'a mut F>;
|
||||
|
||||
fn call(&'a mut self, ctx: context::Context, request: Req) -> Self::Future {
|
||||
self.inner.call(ctx, request).map_ok(&mut self.f)
|
||||
}
|
||||
}
|
||||
|
||||
/// A Client that applies a pre-processing function to the request.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct WithRequest<C, F> {
|
||||
inner: C,
|
||||
f: F,
|
||||
}
|
||||
|
||||
impl<'a, C, F, Req, Req2, Resp> Client<'a, Req2> for WithRequest<C, F>
|
||||
where
|
||||
C: Client<'a, Req, Response = Resp>,
|
||||
F: FnMut(Req2) -> Req,
|
||||
{
|
||||
type Response = Resp;
|
||||
type Future = <C as Client<'a, Req>>::Future;
|
||||
|
||||
fn call(&'a mut self, ctx: context::Context, request: Req2) -> Self::Future {
|
||||
self.inner.call(ctx, (self.f)(request))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, Req, Resp> Client<'a, Req> for Channel<Req, Resp>
|
||||
where
|
||||
Req: 'a,
|
||||
Resp: 'a,
|
||||
{
|
||||
type Response = Resp;
|
||||
type Future = channel::Call<'a, Req, Resp>;
|
||||
|
||||
fn call(&'a mut self, ctx: context::Context, request: Req) -> channel::Call<'a, Req, Resp> {
|
||||
self.call(ctx, request)
|
||||
}
|
||||
}
|
||||
|
||||
/// Settings that control the behavior of the client.
|
||||
#[derive(Clone, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub struct Config {
|
||||
/// The number of requests that can be in flight at once.
|
||||
/// `max_in_flight_requests` controls the size of the map used by the client
|
||||
/// for storing pending requests.
|
||||
pub max_in_flight_requests: usize,
|
||||
/// The number of requests that can be buffered client-side before being sent.
|
||||
/// `pending_requests_buffer` controls the size of the channel clients use
|
||||
/// to communicate with the request dispatch task.
|
||||
pub pending_request_buffer: usize,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Config {
|
||||
max_in_flight_requests: 1_000,
|
||||
pending_request_buffer: 100,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A channel and dispatch pair. The dispatch drives the sending and receiving of requests
|
||||
/// and must be polled continuously or spawned.
|
||||
#[derive(Debug)]
|
||||
pub struct NewClient<C, D> {
|
||||
/// The new client.
|
||||
pub client: C,
|
||||
/// The client's dispatch.
|
||||
pub dispatch: D,
|
||||
}
|
||||
|
||||
impl<C, D, E> NewClient<C, D>
|
||||
where
|
||||
D: Future<Output = Result<(), E>> + Send + 'static,
|
||||
E: std::fmt::Display,
|
||||
{
|
||||
/// Helper method to spawn the dispatch on the default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub fn spawn(self) -> io::Result<C> {
|
||||
use log::error;
|
||||
|
||||
let dispatch = self
|
||||
.dispatch
|
||||
.unwrap_or_else(move |e| error!("Connection broken: {}", e));
|
||||
tokio::spawn(dispatch);
|
||||
Ok(self.client)
|
||||
}
|
||||
}
|
||||
@@ -1,907 +0,0 @@
|
||||
// Copyright 2018 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use crate::{
|
||||
context,
|
||||
trace::SpanId,
|
||||
util::{Compact, TimeUntil},
|
||||
ClientMessage, PollContext, PollIo, Request, Response, Transport,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
prelude::*,
|
||||
ready,
|
||||
stream::Fuse,
|
||||
task::*,
|
||||
};
|
||||
use log::{debug, info, trace};
|
||||
use pin_project::{pin_project, pinned_drop};
|
||||
use std::{
|
||||
io,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use super::{Config, NewClient};
|
||||
|
||||
/// Handles communication from the client to request dispatch.
|
||||
#[derive(Debug)]
|
||||
pub struct Channel<Req, Resp> {
|
||||
to_dispatch: mpsc::Sender<DispatchRequest<Req, Resp>>,
|
||||
/// Channel to send a cancel message to the dispatcher.
|
||||
cancellation: RequestCancellation,
|
||||
/// The ID to use for the next request to stage.
|
||||
next_request_id: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl<Req, Resp> Clone for Channel<Req, Resp> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
to_dispatch: self.to_dispatch.clone(),
|
||||
cancellation: self.cancellation.clone(),
|
||||
next_request_id: self.next_request_id.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A future returned by [`Channel::send`] that resolves to a server response.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
struct Send<'a, Req, Resp> {
|
||||
#[pin]
|
||||
fut: MapOkDispatchResponse<SendMapErrConnectionReset<'a, Req, Resp>, Resp>,
|
||||
}
|
||||
|
||||
type SendMapErrConnectionReset<'a, Req, Resp> = MapErrConnectionReset<
|
||||
futures::sink::Send<'a, mpsc::Sender<DispatchRequest<Req, Resp>>, DispatchRequest<Req, Resp>>,
|
||||
>;
|
||||
|
||||
impl<'a, Req, Resp> Future for Send<'a, Req, Resp> {
|
||||
type Output = io::Result<DispatchResponse<Resp>>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.as_mut().project().fut.poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
/// A future returned by [`Channel::call`] that resolves to a server response.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct Call<'a, Req, Resp> {
|
||||
#[pin]
|
||||
fut: tokio::time::Timeout<AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>>,
|
||||
}
|
||||
|
||||
impl<'a, Req, Resp> Future for Call<'a, Req, Resp> {
|
||||
type Output = io::Result<Resp>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let resp = ready!(self.as_mut().project().fut.poll(cx));
|
||||
Poll::Ready(match resp {
|
||||
Ok(resp) => resp,
|
||||
Err(tokio::time::Elapsed { .. }) => Err(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"Client dropped expired request.".to_string(),
|
||||
)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> Channel<Req, Resp> {
|
||||
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
|
||||
/// resolves when the request is sent (not when the response is received).
|
||||
fn send(&mut self, mut ctx: context::Context, request: Req) -> Send<Req, Resp> {
|
||||
// Convert the context to the call context.
|
||||
ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
|
||||
ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
|
||||
|
||||
let (response_completion, response) = oneshot::channel();
|
||||
let cancellation = self.cancellation.clone();
|
||||
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
||||
Send {
|
||||
fut: MapOkDispatchResponse::new(
|
||||
MapErrConnectionReset::new(self.to_dispatch.send(DispatchRequest {
|
||||
ctx,
|
||||
request_id,
|
||||
request,
|
||||
response_completion,
|
||||
})),
|
||||
DispatchResponse {
|
||||
response,
|
||||
complete: false,
|
||||
request_id,
|
||||
cancellation,
|
||||
ctx,
|
||||
},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
|
||||
/// resolves to the response.
|
||||
pub fn call(&mut self, ctx: context::Context, request: Req) -> Call<Req, Resp> {
|
||||
let timeout = ctx.deadline.time_until();
|
||||
trace!(
|
||||
"[{}] Queuing request with timeout {:?}.",
|
||||
ctx.trace_id(),
|
||||
timeout,
|
||||
);
|
||||
|
||||
Call {
|
||||
fut: tokio::time::timeout(timeout, AndThenIdent::new(self.send(ctx, request))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A server response that is completed by request dispatch when the corresponding response
|
||||
/// arrives off the wire.
|
||||
#[pin_project(PinnedDrop)]
|
||||
#[derive(Debug)]
|
||||
struct DispatchResponse<Resp> {
|
||||
response: oneshot::Receiver<Response<Resp>>,
|
||||
ctx: context::Context,
|
||||
complete: bool,
|
||||
cancellation: RequestCancellation,
|
||||
request_id: u64,
|
||||
}
|
||||
|
||||
impl<Resp> Future for DispatchResponse<Resp> {
|
||||
type Output = io::Result<Resp>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> {
|
||||
let resp = ready!(self.response.poll_unpin(cx));
|
||||
self.complete = true;
|
||||
Poll::Ready(match resp {
|
||||
Ok(resp) => Ok(resp.message?),
|
||||
Err(oneshot::Canceled) => {
|
||||
// The oneshot is Canceled when the dispatch task ends. In that case,
|
||||
// there's nothing listening on the other side, so there's no point in
|
||||
// propagating cancellation.
|
||||
Err(io::Error::from(io::ErrorKind::ConnectionReset))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Cancels the request when dropped, if not already complete.
|
||||
#[pinned_drop]
|
||||
impl<Resp> PinnedDrop for DispatchResponse<Resp> {
|
||||
fn drop(mut self: Pin<&mut Self>) {
|
||||
if !self.complete {
|
||||
// The receiver needs to be closed to handle the edge case that the request has not
|
||||
// yet been received by the dispatch task. It is possible for the cancel message to
|
||||
// arrive before the request itself, in which case the request could get stuck in the
|
||||
// dispatch map forever if the server never responds (e.g. if the server dies while
|
||||
// responding). Even if the server does respond, it will have unnecessarily done work
|
||||
// for a client no longer waiting for a response. To avoid this, the dispatch task
|
||||
// checks if the receiver is closed before inserting the request in the map. By
|
||||
// closing the receiver before sending the cancel message, it is guaranteed that if the
|
||||
// dispatch task misses an early-arriving cancellation message, then it will see the
|
||||
// receiver as closed.
|
||||
self.response.close();
|
||||
let request_id = self.request_id;
|
||||
self.cancellation.cancel(request_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the
|
||||
/// channel.
|
||||
pub fn new<Req, Resp, C>(
|
||||
config: Config,
|
||||
transport: C,
|
||||
) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C>>
|
||||
where
|
||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
||||
{
|
||||
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
|
||||
let (cancellation, canceled_requests) = cancellations();
|
||||
let canceled_requests = canceled_requests.fuse();
|
||||
|
||||
NewClient {
|
||||
client: Channel {
|
||||
to_dispatch,
|
||||
cancellation,
|
||||
next_request_id: Arc::new(AtomicU64::new(0)),
|
||||
},
|
||||
dispatch: RequestDispatch {
|
||||
config,
|
||||
canceled_requests,
|
||||
transport: transport.fuse(),
|
||||
in_flight_requests: FnvHashMap::default(),
|
||||
pending_requests: pending_requests.fuse(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
|
||||
/// and dispatching responses to the appropriate channel.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct RequestDispatch<Req, Resp, C> {
|
||||
/// Writes requests to the wire and reads responses off the wire.
|
||||
#[pin]
|
||||
transport: Fuse<C>,
|
||||
/// Requests waiting to be written to the wire.
|
||||
#[pin]
|
||||
pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>,
|
||||
/// Requests that were dropped.
|
||||
#[pin]
|
||||
canceled_requests: Fuse<CanceledRequests>,
|
||||
/// Requests already written to the wire that haven't yet received responses.
|
||||
in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>,
|
||||
/// Configures limits to prevent unlimited resource usage.
|
||||
config: Config,
|
||||
}
|
||||
|
||||
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
|
||||
where
|
||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
||||
{
|
||||
fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
||||
Poll::Ready(
|
||||
match ready!(self.as_mut().project().transport.poll_next(cx)?) {
|
||||
Some(response) => {
|
||||
self.complete(response);
|
||||
Some(Ok(()))
|
||||
}
|
||||
None => None,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
||||
enum ReceiverStatus {
|
||||
NotReady,
|
||||
Closed,
|
||||
}
|
||||
|
||||
let pending_requests_status = match self.as_mut().poll_next_request(cx)? {
|
||||
Poll::Ready(Some(dispatch_request)) => {
|
||||
self.as_mut().write_request(dispatch_request)?;
|
||||
return Poll::Ready(Some(Ok(())));
|
||||
}
|
||||
Poll::Ready(None) => ReceiverStatus::Closed,
|
||||
Poll::Pending => ReceiverStatus::NotReady,
|
||||
};
|
||||
|
||||
let canceled_requests_status = match self.as_mut().poll_next_cancellation(cx)? {
|
||||
Poll::Ready(Some((context, request_id))) => {
|
||||
self.as_mut().write_cancel(context, request_id)?;
|
||||
return Poll::Ready(Some(Ok(())));
|
||||
}
|
||||
Poll::Ready(None) => ReceiverStatus::Closed,
|
||||
Poll::Pending => ReceiverStatus::NotReady,
|
||||
};
|
||||
|
||||
match (pending_requests_status, canceled_requests_status) {
|
||||
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
|
||||
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
||||
Poll::Ready(None)
|
||||
}
|
||||
(ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
|
||||
// No more messages to process, so flush any messages buffered in the transport.
|
||||
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
||||
|
||||
// Even if we fully-flush, we return Pending, because we have no more requests
|
||||
// or cancellations right now.
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Yields the next pending request, if one is ready to be sent.
|
||||
fn poll_next_request(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<DispatchRequest<Req, Resp>> {
|
||||
if self.as_mut().project().in_flight_requests.len() >= self.config.max_in_flight_requests {
|
||||
info!(
|
||||
"At in-flight request capacity ({}/{}).",
|
||||
self.as_mut().project().in_flight_requests.len(),
|
||||
self.config.max_in_flight_requests
|
||||
);
|
||||
|
||||
// No need to schedule a wakeup, because timers and responses are responsible
|
||||
// for clearing out in-flight requests.
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
while let Poll::Pending = self.as_mut().project().transport.poll_ready(cx)? {
|
||||
// We can't yield a request-to-be-sent before the transport is capable of buffering it.
|
||||
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
||||
}
|
||||
|
||||
loop {
|
||||
match ready!(self.as_mut().project().pending_requests.poll_next_unpin(cx)) {
|
||||
Some(request) => {
|
||||
if request.response_completion.is_canceled() {
|
||||
trace!(
|
||||
"[{}] Request canceled before being sent.",
|
||||
request.ctx.trace_id()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
return Poll::Ready(Some(Ok(request)));
|
||||
}
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Yields the next pending cancellation, and, if one is ready, cancels the associated request.
|
||||
fn poll_next_cancellation(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<(context::Context, u64)> {
|
||||
while let Poll::Pending = self.as_mut().project().transport.poll_ready(cx)? {
|
||||
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
||||
}
|
||||
|
||||
loop {
|
||||
let cancellation = self
|
||||
.as_mut()
|
||||
.project()
|
||||
.canceled_requests
|
||||
.poll_next_unpin(cx);
|
||||
match ready!(cancellation) {
|
||||
Some(request_id) => {
|
||||
if let Some(in_flight_data) = self
|
||||
.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&request_id)
|
||||
{
|
||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
||||
debug!("[{}] Removed request.", in_flight_data.ctx.trace_id());
|
||||
return Poll::Ready(Some(Ok((in_flight_data.ctx, request_id))));
|
||||
}
|
||||
}
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn write_request(
|
||||
mut self: Pin<&mut Self>,
|
||||
dispatch_request: DispatchRequest<Req, Resp>,
|
||||
) -> io::Result<()> {
|
||||
let request_id = dispatch_request.request_id;
|
||||
let request = ClientMessage::Request(Request {
|
||||
id: request_id,
|
||||
message: dispatch_request.request,
|
||||
context: context::Context {
|
||||
deadline: dispatch_request.ctx.deadline,
|
||||
trace_context: dispatch_request.ctx.trace_context,
|
||||
},
|
||||
});
|
||||
self.as_mut().project().transport.start_send(request)?;
|
||||
self.as_mut().project().in_flight_requests.insert(
|
||||
request_id,
|
||||
InFlightData {
|
||||
ctx: dispatch_request.ctx,
|
||||
response_completion: dispatch_request.response_completion,
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_cancel(
|
||||
mut self: Pin<&mut Self>,
|
||||
context: context::Context,
|
||||
request_id: u64,
|
||||
) -> io::Result<()> {
|
||||
let trace_id = *context.trace_id();
|
||||
let cancel = ClientMessage::Cancel {
|
||||
trace_context: context.trace_context,
|
||||
request_id,
|
||||
};
|
||||
self.as_mut().project().transport.start_send(cancel)?;
|
||||
trace!("[{}] Cancel message sent.", trace_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sends a server response to the client task that initiated the associated request.
|
||||
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
|
||||
if let Some(in_flight_data) = self
|
||||
.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&response.request_id)
|
||||
{
|
||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
||||
|
||||
trace!("[{}] Received response.", in_flight_data.ctx.trace_id());
|
||||
let _ = in_flight_data.response_completion.send(response);
|
||||
return true;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"No in-flight request found for request_id = {}.",
|
||||
response.request_id
|
||||
);
|
||||
|
||||
// If the response completion was absent, then the request was already canceled.
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
|
||||
where
|
||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
||||
{
|
||||
type Output = anyhow::Result<()>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<anyhow::Result<()>> {
|
||||
loop {
|
||||
match (
|
||||
self.as_mut()
|
||||
.pump_read(cx)
|
||||
.context("failed to read from transport")?,
|
||||
self.as_mut()
|
||||
.pump_write(cx)
|
||||
.context("failed to write to transport")?,
|
||||
) {
|
||||
(Poll::Ready(None), _) => {
|
||||
info!("Shutdown: read half closed, so shutting down.");
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
(read, Poll::Ready(None)) => {
|
||||
if self.as_mut().project().in_flight_requests.is_empty() {
|
||||
info!("Shutdown: write half closed, and no requests in flight.");
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
info!(
|
||||
"Shutdown: write half closed, and {} requests in flight.",
|
||||
self.as_mut().project().in_flight_requests.len()
|
||||
);
|
||||
match read {
|
||||
Poll::Ready(Some(())) => continue,
|
||||
_ => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
(Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
|
||||
_ => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage
|
||||
/// the lifecycle of the request.
|
||||
#[derive(Debug)]
|
||||
struct DispatchRequest<Req, Resp> {
|
||||
ctx: context::Context,
|
||||
request_id: u64,
|
||||
request: Req,
|
||||
response_completion: oneshot::Sender<Response<Resp>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InFlightData<Resp> {
|
||||
ctx: context::Context,
|
||||
response_completion: oneshot::Sender<Response<Resp>>,
|
||||
}
|
||||
|
||||
/// Sends request cancellation signals.
|
||||
#[derive(Debug, Clone)]
|
||||
struct RequestCancellation(mpsc::UnboundedSender<u64>);
|
||||
|
||||
/// A stream of IDs of requests that have been canceled.
|
||||
#[derive(Debug)]
|
||||
struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
|
||||
|
||||
/// Returns a channel to send request cancellation messages.
|
||||
fn cancellations() -> (RequestCancellation, CanceledRequests) {
|
||||
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
|
||||
// bounded by the number of in-flight requests. Additionally, each request has a clone
|
||||
// of the sender, so the bounded channel would have the same behavior,
|
||||
// since it guarantees a slot.
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
(RequestCancellation(tx), CanceledRequests(rx))
|
||||
}
|
||||
|
||||
impl RequestCancellation {
|
||||
/// Cancels the request with ID `request_id`.
|
||||
fn cancel(&mut self, request_id: u64) {
|
||||
let _ = self.0.unbounded_send(request_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for CanceledRequests {
|
||||
type Item = u64;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
|
||||
self.0.poll_next_unpin(cx)
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
struct MapErrConnectionReset<Fut> {
|
||||
#[pin]
|
||||
future: Fut,
|
||||
finished: Option<()>,
|
||||
}
|
||||
|
||||
impl<Fut> MapErrConnectionReset<Fut> {
|
||||
fn new(future: Fut) -> MapErrConnectionReset<Fut> {
|
||||
MapErrConnectionReset {
|
||||
future,
|
||||
finished: Some(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fut> Future for MapErrConnectionReset<Fut>
|
||||
where
|
||||
Fut: TryFuture,
|
||||
{
|
||||
type Output = io::Result<Fut::Ok>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.as_mut().project().future.try_poll(cx) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(result) => {
|
||||
self.project().finished.take().expect(
|
||||
"MapErrConnectionReset must not be polled after it returned `Poll::Ready`",
|
||||
);
|
||||
Poll::Ready(result.map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
struct MapOkDispatchResponse<Fut, Resp> {
|
||||
#[pin]
|
||||
future: Fut,
|
||||
response: Option<DispatchResponse<Resp>>,
|
||||
}
|
||||
|
||||
impl<Fut, Resp> MapOkDispatchResponse<Fut, Resp> {
|
||||
fn new(future: Fut, response: DispatchResponse<Resp>) -> MapOkDispatchResponse<Fut, Resp> {
|
||||
MapOkDispatchResponse {
|
||||
future,
|
||||
response: Some(response),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fut, Resp> Future for MapOkDispatchResponse<Fut, Resp>
|
||||
where
|
||||
Fut: TryFuture,
|
||||
{
|
||||
type Output = Result<DispatchResponse<Resp>, Fut::Error>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.as_mut().project().future.try_poll(cx) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(result) => {
|
||||
let response = self
|
||||
.as_mut()
|
||||
.project()
|
||||
.response
|
||||
.take()
|
||||
.expect("MapOk must not be polled after it returned `Poll::Ready`");
|
||||
Poll::Ready(result.map(|_| response))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
struct AndThenIdent<Fut1, Fut2> {
|
||||
#[pin]
|
||||
try_chain: TryChain<Fut1, Fut2>,
|
||||
}
|
||||
|
||||
impl<Fut1, Fut2> AndThenIdent<Fut1, Fut2>
|
||||
where
|
||||
Fut1: TryFuture<Ok = Fut2>,
|
||||
Fut2: TryFuture,
|
||||
{
|
||||
/// Creates a new `Then`.
|
||||
fn new(future: Fut1) -> AndThenIdent<Fut1, Fut2> {
|
||||
AndThenIdent {
|
||||
try_chain: TryChain::new(future),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fut1, Fut2> Future for AndThenIdent<Fut1, Fut2>
|
||||
where
|
||||
Fut1: TryFuture<Ok = Fut2>,
|
||||
Fut2: TryFuture<Error = Fut1::Error>,
|
||||
{
|
||||
type Output = Result<Fut2::Ok, Fut2::Error>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.project().try_chain.poll(cx, |result| match result {
|
||||
Ok(ok) => TryChainAction::Future(ok),
|
||||
Err(err) => TryChainAction::Output(Err(err)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project(project = TryChainProj)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
#[derive(Debug)]
|
||||
enum TryChain<Fut1, Fut2> {
|
||||
First(#[pin] Fut1),
|
||||
Second(#[pin] Fut2),
|
||||
Empty,
|
||||
}
|
||||
|
||||
enum TryChainAction<Fut2>
|
||||
where
|
||||
Fut2: TryFuture,
|
||||
{
|
||||
Future(Fut2),
|
||||
Output(Result<Fut2::Ok, Fut2::Error>),
|
||||
}
|
||||
|
||||
impl<Fut1, Fut2> TryChain<Fut1, Fut2>
|
||||
where
|
||||
Fut1: TryFuture<Ok = Fut2>,
|
||||
Fut2: TryFuture,
|
||||
{
|
||||
fn new(fut1: Fut1) -> TryChain<Fut1, Fut2> {
|
||||
TryChain::First(fut1)
|
||||
}
|
||||
|
||||
fn poll<F>(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
f: F,
|
||||
) -> Poll<Result<Fut2::Ok, Fut2::Error>>
|
||||
where
|
||||
F: FnOnce(Result<Fut1::Ok, Fut1::Error>) -> TryChainAction<Fut2>,
|
||||
{
|
||||
let mut f = Some(f);
|
||||
|
||||
loop {
|
||||
let output = match self.as_mut().project() {
|
||||
TryChainProj::First(fut1) => {
|
||||
// Poll the first future
|
||||
match fut1.try_poll(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(output) => output,
|
||||
}
|
||||
}
|
||||
TryChainProj::Second(fut2) => {
|
||||
// Poll the second future
|
||||
return fut2.try_poll(cx);
|
||||
}
|
||||
TryChainProj::Empty => {
|
||||
panic!("future must not be polled after it returned `Poll::Ready`");
|
||||
}
|
||||
};
|
||||
|
||||
self.set(TryChain::Empty); // Drop fut1
|
||||
let f = f.take().unwrap();
|
||||
match f(output) {
|
||||
TryChainAction::Future(fut2) => self.set(TryChain::Second(fut2)),
|
||||
TryChainAction::Output(output) => return Poll::Ready(output),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
cancellations, CanceledRequests, Channel, DispatchResponse, RequestCancellation,
|
||||
RequestDispatch,
|
||||
};
|
||||
use crate::{
|
||||
client::Config,
|
||||
context,
|
||||
transport::{self, channel::UnboundedChannel},
|
||||
ClientMessage, Response,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
prelude::*,
|
||||
task::*,
|
||||
};
|
||||
use std::{pin::Pin, sync::atomic::AtomicU64, sync::Arc};
|
||||
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
async fn dispatch_response_cancels_on_drop() {
|
||||
let (cancellation, mut canceled_requests) = cancellations();
|
||||
let (_, response) = oneshot::channel();
|
||||
drop(DispatchResponse::<u32> {
|
||||
response,
|
||||
cancellation,
|
||||
complete: false,
|
||||
request_id: 3,
|
||||
ctx: context::current(),
|
||||
});
|
||||
// resp's drop() is run, which should send a cancel message.
|
||||
assert_eq!(canceled_requests.0.try_next().unwrap(), Some(3));
|
||||
}
|
||||
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
async fn stage_request() {
|
||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||
let dispatch = Pin::new(&mut dispatch);
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
|
||||
let _resp = send_request(&mut channel, "hi").await;
|
||||
|
||||
let req = dispatch.poll_next_request(cx).ready();
|
||||
assert!(req.is_some());
|
||||
|
||||
let req = req.unwrap();
|
||||
assert_eq!(req.request_id, 0);
|
||||
assert_eq!(req.request, "hi".to_string());
|
||||
}
|
||||
|
||||
// Regression test for https://github.com/google/tarpc/issues/220
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
async fn stage_request_channel_dropped_doesnt_panic() {
|
||||
let (mut dispatch, mut channel, mut server_channel) = set_up();
|
||||
let mut dispatch = Pin::new(&mut dispatch);
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
|
||||
let _ = send_request(&mut channel, "hi").await;
|
||||
drop(channel);
|
||||
|
||||
assert!(dispatch.as_mut().poll(cx).is_ready());
|
||||
send_response(
|
||||
&mut server_channel,
|
||||
Response {
|
||||
request_id: 0,
|
||||
message: Ok("hello".into()),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
dispatch.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
async fn stage_request_response_future_dropped_is_canceled_before_sending() {
|
||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||
let dispatch = Pin::new(&mut dispatch);
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
|
||||
let _ = send_request(&mut channel, "hi").await;
|
||||
|
||||
// Drop the channel so polling returns none if no requests are currently ready.
|
||||
drop(channel);
|
||||
// Test that a request future dropped before it's processed by dispatch will cause the request
|
||||
// to not be added to the in-flight request map.
|
||||
assert!(dispatch.poll_next_request(cx).ready().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
async fn stage_request_response_future_dropped_is_canceled_after_sending() {
|
||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let mut dispatch = Pin::new(&mut dispatch);
|
||||
|
||||
let req = send_request(&mut channel, "hi").await;
|
||||
|
||||
assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
|
||||
assert!(!dispatch.as_mut().project().in_flight_requests.is_empty());
|
||||
|
||||
// Test that a request future dropped after it's processed by dispatch will cause the request
|
||||
// to be removed from the in-flight request map.
|
||||
drop(req);
|
||||
if let Poll::Ready(Some(_)) = dispatch.as_mut().poll_next_cancellation(cx).unwrap() {
|
||||
// ok
|
||||
} else {
|
||||
panic!("Expected request to be cancelled")
|
||||
};
|
||||
assert!(dispatch.project().in_flight_requests.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
async fn stage_request_response_closed_skipped() {
|
||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||
let dispatch = Pin::new(&mut dispatch);
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
|
||||
// Test that a request future that's closed its receiver but not yet canceled its request --
|
||||
// i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
|
||||
// map.
|
||||
let mut resp = send_request(&mut channel, "hi").await;
|
||||
resp.response.close();
|
||||
|
||||
assert!(dispatch.poll_next_request(cx).is_pending());
|
||||
}
|
||||
|
||||
fn set_up() -> (
|
||||
RequestDispatch<String, String, UnboundedChannel<Response<String>, ClientMessage<String>>>,
|
||||
Channel<String, String>,
|
||||
UnboundedChannel<ClientMessage<String>, Response<String>>,
|
||||
) {
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
let (to_dispatch, pending_requests) = mpsc::channel(1);
|
||||
let (cancel_tx, canceled_requests) = mpsc::unbounded();
|
||||
let (client_channel, server_channel) = transport::channel::unbounded();
|
||||
|
||||
let dispatch = RequestDispatch::<String, String, _> {
|
||||
transport: client_channel.fuse(),
|
||||
pending_requests: pending_requests.fuse(),
|
||||
canceled_requests: CanceledRequests(canceled_requests).fuse(),
|
||||
in_flight_requests: FnvHashMap::default(),
|
||||
config: Config::default(),
|
||||
};
|
||||
|
||||
let cancellation = RequestCancellation(cancel_tx);
|
||||
let channel = Channel {
|
||||
to_dispatch,
|
||||
cancellation,
|
||||
next_request_id: Arc::new(AtomicU64::new(0)),
|
||||
};
|
||||
|
||||
(dispatch, channel, server_channel)
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
channel: &mut Channel<String, String>,
|
||||
request: &str,
|
||||
) -> DispatchResponse<String> {
|
||||
channel
|
||||
.send(context::current(), request.to_string())
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
async fn send_response(
|
||||
channel: &mut UnboundedChannel<ClientMessage<String>, Response<String>>,
|
||||
response: Response<String>,
|
||||
) {
|
||||
channel.send(response).await.unwrap();
|
||||
}
|
||||
|
||||
trait PollTest {
|
||||
type T;
|
||||
fn unwrap(self) -> Poll<Self::T>;
|
||||
fn ready(self) -> Self::T;
|
||||
}
|
||||
|
||||
impl<T, E> PollTest for Poll<Option<Result<T, E>>>
|
||||
where
|
||||
E: ::std::fmt::Display,
|
||||
{
|
||||
type T = Option<T>;
|
||||
|
||||
fn unwrap(self) -> Poll<Option<T>> {
|
||||
match self {
|
||||
Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)),
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Ready(Some(Err(e))) => panic!(e.to_string()),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn ready(self) -> Option<T> {
|
||||
match self {
|
||||
Poll::Ready(Some(Ok(t))) => Some(t),
|
||||
Poll::Ready(None) => None,
|
||||
Poll::Ready(Some(Err(e))) => panic!(e.to_string()),
|
||||
Poll::Pending => panic!("Pending"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
// Copyright 2018 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! Provides a request context that carries a deadline and trace context. This context is sent from
|
||||
//! client to server and is used by the server to enforce response deadlines.
|
||||
|
||||
use crate::trace::{self, TraceId};
|
||||
use static_assertions::assert_impl_all;
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
/// A request context that carries request-scoped information like deadlines and trace information.
|
||||
/// It is sent from client to server and is used by the server to enforce response deadlines.
|
||||
///
|
||||
/// The context should not be stored directly in a server implementation, because the context will
|
||||
/// be different for each request in scope.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Context {
|
||||
/// When the client expects the request to be complete by. The server should cancel the request
|
||||
/// if it is not complete by this time.
|
||||
#[cfg_attr(
|
||||
feature = "serde1",
|
||||
serde(serialize_with = "crate::util::serde::serialize_epoch_secs")
|
||||
)]
|
||||
#[cfg_attr(
|
||||
feature = "serde1",
|
||||
serde(deserialize_with = "crate::util::serde::deserialize_epoch_secs")
|
||||
)]
|
||||
#[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))]
|
||||
pub deadline: SystemTime,
|
||||
/// Uniquely identifies requests originating from the same source.
|
||||
/// When a service handles a request by making requests itself, those requests should
|
||||
/// include the same `trace_id` as that included on the original request. This way,
|
||||
/// users can trace related actions across a distributed system.
|
||||
pub trace_context: trace::Context,
|
||||
}
|
||||
|
||||
assert_impl_all!(Context: Send, Sync);
|
||||
|
||||
#[cfg(feature = "serde1")]
|
||||
fn ten_seconds_from_now() -> SystemTime {
|
||||
SystemTime::now() + Duration::from_secs(10)
|
||||
}
|
||||
|
||||
/// Returns the context for the current request, or a default Context if no request is active.
|
||||
// TODO: populate Context with request-scoped data, with default fallbacks.
|
||||
pub fn current() -> Context {
|
||||
Context {
|
||||
deadline: SystemTime::now() + Duration::from_secs(10),
|
||||
trace_context: trace::Context::new_root(),
|
||||
}
|
||||
}
|
||||
|
||||
impl Context {
|
||||
/// Returns the ID of the request-scoped trace.
|
||||
pub fn trace_id(&self) -> &TraceId {
|
||||
&self.trace_context.trace_id
|
||||
}
|
||||
}
|
||||
@@ -1,707 +0,0 @@
|
||||
// Copyright 2018 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! Provides a server that concurrently handles many connections sending multiplexed requests.
|
||||
|
||||
use crate::{
|
||||
context, trace, util::Compact, util::TimeUntil, ClientMessage, PollIo, Request, Response,
|
||||
ServerError, Transport,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{
|
||||
channel::mpsc,
|
||||
future::{AbortHandle, AbortRegistration, Abortable},
|
||||
prelude::*,
|
||||
ready,
|
||||
stream::Fuse,
|
||||
task::*,
|
||||
};
|
||||
use humantime::format_rfc3339;
|
||||
use log::{debug, trace};
|
||||
use pin_project::pin_project;
|
||||
use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
|
||||
use tokio::time::Timeout;
|
||||
|
||||
mod filter;
|
||||
#[cfg(test)]
|
||||
mod testing;
|
||||
mod throttle;
|
||||
|
||||
pub use self::{
|
||||
filter::ChannelFilter,
|
||||
throttle::{Throttler, ThrottlerStream},
|
||||
};
|
||||
|
||||
/// Manages clients, serving multiplexed requests over each connection.
|
||||
#[derive(Debug)]
|
||||
pub struct Server<Req, Resp> {
|
||||
config: Config,
|
||||
ghost: PhantomData<(Req, Resp)>,
|
||||
}
|
||||
|
||||
impl<Req, Resp> Default for Server<Req, Resp> {
|
||||
fn default() -> Self {
|
||||
new(Config::default())
|
||||
}
|
||||
}
|
||||
|
||||
/// Settings that control the behavior of the server.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Config {
|
||||
/// The number of responses per client that can be buffered server-side before being sent.
|
||||
/// `pending_response_buffer` controls the buffer size of the channel that a server's
|
||||
/// response tasks use to send responses to the client handler task.
|
||||
pub pending_response_buffer: usize,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Config {
|
||||
pending_response_buffer: 100,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Returns a channel backed by `transport` and configured with `self`.
|
||||
pub fn channel<Req, Resp, T>(self, transport: T) -> BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
{
|
||||
BaseChannel::new(self, transport)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new server with configuration specified `config`.
|
||||
pub fn new<Req, Resp>(config: Config) -> Server<Req, Resp> {
|
||||
Server {
|
||||
config,
|
||||
ghost: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> Server<Req, Resp> {
|
||||
/// Returns the config for this server.
|
||||
pub fn config(&self) -> &Config {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Returns a stream of server channels.
|
||||
pub fn incoming<S, T>(self, listener: S) -> impl Stream<Item = BaseChannel<Req, Resp, T>>
|
||||
where
|
||||
S: Stream<Item = T>,
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
{
|
||||
listener.map(move |t| BaseChannel::new(self.config.clone(), t))
|
||||
}
|
||||
}
|
||||
|
||||
/// Basically a Fn(Req) -> impl Future<Output = Resp>;
|
||||
pub trait Serve<Req>: Sized + Clone {
|
||||
/// Type of response.
|
||||
type Resp;
|
||||
|
||||
/// Type of response future.
|
||||
type Fut: Future<Output = Self::Resp>;
|
||||
|
||||
/// Responds to a single request.
|
||||
fn serve(self, ctx: context::Context, req: Req) -> Self::Fut;
|
||||
}
|
||||
|
||||
impl<Req, Resp, Fut, F> Serve<Req> for F
|
||||
where
|
||||
F: FnOnce(context::Context, Req) -> Fut + Clone,
|
||||
Fut: Future<Output = Resp>,
|
||||
{
|
||||
type Resp = Resp;
|
||||
type Fut = Fut;
|
||||
|
||||
fn serve(self, ctx: context::Context, req: Req) -> Self::Fut {
|
||||
self(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
/// A utility trait enabling a stream to fluently chain a request handler.
|
||||
pub trait Handler<C>
|
||||
where
|
||||
Self: Sized + Stream<Item = C>,
|
||||
C: Channel,
|
||||
{
|
||||
/// Enforces channel per-key limits.
|
||||
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> filter::ChannelFilter<Self, K, KF>
|
||||
where
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
KF: Fn(&C) -> K,
|
||||
{
|
||||
ChannelFilter::new(self, n, keymaker)
|
||||
}
|
||||
|
||||
/// Caps the number of concurrent requests per channel.
|
||||
fn max_concurrent_requests_per_channel(self, n: usize) -> ThrottlerStream<Self> {
|
||||
ThrottlerStream::new(self, n)
|
||||
}
|
||||
|
||||
/// Responds to all requests with [`server::serve`](Serve).
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
fn respond_with<S>(self, server: S) -> Running<Self, S>
|
||||
where
|
||||
S: Serve<C::Req, Resp = C::Resp>,
|
||||
{
|
||||
Running {
|
||||
incoming: self,
|
||||
server,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, C> Handler<C> for S
|
||||
where
|
||||
S: Sized + Stream<Item = C>,
|
||||
C: Channel,
|
||||
{
|
||||
}
|
||||
|
||||
/// BaseChannel lifts a Transport to a Channel by tracking in-flight requests.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct BaseChannel<Req, Resp, T> {
|
||||
config: Config,
|
||||
/// Writes responses to the wire and reads requests off the wire.
|
||||
#[pin]
|
||||
transport: Fuse<T>,
|
||||
/// Number of requests currently being responded to.
|
||||
in_flight_requests: FnvHashMap<u64, AbortHandle>,
|
||||
/// Types the request and response.
|
||||
ghost: PhantomData<(Req, Resp)>,
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
{
|
||||
/// Creates a new channel backed by `transport` and configured with `config`.
|
||||
pub fn new(config: Config, transport: T) -> Self {
|
||||
BaseChannel {
|
||||
config,
|
||||
transport: transport.fuse(),
|
||||
in_flight_requests: FnvHashMap::default(),
|
||||
ghost: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new channel backed by `transport` and configured with the defaults.
|
||||
pub fn with_defaults(transport: T) -> Self {
|
||||
Self::new(Config::default(), transport)
|
||||
}
|
||||
|
||||
/// Returns the inner transport over which messages are sent and received.
|
||||
pub fn get_ref(&self) -> &T {
|
||||
self.transport.get_ref()
|
||||
}
|
||||
|
||||
/// Returns the inner transport over which messages are sent and received.
|
||||
pub fn get_pin_ref(self: Pin<&mut Self>) -> Pin<&mut T> {
|
||||
self.project().transport.get_pin_mut()
|
||||
}
|
||||
|
||||
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
|
||||
// It's possible the request was already completed, so it's fine
|
||||
// if this is None.
|
||||
if let Some(cancel_handle) = self
|
||||
.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&request_id)
|
||||
{
|
||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
||||
|
||||
cancel_handle.abort();
|
||||
let remaining = self.as_mut().project().in_flight_requests.len();
|
||||
trace!(
|
||||
"[{}] Request canceled. In-flight requests = {}",
|
||||
trace_context.trace_id,
|
||||
remaining,
|
||||
);
|
||||
} else {
|
||||
trace!(
|
||||
"[{}] Received cancellation, but response handler \
|
||||
is already complete.",
|
||||
trace_context.trace_id,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The server end of an open connection with a client, streaming in requests from, and sinking
|
||||
/// responses to, the client.
|
||||
///
|
||||
/// Channels are free to somewhat rely on the assumption that all in-flight requests are eventually
|
||||
/// either [cancelled](BaseChannel::cancel_request) or [responded to](Sink::start_send). Safety cannot
|
||||
/// rely on this assumption, but it is best for `Channel` users to always account for all outstanding
|
||||
/// requests.
|
||||
pub trait Channel
|
||||
where
|
||||
Self: Transport<Response<<Self as Channel>::Resp>, Request<<Self as Channel>::Req>>,
|
||||
{
|
||||
/// Type of request item.
|
||||
type Req;
|
||||
|
||||
/// Type of response sink item.
|
||||
type Resp;
|
||||
|
||||
/// Configuration of the channel.
|
||||
fn config(&self) -> &Config;
|
||||
|
||||
/// Returns the number of in-flight requests over this channel.
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize;
|
||||
|
||||
/// Caps the number of concurrent requests.
|
||||
fn max_concurrent_requests(self, n: usize) -> Throttler<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Throttler::new(self, n)
|
||||
}
|
||||
|
||||
/// Tells the Channel that request with ID `request_id` is being handled.
|
||||
/// The request will be tracked until a response with the same ID is sent
|
||||
/// to the Channel.
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration;
|
||||
|
||||
/// Respond to requests coming over the channel with `f`. Returns a future that drives the
|
||||
/// responses and resolves when the connection is closed.
|
||||
fn respond_with<S>(self, server: S) -> ClientHandler<Self, S>
|
||||
where
|
||||
S: Serve<Self::Req, Resp = Self::Resp>,
|
||||
Self: Sized,
|
||||
{
|
||||
let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer);
|
||||
let responses = responses.fuse();
|
||||
|
||||
ClientHandler {
|
||||
channel: self,
|
||||
server,
|
||||
pending_responses: responses,
|
||||
responses_tx,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
{
|
||||
type Item = io::Result<Request<Req>>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
match ready!(self.as_mut().project().transport.poll_next(cx)?) {
|
||||
Some(message) => match message {
|
||||
ClientMessage::Request(request) => {
|
||||
return Poll::Ready(Some(Ok(request)));
|
||||
}
|
||||
ClientMessage::Cancel {
|
||||
trace_context,
|
||||
request_id,
|
||||
} => {
|
||||
self.as_mut().cancel_request(&trace_context, request_id);
|
||||
}
|
||||
},
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> Sink<Response<Resp>> for BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
{
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().transport.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
|
||||
if self
|
||||
.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&response.request_id)
|
||||
.is_some()
|
||||
{
|
||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
||||
}
|
||||
|
||||
self.project().transport.start_send(response)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().transport.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().transport.poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
|
||||
fn as_ref(&self) -> &T {
|
||||
self.transport.get_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> Channel for BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
{
|
||||
type Req = Req;
|
||||
type Resp = Resp;
|
||||
|
||||
fn config(&self) -> &Config {
|
||||
&self.config
|
||||
}
|
||||
|
||||
fn in_flight_requests(mut self: Pin<&mut Self>) -> usize {
|
||||
self.as_mut().project().in_flight_requests.len()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
let (abort_handle, abort_registration) = AbortHandle::new_pair();
|
||||
assert!(self
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.insert(request_id, abort_handle)
|
||||
.is_none());
|
||||
abort_registration
|
||||
}
|
||||
}
|
||||
|
||||
/// A running handler serving all requests coming over a channel.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct ClientHandler<C, S>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
#[pin]
|
||||
channel: C,
|
||||
/// Responses waiting to be written to the wire.
|
||||
#[pin]
|
||||
pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>,
|
||||
/// Handed out to request handlers to fan in responses.
|
||||
#[pin]
|
||||
responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
|
||||
/// Server
|
||||
server: S,
|
||||
}
|
||||
|
||||
impl<C, S> ClientHandler<C, S>
|
||||
where
|
||||
C: Channel,
|
||||
S: Serve<C::Req, Resp = C::Resp>,
|
||||
{
|
||||
/// Returns the inner channel over which messages are sent and received.
|
||||
pub fn get_pin_channel(self: Pin<&mut Self>) -> Pin<&mut C> {
|
||||
self.project().channel
|
||||
}
|
||||
|
||||
fn pump_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<RequestHandler<S::Fut, C::Resp>> {
|
||||
match ready!(self.as_mut().project().channel.poll_next(cx)?) {
|
||||
Some(request) => Poll::Ready(Some(Ok(self.handle_request(request)))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn pump_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
read_half_closed: bool,
|
||||
) -> PollIo<()> {
|
||||
match self.as_mut().poll_next_response(cx)? {
|
||||
Poll::Ready(Some((ctx, response))) => {
|
||||
trace!(
|
||||
"[{}] Staging response. In-flight requests = {}.",
|
||||
ctx.trace_id(),
|
||||
self.as_mut().project().channel.in_flight_requests(),
|
||||
);
|
||||
self.as_mut().project().channel.start_send(response)?;
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
// Shutdown can't be done before we finish pumping out remaining responses.
|
||||
ready!(self.as_mut().project().channel.poll_flush(cx)?);
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Pending => {
|
||||
// No more requests to process, so flush any requests buffered in the transport.
|
||||
ready!(self.as_mut().project().channel.poll_flush(cx)?);
|
||||
|
||||
// Being here means there are no staged requests and all written responses are
|
||||
// fully flushed. So, if the read half is closed and there are no in-flight
|
||||
// requests, then we can close the write half.
|
||||
if read_half_closed && self.as_mut().project().channel.in_flight_requests() == 0 {
|
||||
Poll::Ready(None)
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_next_response(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<(context::Context, Response<C::Resp>)> {
|
||||
// Ensure there's room to write a response.
|
||||
while let Poll::Pending = self.as_mut().project().channel.poll_ready(cx)? {
|
||||
ready!(self.as_mut().project().channel.poll_flush(cx)?);
|
||||
}
|
||||
|
||||
match ready!(self.as_mut().project().pending_responses.poll_next(cx)) {
|
||||
Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))),
|
||||
None => {
|
||||
// This branch likely won't happen, since the ClientHandler is holding a Sender.
|
||||
Poll::Ready(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_request(
|
||||
mut self: Pin<&mut Self>,
|
||||
request: Request<C::Req>,
|
||||
) -> RequestHandler<S::Fut, C::Resp> {
|
||||
let request_id = request.id;
|
||||
let deadline = request.context.deadline;
|
||||
let timeout = deadline.time_until();
|
||||
trace!(
|
||||
"[{}] Received request with deadline {} (timeout {:?}).",
|
||||
request.context.trace_id(),
|
||||
format_rfc3339(deadline),
|
||||
timeout,
|
||||
);
|
||||
let ctx = request.context;
|
||||
let request = request.message;
|
||||
|
||||
let response = self.as_mut().project().server.clone().serve(ctx, request);
|
||||
let response = Resp {
|
||||
state: RespState::PollResp,
|
||||
request_id,
|
||||
ctx,
|
||||
deadline,
|
||||
f: tokio::time::timeout(timeout, response),
|
||||
response: None,
|
||||
response_tx: self.as_mut().project().responses_tx.clone(),
|
||||
};
|
||||
let abort_registration = self.as_mut().project().channel.start_request(request_id);
|
||||
RequestHandler {
|
||||
resp: Abortable::new(response, abort_registration),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A future fulfilling a single client request.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct RequestHandler<F, R> {
|
||||
#[pin]
|
||||
resp: Abortable<Resp<F, R>>,
|
||||
}
|
||||
|
||||
impl<F, R> Future for RequestHandler<F, R>
|
||||
where
|
||||
F: Future<Output = R>,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
let _ = ready!(self.project().resp.poll(cx));
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
struct Resp<F, R> {
|
||||
state: RespState,
|
||||
request_id: u64,
|
||||
ctx: context::Context,
|
||||
deadline: SystemTime,
|
||||
#[pin]
|
||||
f: Timeout<F>,
|
||||
response: Option<Response<R>>,
|
||||
#[pin]
|
||||
response_tx: mpsc::Sender<(context::Context, Response<R>)>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
enum RespState {
|
||||
PollResp,
|
||||
PollReady,
|
||||
PollFlush,
|
||||
}
|
||||
|
||||
impl<F, R> Future for Resp<F, R>
|
||||
where
|
||||
F: Future<Output = R>,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
loop {
|
||||
match self.as_mut().project().state {
|
||||
RespState::PollResp => {
|
||||
let result = ready!(self.as_mut().project().f.poll(cx));
|
||||
*self.as_mut().project().response = Some(Response {
|
||||
request_id: self.request_id,
|
||||
message: match result {
|
||||
Ok(message) => Ok(message),
|
||||
Err(tokio::time::Elapsed { .. }) => {
|
||||
debug!(
|
||||
"[{}] Response did not complete before deadline of {}s.",
|
||||
self.ctx.trace_id(),
|
||||
format_rfc3339(self.deadline)
|
||||
);
|
||||
// No point in responding, since the client will have dropped the
|
||||
// request.
|
||||
Err(ServerError {
|
||||
kind: io::ErrorKind::TimedOut,
|
||||
detail: Some(format!(
|
||||
"Response did not complete before deadline of {}s.",
|
||||
format_rfc3339(self.deadline)
|
||||
)),
|
||||
})
|
||||
}
|
||||
},
|
||||
});
|
||||
*self.as_mut().project().state = RespState::PollReady;
|
||||
}
|
||||
RespState::PollReady => {
|
||||
let ready = ready!(self.as_mut().project().response_tx.poll_ready(cx));
|
||||
if ready.is_err() {
|
||||
return Poll::Ready(());
|
||||
}
|
||||
let resp = (self.ctx, self.as_mut().project().response.take().unwrap());
|
||||
if self
|
||||
.as_mut()
|
||||
.project()
|
||||
.response_tx
|
||||
.start_send(resp)
|
||||
.is_err()
|
||||
{
|
||||
return Poll::Ready(());
|
||||
}
|
||||
*self.as_mut().project().state = RespState::PollFlush;
|
||||
}
|
||||
RespState::PollFlush => {
|
||||
let ready = ready!(self.as_mut().project().response_tx.poll_flush(cx));
|
||||
if ready.is_err() {
|
||||
return Poll::Ready(());
|
||||
}
|
||||
return Poll::Ready(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, S> Stream for ClientHandler<C, S>
|
||||
where
|
||||
C: Channel,
|
||||
S: Serve<C::Req, Resp = C::Resp>,
|
||||
{
|
||||
type Item = io::Result<RequestHandler<S::Fut, C::Resp>>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
let read = self.as_mut().pump_read(cx)?;
|
||||
let read_closed = if let Poll::Ready(None) = read {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
};
|
||||
match (read, self.as_mut().pump_write(cx, read_closed)?) {
|
||||
(Poll::Ready(None), Poll::Ready(None)) => {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
(Poll::Ready(Some(request_handler)), _) => {
|
||||
return Poll::Ready(Some(Ok(request_handler)));
|
||||
}
|
||||
(_, Poll::Ready(Some(()))) => {}
|
||||
_ => {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send + 'static execution helper methods.
|
||||
|
||||
impl<C, S> ClientHandler<C, S>
|
||||
where
|
||||
C: Channel + 'static,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
|
||||
S::Fut: Send + 'static,
|
||||
{
|
||||
/// Runs the client handler until completion by [spawning](tokio::spawn) each
|
||||
/// request handler onto the default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub fn execute(self) -> impl Future<Output = ()> {
|
||||
self.try_for_each(|request_handler| async {
|
||||
tokio::spawn(request_handler);
|
||||
Ok(())
|
||||
})
|
||||
.map_ok(|()| log::info!("ClientHandler finished."))
|
||||
.unwrap_or_else(|e| log::info!("ClientHandler errored out: {}", e))
|
||||
}
|
||||
}
|
||||
|
||||
/// A future that drives the server by [spawning](tokio::spawn) channels and request handlers on the default
|
||||
/// executor.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub struct Running<St, Se> {
|
||||
#[pin]
|
||||
incoming: St,
|
||||
server: Se,
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
impl<St, C, Se> Future for Running<St, Se>
|
||||
where
|
||||
St: Sized + Stream<Item = C>,
|
||||
C: Channel + Send + 'static,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
||||
Se::Fut: Send + 'static,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
while let Some(channel) = ready!(self.as_mut().project().incoming.poll_next(cx)) {
|
||||
tokio::spawn(
|
||||
channel
|
||||
.respond_with(self.as_mut().project().server.clone())
|
||||
.execute(),
|
||||
);
|
||||
}
|
||||
log::info!("Server shutting down.");
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
@@ -1,328 +0,0 @@
|
||||
// Copyright 2020 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use super::{Channel, Config};
|
||||
use crate::{Response, ServerError};
|
||||
use futures::{future::AbortRegistration, prelude::*, ready, task::*};
|
||||
use log::debug;
|
||||
use pin_project::pin_project;
|
||||
use std::{io, pin::Pin};
|
||||
|
||||
/// A [`Channel`] that limits the number of concurrent
|
||||
/// requests by throttling.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct Throttler<C> {
|
||||
max_in_flight_requests: usize,
|
||||
#[pin]
|
||||
inner: C,
|
||||
}
|
||||
|
||||
impl<C> Throttler<C> {
|
||||
/// Returns the inner channel.
|
||||
pub fn get_ref(&self) -> &C {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Throttler<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
/// Returns a new `Throttler` that wraps the given channel and limits concurrent requests to
|
||||
/// `max_in_flight_requests`.
|
||||
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
|
||||
Throttler {
|
||||
inner,
|
||||
max_in_flight_requests,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Stream for Throttler<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
type Item = <C as Stream>::Item;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
|
||||
{
|
||||
ready!(self.as_mut().project().inner.poll_ready(cx)?);
|
||||
|
||||
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
|
||||
Some(request) => {
|
||||
debug!(
|
||||
"[{}] Client has reached in-flight request limit ({}/{}).",
|
||||
request.context.trace_id(),
|
||||
self.as_mut().in_flight_requests(),
|
||||
self.as_mut().project().max_in_flight_requests,
|
||||
);
|
||||
|
||||
self.as_mut().start_send(Response {
|
||||
request_id: request.id,
|
||||
message: Err(ServerError {
|
||||
kind: io::ErrorKind::WouldBlock,
|
||||
detail: Some("Server throttled the request.".into()),
|
||||
}),
|
||||
})?;
|
||||
}
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
self.project().inner.poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Sink<Response<<C as Channel>::Resp>> for Throttler<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: Response<<C as Channel>::Resp>) -> io::Result<()> {
|
||||
self.project().inner.start_send(item)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
self.project().inner.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
self.project().inner.poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> AsRef<C> for Throttler<C> {
|
||||
fn as_ref(&self) -> &C {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Channel for Throttler<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
type Req = <C as Channel>::Req;
|
||||
type Resp = <C as Channel>::Resp;
|
||||
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
self.project().inner.in_flight_requests()
|
||||
}
|
||||
|
||||
fn config(&self) -> &Config {
|
||||
self.inner.config()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
self.project().inner.start_request(request_id)
|
||||
}
|
||||
}
|
||||
|
||||
/// A stream of throttling channels.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct ThrottlerStream<S> {
|
||||
#[pin]
|
||||
inner: S,
|
||||
max_in_flight_requests: usize,
|
||||
}
|
||||
|
||||
impl<S> ThrottlerStream<S>
|
||||
where
|
||||
S: Stream,
|
||||
<S as Stream>::Item: Channel,
|
||||
{
|
||||
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
max_in_flight_requests,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Stream for ThrottlerStream<S>
|
||||
where
|
||||
S: Stream,
|
||||
<S as Stream>::Item: Channel,
|
||||
{
|
||||
type Item = Throttler<<S as Stream>::Item>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
match ready!(self.as_mut().project().inner.poll_next(cx)) {
|
||||
Some(channel) => Poll::Ready(Some(Throttler::new(
|
||||
channel,
|
||||
*self.project().max_in_flight_requests,
|
||||
))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
use super::testing::{self, FakeChannel, PollExt};
|
||||
#[cfg(test)]
|
||||
use crate::Request;
|
||||
#[cfg(test)]
|
||||
use pin_utils::pin_mut;
|
||||
#[cfg(test)]
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[test]
|
||||
fn throttler_in_flight_requests() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
for i in 0..5 {
|
||||
throttler.inner.in_flight_requests.insert(i);
|
||||
}
|
||||
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_start_request() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler.as_mut().start_request(1);
|
||||
assert_eq!(throttler.inner.in_flight_requests.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_done() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_some() -> io::Result<()> {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 1,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler.inner.push_req(0, 1);
|
||||
assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());
|
||||
assert_eq!(
|
||||
throttler
|
||||
.as_mut()
|
||||
.poll_next(&mut testing::cx())?
|
||||
.map(|r| r.map(|r| (r.id, r.message))),
|
||||
Poll::Ready(Some((0, 1)))
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_throttled() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler.inner.push_req(1, 1);
|
||||
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
|
||||
assert_eq!(throttler.inner.sink.len(), 1);
|
||||
let resp = throttler.inner.sink.get(0).unwrap();
|
||||
assert_eq!(resp.request_id, 1);
|
||||
assert!(resp.message.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_throttled_sink_not_ready() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: PendingSink::default::<isize, isize>(),
|
||||
};
|
||||
pin_mut!(throttler);
|
||||
assert!(throttler.poll_next(&mut testing::cx()).is_pending());
|
||||
|
||||
struct PendingSink<In, Out> {
|
||||
ghost: PhantomData<fn(Out) -> In>,
|
||||
}
|
||||
impl PendingSink<(), ()> {
|
||||
pub fn default<Req, Resp>() -> PendingSink<io::Result<Request<Req>>, Response<Resp>> {
|
||||
PendingSink { ghost: PhantomData }
|
||||
}
|
||||
}
|
||||
impl<In, Out> Stream for PendingSink<In, Out> {
|
||||
type Item = In;
|
||||
fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
impl<In, Out> Sink<Out> for PendingSink<In, Out> {
|
||||
type Error = io::Error;
|
||||
fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Pending
|
||||
}
|
||||
fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
|
||||
Err(io::Error::from(io::ErrorKind::WouldBlock))
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Pending
|
||||
}
|
||||
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
impl<Req, Resp> Channel for PendingSink<io::Result<Request<Req>>, Response<Resp>> {
|
||||
type Req = Req;
|
||||
type Resp = Resp;
|
||||
fn config(&self) -> &Config {
|
||||
unimplemented!()
|
||||
}
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
0
|
||||
}
|
||||
fn start_request(self: Pin<&mut Self>, _: u64) -> AbortRegistration {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_start_send() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler.inner.in_flight_requests.insert(0);
|
||||
throttler
|
||||
.as_mut()
|
||||
.start_send(Response {
|
||||
request_id: 0,
|
||||
message: Ok(1),
|
||||
})
|
||||
.unwrap();
|
||||
assert!(throttler.inner.in_flight_requests.is_empty());
|
||||
assert_eq!(
|
||||
throttler.inner.sink.get(0),
|
||||
Some(&Response {
|
||||
request_id: 0,
|
||||
message: Ok(1),
|
||||
})
|
||||
);
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
// Copyright 2018 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! Provides a [`Transport`](sealed::Transport) trait as well as implementations.
|
||||
//!
|
||||
//! The rpc crate is transport- and protocol-agnostic. Any transport that impls [`Transport`](sealed::Transport)
|
||||
//! can be plugged in, using whatever protocol it wants.
|
||||
|
||||
use futures::prelude::*;
|
||||
use std::io;
|
||||
|
||||
pub mod channel;
|
||||
|
||||
pub(crate) mod sealed {
|
||||
use super::*;
|
||||
|
||||
/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
|
||||
pub trait Transport<SinkItem, Item>:
|
||||
Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error>
|
||||
{
|
||||
}
|
||||
|
||||
impl<T, SinkItem, Item> Transport<SinkItem, Item> for T where
|
||||
T: Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error> + ?Sized
|
||||
{
|
||||
}
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
// Copyright 2018 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! Transports backed by in-memory channels.
|
||||
|
||||
use crate::PollIo;
|
||||
use futures::{channel::mpsc, task::*, Sink, Stream};
|
||||
use pin_project::pin_project;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
|
||||
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
|
||||
/// [`Sink`].
|
||||
pub fn unbounded<SinkItem, Item>() -> (
|
||||
UnboundedChannel<SinkItem, Item>,
|
||||
UnboundedChannel<Item, SinkItem>,
|
||||
) {
|
||||
let (tx1, rx2) = mpsc::unbounded();
|
||||
let (tx2, rx1) = mpsc::unbounded();
|
||||
(
|
||||
UnboundedChannel { tx: tx1, rx: rx1 },
|
||||
UnboundedChannel { tx: tx2, rx: rx2 },
|
||||
)
|
||||
}
|
||||
|
||||
/// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender)
|
||||
/// and [`UnboundedReceiver`](mpsc::UnboundedReceiver).
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct UnboundedChannel<Item, SinkItem> {
|
||||
#[pin]
|
||||
rx: mpsc::UnboundedReceiver<Item>,
|
||||
#[pin]
|
||||
tx: mpsc::UnboundedSender<SinkItem>,
|
||||
}
|
||||
|
||||
impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
|
||||
type Item = Result<Item, io::Error>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Item> {
|
||||
self.project().rx.poll_next(cx).map(|option| option.map(Ok))
|
||||
}
|
||||
}
|
||||
|
||||
impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.project()
|
||||
.tx
|
||||
.poll_ready(cx)
|
||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
|
||||
self.project()
|
||||
.tx
|
||||
.start_send(item)
|
||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.project()
|
||||
.tx
|
||||
.poll_flush(cx)
|
||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.project()
|
||||
.tx
|
||||
.poll_close(cx)
|
||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
client, context,
|
||||
server::{Handler, Server},
|
||||
transport,
|
||||
};
|
||||
use assert_matches::assert_matches;
|
||||
use futures::{prelude::*, stream};
|
||||
use log::trace;
|
||||
use std::io;
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
async fn integration() -> io::Result<()> {
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
let (client_channel, server_channel) = transport::channel::unbounded();
|
||||
tokio::spawn(
|
||||
Server::default()
|
||||
.incoming(stream::once(future::ready(server_channel)))
|
||||
.respond_with(|_ctx, request: String| {
|
||||
future::ready(request.parse::<u64>().map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
format!("{:?} is not an int", request),
|
||||
)
|
||||
}))
|
||||
}),
|
||||
);
|
||||
|
||||
let mut client = client::new(client::Config::default(), client_channel).spawn()?;
|
||||
|
||||
let response1 = client.call(context::current(), "123".into()).await?;
|
||||
let response2 = client.call(context::current(), "abc".into()).await?;
|
||||
|
||||
trace!("response1: {:?}, response2: {:?}", response1, response2);
|
||||
|
||||
assert_matches!(response1, Ok(123));
|
||||
assert_matches!(response2, Err(ref e) if e.kind() == io::ErrorKind::InvalidInput);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -14,10 +14,7 @@ use serde::{Deserialize, Serialize};
|
||||
use std::{error::Error, io, pin::Pin};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_serde::{Framed as SerdeFramed, *};
|
||||
use tokio_util::codec::{
|
||||
length_delimited::{self, LengthDelimitedCodec},
|
||||
Framed,
|
||||
};
|
||||
use tokio_util::codec::{length_delimited::LengthDelimitedCodec, Framed};
|
||||
|
||||
/// A transport that serializes to, and deserializes from, a byte stream.
|
||||
#[pin_project]
|
||||
@@ -45,14 +42,10 @@ where
|
||||
type Item = io::Result<Item>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
|
||||
match self.project().inner.poll_next(cx) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Ready(Some(Ok::<_, CodecError>(next))) => Poll::Ready(Some(Ok(next))),
|
||||
Poll::Ready(Some(Err::<_, CodecError>(e))) => {
|
||||
Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e))))
|
||||
}
|
||||
}
|
||||
self.project()
|
||||
.inner
|
||||
.poll_next(cx)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,7 +61,10 @@ where
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
convert(self.project().inner.poll_ready(cx))
|
||||
self.project()
|
||||
.inner
|
||||
.poll_ready(cx)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
|
||||
@@ -79,20 +75,20 @@ where
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
convert(self.project().inner.poll_flush(cx))
|
||||
self.project()
|
||||
.inner
|
||||
.poll_flush(cx)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
convert(self.project().inner.poll_close(cx))
|
||||
self.project()
|
||||
.inner
|
||||
.poll_close(cx)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
}
|
||||
|
||||
fn convert<E: Into<Box<dyn Error + Send + Sync>>>(
|
||||
poll: Poll<Result<(), E>>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
poll.map(|ready| ready.map_err(|e| io::Error::new(io::ErrorKind::Other, e)))
|
||||
}
|
||||
|
||||
/// Constructs a new transport from a framed transport and a serialization codec.
|
||||
pub fn new<S, Item, SinkItem, Codec>(
|
||||
framed_io: Framed<S, LengthDelimitedCodec>,
|
||||
@@ -130,6 +126,7 @@ pub mod tcp {
|
||||
futures::ready,
|
||||
std::{marker::PhantomData, net::SocketAddr},
|
||||
tokio::net::{TcpListener, TcpStream, ToSocketAddrs},
|
||||
tokio_util::codec::length_delimited,
|
||||
};
|
||||
|
||||
mod private {
|
||||
@@ -152,6 +149,7 @@ pub mod tcp {
|
||||
}
|
||||
|
||||
/// A connection Future that also exposes the length-delimited framing config.
|
||||
#[must_use]
|
||||
#[pin_project]
|
||||
pub struct Connect<T, Item, SinkItem, CodecFn> {
|
||||
#[pin]
|
||||
@@ -269,9 +267,12 @@ pub mod tcp {
|
||||
type Item = io::Result<Transport<TcpStream, Item, SinkItem, Codec>>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let next =
|
||||
ready!(Pin::new(&mut self.as_mut().project().listener.incoming()).poll_next(cx)?);
|
||||
Poll::Ready(next.map(|conn| Ok(new(self.config.new_framed(conn), (self.codec_fn)()))))
|
||||
let conn: TcpStream =
|
||||
ready!(Pin::new(&mut self.as_mut().project().listener).poll_accept(cx)?).0;
|
||||
Poll::Ready(Some(Ok(new(
|
||||
self.config.new_framed(conn),
|
||||
(self.codec_fn)(),
|
||||
))))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -286,96 +287,73 @@ mod tests {
|
||||
io::{self, Cursor},
|
||||
pin::Pin,
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_serde::formats::SymmetricalJson;
|
||||
|
||||
fn ctx() -> Context<'static> {
|
||||
Context::from_waker(&noop_waker_ref())
|
||||
}
|
||||
|
||||
struct TestIo(Cursor<Vec<u8>>);
|
||||
|
||||
impl AsyncRead for TestIo {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for TestIo {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn close() {
|
||||
let (tx, _rx) = crate::transport::channel::bounded::<(), ()>(0);
|
||||
pin_mut!(tx);
|
||||
assert_matches!(tx.as_mut().poll_close(&mut ctx()), Poll::Ready(Ok(())));
|
||||
assert_matches!(tx.as_mut().start_send(()), Err(_));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stream() {
|
||||
struct TestIo(Cursor<&'static [u8]>);
|
||||
|
||||
impl AsyncRead for TestIo {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
AsyncRead::poll_read(Pin::new(self.0.get_mut()), cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for TestIo {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
let data = b"\x00\x00\x00\x18\"Test one, check check.\"";
|
||||
let data: &[u8] = b"\x00\x00\x00\x18\"Test one, check check.\"";
|
||||
let transport = Transport::from((
|
||||
TestIo(Cursor::new(data)),
|
||||
TestIo(Cursor::new(Vec::from(data))),
|
||||
SymmetricalJson::<String>::default(),
|
||||
));
|
||||
pin_mut!(transport);
|
||||
|
||||
assert_matches!(
|
||||
transport.poll_next(&mut ctx()),
|
||||
transport.as_mut().poll_next(&mut ctx()),
|
||||
Poll::Ready(Some(Ok(ref s))) if s == "Test one, check check.");
|
||||
assert_matches!(transport.as_mut().poll_next(&mut ctx()), Poll::Ready(None));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sink() {
|
||||
struct TestIo<'a>(&'a mut Vec<u8>);
|
||||
|
||||
impl<'a> AsyncRead for TestIo<'a> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> AsyncWrite for TestIo<'a> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
AsyncWrite::poll_write(Pin::new(&mut *self.0), cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
AsyncWrite::poll_flush(Pin::new(&mut *self.0), cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
AsyncWrite::poll_shutdown(Pin::new(&mut *self.0), cx)
|
||||
}
|
||||
}
|
||||
|
||||
let mut writer = vec![];
|
||||
let transport =
|
||||
Transport::from((TestIo(&mut writer), SymmetricalJson::<String>::default()));
|
||||
pin_mut!(transport);
|
||||
let writer = Cursor::new(vec![]);
|
||||
let mut transport = Box::pin(Transport::from((
|
||||
TestIo(writer),
|
||||
SymmetricalJson::<String>::default(),
|
||||
)));
|
||||
|
||||
assert_matches!(
|
||||
transport.as_mut().poll_ready(&mut ctx()),
|
||||
@@ -387,7 +365,32 @@ mod tests {
|
||||
.start_send("Test one, check check.".into()),
|
||||
Ok(())
|
||||
);
|
||||
assert_matches!(transport.poll_flush(&mut ctx()), Poll::Ready(Ok(())));
|
||||
assert_eq!(writer, b"\x00\x00\x00\x18\"Test one, check check.\"");
|
||||
assert_matches!(
|
||||
transport.as_mut().poll_flush(&mut ctx()),
|
||||
Poll::Ready(Ok(()))
|
||||
);
|
||||
assert_eq!(
|
||||
transport.get_ref().0.get_ref(),
|
||||
b"\x00\x00\x00\x18\"Test one, check check.\""
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(tcp)]
|
||||
#[tokio::test]
|
||||
async fn tcp() -> io::Result<()> {
|
||||
use super::tcp;
|
||||
|
||||
let mut listener = tcp::listen("0.0.0.0:0", SymmetricalJson::<String>::default).await?;
|
||||
let addr = listener.local_addr();
|
||||
tokio::spawn(async move {
|
||||
let mut transport = listener.next().await.unwrap().unwrap();
|
||||
let message = transport.next().await.unwrap().unwrap();
|
||||
transport.send(message).await.unwrap();
|
||||
});
|
||||
let mut transport = tcp::connect(addr, SymmetricalJson::<String>::default).await?;
|
||||
transport.send(String::from("test")).await?;
|
||||
assert_matches!(transport.next().await, Some(Ok(s)) if s == "test");
|
||||
assert_matches!(transport.next().await, None);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
1188
tarpc/src/server.rs
Normal file
1188
tarpc/src/server.rs
Normal file
File diff suppressed because it is too large
Load Diff
221
tarpc/src/server/in_flight_requests.rs
Normal file
221
tarpc/src/server/in_flight_requests.rs
Normal file
@@ -0,0 +1,221 @@
|
||||
use crate::util::{Compact, TimeUntil};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::future::{AbortHandle, AbortRegistration};
|
||||
use std::{
|
||||
collections::hash_map,
|
||||
task::{Context, Poll},
|
||||
time::SystemTime,
|
||||
};
|
||||
use tokio_util::time::delay_queue::{self, DelayQueue};
|
||||
use tracing::Span;
|
||||
|
||||
/// A data structure that tracks in-flight requests. It aborts requests,
|
||||
/// either on demand or when a request deadline expires.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct InFlightRequests {
|
||||
request_data: FnvHashMap<u64, RequestData>,
|
||||
deadlines: DelayQueue<u64>,
|
||||
}
|
||||
|
||||
/// Data needed to clean up a single in-flight request.
|
||||
#[derive(Debug)]
|
||||
struct RequestData {
|
||||
/// Aborts the response handler for the associated request.
|
||||
abort_handle: AbortHandle,
|
||||
/// The key to remove the timer for the request's deadline.
|
||||
deadline_key: delay_queue::Key,
|
||||
/// The client span.
|
||||
span: Span,
|
||||
}
|
||||
|
||||
/// An error returned when a request attempted to start with the same ID as a request already
|
||||
/// in flight.
|
||||
#[derive(Debug)]
|
||||
pub struct AlreadyExistsError;
|
||||
|
||||
impl InFlightRequests {
|
||||
/// Returns the number of in-flight requests.
|
||||
pub fn len(&self) -> usize {
|
||||
self.request_data.len()
|
||||
}
|
||||
|
||||
/// Starts a request, unless a request with the same ID is already in flight.
|
||||
pub fn start_request(
|
||||
&mut self,
|
||||
request_id: u64,
|
||||
deadline: SystemTime,
|
||||
span: Span,
|
||||
) -> Result<AbortRegistration, AlreadyExistsError> {
|
||||
match self.request_data.entry(request_id) {
|
||||
hash_map::Entry::Vacant(vacant) => {
|
||||
let timeout = deadline.time_until();
|
||||
let (abort_handle, abort_registration) = AbortHandle::new_pair();
|
||||
let deadline_key = self.deadlines.insert(request_id, timeout);
|
||||
vacant.insert(RequestData {
|
||||
abort_handle,
|
||||
deadline_key,
|
||||
span,
|
||||
});
|
||||
Ok(abort_registration)
|
||||
}
|
||||
hash_map::Entry::Occupied(_) => Err(AlreadyExistsError),
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancels an in-flight request. Returns true iff the request was found.
|
||||
pub fn cancel_request(&mut self, request_id: u64) -> bool {
|
||||
if let Some(RequestData {
|
||||
span,
|
||||
abort_handle,
|
||||
deadline_key,
|
||||
}) = self.request_data.remove(&request_id)
|
||||
{
|
||||
let _entered = span.enter();
|
||||
self.request_data.compact(0.1);
|
||||
abort_handle.abort();
|
||||
self.deadlines.remove(&deadline_key);
|
||||
tracing::info!("ReceiveCancel");
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes a request without aborting. Returns true iff the request was found.
|
||||
/// This method should be used when a response is being sent.
|
||||
pub fn remove_request(&mut self, request_id: u64) -> Option<Span> {
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
self.request_data.compact(0.1);
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
Some(request_data.span)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Yields a request that has expired, aborting any ongoing processing of that request.
|
||||
pub fn poll_expired(&mut self, cx: &mut Context) -> Poll<Option<u64>> {
|
||||
if self.deadlines.is_empty() {
|
||||
// TODO(https://github.com/tokio-rs/tokio/issues/4161)
|
||||
// This is a workaround for DelayQueue not always treating this case correctly.
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
self.deadlines.poll_expired(cx).map(|expired| {
|
||||
let expired = expired?;
|
||||
if let Some(RequestData {
|
||||
abort_handle, span, ..
|
||||
}) = self.request_data.remove(expired.get_ref())
|
||||
{
|
||||
let _entered = span.enter();
|
||||
self.request_data.compact(0.1);
|
||||
abort_handle.abort();
|
||||
tracing::error!("DeadlineExceeded");
|
||||
}
|
||||
Some(expired.into_inner())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// When InFlightRequests is dropped, any outstanding requests are aborted.
|
||||
impl Drop for InFlightRequests {
|
||||
fn drop(&mut self) {
|
||||
self.request_data
|
||||
.values()
|
||||
.for_each(|request_data| request_data.abort_handle.abort())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use assert_matches::assert_matches;
|
||||
use futures::{
|
||||
future::{pending, Abortable},
|
||||
FutureExt,
|
||||
};
|
||||
use futures_test::task::noop_context;
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_request_increases_len() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
in_flight_requests
|
||||
.start_request(0, SystemTime::now(), Span::current())
|
||||
.unwrap();
|
||||
assert_eq!(in_flight_requests.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn polling_expired_aborts() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
let abort_registration = in_flight_requests
|
||||
.start_request(0, SystemTime::now(), Span::current())
|
||||
.unwrap();
|
||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||
|
||||
tokio::time::pause();
|
||||
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
|
||||
|
||||
assert_matches!(
|
||||
in_flight_requests.poll_expired(&mut noop_context()),
|
||||
Poll::Ready(Some(_))
|
||||
);
|
||||
assert_matches!(
|
||||
abortable_future.poll_unpin(&mut noop_context()),
|
||||
Poll::Ready(Err(_))
|
||||
);
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancel_request_aborts() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
let abort_registration = in_flight_requests
|
||||
.start_request(0, SystemTime::now(), Span::current())
|
||||
.unwrap();
|
||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||
|
||||
assert_eq!(in_flight_requests.cancel_request(0), true);
|
||||
assert_matches!(
|
||||
abortable_future.poll_unpin(&mut noop_context()),
|
||||
Poll::Ready(Err(_))
|
||||
);
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remove_request_doesnt_abort() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
assert!(in_flight_requests.deadlines.is_empty());
|
||||
|
||||
let abort_registration = in_flight_requests
|
||||
.start_request(
|
||||
0,
|
||||
SystemTime::now() + std::time::Duration::from_secs(10),
|
||||
Span::current(),
|
||||
)
|
||||
.unwrap();
|
||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||
|
||||
// Precondition: Pending expiration
|
||||
assert_matches!(
|
||||
in_flight_requests.poll_expired(&mut noop_context()),
|
||||
Poll::Pending
|
||||
);
|
||||
assert!(!in_flight_requests.deadlines.is_empty());
|
||||
|
||||
assert_matches!(in_flight_requests.remove_request(0), Some(_));
|
||||
// Postcondition: No pending expirations
|
||||
assert!(in_flight_requests.deadlines.is_empty());
|
||||
assert_matches!(
|
||||
in_flight_requests.poll_expired(&mut noop_context()),
|
||||
Poll::Ready(None)
|
||||
);
|
||||
assert_matches!(
|
||||
abortable_future.poll_unpin(&mut noop_context()),
|
||||
Poll::Pending
|
||||
);
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
}
|
||||
}
|
||||
49
tarpc/src/server/incoming.rs
Normal file
49
tarpc/src/server/incoming.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use super::{
|
||||
limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel},
|
||||
Channel,
|
||||
};
|
||||
use futures::prelude::*;
|
||||
use std::{fmt, hash::Hash};
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
use super::{tokio::TokioServerExecutor, Serve};
|
||||
|
||||
/// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel).
|
||||
pub trait Incoming<C>
|
||||
where
|
||||
Self: Sized + Stream<Item = C>,
|
||||
C: Channel,
|
||||
{
|
||||
/// Enforces channel per-key limits.
|
||||
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> MaxChannelsPerKey<Self, K, KF>
|
||||
where
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
KF: Fn(&C) -> K,
|
||||
{
|
||||
MaxChannelsPerKey::new(self, n, keymaker)
|
||||
}
|
||||
|
||||
/// Caps the number of concurrent requests per channel.
|
||||
fn max_concurrent_requests_per_channel(self, n: usize) -> MaxRequestsPerChannel<Self> {
|
||||
MaxRequestsPerChannel::new(self, n)
|
||||
}
|
||||
|
||||
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled
|
||||
/// concurrently by spawning on tokio's default executor, and each request will be also
|
||||
/// be spawned on tokio's default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
|
||||
where
|
||||
S: Serve<C::Req, Resp = C::Resp>,
|
||||
{
|
||||
TokioServerExecutor::new(self, serve)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, C> Incoming<C> for S
|
||||
where
|
||||
S: Sized + Stream<Item = C>,
|
||||
C: Channel,
|
||||
{
|
||||
}
|
||||
5
tarpc/src/server/limits.rs
Normal file
5
tarpc/src/server/limits.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
/// Provides functionality to limit the number of active channels.
|
||||
pub mod channels_per_key;
|
||||
|
||||
/// Provides a [channel](crate::server::Channel) that limits the number of in-flight requests.
|
||||
pub mod requests_per_channel;
|
||||
@@ -9,33 +9,36 @@ use crate::{
|
||||
util::Compact,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{channel::mpsc, future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
|
||||
use log::{debug, info, trace};
|
||||
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
||||
use pin_project::pin_project;
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::{
|
||||
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin,
|
||||
collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, info, trace};
|
||||
|
||||
/// A single-threaded filter that drops channels based on per-key limits.
|
||||
/// An [`Incoming`](crate::server::incoming::Incoming) stream that drops new channels based on
|
||||
/// per-key limits.
|
||||
///
|
||||
/// The decision to drop a Channel is made once at the time the Channel materializes. Once a
|
||||
/// Channel is yielded, it will not be prematurely dropped.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct ChannelFilter<S, K, F>
|
||||
pub struct MaxChannelsPerKey<S, K, F>
|
||||
where
|
||||
K: Eq + Hash,
|
||||
{
|
||||
#[pin]
|
||||
listener: Fuse<S>,
|
||||
channels_per_key: u32,
|
||||
#[pin]
|
||||
dropped_keys: mpsc::UnboundedReceiver<K>,
|
||||
#[pin]
|
||||
dropped_keys_tx: mpsc::UnboundedSender<K>,
|
||||
key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
|
||||
keymaker: F,
|
||||
}
|
||||
|
||||
/// A channel that is tracked by a ChannelFilter.
|
||||
/// A channel that is tracked by [`MaxChannelsPerKey`].
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct TrackedChannel<C, K> {
|
||||
@@ -53,7 +56,7 @@ struct Tracker<K> {
|
||||
impl<K> Drop for Tracker<K> {
|
||||
fn drop(&mut self) {
|
||||
// Don't care if the listener is dropped.
|
||||
let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap());
|
||||
let _ = self.dropped_keys.send(self.key.take().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,8 +66,8 @@ where
|
||||
{
|
||||
type Item = <C as Stream>::Item;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
self.channel().poll_next(cx)
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
self.inner_pin_mut().poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,20 +77,20 @@ where
|
||||
{
|
||||
type Error = C::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.channel().poll_ready(cx)
|
||||
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner_pin_mut().poll_ready(cx)
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
|
||||
self.channel().start_send(item)
|
||||
fn start_send(mut self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
|
||||
self.inner_pin_mut().start_send(item)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.channel().poll_flush(cx)
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner_pin_mut().poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.channel().poll_close(cx)
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner_pin_mut().poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,17 +106,18 @@ where
|
||||
{
|
||||
type Req = C::Req;
|
||||
type Resp = C::Resp;
|
||||
type Transport = C::Transport;
|
||||
|
||||
fn config(&self) -> &server::Config {
|
||||
self.inner.config()
|
||||
}
|
||||
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
self.project().inner.in_flight_requests()
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
self.inner.in_flight_requests()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
self.project().inner.start_request(request_id)
|
||||
fn transport(&self) -> &Self::Transport {
|
||||
self.inner.transport()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,12 +128,12 @@ impl<C, K> TrackedChannel<C, K> {
|
||||
}
|
||||
|
||||
/// Returns the pinned inner channel.
|
||||
fn channel<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut C> {
|
||||
self.project().inner
|
||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> {
|
||||
self.as_mut().project().inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, K, F> ChannelFilter<S, K, F>
|
||||
impl<S, K, F> MaxChannelsPerKey<S, K, F>
|
||||
where
|
||||
K: Eq + Hash,
|
||||
S: Stream,
|
||||
@@ -137,8 +141,8 @@ where
|
||||
{
|
||||
/// Sheds new channels to stay under configured limits.
|
||||
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
|
||||
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded();
|
||||
ChannelFilter {
|
||||
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel();
|
||||
MaxChannelsPerKey {
|
||||
listener: listener.fuse(),
|
||||
channels_per_key,
|
||||
dropped_keys,
|
||||
@@ -149,12 +153,16 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, K, F> ChannelFilter<S, K, F>
|
||||
impl<S, K, F> MaxChannelsPerKey<S, K, F>
|
||||
where
|
||||
S: Stream,
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
F: Fn(&S::Item) -> K,
|
||||
{
|
||||
fn listener_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<S>> {
|
||||
self.as_mut().project().listener
|
||||
}
|
||||
|
||||
fn handle_new_channel(
|
||||
mut self: Pin<&mut Self>,
|
||||
stream: S::Item,
|
||||
@@ -163,11 +171,10 @@ where
|
||||
let tracker = self.as_mut().increment_channels_for_key(key.clone())?;
|
||||
|
||||
trace!(
|
||||
"[{}] Opening channel ({}/{}) channels for key.",
|
||||
key,
|
||||
Arc::strong_count(&tracker),
|
||||
self.as_mut().project().channels_per_key
|
||||
);
|
||||
channel_filter_key = %key,
|
||||
open_channels = Arc::strong_count(&tracker),
|
||||
max_open_channels = self.channels_per_key,
|
||||
"Opening channel");
|
||||
|
||||
Ok(TrackedChannel {
|
||||
tracker,
|
||||
@@ -175,15 +182,14 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
|
||||
let channels_per_key = self.channels_per_key;
|
||||
let dropped_keys = self.dropped_keys_tx.clone();
|
||||
let key_counts = &mut self.as_mut().project().key_counts;
|
||||
match key_counts.entry(key.clone()) {
|
||||
fn increment_channels_for_key(self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
|
||||
let self_ = self.project();
|
||||
let dropped_keys = self_.dropped_keys_tx;
|
||||
match self_.key_counts.entry(key.clone()) {
|
||||
Entry::Vacant(vacant) => {
|
||||
let tracker = Arc::new(Tracker {
|
||||
key: Some(key),
|
||||
dropped_keys,
|
||||
dropped_keys: dropped_keys.clone(),
|
||||
});
|
||||
|
||||
vacant.insert(Arc::downgrade(&tracker));
|
||||
@@ -191,17 +197,18 @@ where
|
||||
}
|
||||
Entry::Occupied(mut o) => {
|
||||
let count = o.get().strong_count();
|
||||
if count >= channels_per_key.try_into().unwrap() {
|
||||
if count >= TryFrom::try_from(*self_.channels_per_key).unwrap() {
|
||||
info!(
|
||||
"[{}] Opened max channels from key ({}/{}).",
|
||||
key, count, channels_per_key
|
||||
);
|
||||
channel_filter_key = %key,
|
||||
open_channels = count,
|
||||
max_open_channels = *self_.channels_per_key,
|
||||
"At open channel limit");
|
||||
Err(key)
|
||||
} else {
|
||||
Ok(o.get().upgrade().unwrap_or_else(|| {
|
||||
let tracker = Arc::new(Tracker {
|
||||
key: Some(key),
|
||||
dropped_keys,
|
||||
dropped_keys: dropped_keys.clone(),
|
||||
});
|
||||
|
||||
*o.get_mut() = Arc::downgrade(&tracker);
|
||||
@@ -216,18 +223,21 @@ where
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
|
||||
match ready!(self.as_mut().project().listener.poll_next_unpin(cx)) {
|
||||
match ready!(self.listener_pin_mut().poll_next_unpin(cx)) {
|
||||
Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_closed_channels(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
match ready!(self.as_mut().project().dropped_keys.poll_next_unpin(cx)) {
|
||||
fn poll_closed_channels(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
let self_ = self.project();
|
||||
match ready!(self_.dropped_keys.poll_recv(cx)) {
|
||||
Some(key) => {
|
||||
debug!("All channels dropped for key [{}]", key);
|
||||
self.as_mut().project().key_counts.remove(&key);
|
||||
self.as_mut().project().key_counts.compact(0.1);
|
||||
debug!(
|
||||
channel_filter_key = %key,
|
||||
"All channels dropped");
|
||||
self_.key_counts.remove(&key);
|
||||
self_.key_counts.compact(0.1);
|
||||
Poll::Ready(())
|
||||
}
|
||||
None => unreachable!("Holding a copy of closed_channels and didn't close it."),
|
||||
@@ -235,7 +245,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, K, F> Stream for ChannelFilter<S, K, F>
|
||||
impl<S, K, F> Stream for MaxChannelsPerKey<S, K, F>
|
||||
where
|
||||
S: Stream,
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
@@ -268,7 +278,6 @@ where
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn ctx() -> Context<'static> {
|
||||
use futures::task::*;
|
||||
@@ -280,12 +289,12 @@ fn ctx() -> Context<'static> {
|
||||
fn tracker_drop() {
|
||||
use assert_matches::assert_matches;
|
||||
|
||||
let (tx, mut rx) = mpsc::unbounded();
|
||||
let (tx, mut rx) = mpsc::unbounded_channel();
|
||||
Tracker {
|
||||
key: Some(1),
|
||||
dropped_keys: tx,
|
||||
};
|
||||
assert_matches!(rx.try_next(), Ok(Some(1)));
|
||||
assert_matches!(rx.poll_recv(&mut ctx()), Poll::Ready(Some(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -293,8 +302,8 @@ fn tracked_channel_stream() {
|
||||
use assert_matches::assert_matches;
|
||||
use pin_utils::pin_mut;
|
||||
|
||||
let (chan_tx, chan) = mpsc::unbounded();
|
||||
let (dropped_keys, _) = mpsc::unbounded();
|
||||
let (chan_tx, chan) = futures::channel::mpsc::unbounded();
|
||||
let (dropped_keys, _) = mpsc::unbounded_channel();
|
||||
let channel = TrackedChannel {
|
||||
inner: chan,
|
||||
tracker: Arc::new(Tracker {
|
||||
@@ -313,8 +322,8 @@ fn tracked_channel_sink() {
|
||||
use assert_matches::assert_matches;
|
||||
use pin_utils::pin_mut;
|
||||
|
||||
let (chan, mut chan_rx) = mpsc::unbounded();
|
||||
let (dropped_keys, _) = mpsc::unbounded();
|
||||
let (chan, mut chan_rx) = futures::channel::mpsc::unbounded();
|
||||
let (dropped_keys, _) = mpsc::unbounded_channel();
|
||||
let channel = TrackedChannel {
|
||||
inner: chan,
|
||||
tracker: Arc::new(Tracker {
|
||||
@@ -338,8 +347,8 @@ fn channel_filter_increment_channels_for_key() {
|
||||
struct TestChannel {
|
||||
key: &'static str,
|
||||
}
|
||||
let (_, listener) = mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
let (_, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
|
||||
assert_eq!(Arc::strong_count(&tracker1), 1);
|
||||
@@ -359,8 +368,8 @@ fn channel_filter_handle_new_channel() {
|
||||
struct TestChannel {
|
||||
key: &'static str,
|
||||
}
|
||||
let (_, listener) = mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
let (_, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
let channel1 = filter
|
||||
.as_mut()
|
||||
@@ -391,8 +400,8 @@ fn channel_filter_poll_listener() {
|
||||
struct TestChannel {
|
||||
key: &'static str,
|
||||
}
|
||||
let (new_channels, listener) = mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
|
||||
new_channels
|
||||
@@ -427,8 +436,8 @@ fn channel_filter_poll_closed_channels() {
|
||||
struct TestChannel {
|
||||
key: &'static str,
|
||||
}
|
||||
let (new_channels, listener) = mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
|
||||
new_channels
|
||||
@@ -455,8 +464,8 @@ fn channel_filter_stream() {
|
||||
struct TestChannel {
|
||||
key: &'static str,
|
||||
}
|
||||
let (new_channels, listener) = mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
|
||||
new_channels
|
||||
349
tarpc/src/server/limits/requests_per_channel.rs
Normal file
349
tarpc/src/server/limits/requests_per_channel.rs
Normal file
@@ -0,0 +1,349 @@
|
||||
// Copyright 2020 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use crate::{
|
||||
server::{Channel, Config},
|
||||
Response, ServerError,
|
||||
};
|
||||
use futures::{prelude::*, ready, task::*};
|
||||
use pin_project::pin_project;
|
||||
use std::{io, pin::Pin};
|
||||
|
||||
/// A [`Channel`] that limits the number of concurrent requests by throttling.
|
||||
///
|
||||
/// Note that this is a very basic throttling heuristic. It is easy to set a number that is too low
|
||||
/// for the resources available to the server. For production use cases, a more advanced throttler
|
||||
/// is likely needed.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct MaxRequests<C> {
|
||||
max_in_flight_requests: usize,
|
||||
#[pin]
|
||||
inner: C,
|
||||
}
|
||||
|
||||
impl<C> MaxRequests<C> {
|
||||
/// Returns the inner channel.
|
||||
pub fn get_ref(&self) -> &C {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> MaxRequests<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
/// Returns a new `MaxRequests` that wraps the given channel and limits concurrent requests to
|
||||
/// `max_in_flight_requests`.
|
||||
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
|
||||
MaxRequests {
|
||||
max_in_flight_requests,
|
||||
inner,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Stream for MaxRequests<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
type Item = <C as Stream>::Item;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
|
||||
{
|
||||
ready!(self.as_mut().project().inner.poll_ready(cx)?);
|
||||
|
||||
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
|
||||
Some(r) => {
|
||||
let _entered = r.span.enter();
|
||||
tracing::info!(
|
||||
in_flight_requests = self.as_mut().in_flight_requests(),
|
||||
"ThrottleRequest",
|
||||
);
|
||||
|
||||
self.as_mut().start_send(Response {
|
||||
request_id: r.request.id,
|
||||
message: Err(ServerError {
|
||||
kind: io::ErrorKind::WouldBlock,
|
||||
detail: "server throttled the request.".into(),
|
||||
}),
|
||||
})?;
|
||||
}
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
self.project().inner.poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Sink<Response<<C as Channel>::Resp>> for MaxRequests<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
type Error = C::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn start_send(
|
||||
self: Pin<&mut Self>,
|
||||
item: Response<<C as Channel>::Resp>,
|
||||
) -> Result<(), Self::Error> {
|
||||
self.project().inner.start_send(item)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().inner.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().inner.poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> AsRef<C> for MaxRequests<C> {
|
||||
fn as_ref(&self) -> &C {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Channel for MaxRequests<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
type Req = <C as Channel>::Req;
|
||||
type Resp = <C as Channel>::Resp;
|
||||
type Transport = <C as Channel>::Transport;
|
||||
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
self.inner.in_flight_requests()
|
||||
}
|
||||
|
||||
fn config(&self) -> &Config {
|
||||
self.inner.config()
|
||||
}
|
||||
|
||||
fn transport(&self) -> &Self::Transport {
|
||||
self.inner.transport()
|
||||
}
|
||||
}
|
||||
|
||||
/// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on
|
||||
/// the number of in-flight requests.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct MaxRequestsPerChannel<S> {
|
||||
#[pin]
|
||||
inner: S,
|
||||
max_in_flight_requests: usize,
|
||||
}
|
||||
|
||||
impl<S> MaxRequestsPerChannel<S>
|
||||
where
|
||||
S: Stream,
|
||||
<S as Stream>::Item: Channel,
|
||||
{
|
||||
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
max_in_flight_requests,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Stream for MaxRequestsPerChannel<S>
|
||||
where
|
||||
S: Stream,
|
||||
<S as Stream>::Item: Channel,
|
||||
{
|
||||
type Item = MaxRequests<<S as Stream>::Item>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
match ready!(self.as_mut().project().inner.poll_next(cx)) {
|
||||
Some(channel) => Poll::Ready(Some(MaxRequests::new(
|
||||
channel,
|
||||
*self.project().max_in_flight_requests,
|
||||
))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use crate::server::{
|
||||
testing::{self, FakeChannel, PollExt},
|
||||
TrackedRequest,
|
||||
};
|
||||
use pin_utils::pin_mut;
|
||||
use std::{
|
||||
marker::PhantomData,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
use tracing::Span;
|
||||
|
||||
#[tokio::test]
|
||||
async fn throttler_in_flight_requests() {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
for i in 0..5 {
|
||||
throttler
|
||||
.inner
|
||||
.in_flight_requests
|
||||
.start_request(
|
||||
i,
|
||||
SystemTime::now() + Duration::from_secs(1),
|
||||
Span::current(),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_done() {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_some() -> io::Result<()> {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 1,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler.inner.push_req(0, 1);
|
||||
assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());
|
||||
assert_eq!(
|
||||
throttler
|
||||
.as_mut()
|
||||
.poll_next(&mut testing::cx())?
|
||||
.map(|r| r.map(|r| (r.request.id, r.request.message))),
|
||||
Poll::Ready(Some((0, 1)))
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_throttled() {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler.inner.push_req(1, 1);
|
||||
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
|
||||
assert_eq!(throttler.inner.sink.len(), 1);
|
||||
let resp = throttler.inner.sink.get(0).unwrap();
|
||||
assert_eq!(resp.request_id, 1);
|
||||
assert!(resp.message.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_throttled_sink_not_ready() {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 0,
|
||||
inner: PendingSink::default::<isize, isize>(),
|
||||
};
|
||||
pin_mut!(throttler);
|
||||
assert!(throttler.poll_next(&mut testing::cx()).is_pending());
|
||||
|
||||
struct PendingSink<In, Out> {
|
||||
ghost: PhantomData<fn(Out) -> In>,
|
||||
}
|
||||
impl PendingSink<(), ()> {
|
||||
pub fn default<Req, Resp>(
|
||||
) -> PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||
PendingSink { ghost: PhantomData }
|
||||
}
|
||||
}
|
||||
impl<In, Out> Stream for PendingSink<In, Out> {
|
||||
type Item = In;
|
||||
fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
impl<In, Out> Sink<Out> for PendingSink<In, Out> {
|
||||
type Error = io::Error;
|
||||
fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Pending
|
||||
}
|
||||
fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
|
||||
Err(io::Error::from(io::ErrorKind::WouldBlock))
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Pending
|
||||
}
|
||||
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
impl<Req, Resp> Channel for PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||
type Req = Req;
|
||||
type Resp = Resp;
|
||||
type Transport = ();
|
||||
fn config(&self) -> &Config {
|
||||
unimplemented!()
|
||||
}
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
0
|
||||
}
|
||||
fn transport(&self) -> &() {
|
||||
&()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn throttler_start_send() {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler
|
||||
.inner
|
||||
.in_flight_requests
|
||||
.start_request(
|
||||
0,
|
||||
SystemTime::now() + Duration::from_secs(1),
|
||||
Span::current(),
|
||||
)
|
||||
.unwrap();
|
||||
throttler
|
||||
.as_mut()
|
||||
.start_send(Response {
|
||||
request_id: 0,
|
||||
message: Ok(1),
|
||||
})
|
||||
.unwrap();
|
||||
assert_eq!(throttler.inner.in_flight_requests.len(), 0);
|
||||
assert_eq!(
|
||||
throttler.inner.sink.get(0),
|
||||
Some(&Response {
|
||||
request_id: 0,
|
||||
message: Ok(1),
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -4,19 +4,16 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use crate::server::{Channel, Config};
|
||||
use crate::{context, Request, Response};
|
||||
use fnv::FnvHashSet;
|
||||
use futures::{
|
||||
future::{AbortHandle, AbortRegistration},
|
||||
task::*,
|
||||
Sink, Stream,
|
||||
use crate::{
|
||||
cancellations::{cancellations, CanceledRequests, RequestCancellation},
|
||||
context,
|
||||
server::{Channel, Config, ResponseGuard, TrackedRequest},
|
||||
Request, Response,
|
||||
};
|
||||
use futures::{task::*, Sink, Stream};
|
||||
use pin_project::pin_project;
|
||||
use std::collections::VecDeque;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::time::SystemTime;
|
||||
use std::{collections::VecDeque, io, mem::ManuallyDrop, pin::Pin, time::SystemTime};
|
||||
use tracing::Span;
|
||||
|
||||
#[pin_project]
|
||||
pub(crate) struct FakeChannel<In, Out> {
|
||||
@@ -25,7 +22,9 @@ pub(crate) struct FakeChannel<In, Out> {
|
||||
#[pin]
|
||||
pub sink: VecDeque<Out>,
|
||||
pub config: Config,
|
||||
pub in_flight_requests: FnvHashSet<u64>,
|
||||
pub in_flight_requests: super::in_flight_requests::InFlightRequests,
|
||||
pub request_cancellation: RequestCancellation,
|
||||
pub canceled_requests: CanceledRequests,
|
||||
}
|
||||
|
||||
impl<In, Out> Stream for FakeChannel<In, Out>
|
||||
@@ -50,7 +49,7 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
|
||||
self.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&response.request_id);
|
||||
.remove_request(response.request_id);
|
||||
self.project()
|
||||
.sink
|
||||
.start_send(response)
|
||||
@@ -66,47 +65,60 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> Channel for FakeChannel<io::Result<Request<Req>>, Response<Resp>>
|
||||
impl<Req, Resp> Channel for FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>>
|
||||
where
|
||||
Req: Unpin,
|
||||
{
|
||||
type Req = Req;
|
||||
type Resp = Resp;
|
||||
type Transport = ();
|
||||
|
||||
fn config(&self) -> &Config {
|
||||
&self.config
|
||||
}
|
||||
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
self.in_flight_requests.len()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, id: u64) -> AbortRegistration {
|
||||
self.project().in_flight_requests.insert(id);
|
||||
AbortHandle::new_pair().1
|
||||
fn transport(&self) -> &() {
|
||||
&()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
|
||||
impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||
pub fn push_req(&mut self, id: u64, message: Req) {
|
||||
self.stream.push_back(Ok(Request {
|
||||
context: context::Context {
|
||||
deadline: SystemTime::UNIX_EPOCH,
|
||||
trace_context: Default::default(),
|
||||
let (_, abort_registration) = futures::future::AbortHandle::new_pair();
|
||||
let (request_cancellation, _) = cancellations();
|
||||
self.stream.push_back(Ok(TrackedRequest {
|
||||
request: Request {
|
||||
context: context::Context {
|
||||
deadline: SystemTime::UNIX_EPOCH,
|
||||
trace_context: Default::default(),
|
||||
},
|
||||
id,
|
||||
message,
|
||||
},
|
||||
id,
|
||||
message,
|
||||
abort_registration,
|
||||
span: Span::none(),
|
||||
response_guard: ManuallyDrop::new(ResponseGuard {
|
||||
request_cancellation,
|
||||
request_id: id,
|
||||
}),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeChannel<(), ()> {
|
||||
pub fn default<Req, Resp>() -> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
|
||||
pub fn default<Req, Resp>() -> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||
let (request_cancellation, canceled_requests) = cancellations();
|
||||
FakeChannel {
|
||||
stream: Default::default(),
|
||||
sink: Default::default(),
|
||||
config: Default::default(),
|
||||
in_flight_requests: Default::default(),
|
||||
request_cancellation,
|
||||
canceled_requests,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -117,10 +129,7 @@ pub trait PollExt {
|
||||
|
||||
impl<T> PollExt for Poll<Option<T>> {
|
||||
fn is_done(&self) -> bool {
|
||||
match self {
|
||||
Poll::Ready(None) => true,
|
||||
_ => false,
|
||||
}
|
||||
matches!(self, Poll::Ready(None))
|
||||
}
|
||||
}
|
||||
|
||||
113
tarpc/src/server/tokio.rs
Normal file
113
tarpc/src/server/tokio.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
use super::{Channel, Requests, Serve};
|
||||
use futures::{prelude::*, ready, task::*};
|
||||
use pin_project::pin_project;
|
||||
use std::pin::Pin;
|
||||
|
||||
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
|
||||
/// for each new channel. Returned by
|
||||
/// [`Incoming::execute`](crate::server::incoming::Incoming::execute).
|
||||
#[must_use]
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct TokioServerExecutor<T, S> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
serve: S,
|
||||
}
|
||||
|
||||
impl<T, S> TokioServerExecutor<T, S> {
|
||||
pub(crate) fn new(inner: T, serve: S) -> Self {
|
||||
Self { inner, serve }
|
||||
}
|
||||
}
|
||||
|
||||
/// A future that drives the server by [spawning](tokio::spawn) each [response
|
||||
/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by
|
||||
/// [`Channel::execute`](crate::server::Channel::execute).
|
||||
#[must_use]
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct TokioChannelExecutor<T, S> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
serve: S,
|
||||
}
|
||||
|
||||
impl<T, S> TokioServerExecutor<T, S> {
|
||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
||||
self.as_mut().project().inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S> TokioChannelExecutor<T, S> {
|
||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
||||
self.as_mut().project().inner
|
||||
}
|
||||
}
|
||||
|
||||
// Send + 'static execution helper methods.
|
||||
|
||||
impl<C> Requests<C>
|
||||
where
|
||||
C: Channel,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
{
|
||||
/// Executes all requests using the given service function. Requests are handled concurrently
|
||||
/// by [spawning](::tokio::spawn) each handler on tokio's default executor.
|
||||
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
|
||||
where
|
||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
|
||||
{
|
||||
TokioChannelExecutor { inner: self, serve }
|
||||
}
|
||||
}
|
||||
|
||||
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
|
||||
where
|
||||
St: Sized + Stream<Item = C>,
|
||||
C: Channel + Send + 'static,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
||||
Se::Fut: Send,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
||||
tokio::spawn(channel.execute(self.serve.clone()));
|
||||
}
|
||||
tracing::info!("Server shutting down.");
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
|
||||
where
|
||||
C: Channel + 'static,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
||||
S::Fut: Send,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
||||
match response_handler {
|
||||
Ok(resp) => {
|
||||
let server = self.serve.clone();
|
||||
tokio::spawn(async move {
|
||||
resp.execute(server).await;
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Requests stream errored out: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
@@ -16,18 +16,21 @@
|
||||
//! This crate's design is based on [opencensus
|
||||
//! tracing](https://opencensus.io/core-concepts/tracing/).
|
||||
|
||||
use opentelemetry::trace::TraceContextExt;
|
||||
use rand::Rng;
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
fmt::{self, Formatter},
|
||||
mem,
|
||||
num::{NonZeroU128, NonZeroU64},
|
||||
};
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
|
||||
/// A context for tracing the execution of processes, distributed or otherwise.
|
||||
///
|
||||
/// Consists of a span identifying an event, an optional parent span identifying a causal event
|
||||
/// that triggered the current span, and a trace with which all related spans are associated.
|
||||
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Context {
|
||||
/// An identifier of the trace associated with the current context. A trace ID is typically
|
||||
/// created at a root span and passed along through all causal events.
|
||||
@@ -36,33 +39,50 @@ pub struct Context {
|
||||
/// before making an RPC, and the span ID is sent to the server. The server is free to create
|
||||
/// its own spans, for which it sets the client's span as the parent span.
|
||||
pub span_id: SpanId,
|
||||
/// An identifier of the span that originated the current span. For example, if a server sends
|
||||
/// an RPC in response to a client request that included a span, the server would create a span
|
||||
/// for the RPC and set its parent to the span_id in the incoming request's context.
|
||||
///
|
||||
/// If `parent_id` is `None`, then this is a root context.
|
||||
pub parent_id: Option<SpanId>,
|
||||
/// Indicates whether a sampler has already decided whether or not to sample the trace
|
||||
/// associated with the Context. If `sampling_decision` is None, then a decision has not yet
|
||||
/// been made. Downstream samplers do not need to abide by "no sample" decisions--for example,
|
||||
/// an upstream client may choose to never sample, which may not make sense for the client's
|
||||
/// dependencies. On the other hand, if an upstream process has chosen to sample this trace,
|
||||
/// then the downstream samplers are expected to respect that decision and also sample the
|
||||
/// trace. Otherwise, the full trace would not be able to be reconstructed.
|
||||
pub sampling_decision: SamplingDecision,
|
||||
}
|
||||
|
||||
/// A 128-bit UUID identifying a trace. All spans caused by the same originating span share the
|
||||
/// same trace ID.
|
||||
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct TraceId(u128);
|
||||
#[derive(Default, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct TraceId(#[cfg_attr(feature = "serde1", serde(with = "u128_serde"))] u128);
|
||||
|
||||
/// A 64-bit identifier of a span within a trace. The identifier is unique within the span's trace.
|
||||
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[derive(Default, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct SpanId(u64);
|
||||
|
||||
/// Indicates whether a sampler has decided whether or not to sample the trace associated with the
|
||||
/// Context. Downstream samplers do not need to abide by "no sample" decisions--for example, an
|
||||
/// upstream client may choose to never sample, which may not make sense for the client's
|
||||
/// dependencies. On the other hand, if an upstream process has chosen to sample this trace, then
|
||||
/// the downstream samplers are expected to respect that decision and also sample the trace.
|
||||
/// Otherwise, the full trace would not be able to be reconstructed reliably.
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[repr(u8)]
|
||||
pub enum SamplingDecision {
|
||||
/// The associated span was sampled by its creating process. Child spans must also be sampled.
|
||||
Sampled,
|
||||
/// The associated span was not sampled by its creating process.
|
||||
Unsampled,
|
||||
}
|
||||
|
||||
impl Context {
|
||||
/// Constructs a new root context. A root context is one with no parent span.
|
||||
pub fn new_root() -> Self {
|
||||
let rng = &mut rand::thread_rng();
|
||||
Context {
|
||||
trace_id: TraceId::random(rng),
|
||||
span_id: SpanId::random(rng),
|
||||
parent_id: None,
|
||||
/// Constructs a new context with the trace ID and sampling decision inherited from the parent.
|
||||
pub(crate) fn new_child(&self) -> Self {
|
||||
Self {
|
||||
trace_id: self.trace_id,
|
||||
span_id: SpanId::random(&mut rand::thread_rng()),
|
||||
sampling_decision: self.sampling_decision,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -71,17 +91,128 @@ impl TraceId {
|
||||
/// Returns a random trace ID that can be assumed to be globally unique if `rng` generates
|
||||
/// actually-random numbers.
|
||||
pub fn random<R: Rng>(rng: &mut R) -> Self {
|
||||
TraceId(u128::from(rng.next_u64()) << mem::size_of::<u64>() | u128::from(rng.next_u64()))
|
||||
TraceId(rng.gen::<NonZeroU128>().get())
|
||||
}
|
||||
|
||||
/// Returns true iff the trace ID is 0.
|
||||
pub fn is_none(&self) -> bool {
|
||||
self.0 == 0
|
||||
}
|
||||
}
|
||||
|
||||
impl SpanId {
|
||||
/// Returns a random span ID that can be assumed to be unique within a single trace.
|
||||
pub fn random<R: Rng>(rng: &mut R) -> Self {
|
||||
SpanId(rng.next_u64())
|
||||
SpanId(rng.gen::<NonZeroU64>().get())
|
||||
}
|
||||
|
||||
/// Returns true iff the span ID is 0.
|
||||
pub fn is_none(&self) -> bool {
|
||||
self.0 == 0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TraceId> for u128 {
|
||||
fn from(trace_id: TraceId) -> Self {
|
||||
trace_id.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u128> for TraceId {
|
||||
fn from(trace_id: u128) -> Self {
|
||||
Self(trace_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SpanId> for u64 {
|
||||
fn from(span_id: SpanId) -> Self {
|
||||
span_id.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u64> for SpanId {
|
||||
fn from(span_id: u64) -> Self {
|
||||
Self(span_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<opentelemetry::trace::TraceId> for TraceId {
|
||||
fn from(trace_id: opentelemetry::trace::TraceId) -> Self {
|
||||
Self::from(u128::from_be_bytes(trace_id.to_bytes()))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TraceId> for opentelemetry::trace::TraceId {
|
||||
fn from(trace_id: TraceId) -> Self {
|
||||
Self::from_bytes(u128::from(trace_id).to_be_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<opentelemetry::trace::SpanId> for SpanId {
|
||||
fn from(span_id: opentelemetry::trace::SpanId) -> Self {
|
||||
Self::from(u64::from_be_bytes(span_id.to_bytes()))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SpanId> for opentelemetry::trace::SpanId {
|
||||
fn from(span_id: SpanId) -> Self {
|
||||
Self::from_bytes(u64::from(span_id).to_be_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&tracing::Span> for Context {
|
||||
type Error = NoActiveSpan;
|
||||
|
||||
fn try_from(span: &tracing::Span) -> Result<Self, NoActiveSpan> {
|
||||
let context = span.context();
|
||||
if context.has_active_span() {
|
||||
Ok(Self::from(context.span()))
|
||||
} else {
|
||||
Err(NoActiveSpan)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<opentelemetry::trace::SpanRef<'_>> for Context {
|
||||
fn from(span: opentelemetry::trace::SpanRef<'_>) -> Self {
|
||||
let otel_ctx = span.span_context();
|
||||
Self {
|
||||
trace_id: TraceId::from(otel_ctx.trace_id()),
|
||||
span_id: SpanId::from(otel_ctx.span_id()),
|
||||
sampling_decision: SamplingDecision::from(otel_ctx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SamplingDecision> for opentelemetry::trace::TraceFlags {
|
||||
fn from(decision: SamplingDecision) -> Self {
|
||||
match decision {
|
||||
SamplingDecision::Sampled => opentelemetry::trace::TraceFlags::SAMPLED,
|
||||
SamplingDecision::Unsampled => opentelemetry::trace::TraceFlags::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&opentelemetry::trace::SpanContext> for SamplingDecision {
|
||||
fn from(context: &opentelemetry::trace::SpanContext) -> Self {
|
||||
if context.is_sampled() {
|
||||
SamplingDecision::Sampled
|
||||
} else {
|
||||
SamplingDecision::Unsampled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SamplingDecision {
|
||||
fn default() -> Self {
|
||||
Self::Unsampled
|
||||
}
|
||||
}
|
||||
|
||||
/// Returned when a [`Context`] cannot be constructed from a [`Span`](tracing::Span).
|
||||
#[derive(Debug)]
|
||||
pub struct NoActiveSpan;
|
||||
|
||||
impl fmt::Display for TraceId {
|
||||
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
|
||||
write!(f, "{:02x}", self.0)?;
|
||||
@@ -89,9 +220,42 @@ impl fmt::Display for TraceId {
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for TraceId {
|
||||
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
|
||||
write!(f, "{:02x}", self.0)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SpanId {
|
||||
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
|
||||
write!(f, "{:02x}", self.0)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for SpanId {
|
||||
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
|
||||
write!(f, "{:02x}", self.0)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde1")]
|
||||
mod u128_serde {
|
||||
pub fn serialize<S>(u: &u128, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serde::Serialize::serialize(&u.to_le_bytes(), serializer)
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<u128, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
Ok(u128::from_le_bytes(serde::Deserialize::deserialize(
|
||||
deserializer,
|
||||
)?))
|
||||
}
|
||||
}
|
||||
|
||||
40
tarpc/src/transport.rs
Normal file
40
tarpc/src/transport.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright 2018 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! Provides a [`Transport`](sealed::Transport) trait as well as implementations.
|
||||
//!
|
||||
//! The rpc crate is transport- and protocol-agnostic. Any transport that impls [`Transport`](sealed::Transport)
|
||||
//! can be plugged in, using whatever protocol it wants.
|
||||
|
||||
pub mod channel;
|
||||
|
||||
pub(crate) mod sealed {
|
||||
use futures::prelude::*;
|
||||
use std::error::Error;
|
||||
|
||||
/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
|
||||
pub trait Transport<SinkItem, Item>
|
||||
where
|
||||
Self: Stream<Item = Result<Item, <Self as Sink<SinkItem>>::Error>>,
|
||||
Self: Sink<SinkItem, Error = <Self as Transport<SinkItem, Item>>::TransportError>,
|
||||
<Self as Sink<SinkItem>>::Error: Error,
|
||||
{
|
||||
/// Associated type where clauses are not elaborated; this associated type allows users
|
||||
/// bounding types by Transport to avoid having to explicitly add `T::Error: Error` to their
|
||||
/// bounds.
|
||||
type TransportError: Error + Send + Sync + 'static;
|
||||
}
|
||||
|
||||
impl<T, SinkItem, Item, E> Transport<SinkItem, Item> for T
|
||||
where
|
||||
T: ?Sized,
|
||||
T: Stream<Item = Result<Item, E>>,
|
||||
T: Sink<SinkItem, Error = E>,
|
||||
T::Error: Error + Send + Sync + 'static,
|
||||
{
|
||||
type TransportError = E;
|
||||
}
|
||||
}
|
||||
202
tarpc/src/transport/channel.rs
Normal file
202
tarpc/src/transport/channel.rs
Normal file
@@ -0,0 +1,202 @@
|
||||
// Copyright 2018 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! Transports backed by in-memory channels.
|
||||
|
||||
use futures::{task::*, Sink, Stream};
|
||||
use pin_project::pin_project;
|
||||
use std::{error::Error, pin::Pin};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Errors that occur in the sending or receiving of messages over a channel.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ChannelError {
|
||||
/// An error occurred sending over the channel.
|
||||
#[error("an error occurred sending over the channel")]
|
||||
Send(#[source] Box<dyn Error + Send + Sync + 'static>),
|
||||
}
|
||||
|
||||
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
|
||||
/// [`Sink`].
|
||||
pub fn unbounded<SinkItem, Item>() -> (
|
||||
UnboundedChannel<SinkItem, Item>,
|
||||
UnboundedChannel<Item, SinkItem>,
|
||||
) {
|
||||
let (tx1, rx2) = mpsc::unbounded_channel();
|
||||
let (tx2, rx1) = mpsc::unbounded_channel();
|
||||
(
|
||||
UnboundedChannel { tx: tx1, rx: rx1 },
|
||||
UnboundedChannel { tx: tx2, rx: rx2 },
|
||||
)
|
||||
}
|
||||
|
||||
/// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender)
|
||||
/// and [`UnboundedReceiver`](mpsc::UnboundedReceiver).
|
||||
#[derive(Debug)]
|
||||
pub struct UnboundedChannel<Item, SinkItem> {
|
||||
rx: mpsc::UnboundedReceiver<Item>,
|
||||
tx: mpsc::UnboundedSender<SinkItem>,
|
||||
}
|
||||
|
||||
impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
|
||||
type Item = Result<Item, ChannelError>;
|
||||
|
||||
fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Item, ChannelError>>> {
|
||||
self.rx.poll_recv(cx).map(|option| option.map(Ok))
|
||||
}
|
||||
}
|
||||
|
||||
const CLOSED_MESSAGE: &str = "the channel is closed and cannot accept new items for sending";
|
||||
|
||||
impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
|
||||
type Error = ChannelError;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(if self.tx.is_closed() {
|
||||
Err(ChannelError::Send(CLOSED_MESSAGE.into()))
|
||||
} else {
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||
self.tx
|
||||
.send(item)
|
||||
.map_err(|_| ChannelError::Send(CLOSED_MESSAGE.into()))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
// UnboundedSender requires no flushing.
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
// UnboundedSender can't initiate closure.
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns two channel peers with buffer equal to `capacity`. Each [`Stream`] yields items sent
|
||||
/// through the other's [`Sink`].
|
||||
pub fn bounded<SinkItem, Item>(
|
||||
capacity: usize,
|
||||
) -> (Channel<SinkItem, Item>, Channel<Item, SinkItem>) {
|
||||
let (tx1, rx2) = futures::channel::mpsc::channel(capacity);
|
||||
let (tx2, rx1) = futures::channel::mpsc::channel(capacity);
|
||||
(Channel { tx: tx1, rx: rx1 }, Channel { tx: tx2, rx: rx2 })
|
||||
}
|
||||
|
||||
/// A bi-directional channel backed by a [`Sender`](futures::channel::mpsc::Sender)
|
||||
/// and [`Receiver`](futures::channel::mpsc::Receiver).
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct Channel<Item, SinkItem> {
|
||||
#[pin]
|
||||
rx: futures::channel::mpsc::Receiver<Item>,
|
||||
#[pin]
|
||||
tx: futures::channel::mpsc::Sender<SinkItem>,
|
||||
}
|
||||
|
||||
impl<Item, SinkItem> Stream for Channel<Item, SinkItem> {
|
||||
type Item = Result<Item, ChannelError>;
|
||||
|
||||
fn poll_next(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Item, ChannelError>>> {
|
||||
self.project().rx.poll_next(cx).map(|option| option.map(Ok))
|
||||
}
|
||||
}
|
||||
|
||||
impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
|
||||
type Error = ChannelError;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.project()
|
||||
.tx
|
||||
.poll_ready(cx)
|
||||
.map_err(|e| ChannelError::Send(Box::new(e)))
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||
self.project()
|
||||
.tx
|
||||
.start_send(item)
|
||||
.map_err(|e| ChannelError::Send(Box::new(e)))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.project()
|
||||
.tx
|
||||
.poll_flush(cx)
|
||||
.map_err(|e| ChannelError::Send(Box::new(e)))
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.project()
|
||||
.tx
|
||||
.poll_close(cx)
|
||||
.map_err(|e| ChannelError::Send(Box::new(e)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(feature = "tokio1")]
|
||||
mod tests {
|
||||
use crate::{
|
||||
client, context,
|
||||
server::{incoming::Incoming, BaseChannel},
|
||||
transport::{
|
||||
self,
|
||||
channel::{Channel, UnboundedChannel},
|
||||
},
|
||||
};
|
||||
use assert_matches::assert_matches;
|
||||
use futures::{prelude::*, stream};
|
||||
use std::io;
|
||||
use tracing::trace;
|
||||
|
||||
#[test]
|
||||
fn ensure_is_transport() {
|
||||
fn is_transport<SinkItem, Item, T: crate::Transport<SinkItem, Item>>() {}
|
||||
is_transport::<(), (), UnboundedChannel<(), ()>>();
|
||||
is_transport::<(), (), Channel<(), ()>>();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration() -> anyhow::Result<()> {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let (client_channel, server_channel) = transport::channel::unbounded();
|
||||
tokio::spawn(
|
||||
stream::once(future::ready(server_channel))
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(|_ctx, request: String| {
|
||||
future::ready(request.parse::<u64>().map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
format!("{request:?} is not an int"),
|
||||
)
|
||||
}))
|
||||
}),
|
||||
);
|
||||
|
||||
let client = client::new(client::Config::default(), client_channel).spawn();
|
||||
|
||||
let response1 = client.call(context::current(), "", "123".into()).await?;
|
||||
let response2 = client.call(context::current(), "", "abc".into()).await?;
|
||||
|
||||
trace!("response1: {:?}, response2: {:?}", response1, response2);
|
||||
|
||||
assert_matches!(response1, Ok(123));
|
||||
assert_matches!(response2, Err(ref e) if e.kind() == io::ErrorKind::InvalidInput);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -10,8 +10,8 @@ use std::{
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
|
||||
#[cfg(feature = "serde1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "serde1")))]
|
||||
pub mod serde;
|
||||
|
||||
/// Extension trait for [SystemTimes](SystemTime) in the future, i.e. deadlines.
|
||||
@@ -38,11 +38,34 @@ where
|
||||
H: BuildHasher,
|
||||
{
|
||||
fn compact(&mut self, usage_ratio_threshold: f64) {
|
||||
if self.capacity() > 1000 {
|
||||
let usage_ratio = self.len() as f64 / self.capacity() as f64;
|
||||
if usage_ratio < usage_ratio_threshold {
|
||||
self.shrink_to_fit();
|
||||
}
|
||||
}
|
||||
let usage_ratio_threshold = usage_ratio_threshold.clamp(f64::MIN_POSITIVE, 1.);
|
||||
let cap = f64::max(1000., self.len() as f64 / usage_ratio_threshold);
|
||||
self.shrink_to(cap as usize);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compact() {
|
||||
let mut map = HashMap::with_capacity(2048);
|
||||
assert_eq!(map.capacity(), 3584);
|
||||
|
||||
// Make usage ratio 25%
|
||||
for i in 0..896 {
|
||||
map.insert(format!("k{i}"), "v");
|
||||
}
|
||||
|
||||
map.compact(-1.0);
|
||||
assert_eq!(map.capacity(), 3584);
|
||||
|
||||
map.compact(0.25);
|
||||
assert_eq!(map.capacity(), 3584);
|
||||
|
||||
map.compact(0.50);
|
||||
assert_eq!(map.capacity(), 1792);
|
||||
|
||||
map.compact(1.0);
|
||||
assert_eq!(map.capacity(), 1792);
|
||||
|
||||
map.compact(2.0);
|
||||
assert_eq!(map.capacity(), 1792);
|
||||
}
|
||||
@@ -5,31 +5,7 @@
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::{
|
||||
io,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
/// Serializes `system_time` as a `u64` equal to the number of seconds since the epoch.
|
||||
pub fn serialize_epoch_secs<S>(system_time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
const ZERO_SECS: Duration = Duration::from_secs(0);
|
||||
system_time
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap_or(ZERO_SECS)
|
||||
.as_secs() // Only care about second precision
|
||||
.serialize(serializer)
|
||||
}
|
||||
|
||||
/// Deserializes [`SystemTime`] from a `u64` equal to the number of seconds since the epoch.
|
||||
pub fn deserialize_epoch_secs<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
Ok(SystemTime::UNIX_EPOCH + Duration::from_secs(u64::deserialize(deserializer)?))
|
||||
}
|
||||
use std::io;
|
||||
|
||||
/// Serializes [`io::ErrorKind`] as a `u32`.
|
||||
#[allow(clippy::trivially_copy_pass_by_ref)] // Exact fn signature required by serde derive
|
||||
@@ -2,4 +2,8 @@
|
||||
fn ui() {
|
||||
let t = trybuild::TestCases::new();
|
||||
t.compile_fail("tests/compile_fail/*.rs");
|
||||
#[cfg(feature = "tokio1")]
|
||||
t.compile_fail("tests/compile_fail/tokio/*.rs");
|
||||
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
|
||||
t.compile_fail("tests/compile_fail/serde_transport/*.rs");
|
||||
}
|
||||
|
||||
15
tarpc/tests/compile_fail/must_use_request_dispatch.rs
Normal file
15
tarpc/tests/compile_fail/must_use_request_dispatch.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
use tarpc::client;
|
||||
|
||||
#[tarpc::service]
|
||||
trait World {
|
||||
async fn hello(name: String) -> String;
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let (client_transport, _) = tarpc::transport::channel::unbounded();
|
||||
|
||||
#[deny(unused_must_use)]
|
||||
{
|
||||
WorldClient::new(client::Config::default(), client_transport).dispatch;
|
||||
}
|
||||
}
|
||||
11
tarpc/tests/compile_fail/must_use_request_dispatch.stderr
Normal file
11
tarpc/tests/compile_fail/must_use_request_dispatch.stderr
Normal file
@@ -0,0 +1,11 @@
|
||||
error: unused `RequestDispatch` that must be used
|
||||
--> tests/compile_fail/must_use_request_dispatch.rs:13:9
|
||||
|
|
||||
13 | WorldClient::new(client::Config::default(), client_transport).dispatch;
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
note: the lint level is defined here
|
||||
--> tests/compile_fail/must_use_request_dispatch.rs:11:12
|
||||
|
|
||||
11 | #[deny(unused_must_use)]
|
||||
| ^^^^^^^^^^^^^^^
|
||||
@@ -0,0 +1,9 @@
|
||||
use tarpc::serde_transport;
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
fn main() {
|
||||
#[deny(unused_must_use)]
|
||||
{
|
||||
serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
error: unused `Connect` that must be used
|
||||
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:7:9
|
||||
|
|
||||
7 | serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
note: the lint level is defined here
|
||||
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:5:12
|
||||
|
|
||||
5 | #[deny(unused_must_use)]
|
||||
| ^^^^^^^^^^^^^^^
|
||||
@@ -1,4 +1,4 @@
|
||||
#[tarpc::service]
|
||||
#[tarpc::service(derive_serde = false)]
|
||||
trait World {
|
||||
async fn hello(name: String) -> String;
|
||||
}
|
||||
@@ -7,8 +7,8 @@ struct HelloServer;
|
||||
|
||||
#[tarpc::server]
|
||||
impl World for HelloServer {
|
||||
fn hello(name: String) -> String {
|
||||
format!("Hello, {}!", name)
|
||||
fn hello(name: String) -> String {
|
||||
format!("Hello, {name}!", name)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,13 +7,5 @@ error: not all trait items implemented, missing: `HelloFut`
|
||||
error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async
|
||||
--> $DIR/tarpc_server_missing_async.rs:10:5
|
||||
|
|
||||
10 | fn hello(name: String) -> String {
|
||||
10 | fn hello(name: String) -> String {
|
||||
| ^^
|
||||
|
||||
error[E0433]: failed to resolve: use of undeclared type or module `serde`
|
||||
--> $DIR/tarpc_server_missing_async.rs:1:1
|
||||
|
|
||||
1 | #[tarpc::service]
|
||||
| ^^^^^^^^^^^^^^^^^ use of undeclared type or module `serde`
|
||||
|
|
||||
= note: this error originates in an attribute macro (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||
|
||||
29
tarpc/tests/compile_fail/tokio/must_use_channel_executor.rs
Normal file
29
tarpc/tests/compile_fail/tokio/must_use_channel_executor.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use tarpc::{
|
||||
context,
|
||||
server::{self, Channel},
|
||||
};
|
||||
|
||||
#[tarpc::service]
|
||||
trait World {
|
||||
async fn hello(name: String) -> String;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct HelloServer;
|
||||
|
||||
#[tarpc::server]
|
||||
impl World for HelloServer {
|
||||
async fn hello(self, _: context::Context, name: String) -> String {
|
||||
format!("Hello, {name}!")
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let (_, server_transport) = tarpc::transport::channel::unbounded();
|
||||
let server = server::BaseChannel::with_defaults(server_transport);
|
||||
|
||||
#[deny(unused_must_use)]
|
||||
{
|
||||
server.execute(HelloServer.serve());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
error: unused `TokioChannelExecutor` that must be used
|
||||
--> tests/compile_fail/tokio/must_use_channel_executor.rs:27:9
|
||||
|
|
||||
27 | server.execute(HelloServer.serve());
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
note: the lint level is defined here
|
||||
--> tests/compile_fail/tokio/must_use_channel_executor.rs:25:12
|
||||
|
|
||||
25 | #[deny(unused_must_use)]
|
||||
| ^^^^^^^^^^^^^^^
|
||||
30
tarpc/tests/compile_fail/tokio/must_use_server_executor.rs
Normal file
30
tarpc/tests/compile_fail/tokio/must_use_server_executor.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
use futures::stream::once;
|
||||
use tarpc::{
|
||||
context,
|
||||
server::{self, incoming::Incoming},
|
||||
};
|
||||
|
||||
#[tarpc::service]
|
||||
trait World {
|
||||
async fn hello(name: String) -> String;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct HelloServer;
|
||||
|
||||
#[tarpc::server]
|
||||
impl World for HelloServer {
|
||||
async fn hello(self, _: context::Context, name: String) -> String {
|
||||
format!("Hello, {name}!")
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let (_, server_transport) = tarpc::transport::channel::unbounded();
|
||||
let server = once(async move { server::BaseChannel::with_defaults(server_transport) });
|
||||
|
||||
#[deny(unused_must_use)]
|
||||
{
|
||||
server.execute(HelloServer.serve());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
error: unused `TokioServerExecutor` that must be used
|
||||
--> tests/compile_fail/tokio/must_use_server_executor.rs:28:9
|
||||
|
|
||||
28 | server.execute(HelloServer.serve());
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
note: the lint level is defined here
|
||||
--> tests/compile_fail/tokio/must_use_server_executor.rs:26:12
|
||||
|
|
||||
26 | #[deny(unused_must_use)]
|
||||
| ^^^^^^^^^^^^^^^
|
||||
55
tarpc/tests/dataservice.rs
Normal file
55
tarpc/tests/dataservice.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
use futures::prelude::*;
|
||||
use tarpc::serde_transport;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{incoming::Incoming, BaseChannel},
|
||||
};
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
#[tarpc::derive_serde]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum TestData {
|
||||
Black,
|
||||
White,
|
||||
}
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait ColorProtocol {
|
||||
async fn get_opposite_color(color: TestData) -> TestData;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ColorServer;
|
||||
|
||||
#[tarpc::server]
|
||||
impl ColorProtocol for ColorServer {
|
||||
async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData {
|
||||
match color {
|
||||
TestData::White => TestData::Black,
|
||||
TestData::Black => TestData::White,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_call() -> anyhow::Result<()> {
|
||||
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
|
||||
let addr = transport.local_addr();
|
||||
tokio::spawn(
|
||||
transport
|
||||
.take(1)
|
||||
.filter_map(|r| async { r.ok() })
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(ColorServer.serve()),
|
||||
);
|
||||
|
||||
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
let client = ColorProtocolClient::new(client::Config::default(), transport).spawn();
|
||||
|
||||
let color = client
|
||||
.get_opposite_color(context::current(), TestData::White)
|
||||
.await?;
|
||||
assert_eq!(color, TestData::Black);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -3,15 +3,14 @@ use futures::{
|
||||
future::{join_all, ready, Ready},
|
||||
prelude::*,
|
||||
};
|
||||
use std::io;
|
||||
use std::time::{Duration, SystemTime};
|
||||
use tarpc::{
|
||||
client::{self},
|
||||
context, serde_transport,
|
||||
server::{self, BaseChannel, Channel, Handler},
|
||||
context,
|
||||
server::{self, incoming::Incoming, BaseChannel, Channel},
|
||||
transport::channel,
|
||||
};
|
||||
use tokio::join;
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
#[tarpc_plugins::service]
|
||||
trait Service {
|
||||
@@ -32,23 +31,23 @@ impl Service for Server {
|
||||
type HeyFut = Ready<String>;
|
||||
|
||||
fn hey(self, _: context::Context, name: String) -> Self::HeyFut {
|
||||
ready(format!("Hey, {}.", name))
|
||||
ready(format!("Hey, {name}."))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
async fn sequential() -> io::Result<()> {
|
||||
let _ = env_logger::try_init();
|
||||
#[tokio::test]
|
||||
async fn sequential() -> anyhow::Result<()> {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
|
||||
tokio::spawn(
|
||||
BaseChannel::new(server::Config::default(), rx)
|
||||
.respond_with(Server.serve())
|
||||
.execute(),
|
||||
.requests()
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let mut client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||
|
||||
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
|
||||
assert_matches!(
|
||||
@@ -58,21 +57,75 @@ async fn sequential() -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde1")]
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
async fn serde() -> io::Result<()> {
|
||||
let _ = env_logger::try_init();
|
||||
#[tokio::test]
|
||||
async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> {
|
||||
#[tarpc_plugins::service]
|
||||
trait Loop {
|
||||
async fn r#loop();
|
||||
}
|
||||
|
||||
let transport = serde_transport::tcp::listen("localhost:56789", Json::default).await?;
|
||||
#[derive(Clone)]
|
||||
struct LoopServer;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AllHandlersComplete;
|
||||
|
||||
#[tarpc::server]
|
||||
impl Loop for LoopServer {
|
||||
async fn r#loop(self, _: context::Context) {
|
||||
loop {
|
||||
futures::pending!();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
|
||||
// Set up a client that initiates a long-lived request.
|
||||
// The request will complete in error when the server drops the connection.
|
||||
tokio::spawn(async move {
|
||||
let client = LoopClient::new(client::Config::default(), tx).spawn();
|
||||
|
||||
let mut ctx = context::current();
|
||||
ctx.deadline = SystemTime::now() + Duration::from_secs(60 * 60);
|
||||
let _ = client.r#loop(ctx).await;
|
||||
});
|
||||
|
||||
let mut requests = BaseChannel::with_defaults(rx).requests();
|
||||
// Reading a request should trigger the request being registered with BaseChannel.
|
||||
let first_request = requests.next().await.unwrap()?;
|
||||
// Dropping the channel should trigger cleanup of outstanding requests.
|
||||
drop(requests);
|
||||
// In-flight requests should be aborted by channel cleanup.
|
||||
// The first and only request sent by the client is `loop`, which is an infinite loop
|
||||
// on the server side, so if cleanup was not triggered, this line should hang indefinitely.
|
||||
first_request.execute(LoopServer.serve()).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
|
||||
#[tokio::test]
|
||||
async fn serde() -> anyhow::Result<()> {
|
||||
use tarpc::serde_transport;
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let transport = tarpc::serde_transport::tcp::listen("localhost:56789", Json::default).await?;
|
||||
let addr = transport.local_addr();
|
||||
tokio::spawn(
|
||||
tarpc::Server::default()
|
||||
.incoming(transport.take(1).filter_map(|r| async { r.ok() }))
|
||||
.respond_with(Server.serve()),
|
||||
transport
|
||||
.take(1)
|
||||
.filter_map(|r| async { r.ok() })
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
let mut client = ServiceClient::new(client::Config::default(), transport).spawn()?;
|
||||
let client = ServiceClient::new(client::Config::default(), transport).spawn();
|
||||
|
||||
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
|
||||
assert_matches!(
|
||||
@@ -83,27 +136,22 @@ async fn serde() -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
async fn concurrent() -> io::Result<()> {
|
||||
let _ = env_logger::try_init();
|
||||
#[tokio::test]
|
||||
async fn concurrent() -> anyhow::Result<()> {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
tokio::spawn(
|
||||
tarpc::Server::default()
|
||||
.incoming(stream::once(ready(rx)))
|
||||
.respond_with(Server.serve()),
|
||||
stream::once(ready(rx))
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||
|
||||
let mut c = client.clone();
|
||||
let req1 = c.add(context::current(), 1, 2);
|
||||
|
||||
let mut c = client.clone();
|
||||
let req2 = c.add(context::current(), 3, 4);
|
||||
|
||||
let mut c = client.clone();
|
||||
let req3 = c.hey(context::current(), "Tim".to_string());
|
||||
let req1 = client.add(context::current(), 1, 2);
|
||||
let req2 = client.add(context::current(), 3, 4);
|
||||
let req3 = client.hey(context::current(), "Tim".to_string());
|
||||
|
||||
assert_matches!(req1.await, Ok(3));
|
||||
assert_matches!(req2.await, Ok(7));
|
||||
@@ -112,27 +160,22 @@ async fn concurrent() -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
async fn concurrent_join() -> io::Result<()> {
|
||||
let _ = env_logger::try_init();
|
||||
#[tokio::test]
|
||||
async fn concurrent_join() -> anyhow::Result<()> {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
tokio::spawn(
|
||||
tarpc::Server::default()
|
||||
.incoming(stream::once(ready(rx)))
|
||||
.respond_with(Server.serve()),
|
||||
stream::once(ready(rx))
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||
|
||||
let mut c = client.clone();
|
||||
let req1 = c.add(context::current(), 1, 2);
|
||||
|
||||
let mut c = client.clone();
|
||||
let req2 = c.add(context::current(), 3, 4);
|
||||
|
||||
let mut c = client.clone();
|
||||
let req3 = c.hey(context::current(), "Tim".to_string());
|
||||
let req1 = client.add(context::current(), 1, 2);
|
||||
let req2 = client.add(context::current(), 3, 4);
|
||||
let req3 = client.hey(context::current(), "Tim".to_string());
|
||||
|
||||
let (resp1, resp2, resp3) = join!(req1, req2, req3);
|
||||
assert_matches!(resp1, Ok(3));
|
||||
@@ -142,24 +185,21 @@ async fn concurrent_join() -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
async fn concurrent_join_all() -> io::Result<()> {
|
||||
let _ = env_logger::try_init();
|
||||
#[tokio::test]
|
||||
async fn concurrent_join_all() -> anyhow::Result<()> {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
tokio::spawn(
|
||||
tarpc::Server::default()
|
||||
.incoming(stream::once(ready(rx)))
|
||||
.respond_with(Server.serve()),
|
||||
stream::once(ready(rx))
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||
|
||||
let mut c1 = client.clone();
|
||||
let mut c2 = client.clone();
|
||||
|
||||
let req1 = c1.add(context::current(), 1, 2);
|
||||
let req2 = c2.add(context::current(), 3, 4);
|
||||
let req1 = client.add(context::current(), 1, 2);
|
||||
let req2 = client.add(context::current(), 3, 4);
|
||||
|
||||
let responses = join_all(vec![req1, req2]).await;
|
||||
assert_matches!(responses[0], Ok(3));
|
||||
@@ -167,3 +207,38 @@ async fn concurrent_join_all() -> io::Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn counter() -> anyhow::Result<()> {
|
||||
#[tarpc::service]
|
||||
trait Counter {
|
||||
async fn count() -> u32;
|
||||
}
|
||||
|
||||
struct CountService(u32);
|
||||
|
||||
impl Counter for &mut CountService {
|
||||
type CountFut = futures::future::Ready<u32>;
|
||||
|
||||
fn count(self, _: context::Context) -> Self::CountFut {
|
||||
self.0 += 1;
|
||||
futures::future::ready(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
tokio::spawn(async {
|
||||
let mut requests = BaseChannel::with_defaults(rx).requests();
|
||||
let mut counter = CountService(0);
|
||||
|
||||
while let Some(Ok(request)) = requests.next().await {
|
||||
request.execute(counter.serve()).await;
|
||||
}
|
||||
});
|
||||
|
||||
let client = CounterClient::new(client::Config::default(), tx).spawn();
|
||||
assert_matches!(client.count(context::current()).await, Ok(1));
|
||||
assert_matches!(client.count(context::current()).await, Ok(2));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user