mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-02-23 15:49:54 +01:00
Compare commits
117 Commits
v0.20.0
...
client-clo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9e4c3a2b8b | ||
|
|
d78b24b631 | ||
|
|
49900d7a35 | ||
|
|
1e680e3a5a | ||
|
|
2591d21e94 | ||
|
|
6632f68d95 | ||
|
|
25985ad56a | ||
|
|
d6a24e9420 | ||
|
|
281a78f3c7 | ||
|
|
a0787d0091 | ||
|
|
d2acba0e8a | ||
|
|
ea7b6763c4 | ||
|
|
eb67c540b9 | ||
|
|
4151d0abd3 | ||
|
|
d0c11a6efa | ||
|
|
82c4da1743 | ||
|
|
0a15e0b75c | ||
|
|
0b315c29bf | ||
|
|
56f09bf61f | ||
|
|
6d82e82419 | ||
|
|
9bebaf814a | ||
|
|
5f4d6e6008 | ||
|
|
07d07d7ba3 | ||
|
|
a41bbf65b2 | ||
|
|
21e2f7ca62 | ||
|
|
7b7c182411 | ||
|
|
db0c778ead | ||
|
|
c3efb83ac1 | ||
|
|
3d7b0171fe | ||
|
|
c191ff5b2e | ||
|
|
90bc7f741d | ||
|
|
d3f6c01df2 | ||
|
|
c6450521e6 | ||
|
|
1da6bcec57 | ||
|
|
75a5591158 | ||
|
|
9462aad3bf | ||
|
|
0964fc51ff | ||
|
|
27aacab432 | ||
|
|
3feb465ad3 | ||
|
|
66cdc99ae0 | ||
|
|
66419db6fd | ||
|
|
72d5dbba89 | ||
|
|
e75193c191 | ||
|
|
ce4fd49161 | ||
|
|
3c978c5bf6 | ||
|
|
6f419e9a9a | ||
|
|
b3eb8d0b7a | ||
|
|
3b422eb179 | ||
|
|
4b513bad73 | ||
|
|
e71e17866d | ||
|
|
7e3fbec077 | ||
|
|
e4bc5e8e32 | ||
|
|
bc982c5584 | ||
|
|
d440e12c19 | ||
|
|
bc8128af69 | ||
|
|
1d87c14262 | ||
|
|
ca929c2178 | ||
|
|
569039734b | ||
|
|
3d43310e6a | ||
|
|
d21cbddb0d | ||
|
|
25aa857edf | ||
|
|
0bb2e2bbbe | ||
|
|
dc376343d6 | ||
|
|
2e7d1f8a88 | ||
|
|
6314591c65 | ||
|
|
7dd7494420 | ||
|
|
6c10e3649f | ||
|
|
4c6dee13d2 | ||
|
|
e45abe953a | ||
|
|
dec3e491b5 | ||
|
|
6ce341cf79 | ||
|
|
b9868250f8 | ||
|
|
a3f1064efe | ||
|
|
026083d653 | ||
|
|
d27f341bde | ||
|
|
2264ebecfc | ||
|
|
3207affb4a | ||
|
|
0602afd50c | ||
|
|
4343e12217 | ||
|
|
7fda862fb8 | ||
|
|
aa7b875b1a | ||
|
|
54d6e0e3b6 | ||
|
|
bea3b442aa | ||
|
|
954a2502e7 | ||
|
|
e3f34917c5 | ||
|
|
f65dd05949 | ||
|
|
240c436b34 | ||
|
|
c9803688cc | ||
|
|
4987094483 | ||
|
|
ff55080193 | ||
|
|
258193c932 | ||
|
|
67823ef5de | ||
|
|
a671457243 | ||
|
|
cf654549da | ||
|
|
6a01e32a2d | ||
|
|
e6597fab03 | ||
|
|
ebd245a93d | ||
|
|
3ebc3b5845 | ||
|
|
0e5973109d | ||
|
|
5f02d7383a | ||
|
|
2bae148529 | ||
|
|
42a2e03aab | ||
|
|
b566d0c646 | ||
|
|
b359f16767 | ||
|
|
f8681ab134 | ||
|
|
7e521768ab | ||
|
|
e9b1e7d101 | ||
|
|
f0322fb892 | ||
|
|
617daebb88 | ||
|
|
a11d4fff58 | ||
|
|
bf42a04d83 | ||
|
|
06528d6953 | ||
|
|
9f00395746 | ||
|
|
e0674cd57f | ||
|
|
7e49bd9ee7 | ||
|
|
8a1baa9c4e | ||
|
|
31c713d188 |
48
.github/workflows/main.yml
vendored
48
.github/workflows/main.yml
vendored
@@ -1,4 +1,10 @@
|
|||||||
on: [push, pull_request]
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
|
||||||
name: Continuous integration
|
name: Continuous integration
|
||||||
|
|
||||||
@@ -7,27 +13,59 @@ jobs:
|
|||||||
name: Check
|
name: Check
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
- name: Cancel previous
|
||||||
|
uses: styfle/cancel-workflow-action@0.7.0
|
||||||
|
with:
|
||||||
|
access_token: ${{ github.token }}
|
||||||
- uses: actions/checkout@v1
|
- uses: actions/checkout@v1
|
||||||
- uses: actions-rs/toolchain@v1
|
- uses: actions-rs/toolchain@v1
|
||||||
with:
|
with:
|
||||||
profile: minimal
|
profile: minimal
|
||||||
toolchain: stable
|
toolchain: stable
|
||||||
|
target: mipsel-unknown-linux-gnu
|
||||||
override: true
|
override: true
|
||||||
- uses: actions-rs/cargo@v1
|
- uses: actions-rs/cargo@v1
|
||||||
with:
|
with:
|
||||||
command: check
|
command: check
|
||||||
args: --all-features
|
args: --all-features
|
||||||
|
- uses: actions-rs/cargo@v1
|
||||||
|
with:
|
||||||
|
command: check
|
||||||
|
args: --all-features --target mipsel-unknown-linux-gnu
|
||||||
|
|
||||||
test:
|
test:
|
||||||
name: Test Suite
|
name: Test Suite
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
- name: Cancel previous
|
||||||
|
uses: styfle/cancel-workflow-action@0.7.0
|
||||||
|
with:
|
||||||
|
access_token: ${{ github.token }}
|
||||||
- uses: actions/checkout@v1
|
- uses: actions/checkout@v1
|
||||||
- uses: actions-rs/toolchain@v1
|
- uses: actions-rs/toolchain@v1
|
||||||
with:
|
with:
|
||||||
profile: minimal
|
profile: minimal
|
||||||
toolchain: stable
|
toolchain: stable
|
||||||
override: true
|
override: true
|
||||||
|
- uses: actions-rs/cargo@v1
|
||||||
|
with:
|
||||||
|
command: test
|
||||||
|
- uses: actions-rs/cargo@v1
|
||||||
|
with:
|
||||||
|
command: test
|
||||||
|
args: --manifest-path tarpc/Cargo.toml --features serde1
|
||||||
|
- uses: actions-rs/cargo@v1
|
||||||
|
with:
|
||||||
|
command: test
|
||||||
|
args: --manifest-path tarpc/Cargo.toml --features tokio1
|
||||||
|
- uses: actions-rs/cargo@v1
|
||||||
|
with:
|
||||||
|
command: test
|
||||||
|
args: --manifest-path tarpc/Cargo.toml --features serde-transport
|
||||||
|
- uses: actions-rs/cargo@v1
|
||||||
|
with:
|
||||||
|
command: test
|
||||||
|
args: --manifest-path tarpc/Cargo.toml --features tcp
|
||||||
- uses: actions-rs/cargo@v1
|
- uses: actions-rs/cargo@v1
|
||||||
with:
|
with:
|
||||||
command: test
|
command: test
|
||||||
@@ -37,6 +75,10 @@ jobs:
|
|||||||
name: Rustfmt
|
name: Rustfmt
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
- name: Cancel previous
|
||||||
|
uses: styfle/cancel-workflow-action@0.7.0
|
||||||
|
with:
|
||||||
|
access_token: ${{ github.token }}
|
||||||
- uses: actions/checkout@v1
|
- uses: actions/checkout@v1
|
||||||
- uses: actions-rs/toolchain@v1
|
- uses: actions-rs/toolchain@v1
|
||||||
with:
|
with:
|
||||||
@@ -53,6 +95,10 @@ jobs:
|
|||||||
name: Clippy
|
name: Clippy
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
- name: Cancel previous
|
||||||
|
uses: styfle/cancel-workflow-action@0.7.0
|
||||||
|
with:
|
||||||
|
access_token: ${{ github.token }}
|
||||||
- uses: actions/checkout@v1
|
- uses: actions/checkout@v1
|
||||||
- uses: actions-rs/toolchain@v1
|
- uses: actions-rs/toolchain@v1
|
||||||
with:
|
with:
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
[workspace]
|
[workspace]
|
||||||
|
resolver = "2"
|
||||||
|
|
||||||
members = [
|
members = [
|
||||||
"example-service",
|
"example-service",
|
||||||
"tarpc",
|
"tarpc",
|
||||||
"plugins",
|
"plugins",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[profile.dev]
|
||||||
|
split-debuginfo = "unpacked"
|
||||||
|
|||||||
74
README.md
74
README.md
@@ -1,9 +1,21 @@
|
|||||||
[](https://github.com/google/tarpc/actions?query=workflow%3A%22Continuous+integration%22)
|
[![Crates.io][crates-badge]][crates-url]
|
||||||
[](https://crates.io/crates/tarpc)
|
[![MIT licensed][mit-badge]][mit-url]
|
||||||
[](https://discordapp.com/channels/647529123996237854)
|
[![Build status][gh-actions-badge]][gh-actions-url]
|
||||||
|
[![Discord chat][discord-badge]][discord-url]
|
||||||
|
|
||||||
|
[crates-badge]: https://img.shields.io/crates/v/tarpc.svg
|
||||||
|
[crates-url]: https://crates.io/crates/tarpc
|
||||||
|
[mit-badge]: https://img.shields.io/badge/license-MIT-blue.svg
|
||||||
|
[mit-url]: LICENSE
|
||||||
|
[gh-actions-badge]: https://github.com/google/tarpc/workflows/Continuous%20integration/badge.svg
|
||||||
|
[gh-actions-url]: https://github.com/google/tarpc/actions?query=workflow%3A%22Continuous+integration%22
|
||||||
|
[discord-badge]: https://img.shields.io/discord/647529123996237854.svg?logo=discord&style=flat-square
|
||||||
|
[discord-url]: https://discord.gg/gXwpdSt
|
||||||
|
|
||||||
# tarpc
|
# tarpc
|
||||||
|
|
||||||
|
<!-- cargo-sync-readme start -->
|
||||||
|
|
||||||
*Disclaimer*: This is not an official Google product.
|
*Disclaimer*: This is not an official Google product.
|
||||||
|
|
||||||
tarpc is an RPC framework for rust with a focus on ease of use. Defining a
|
tarpc is an RPC framework for rust with a focus on ease of use. Defining a
|
||||||
@@ -12,7 +24,7 @@ writing a server is taken care of for you.
|
|||||||
|
|
||||||
[Documentation](https://docs.rs/crate/tarpc/)
|
[Documentation](https://docs.rs/crate/tarpc/)
|
||||||
|
|
||||||
### What is an RPC framework?
|
## What is an RPC framework?
|
||||||
"RPC" stands for "Remote Procedure Call," a function call where the work of
|
"RPC" stands for "Remote Procedure Call," a function call where the work of
|
||||||
producing the return value is being done somewhere else. When an rpc function is
|
producing the return value is being done somewhere else. When an rpc function is
|
||||||
invoked, behind the scenes the function contacts some other process somewhere
|
invoked, behind the scenes the function contacts some other process somewhere
|
||||||
@@ -30,7 +42,7 @@ process, and no context switching between different languages.
|
|||||||
Some other features of tarpc:
|
Some other features of tarpc:
|
||||||
- Pluggable transport: any type impling `Stream<Item = Request> + Sink<Response>` can be
|
- Pluggable transport: any type impling `Stream<Item = Request> + Sink<Response>` can be
|
||||||
used as a transport to connect the client and server.
|
used as a transport to connect the client and server.
|
||||||
- `Send` optional: if the transport doesn't require it, neither does tarpc!
|
- `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
|
||||||
- Cascading cancellation: dropping a request will send a cancellation message to the server.
|
- Cascading cancellation: dropping a request will send a cancellation message to the server.
|
||||||
The server will cease any unfinished work on the request, subsequently cancelling any of its
|
The server will cease any unfinished work on the request, subsequently cancelling any of its
|
||||||
own requests, repeating for the entire chain of transitive dependencies.
|
own requests, repeating for the entire chain of transitive dependencies.
|
||||||
@@ -39,29 +51,39 @@ Some other features of tarpc:
|
|||||||
requests sent by the server that use the request context will propagate the request deadline.
|
requests sent by the server that use the request context will propagate the request deadline.
|
||||||
For example, if a server is handling a request with a 10s deadline, does 2s of work, then
|
For example, if a server is handling a request with a 10s deadline, does 2s of work, then
|
||||||
sends a request to another server, that server will see an 8s deadline.
|
sends a request to another server, that server will see an 8s deadline.
|
||||||
|
- Distributed tracing: tarpc is instrumented with
|
||||||
|
[tracing](https://github.com/tokio-rs/tracing) primitives extended with
|
||||||
|
[OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
|
||||||
|
[Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
|
||||||
|
each RPC can be traced through the client, server, amd other dependencies downstream of the
|
||||||
|
server. Even for applications not connected to a distributed tracing collector, the
|
||||||
|
instrumentation can also be ingested by regular loggers like
|
||||||
|
[env_logger](https://github.com/env-logger-rs/env_logger/).
|
||||||
- Serde serialization: enabling the `serde1` Cargo feature will make service requests and
|
- Serde serialization: enabling the `serde1` Cargo feature will make service requests and
|
||||||
responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
|
responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
|
||||||
be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
|
be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
|
||||||
|
|
||||||
### Usage
|
## Usage
|
||||||
Add to your `Cargo.toml` dependencies:
|
Add to your `Cargo.toml` dependencies:
|
||||||
|
|
||||||
```toml
|
```toml
|
||||||
tarpc = { version = "0.18.0", features = ["full"] }
|
tarpc = "0.27"
|
||||||
```
|
```
|
||||||
|
|
||||||
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||||
These generated types make it easy and ergonomic to write servers with less boilerplate.
|
These generated types make it easy and ergonomic to write servers with less boilerplate.
|
||||||
Simply implement the generated service trait, and you're off to the races!
|
Simply implement the generated service trait, and you're off to the races!
|
||||||
|
|
||||||
### Example
|
## Example
|
||||||
|
|
||||||
For this example, in addition to tarpc, also add two other dependencies to
|
This example uses [tokio](https://tokio.rs), so add the following dependencies to
|
||||||
your `Cargo.toml`:
|
your `Cargo.toml`:
|
||||||
|
|
||||||
```toml
|
```toml
|
||||||
|
anyhow = "1.0"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
tokio = "0.2"
|
tarpc = { version = "0.27", features = ["tokio1"] }
|
||||||
|
tokio = { version = "1.0", features = ["macros"] }
|
||||||
```
|
```
|
||||||
|
|
||||||
In the following example, we use an in-process channel for communication between
|
In the following example, we use an in-process channel for communication between
|
||||||
@@ -71,15 +93,15 @@ For a more real-world example, see [example-service](example-service).
|
|||||||
First, let's set up the dependencies and service definition.
|
First, let's set up the dependencies and service definition.
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
|
|
||||||
use futures::{
|
use futures::{
|
||||||
future::{self, Ready},
|
future::{self, Ready},
|
||||||
prelude::*,
|
prelude::*,
|
||||||
};
|
};
|
||||||
use tarpc::{
|
use tarpc::{
|
||||||
client, context,
|
client, context,
|
||||||
server::{self, Handler},
|
server::{self, incoming::Incoming},
|
||||||
};
|
};
|
||||||
use std::io;
|
|
||||||
|
|
||||||
// This is the service definition. It looks a lot like a trait definition.
|
// This is the service definition. It looks a lot like a trait definition.
|
||||||
// It defines one RPC, hello, which takes one arg, name, and returns a String.
|
// It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||||
@@ -112,27 +134,21 @@ impl World for HelloServer {
|
|||||||
```
|
```
|
||||||
|
|
||||||
Lastly let's write our `main` that will start the server. While this example uses an
|
Lastly let's write our `main` that will start the server. While this example uses an
|
||||||
[in-process
|
[in-process channel](transport::channel), tarpc also ships a generic [`serde_transport`]
|
||||||
channel](https://docs.rs/tarpc/0.18.0/tarpc/transport/channel/struct.UnboundedChannel.html),
|
behind the `serde-transport` feature, with additional [TCP](serde_transport::tcp) functionality
|
||||||
tarpc also ships bincode and JSON
|
available behind the `tcp` feature.
|
||||||
tokio-net based TCP transports that are generic over all serializable types.
|
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> io::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||||
|
|
||||||
let server = server::new(server::Config::default())
|
let server = server::BaseChannel::with_defaults(server_transport);
|
||||||
// incoming() takes a stream of transports such as would be returned by
|
tokio::spawn(server.execute(HelloServer.serve()));
|
||||||
// TcpListener::incoming (but a stream instead of an iterator).
|
|
||||||
.incoming(stream::once(future::ready(server_transport)))
|
|
||||||
.respond_with(HelloServer.serve());
|
|
||||||
|
|
||||||
tokio::spawn(server);
|
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||||
|
// that takes a config and any Transport as input.
|
||||||
// WorldClient is generated by the macro. It has a constructor `new` that takes a config and
|
let mut client = WorldClient::new(client::Config::default(), client_transport).spawn();
|
||||||
// any Transport as input
|
|
||||||
let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
|
|
||||||
|
|
||||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||||
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||||
@@ -145,9 +161,11 @@ async fn main() -> io::Result<()> {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### Service Documentation
|
## Service Documentation
|
||||||
|
|
||||||
Use `cargo doc` as you normally would to see the documentation created for all
|
Use `cargo doc` as you normally would to see the documentation created for all
|
||||||
items expanded by a `service!` invocation.
|
items expanded by a `service!` invocation.
|
||||||
|
|
||||||
|
<!-- cargo-sync-readme end -->
|
||||||
|
|
||||||
License: MIT
|
License: MIT
|
||||||
|
|||||||
295
RELEASES.md
295
RELEASES.md
@@ -1,3 +1,296 @@
|
|||||||
|
## 0.27.1 (2021-09-22)
|
||||||
|
|
||||||
|
### Breaking Changes
|
||||||
|
|
||||||
|
### RPC error type is changing
|
||||||
|
|
||||||
|
RPC return types are changing from `Result<Response, io::Error>` to `Result<Response,
|
||||||
|
tarpc::client::RpcError>`.
|
||||||
|
|
||||||
|
Becaue tarpc is a library, not an application, it should strive to
|
||||||
|
use structured errors in its API so that users have maximal flexibility
|
||||||
|
in how they handle errors. io::Error makes that hard, because it is a
|
||||||
|
kitchen-sink error type.
|
||||||
|
|
||||||
|
RPCs in particular only have 3 classes of errors:
|
||||||
|
|
||||||
|
- The connection breaks.
|
||||||
|
- The request expires.
|
||||||
|
- The server decides not to process the request.
|
||||||
|
|
||||||
|
RPC responses can also contain application-specific errors, but from the
|
||||||
|
perspective of the RPC library, those are opaque to the framework, classified
|
||||||
|
as successful responsees.
|
||||||
|
|
||||||
|
### Open Telemetry
|
||||||
|
|
||||||
|
The Opentelemetry dependency is updated to version 0.16.x.
|
||||||
|
|
||||||
|
## 0.27.0 (2021-09-22)
|
||||||
|
|
||||||
|
This version was yanked due to tarpc-plugins version mismatches.
|
||||||
|
|
||||||
|
|
||||||
|
## 0.26.0 (2021-04-14)
|
||||||
|
|
||||||
|
### New Features
|
||||||
|
|
||||||
|
#### Tracing
|
||||||
|
|
||||||
|
tarpc is now instrumented with tracing primitives extended with
|
||||||
|
OpenTelemetry traces. Using a compatible tracing-opentelemetry
|
||||||
|
subscriber like Jaeger, each RPC can be traced through the client,
|
||||||
|
server, amd other dependencies downstream of the server. Even for
|
||||||
|
applications not connected to a distributed tracing collector, the
|
||||||
|
instrumentation can also be ingested by regular loggers like env_logger.
|
||||||
|
|
||||||
|
### Breaking Changes
|
||||||
|
|
||||||
|
#### Logging
|
||||||
|
|
||||||
|
Logged events are now structured using tracing. For applications using a
|
||||||
|
logger and not a tracing subscriber, these logs may look different or
|
||||||
|
contain information in a less consumable manner. The easiest solution is
|
||||||
|
to add a tracing subscriber that logs to stdout, such as
|
||||||
|
tracing_subscriber::fmt.
|
||||||
|
|
||||||
|
#### Context
|
||||||
|
|
||||||
|
- Context no longer has parent_span, which was actually never needed,
|
||||||
|
because the context sent in an RPC is inherently the parent context.
|
||||||
|
For purposes of distributed tracing, the client side of the RPC has all
|
||||||
|
necessary information to link the span to its parent; the server side
|
||||||
|
need do nothing more than export the (trace ID, span ID) tuple.
|
||||||
|
- Context has a new field, SamplingDecision, which has two variants,
|
||||||
|
Sampled and Unsampled. This field can be used by downstream systems to
|
||||||
|
determine whether a trace needs to be exported. If the parent span is
|
||||||
|
sampled, the expectation is that all child spans be exported, as well;
|
||||||
|
to do otherwise could result in lossy traces being exported. Note that
|
||||||
|
if an Openetelemetry tracing subscriber is not installed, the fallback
|
||||||
|
context will still be used, but the Context's sampling decision will
|
||||||
|
always be inherited by the parent Context's sampling decision.
|
||||||
|
- Context::scope has been removed. Context propagation is now done via
|
||||||
|
tracing's task-local spans. Spans can be propagated across tasks via
|
||||||
|
Span::in_scope. When a service receives a request, it attaches an
|
||||||
|
Opentelemetry context to the local Span created before request handling,
|
||||||
|
and this context contains the request deadline. This span-local deadline
|
||||||
|
is retrieved by Context::current, but it cannot be modified so that
|
||||||
|
future Context::current calls contain a different deadline. However, the
|
||||||
|
deadline in the context passed into an RPC call will override it, so
|
||||||
|
users can retrieve the current context and then modify the deadline
|
||||||
|
field, as has been historically possible.
|
||||||
|
- Context propgation precedence changes: when an RPC is initiated, the
|
||||||
|
current Span's Opentelemetry context takes precedence over the trace
|
||||||
|
context passed into the RPC method. If there is no current Span, then
|
||||||
|
the trace context argument is used as it has been historically. Note
|
||||||
|
that Opentelemetry context propagation requires an Opentelemetry
|
||||||
|
tracing subscriber to be installed.
|
||||||
|
|
||||||
|
#### Server
|
||||||
|
|
||||||
|
- The server::Channel trait now has an additional required associated
|
||||||
|
type and method which returns the underlying transport. This makes it
|
||||||
|
more ergonomic for users to retrieve transport-specific information,
|
||||||
|
like IP Address. BaseChannel implements Channel::transport by returning
|
||||||
|
the underlying transport, and channel decorators like Throttler just
|
||||||
|
delegate to the Channel::transport method of the wrapped channel.
|
||||||
|
|
||||||
|
#### Client
|
||||||
|
|
||||||
|
- NewClient::spawn no longer returns a result, as spawn can't fail.
|
||||||
|
|
||||||
|
### References
|
||||||
|
|
||||||
|
1. https://github.com/tokio-rs/tracing
|
||||||
|
2. https://opentelemetry.io
|
||||||
|
3. https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger
|
||||||
|
4. https://github.com/env-logger-rs/env_logger
|
||||||
|
|
||||||
|
## 0.25.0 (2021-03-10)
|
||||||
|
|
||||||
|
### Breaking Changes
|
||||||
|
|
||||||
|
#### Major server module refactoring
|
||||||
|
|
||||||
|
1. Renames
|
||||||
|
|
||||||
|
Some of the items in this module were renamed to be less generic:
|
||||||
|
|
||||||
|
- Handler => Incoming
|
||||||
|
- ClientHandler => Requests
|
||||||
|
- ResponseHandler => InFlightRequest
|
||||||
|
- Channel::{respond_with => requests}
|
||||||
|
|
||||||
|
In the case of Handler: handler of *what*? Now it's a bit clearer that this is a stream of Channels
|
||||||
|
(aka *incoming* connections).
|
||||||
|
|
||||||
|
Similarly, ClientHandler was a stream of requests over a single connection. Hopefully Requests
|
||||||
|
better reflects that.
|
||||||
|
|
||||||
|
ResponseHandler was renamed InFlightRequest because it no longer contains the serving function.
|
||||||
|
Instead, it is just the request, plus the response channel and an abort hook. As a result of this,
|
||||||
|
Channel::respond_with underwent a big change: it used to take the serving function and return a
|
||||||
|
ClientHandler; now it has been renamed Channel::requests and does not take any args.
|
||||||
|
|
||||||
|
2. Execute methods
|
||||||
|
|
||||||
|
All methods thats actually result in responses being generated have been consolidated into methods
|
||||||
|
named `execute`:
|
||||||
|
|
||||||
|
- InFlightRequest::execute returns a future that completes when a response has been generated and
|
||||||
|
sent to the server Channel.
|
||||||
|
- Requests::execute automatically spawns response handlers for all requests over a single channel.
|
||||||
|
- Channel::execute is a convenience for `channel.requests().execute()`.
|
||||||
|
- Incoming::execute automatically spawns response handlers for all requests over all channels.
|
||||||
|
|
||||||
|
3. Removal of Server.
|
||||||
|
|
||||||
|
server::Server was removed, as it provided no value over the Incoming/Channel abstractions.
|
||||||
|
Additionally, server::new was removed, since it just returned a Server.
|
||||||
|
|
||||||
|
#### Client RPC methods now take &self
|
||||||
|
|
||||||
|
This required the breaking change of removing the Client trait. The intent of the Client trait was
|
||||||
|
to facilitate the decorator pattern by allowing users to create their own Clients that added
|
||||||
|
behavior on top of the base client. Unfortunately, this trait had become a maintenance burden,
|
||||||
|
consistently causing issues with lifetimes and the lack of generic associated types. Specifically,
|
||||||
|
it meant that Client impls could not use async fns, which is no longer tenable today, with channel
|
||||||
|
libraries moving to async fns.
|
||||||
|
|
||||||
|
#### Servers no longer send deadline-exceed responses.
|
||||||
|
|
||||||
|
The deadline-exceeded response was largely redundant, because the client
|
||||||
|
shouldn't normally be waiting for such a response, anyway -- the normal
|
||||||
|
client will automatically remove the in-flight request when it reaches
|
||||||
|
the deadline.
|
||||||
|
|
||||||
|
This also allows for internalizing the expiration+cleanup logic entirely
|
||||||
|
within BaseChannel, without having it leak into the Channel trait and
|
||||||
|
requiring action taken by the Requests struct.
|
||||||
|
|
||||||
|
#### Clients no longer send cancel messages when the request deadline is exceeded.
|
||||||
|
|
||||||
|
The server already knows when the request deadline was exceeded, so the client didn't need to inform
|
||||||
|
it.
|
||||||
|
|
||||||
|
### Fixes
|
||||||
|
|
||||||
|
- When a channel is dropped, all in-flight requests for that channel are now aborted.
|
||||||
|
|
||||||
|
## 0.24.1 (2020-12-28)
|
||||||
|
|
||||||
|
### Breaking Changes
|
||||||
|
|
||||||
|
Upgrades tokio to 1.0.
|
||||||
|
|
||||||
|
## 0.24.0 (2020-12-28)
|
||||||
|
|
||||||
|
This release was yanked.
|
||||||
|
|
||||||
|
## 0.23.0 (2020-10-19)
|
||||||
|
|
||||||
|
### Breaking Changes
|
||||||
|
|
||||||
|
Upgrades tokio to 0.3.
|
||||||
|
|
||||||
|
## 0.22.0 (2020-08-02)
|
||||||
|
|
||||||
|
This release adds some flexibility and consistency to `serde_transport`, with one new feature and
|
||||||
|
one small breaking change.
|
||||||
|
|
||||||
|
### New Features
|
||||||
|
|
||||||
|
`serde_transport::tcp` now exposes framing configuration on `connect()` and `listen()`. This is
|
||||||
|
useful if, for instance, you want to send requests or responses that are larger than the maximum
|
||||||
|
payload allowed by default:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
let mut transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default);
|
||||||
|
transport.config_mut().max_frame_length(4294967296);
|
||||||
|
let mut client = MyClient::new(client::Config::default(), transport.await?).spawn()?;
|
||||||
|
```
|
||||||
|
|
||||||
|
### Breaking Changes
|
||||||
|
|
||||||
|
The codec argument to `serde_transport::tcp::connect` changed from a Codec to impl Fn() -> Codec,
|
||||||
|
to be consistent with `serde_transport::tcp::listen`. While only one Codec is needed, more than one
|
||||||
|
person has been tripped up by the inconsistency between `connect` and `listen`. Unfortunately, the
|
||||||
|
compiler errors are not much help in this case, so it was decided to simply do the more intuitive
|
||||||
|
thing so that the compiler doesn't need to step in in the first place.
|
||||||
|
|
||||||
|
|
||||||
|
## 0.21.1 (2020-08-02)
|
||||||
|
|
||||||
|
### New Features
|
||||||
|
|
||||||
|
#### #[tarpc::server] diagnostics
|
||||||
|
|
||||||
|
When a service impl uses #[tarpc::server], only `async fn`s are re-written. This can lead to
|
||||||
|
confusing compiler errors about missing associated types:
|
||||||
|
|
||||||
|
```
|
||||||
|
error: not all trait items implemented, missing: `HelloFut`
|
||||||
|
--> $DIR/tarpc_server_missing_async.rs:9:1
|
||||||
|
|
|
||||||
|
9 | impl World for HelloServer {
|
||||||
|
| ^^^^
|
||||||
|
```
|
||||||
|
|
||||||
|
The proc macro now provides better diagnostics for this case:
|
||||||
|
|
||||||
|
```
|
||||||
|
error: not all trait items implemented, missing: `HelloFut`
|
||||||
|
--> $DIR/tarpc_server_missing_async.rs:9:1
|
||||||
|
|
|
||||||
|
9 | impl World for HelloServer {
|
||||||
|
| ^^^^
|
||||||
|
|
||||||
|
error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async
|
||||||
|
--> $DIR/tarpc_server_missing_async.rs:10:5
|
||||||
|
|
|
||||||
|
10 | fn hello(name: String) -> String {
|
||||||
|
| ^^
|
||||||
|
```
|
||||||
|
|
||||||
|
### Bug Fixes
|
||||||
|
|
||||||
|
#### Fixed client hanging when server shuts down
|
||||||
|
|
||||||
|
Previously, clients would ignore when the read half of the transport was closed, continuing to
|
||||||
|
write requests. This didn't make much sense, because without the ability to receive responses,
|
||||||
|
clients have no way to know if requests were actually processed by the server. It basically just
|
||||||
|
led to clients that would hang for a few seconds before shutting down. This has now been
|
||||||
|
corrected: clients will immediately shut down when the read-half of the transport is closed.
|
||||||
|
|
||||||
|
#### More docs.rs documentation
|
||||||
|
|
||||||
|
Previously, docs.rs only documented items enabled by default, notably leaving out documentation
|
||||||
|
for tokio and serde features. This has now been corrected: docs.rs should have documentation
|
||||||
|
for all optional features.
|
||||||
|
|
||||||
|
## 0.21.0 (2020-06-26)
|
||||||
|
|
||||||
|
### New Features
|
||||||
|
|
||||||
|
A new proc macro, `#[tarpc::server]` was added! This enables service impls to elide the boilerplate
|
||||||
|
of specifying associated types for each RPC. With the ubiquity of async-await, most code won't have
|
||||||
|
nameable futures and will just be boxing the return type anyway. This macro does that for you.
|
||||||
|
|
||||||
|
### Breaking Changes
|
||||||
|
|
||||||
|
- Enums had `_non_exhaustive` fields replaced with the #[non_exhaustive] attribute.
|
||||||
|
|
||||||
|
### Bug Fixes
|
||||||
|
|
||||||
|
- https://github.com/google/tarpc/issues/304
|
||||||
|
|
||||||
|
A race condition in code that limits number of connections per client caused occasional panics.
|
||||||
|
|
||||||
|
- https://github.com/google/tarpc/pull/295
|
||||||
|
|
||||||
|
Made request timeouts account for time spent in the outbound buffer. Previously, a large outbound
|
||||||
|
queue would lead to requests not timing out correctly.
|
||||||
|
|
||||||
## 0.20.0 (2019-12-11)
|
## 0.20.0 (2019-12-11)
|
||||||
|
|
||||||
### Breaking Changes
|
### Breaking Changes
|
||||||
@@ -10,7 +303,7 @@
|
|||||||
|
|
||||||
## 0.13.0 (2018-10-16)
|
## 0.13.0 (2018-10-16)
|
||||||
|
|
||||||
### Breaking Changes
|
### Breaking Changes
|
||||||
|
|
||||||
Version 0.13 marks a significant departure from previous versions of tarpc. The
|
Version 0.13 marks a significant departure from previous versions of tarpc. The
|
||||||
API has changed significantly. The tokio-proto crate has been torn out and
|
API has changed significantly. The tokio-proto crate has been torn out and
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "tarpc-example-service"
|
name = "tarpc-example-service"
|
||||||
version = "0.6.0"
|
version = "0.10.0"
|
||||||
authors = ["Tim Kuehn <tikue@google.com>"]
|
authors = ["Tim Kuehn <tikue@google.com>"]
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
@@ -13,13 +13,18 @@ readme = "../README.md"
|
|||||||
description = "An example server built on tarpc."
|
description = "An example server built on tarpc."
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
clap = "2.0"
|
anyhow = "1.0"
|
||||||
|
clap = "3.0.0-beta.2"
|
||||||
|
log = "0.4"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
serde = { version = "1.0" }
|
opentelemetry = { version = "0.16", features = ["rt-tokio"] }
|
||||||
tarpc = { version = "0.20", path = "../tarpc", features = ["full"] }
|
opentelemetry-jaeger = { version = "0.15", features = ["rt-tokio"] }
|
||||||
tokio = { version = "0.2", features = ["full"] }
|
rand = "0.8"
|
||||||
tokio-serde = { version = "0.6", features = ["json"] }
|
tarpc = { version = "0.27", path = "../tarpc", features = ["full"] }
|
||||||
env_logger = "0.6"
|
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
|
||||||
|
tracing = { version = "0.1" }
|
||||||
|
tracing-opentelemetry = "0.15"
|
||||||
|
tracing-subscriber = "0.2"
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
name = "service"
|
name = "service"
|
||||||
|
|||||||
@@ -4,55 +4,49 @@
|
|||||||
// license that can be found in the LICENSE file or at
|
// license that can be found in the LICENSE file or at
|
||||||
// https://opensource.org/licenses/MIT.
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
use clap::{App, Arg};
|
use clap::Clap;
|
||||||
use std::{io, net::SocketAddr};
|
use service::{init_tracing, WorldClient};
|
||||||
use tarpc::{client, context};
|
use std::{net::SocketAddr, time::Duration};
|
||||||
use tokio_serde::formats::Json;
|
use tarpc::{client, context, tokio_serde::formats::Json};
|
||||||
|
use tokio::time::sleep;
|
||||||
|
use tracing::Instrument;
|
||||||
|
|
||||||
|
#[derive(Clap)]
|
||||||
|
struct Flags {
|
||||||
|
/// Sets the server address to connect to.
|
||||||
|
#[clap(long)]
|
||||||
|
server_addr: SocketAddr,
|
||||||
|
/// Sets the name to say hello to.
|
||||||
|
#[clap(long)]
|
||||||
|
name: String,
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> io::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
let flags = App::new("Hello Client")
|
let flags = Flags::parse();
|
||||||
.version("0.1")
|
init_tracing("Tarpc Example Client")?;
|
||||||
.author("Tim <tikue@google.com>")
|
|
||||||
.about("Say hello!")
|
|
||||||
.arg(
|
|
||||||
Arg::with_name("server_addr")
|
|
||||||
.long("server_addr")
|
|
||||||
.value_name("ADDRESS")
|
|
||||||
.help("Sets the server address to connect to.")
|
|
||||||
.required(true)
|
|
||||||
.takes_value(true),
|
|
||||||
)
|
|
||||||
.arg(
|
|
||||||
Arg::with_name("name")
|
|
||||||
.short("n")
|
|
||||||
.long("name")
|
|
||||||
.value_name("STRING")
|
|
||||||
.help("Sets the name to say hello to.")
|
|
||||||
.required(true)
|
|
||||||
.takes_value(true),
|
|
||||||
)
|
|
||||||
.get_matches();
|
|
||||||
|
|
||||||
let server_addr = flags.value_of("server_addr").unwrap();
|
let transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
|
||||||
let server_addr = server_addr
|
|
||||||
.parse::<SocketAddr>()
|
|
||||||
.unwrap_or_else(|e| panic!(r#"--server_addr value "{}" invalid: {}"#, server_addr, e));
|
|
||||||
|
|
||||||
let name = flags.value_of("name").unwrap().into();
|
|
||||||
|
|
||||||
let transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default()).await?;
|
|
||||||
|
|
||||||
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
|
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
|
||||||
// config and any Transport as input.
|
// config and any Transport as input.
|
||||||
let mut client = service::WorldClient::new(client::Config::default(), transport).spawn()?;
|
let client = WorldClient::new(client::Config::default(), transport.await?).spawn();
|
||||||
|
|
||||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
let hello = async move {
|
||||||
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
// Send the request twice, just to be safe! ;)
|
||||||
// specifies a deadline and trace information which can be helpful in debugging requests.
|
tokio::select! {
|
||||||
let hello = client.hello(context::current(), name).await?;
|
hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 }
|
||||||
|
hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.instrument(tracing::info_span!("Two Hellos"))
|
||||||
|
.await;
|
||||||
|
|
||||||
println!("{}", hello);
|
tracing::info!("{:?}", hello);
|
||||||
|
|
||||||
|
// Let the background span processor finish.
|
||||||
|
sleep(Duration::from_micros(1)).await;
|
||||||
|
opentelemetry::global::shutdown_tracer_provider();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,9 @@
|
|||||||
// license that can be found in the LICENSE file or at
|
// license that can be found in the LICENSE file or at
|
||||||
// https://opensource.org/licenses/MIT.
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
use std::env;
|
||||||
|
use tracing_subscriber::{fmt::format::FmtSpan, prelude::*};
|
||||||
|
|
||||||
/// This is the service definition. It looks a lot like a trait definition.
|
/// This is the service definition. It looks a lot like a trait definition.
|
||||||
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
|
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||||
#[tarpc::service]
|
#[tarpc::service]
|
||||||
@@ -11,3 +14,21 @@ pub trait World {
|
|||||||
/// Returns a greeting for name.
|
/// Returns a greeting for name.
|
||||||
async fn hello(name: String) -> String;
|
async fn hello(name: String) -> String;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
|
||||||
|
pub fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||||
|
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
||||||
|
|
||||||
|
let tracer = opentelemetry_jaeger::new_pipeline()
|
||||||
|
.with_service_name(service_name)
|
||||||
|
.with_max_packet_size(2usize.pow(13))
|
||||||
|
.install_batch(opentelemetry::runtime::Tokio)?;
|
||||||
|
|
||||||
|
tracing_subscriber::registry()
|
||||||
|
.with(tracing_subscriber::EnvFilter::from_default_env())
|
||||||
|
.with(tracing_subscriber::fmt::layer().with_span_events(FmtSpan::NEW | FmtSpan::CLOSE))
|
||||||
|
.with(tracing_opentelemetry::layer().with_tracer(tracer))
|
||||||
|
.try_init()?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,81 +4,68 @@
|
|||||||
// license that can be found in the LICENSE file or at
|
// license that can be found in the LICENSE file or at
|
||||||
// https://opensource.org/licenses/MIT.
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
use clap::{App, Arg};
|
use clap::Clap;
|
||||||
use futures::{
|
use futures::{future, prelude::*};
|
||||||
future::{self, Ready},
|
use rand::{
|
||||||
prelude::*,
|
distributions::{Distribution, Uniform},
|
||||||
|
thread_rng,
|
||||||
};
|
};
|
||||||
use service::World;
|
use service::{init_tracing, World};
|
||||||
use std::{
|
use std::{
|
||||||
io,
|
net::{IpAddr, Ipv6Addr, SocketAddr},
|
||||||
net::{IpAddr, SocketAddr},
|
time::Duration,
|
||||||
};
|
};
|
||||||
use tarpc::{
|
use tarpc::{
|
||||||
context,
|
context,
|
||||||
server::{self, Channel, Handler},
|
server::{self, incoming::Incoming, Channel},
|
||||||
|
tokio_serde::formats::Json,
|
||||||
};
|
};
|
||||||
use tokio_serde::formats::Json;
|
use tokio::time;
|
||||||
|
|
||||||
|
#[derive(Clap)]
|
||||||
|
struct Flags {
|
||||||
|
/// Sets the port number to listen on.
|
||||||
|
#[clap(long)]
|
||||||
|
port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
// This is the type that implements the generated World trait. It is the business logic
|
// This is the type that implements the generated World trait. It is the business logic
|
||||||
// and is used to start the server.
|
// and is used to start the server.
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct HelloServer(SocketAddr);
|
struct HelloServer(SocketAddr);
|
||||||
|
|
||||||
|
#[tarpc::server]
|
||||||
impl World for HelloServer {
|
impl World for HelloServer {
|
||||||
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and
|
async fn hello(self, _: context::Context, name: String) -> String {
|
||||||
// an associated type representing the future output by the fn.
|
let sleep_time =
|
||||||
|
Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng()));
|
||||||
type HelloFut = Ready<String>;
|
time::sleep(sleep_time).await;
|
||||||
|
format!("Hello, {}! You are connected from {}", name, self.0)
|
||||||
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
|
||||||
future::ready(format!(
|
|
||||||
"Hello, {}! You are connected from {:?}.",
|
|
||||||
name, self.0
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> io::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
env_logger::init();
|
let flags = Flags::parse();
|
||||||
|
init_tracing("Tarpc Example Server")?;
|
||||||
|
|
||||||
let flags = App::new("Hello Server")
|
let server_addr = (IpAddr::V6(Ipv6Addr::LOCALHOST), flags.port);
|
||||||
.version("0.1")
|
|
||||||
.author("Tim <tikue@google.com>")
|
|
||||||
.about("Say hello!")
|
|
||||||
.arg(
|
|
||||||
Arg::with_name("port")
|
|
||||||
.short("p")
|
|
||||||
.long("port")
|
|
||||||
.value_name("NUMBER")
|
|
||||||
.help("Sets the port number to listen on")
|
|
||||||
.required(true)
|
|
||||||
.takes_value(true),
|
|
||||||
)
|
|
||||||
.get_matches();
|
|
||||||
|
|
||||||
let port = flags.value_of("port").unwrap();
|
|
||||||
let port = port
|
|
||||||
.parse()
|
|
||||||
.unwrap_or_else(|e| panic!(r#"--port value "{}" invalid: {}"#, port, e));
|
|
||||||
|
|
||||||
let server_addr = (IpAddr::from([0, 0, 0, 0]), port);
|
|
||||||
|
|
||||||
// JSON transport is provided by the json_transport tarpc module. It makes it easy
|
// JSON transport is provided by the json_transport tarpc module. It makes it easy
|
||||||
// to start up a serde-powered json serialization strategy over TCP.
|
// to start up a serde-powered json serialization strategy over TCP.
|
||||||
tarpc::serde_transport::tcp::listen(&server_addr, Json::default)
|
let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?;
|
||||||
.await?
|
listener.config_mut().max_frame_length(usize::MAX);
|
||||||
|
listener
|
||||||
// Ignore accept errors.
|
// Ignore accept errors.
|
||||||
.filter_map(|r| future::ready(r.ok()))
|
.filter_map(|r| future::ready(r.ok()))
|
||||||
.map(server::BaseChannel::with_defaults)
|
.map(server::BaseChannel::with_defaults)
|
||||||
// Limit channels to 1 per IP.
|
// Limit channels to 1 per IP.
|
||||||
.max_channels_per_key(1, |t| t.as_ref().peer_addr().unwrap().ip())
|
.max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip())
|
||||||
// serve is generated by the service attribute. It takes as input any type implementing
|
// serve is generated by the service attribute. It takes as input any type implementing
|
||||||
// the generated World trait.
|
// the generated World trait.
|
||||||
.map(|channel| {
|
.map(|channel| {
|
||||||
let server = HelloServer(channel.as_ref().as_ref().peer_addr().unwrap());
|
let server = HelloServer(channel.transport().peer_addr().unwrap());
|
||||||
channel.respond_with(server.serve()).execute()
|
channel.execute(server.serve())
|
||||||
})
|
})
|
||||||
// Max 10 channels.
|
// Max 10 channels.
|
||||||
.buffer_unordered(10)
|
.buffer_unordered(10)
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
printf "${PREFIX} Checking for rustfmt ... "
|
printf "${PREFIX} Checking for rustfmt ... "
|
||||||
command -v cargo fmt &>/dev/null
|
command -v rustfmt &>/dev/null
|
||||||
if [ $? == 0 ]; then
|
if [ $? == 0 ]; then
|
||||||
printf "${SUCCESS}\n"
|
printf "${SUCCESS}\n"
|
||||||
else
|
else
|
||||||
@@ -93,19 +93,19 @@ diff=""
|
|||||||
for file in $(git diff --name-only --cached);
|
for file in $(git diff --name-only --cached);
|
||||||
do
|
do
|
||||||
if [ ${file: -3} == ".rs" ]; then
|
if [ ${file: -3} == ".rs" ]; then
|
||||||
diff="$diff$(cargo fmt -- --skip-children --write-mode=diff $file)"
|
diff="$diff$(rustfmt --edition 2018 --check $file)"
|
||||||
|
if [ $? != 0 ]; then
|
||||||
|
FMTRESULT=1
|
||||||
|
fi
|
||||||
fi
|
fi
|
||||||
done
|
done
|
||||||
if grep --quiet "^[-+]" <<< "$diff"; then
|
|
||||||
FMTRESULT=1
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ "${TARPC_SKIP_RUSTFMT}" == 1 ]; then
|
if [ "${TARPC_SKIP_RUSTFMT}" == 1 ]; then
|
||||||
printf "${SKIPPED}\n"$?
|
printf "${SKIPPED}\n"$?
|
||||||
elif [ ${FMTRESULT} != 0 ]; then
|
elif [ ${FMTRESULT} != 0 ]; then
|
||||||
FAILED=1
|
FAILED=1
|
||||||
printf "${FAILURE}\n"
|
printf "${FAILURE}\n"
|
||||||
echo "$diff" | sed 's/Using rustfmt config file.*$/d/'
|
echo "$diff"
|
||||||
else
|
else
|
||||||
printf "${SUCCESS}\n"
|
printf "${SUCCESS}\n"
|
||||||
fi
|
fi
|
||||||
|
|||||||
@@ -84,18 +84,19 @@ command -v rustup &>/dev/null
|
|||||||
if [ "$?" == 0 ]; then
|
if [ "$?" == 0 ]; then
|
||||||
printf "${SUCCESS}\n"
|
printf "${SUCCESS}\n"
|
||||||
|
|
||||||
|
try_run "Building ... " cargo +stable build --color=always
|
||||||
|
try_run "Testing ... " cargo +stable test --color=always
|
||||||
|
try_run "Testing with all features enabled ... " cargo +stable test --all-features --color=always
|
||||||
|
for EXAMPLE in $(cargo +stable run --example 2>&1 | grep ' ' | awk '{print $1}')
|
||||||
|
do
|
||||||
|
try_run "Running example \"$EXAMPLE\" ... " cargo +stable run --example $EXAMPLE
|
||||||
|
done
|
||||||
|
|
||||||
check_toolchain nightly
|
check_toolchain nightly
|
||||||
if [ ${TOOLCHAIN_RESULT} == 1 ]; then
|
if [ ${TOOLCHAIN_RESULT} != 1 ]; then
|
||||||
exit 1
|
try_run "Running clippy ... " cargo +nightly clippy --color=always -Z unstable-options -- --deny warnings
|
||||||
fi
|
fi
|
||||||
|
|
||||||
try_run "Building ... " cargo build --color=always
|
|
||||||
try_run "Testing ... " cargo test --color=always
|
|
||||||
try_run "Testing with all features enabled ... " cargo test --all-features --color=always
|
|
||||||
for EXAMPLE in $(cargo run --example 2>&1 | grep ' ' | awk '{print $1}')
|
|
||||||
do
|
|
||||||
try_run "Running example \"$EXAMPLE\" ... " cargo run --example $EXAMPLE
|
|
||||||
done
|
|
||||||
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "tarpc-plugins"
|
name = "tarpc-plugins"
|
||||||
version = "0.7.0"
|
version = "0.12.0"
|
||||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
@@ -19,14 +19,15 @@ serde1 = []
|
|||||||
travis-ci = { repository = "google/tarpc" }
|
travis-ci = { repository = "google/tarpc" }
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
syn = { version = "1.0.11", features = ["full"] }
|
proc-macro2 = "1.0"
|
||||||
quote = "1.0.2"
|
quote = "1.0"
|
||||||
proc-macro2 = "1.0.6"
|
syn = { version = "1.0", features = ["full"] }
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
proc-macro = true
|
proc-macro = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
assert-type-eq = "0.1.0"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
tarpc = { path = "../tarpc" }
|
tarpc = { path = "../tarpc", features = ["serde1"] }
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
144
plugins/tests/server.rs
Normal file
144
plugins/tests/server.rs
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
use assert_type_eq::assert_type_eq;
|
||||||
|
use futures::Future;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use tarpc::context;
|
||||||
|
|
||||||
|
// these need to be out here rather than inside the function so that the
|
||||||
|
// assert_type_eq macro can pick them up.
|
||||||
|
#[tarpc::service]
|
||||||
|
trait Foo {
|
||||||
|
async fn two_part(s: String, i: i32) -> (String, i32);
|
||||||
|
async fn bar(s: String) -> String;
|
||||||
|
async fn baz();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn type_generation_works() {
|
||||||
|
#[tarpc::server]
|
||||||
|
impl Foo for () {
|
||||||
|
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
|
||||||
|
(s, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn bar(self, _: context::Context, s: String) -> String {
|
||||||
|
s
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn baz(self, _: context::Context) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// the assert_type_eq macro can only be used once per block.
|
||||||
|
{
|
||||||
|
assert_type_eq!(
|
||||||
|
<() as Foo>::TwoPartFut,
|
||||||
|
Pin<Box<dyn Future<Output = (String, i32)> + Send>>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
assert_type_eq!(
|
||||||
|
<() as Foo>::BarFut,
|
||||||
|
Pin<Box<dyn Future<Output = String> + Send>>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
assert_type_eq!(
|
||||||
|
<() as Foo>::BazFut,
|
||||||
|
Pin<Box<dyn Future<Output = ()> + Send>>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(non_camel_case_types)]
|
||||||
|
#[test]
|
||||||
|
fn raw_idents_work() {
|
||||||
|
type r#yield = String;
|
||||||
|
|
||||||
|
#[tarpc::service]
|
||||||
|
trait r#trait {
|
||||||
|
async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32);
|
||||||
|
async fn r#fn(r#impl: r#yield) -> r#yield;
|
||||||
|
async fn r#async();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tarpc::server]
|
||||||
|
impl r#trait for () {
|
||||||
|
async fn r#await(
|
||||||
|
self,
|
||||||
|
_: context::Context,
|
||||||
|
r#struct: r#yield,
|
||||||
|
r#enum: i32,
|
||||||
|
) -> (r#yield, i32) {
|
||||||
|
(r#struct, r#enum)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
|
||||||
|
r#impl
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn r#async(self, _: context::Context) {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn syntax() {
|
||||||
|
#[tarpc::service]
|
||||||
|
trait Syntax {
|
||||||
|
#[deny(warnings)]
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
async fn TestCamelCaseDoesntConflict();
|
||||||
|
async fn hello() -> String;
|
||||||
|
#[doc = "attr"]
|
||||||
|
async fn attr(s: String) -> String;
|
||||||
|
async fn no_args_no_return();
|
||||||
|
async fn no_args() -> ();
|
||||||
|
async fn one_arg(one: String) -> i32;
|
||||||
|
async fn two_args_no_return(one: String, two: u64);
|
||||||
|
async fn two_args(one: String, two: u64) -> String;
|
||||||
|
async fn no_args_ret_error() -> i32;
|
||||||
|
async fn one_arg_ret_error(one: String) -> String;
|
||||||
|
async fn no_arg_implicit_return_error();
|
||||||
|
#[doc = "attr"]
|
||||||
|
async fn one_arg_implicit_return_error(one: String);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tarpc::server]
|
||||||
|
impl Syntax for () {
|
||||||
|
#[deny(warnings)]
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
async fn TestCamelCaseDoesntConflict(self, _: context::Context) {}
|
||||||
|
|
||||||
|
async fn hello(self, _: context::Context) -> String {
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn attr(self, _: context::Context, _s: String) -> String {
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn no_args_no_return(self, _: context::Context) {}
|
||||||
|
|
||||||
|
async fn no_args(self, _: context::Context) -> () {}
|
||||||
|
|
||||||
|
async fn one_arg(self, _: context::Context, _one: String) -> i32 {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn two_args_no_return(self, _: context::Context, _one: String, _two: u64) {}
|
||||||
|
|
||||||
|
async fn two_args(self, _: context::Context, _one: String, _two: u64) -> String {
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn no_args_ret_error(self, _: context::Context) -> i32 {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn one_arg_ret_error(self, _: context::Context, _one: String) -> String {
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn no_arg_implicit_return_error(self, _: context::Context) {}
|
||||||
|
|
||||||
|
async fn one_arg_implicit_return_error(self, _: context::Context, _one: String) {}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -29,6 +29,38 @@ fn att_service_trait() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(non_camel_case_types)]
|
||||||
|
#[test]
|
||||||
|
fn raw_idents() {
|
||||||
|
use futures::future::{ready, Ready};
|
||||||
|
|
||||||
|
type r#yield = String;
|
||||||
|
|
||||||
|
#[tarpc::service]
|
||||||
|
trait r#trait {
|
||||||
|
async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32);
|
||||||
|
async fn r#fn(r#impl: r#yield) -> r#yield;
|
||||||
|
async fn r#async();
|
||||||
|
}
|
||||||
|
|
||||||
|
impl r#trait for () {
|
||||||
|
type AwaitFut = Ready<(r#yield, i32)>;
|
||||||
|
fn r#await(self, _: context::Context, r#struct: r#yield, r#enum: i32) -> Self::AwaitFut {
|
||||||
|
ready((r#struct, r#enum))
|
||||||
|
}
|
||||||
|
|
||||||
|
type FnFut = Ready<r#yield>;
|
||||||
|
fn r#fn(self, _: context::Context, r#impl: r#yield) -> Self::FnFut {
|
||||||
|
ready(r#impl)
|
||||||
|
}
|
||||||
|
|
||||||
|
type AsyncFut = Ready<()>;
|
||||||
|
fn r#async(self, _: context::Context) -> Self::AsyncFut {
|
||||||
|
ready(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn syntax() {
|
fn syntax() {
|
||||||
#[tarpc::service]
|
#[tarpc::service]
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "tarpc"
|
name = "tarpc"
|
||||||
version = "0.20.0"
|
version = "0.27.1"
|
||||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
@@ -16,42 +16,61 @@ description = "An RPC framework for Rust with a focus on ease of use."
|
|||||||
default = []
|
default = []
|
||||||
|
|
||||||
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"]
|
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"]
|
||||||
tokio1 = []
|
tokio1 = ["tokio/rt-multi-thread"]
|
||||||
serde-transport = ["tokio-serde", "tokio-util/codec"]
|
serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
|
||||||
tcp = ["tokio/net", "tokio/stream"]
|
serde-transport-json = ["tokio-serde/json"]
|
||||||
|
serde-transport-bincode = ["tokio-serde/bincode"]
|
||||||
|
tcp = ["tokio/net"]
|
||||||
|
|
||||||
full = ["serde1", "tokio1", "serde-transport", "tcp"]
|
full = ["serde1", "tokio1", "serde-transport", "serde-transport-json", "serde-transport-bincode", "tcp"]
|
||||||
|
|
||||||
[badges]
|
[badges]
|
||||||
travis-ci = { repository = "google/tarpc" }
|
travis-ci = { repository = "google/tarpc" }
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
anyhow = "1.0"
|
||||||
fnv = "1.0"
|
fnv = "1.0"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
humantime = "1.0"
|
humantime = "2.0"
|
||||||
log = "0.4"
|
pin-project = "1.0"
|
||||||
pin-project = "0.4"
|
rand = "0.8"
|
||||||
raii-counter = "0.2"
|
|
||||||
rand = "0.7"
|
|
||||||
tokio = { version = "0.2", features = ["time"] }
|
|
||||||
serde = { optional = true, version = "1.0", features = ["derive"] }
|
serde = { optional = true, version = "1.0", features = ["derive"] }
|
||||||
tokio-util = { optional = true, version = "0.2" }
|
static_assertions = "1.1.0"
|
||||||
tarpc-plugins = { path = "../plugins", version = "0.7" }
|
tarpc-plugins = { path = "../plugins", version = "0.12" }
|
||||||
tokio-serde = { optional = true, version = "0.6" }
|
thiserror = "1.0"
|
||||||
|
tokio = { version = "1", features = ["time"] }
|
||||||
|
tokio-util = { version = "0.6.3", features = ["time"] }
|
||||||
|
tokio-serde = { optional = true, version = "0.8" }
|
||||||
|
tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
|
||||||
|
tracing-opentelemetry = { version = "0.15", default-features = false }
|
||||||
|
opentelemetry = { version = "0.16", default-features = false }
|
||||||
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
assert_matches = "1.0"
|
assert_matches = "1.4"
|
||||||
bytes = { version = "0.5", features = ["serde"] }
|
bincode = "1.3"
|
||||||
env_logger = "0.6"
|
bytes = { version = "1", features = ["serde"] }
|
||||||
futures = "0.3"
|
flate2 = "1.0"
|
||||||
humantime = "1.0"
|
futures-test = "0.3"
|
||||||
log = "0.4"
|
opentelemetry = { version = "0.16", default-features = false, features = ["rt-tokio"] }
|
||||||
|
opentelemetry-jaeger = { version = "0.15", features = ["rt-tokio"] }
|
||||||
pin-utils = "0.1.0-alpha"
|
pin-utils = "0.1.0-alpha"
|
||||||
tokio = { version = "0.2", features = ["full"] }
|
serde_bytes = "0.11"
|
||||||
tokio-serde = { version = "0.6", features = ["json"] }
|
tracing-subscriber = "0.2"
|
||||||
|
tokio = { version = "1", features = ["full", "test-util"] }
|
||||||
|
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
|
||||||
|
trybuild = "1.0"
|
||||||
|
|
||||||
|
[package.metadata.docs.rs]
|
||||||
|
all-features = true
|
||||||
|
rustdoc-args = ["--cfg", "docsrs"]
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
name = "server_calling_server"
|
name = "compression"
|
||||||
|
required-features = ["serde-transport", "tcp"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "tracing"
|
||||||
required-features = ["full"]
|
required-features = ["full"]
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
@@ -62,3 +81,14 @@ required-features = ["full"]
|
|||||||
name = "pubsub"
|
name = "pubsub"
|
||||||
required-features = ["full"]
|
required-features = ["full"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "custom_transport"
|
||||||
|
required-features = ["serde1", "tokio1", "serde-transport"]
|
||||||
|
|
||||||
|
[[test]]
|
||||||
|
name = "service_functional"
|
||||||
|
required-features = ["serde-transport"]
|
||||||
|
|
||||||
|
[[test]]
|
||||||
|
name = "dataservice"
|
||||||
|
required-features = ["serde-transport", "tcp"]
|
||||||
|
|||||||
128
tarpc/examples/compression.rs
Normal file
128
tarpc/examples/compression.rs
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression};
|
||||||
|
use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_bytes::ByteBuf;
|
||||||
|
use std::{io, io::Read, io::Write};
|
||||||
|
use tarpc::{
|
||||||
|
client, context,
|
||||||
|
serde_transport::tcp,
|
||||||
|
server::{BaseChannel, Channel},
|
||||||
|
tokio_serde::formats::Bincode,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Type of compression that should be enabled on the request. The transport is free to ignore this.
|
||||||
|
#[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)]
|
||||||
|
pub enum CompressionAlgorithm {
|
||||||
|
Deflate,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
pub enum CompressedMessage<T> {
|
||||||
|
Uncompressed(T),
|
||||||
|
Compressed {
|
||||||
|
algorithm: CompressionAlgorithm,
|
||||||
|
payload: ByteBuf,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Serialize)]
|
||||||
|
enum CompressionType {
|
||||||
|
Uncompressed,
|
||||||
|
Compressed,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn compress<T>(message: T) -> io::Result<CompressedMessage<T>>
|
||||||
|
where
|
||||||
|
T: Serialize,
|
||||||
|
{
|
||||||
|
let message = serialize(message)?;
|
||||||
|
let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
|
||||||
|
encoder.write_all(&message).unwrap();
|
||||||
|
let compressed = encoder.finish()?;
|
||||||
|
Ok(CompressedMessage::Compressed {
|
||||||
|
algorithm: CompressionAlgorithm::Deflate,
|
||||||
|
payload: ByteBuf::from(compressed),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn decompress<T>(message: CompressedMessage<T>) -> io::Result<T>
|
||||||
|
where
|
||||||
|
for<'a> T: Deserialize<'a>,
|
||||||
|
{
|
||||||
|
match message {
|
||||||
|
CompressedMessage::Compressed { algorithm, payload } => {
|
||||||
|
if algorithm != CompressionAlgorithm::Deflate {
|
||||||
|
return Err(io::Error::new(
|
||||||
|
io::ErrorKind::InvalidData,
|
||||||
|
format!("Compression algorithm {:?} not supported", algorithm),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let mut deflater = DeflateDecoder::new(payload.as_slice());
|
||||||
|
let mut payload = ByteBuf::new();
|
||||||
|
deflater.read_to_end(&mut payload)?;
|
||||||
|
let message = deserialize(payload)?;
|
||||||
|
Ok(message)
|
||||||
|
}
|
||||||
|
CompressedMessage::Uncompressed(message) => Ok(message),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn serialize<T: Serialize>(t: T) -> io::Result<ByteBuf> {
|
||||||
|
bincode::serialize(&t)
|
||||||
|
.map(ByteBuf::from)
|
||||||
|
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn deserialize<D>(message: ByteBuf) -> io::Result<D>
|
||||||
|
where
|
||||||
|
for<'a> D: Deserialize<'a>,
|
||||||
|
{
|
||||||
|
bincode::deserialize(message.as_ref()).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_compression<In, Out>(
|
||||||
|
transport: impl Stream<Item = io::Result<CompressedMessage<In>>>
|
||||||
|
+ Sink<CompressedMessage<Out>, Error = io::Error>,
|
||||||
|
) -> impl Stream<Item = io::Result<In>> + Sink<Out, Error = io::Error>
|
||||||
|
where
|
||||||
|
Out: Serialize,
|
||||||
|
for<'a> In: Deserialize<'a>,
|
||||||
|
{
|
||||||
|
transport.with(compress).and_then(decompress)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tarpc::service]
|
||||||
|
pub trait World {
|
||||||
|
async fn hello(name: String) -> String;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct HelloServer;
|
||||||
|
|
||||||
|
#[tarpc::server]
|
||||||
|
impl World for HelloServer {
|
||||||
|
async fn hello(self, _: context::Context, name: String) -> String {
|
||||||
|
format!("Hey, {}!", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
let mut incoming = tcp::listen("localhost:0", Bincode::default).await?;
|
||||||
|
let addr = incoming.local_addr();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let transport = incoming.next().await.unwrap().unwrap();
|
||||||
|
BaseChannel::with_defaults(add_compression(transport))
|
||||||
|
.execute(HelloServer.serve())
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let transport = tcp::connect(addr, Bincode::default).await?;
|
||||||
|
let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn();
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"{}",
|
||||||
|
client.hello(context::current(), "friend".into()).await?
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
47
tarpc/examples/custom_transport.rs
Normal file
47
tarpc/examples/custom_transport.rs
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
use tarpc::serde_transport as transport;
|
||||||
|
use tarpc::server::{BaseChannel, Channel};
|
||||||
|
use tarpc::{context::Context, tokio_serde::formats::Bincode};
|
||||||
|
use tokio::net::{UnixListener, UnixStream};
|
||||||
|
use tokio_util::codec::length_delimited::LengthDelimitedCodec;
|
||||||
|
|
||||||
|
#[tarpc::service]
|
||||||
|
pub trait PingService {
|
||||||
|
async fn ping();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct Service;
|
||||||
|
|
||||||
|
#[tarpc::server]
|
||||||
|
impl PingService for Service {
|
||||||
|
async fn ping(self, _: Context) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
let bind_addr = "/tmp/tarpc_on_unix_example.sock";
|
||||||
|
|
||||||
|
let _ = std::fs::remove_file(bind_addr);
|
||||||
|
|
||||||
|
let listener = UnixListener::bind(bind_addr).unwrap();
|
||||||
|
let codec_builder = LengthDelimitedCodec::builder();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
let (conn, _addr) = listener.accept().await.unwrap();
|
||||||
|
let framed = codec_builder.new_framed(conn);
|
||||||
|
let transport = transport::new(framed, Bincode::default());
|
||||||
|
|
||||||
|
let fut = BaseChannel::with_defaults(transport).execute(Service.serve());
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let conn = UnixStream::connect(bind_addr).await?;
|
||||||
|
let transport = transport::new(codec_builder.new_framed(conn), Bincode::default());
|
||||||
|
PingServiceClient::new(Default::default(), transport)
|
||||||
|
.spawn()
|
||||||
|
.ping(tarpc::context::current())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -4,192 +4,357 @@
|
|||||||
// license that can be found in the LICENSE file or at
|
// license that can be found in the LICENSE file or at
|
||||||
// https://opensource.org/licenses/MIT.
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
/// - The PubSub server sets up TCP listeners on 2 ports, the "subscriber" port and the "publisher"
|
||||||
|
/// port. Because both publishers and subscribers initiate their connections to the PubSub
|
||||||
|
/// server, the server requires no prior knowledge of either publishers or subscribers.
|
||||||
|
///
|
||||||
|
/// - Subscribers connect to the server on the server's "subscriber" port. Once a connection is
|
||||||
|
/// established, the server acts as the client of the Subscriber service, initially requesting
|
||||||
|
/// the topics the subscriber is interested in, and subsequently sending topical messages to the
|
||||||
|
/// subscriber.
|
||||||
|
///
|
||||||
|
/// - Publishers connect to the server on the "publisher" port and, once connected, they send
|
||||||
|
/// topical messages via Publisher service to the server. The server then broadcasts each
|
||||||
|
/// messages to all clients subscribed to the topic of that message.
|
||||||
|
///
|
||||||
|
/// Subscriber Publisher PubSub Server
|
||||||
|
/// T1 | | |
|
||||||
|
/// T2 |-----Connect------------------------------------------------------>|
|
||||||
|
/// T3 | | |
|
||||||
|
/// T2 |<-------------------------------------------------------Topics-----|
|
||||||
|
/// T2 |-----(OK) Topics-------------------------------------------------->|
|
||||||
|
/// T3 | | |
|
||||||
|
/// T4 | |-----Connect-------------------->|
|
||||||
|
/// T5 | | |
|
||||||
|
/// T6 | |-----Publish-------------------->|
|
||||||
|
/// T7 | | |
|
||||||
|
/// T8 |<------------------------------------------------------Receive-----|
|
||||||
|
/// T9 |-----(OK) Receive------------------------------------------------->|
|
||||||
|
/// T10 | | |
|
||||||
|
/// T11 | |<--------------(OK) Publish------|
|
||||||
|
use anyhow::anyhow;
|
||||||
use futures::{
|
use futures::{
|
||||||
future::{self, Ready},
|
channel::oneshot,
|
||||||
|
future::{self, AbortHandle},
|
||||||
prelude::*,
|
prelude::*,
|
||||||
Future,
|
|
||||||
};
|
};
|
||||||
use publisher::Publisher as _;
|
use publisher::Publisher as _;
|
||||||
use std::{
|
use std::{
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
|
env,
|
||||||
|
error::Error,
|
||||||
io,
|
io,
|
||||||
net::SocketAddr,
|
net::SocketAddr,
|
||||||
pin::Pin,
|
sync::{Arc, Mutex, RwLock},
|
||||||
sync::{Arc, Mutex},
|
|
||||||
time::Duration,
|
|
||||||
};
|
};
|
||||||
use subscriber::Subscriber as _;
|
use subscriber::Subscriber as _;
|
||||||
use tarpc::{
|
use tarpc::{
|
||||||
client, context,
|
client, context,
|
||||||
server::{self, Handler},
|
serde_transport::tcp,
|
||||||
|
server::{self, Channel},
|
||||||
};
|
};
|
||||||
|
use tokio::net::ToSocketAddrs;
|
||||||
use tokio_serde::formats::Json;
|
use tokio_serde::formats::Json;
|
||||||
|
use tracing::info;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
pub mod subscriber {
|
pub mod subscriber {
|
||||||
#[tarpc::service]
|
#[tarpc::service]
|
||||||
pub trait Subscriber {
|
pub trait Subscriber {
|
||||||
async fn receive(message: String);
|
async fn topics() -> Vec<String>;
|
||||||
|
async fn receive(topic: String, message: String);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub mod publisher {
|
pub mod publisher {
|
||||||
use std::net::SocketAddr;
|
|
||||||
|
|
||||||
#[tarpc::service]
|
#[tarpc::service]
|
||||||
pub trait Publisher {
|
pub trait Publisher {
|
||||||
async fn broadcast(message: String);
|
async fn publish(topic: String, message: String);
|
||||||
async fn subscribe(id: u32, address: SocketAddr) -> Result<(), String>;
|
|
||||||
async fn unsubscribe(id: u32);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
struct Subscriber {
|
struct Subscriber {
|
||||||
id: u32,
|
local_addr: SocketAddr,
|
||||||
|
topics: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tarpc::server]
|
||||||
impl subscriber::Subscriber for Subscriber {
|
impl subscriber::Subscriber for Subscriber {
|
||||||
type ReceiveFut = Ready<()>;
|
async fn topics(self, _: context::Context) -> Vec<String> {
|
||||||
|
self.topics.clone()
|
||||||
|
}
|
||||||
|
|
||||||
fn receive(self, _: context::Context, message: String) -> Self::ReceiveFut {
|
async fn receive(self, _: context::Context, topic: String, message: String) {
|
||||||
eprintln!("{} received message: {}", self.id, message);
|
info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage")
|
||||||
future::ready(())
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SubscriberHandle(AbortHandle);
|
||||||
|
|
||||||
|
impl Drop for SubscriberHandle {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.0.abort();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Subscriber {
|
impl Subscriber {
|
||||||
async fn listen(id: u32, config: server::Config) -> io::Result<SocketAddr> {
|
async fn connect(
|
||||||
let incoming = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
publisher_addr: impl ToSocketAddrs,
|
||||||
.await?
|
topics: Vec<String>,
|
||||||
.filter_map(|r| future::ready(r.ok()));
|
) -> anyhow::Result<SubscriberHandle> {
|
||||||
let addr = incoming.get_ref().local_addr();
|
let publisher = tcp::connect(publisher_addr, Json::default).await?;
|
||||||
tokio::spawn(
|
let local_addr = publisher.local_addr()?;
|
||||||
server::new(config)
|
let mut handler = server::BaseChannel::with_defaults(publisher).requests();
|
||||||
.incoming(incoming)
|
let subscriber = Subscriber { local_addr, topics };
|
||||||
.take(1)
|
// The first request is for the topics being subscribed to.
|
||||||
.respond_with(Subscriber { id }.serve()),
|
match handler.next().await {
|
||||||
);
|
Some(init_topics) => init_topics?.execute(subscriber.clone().serve()).await,
|
||||||
Ok(addr)
|
None => {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"[{}] Server never initialized the subscriber.",
|
||||||
|
local_addr
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve()));
|
||||||
|
tokio::spawn(async move {
|
||||||
|
match handler.await {
|
||||||
|
Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."),
|
||||||
|
}
|
||||||
|
});
|
||||||
|
Ok(SubscriberHandle(abort_handle))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Subscription {
|
||||||
|
subscriber: subscriber::SubscriberClient,
|
||||||
|
topics: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
struct Publisher {
|
struct Publisher {
|
||||||
clients: Arc<Mutex<HashMap<u32, subscriber::SubscriberClient>>>,
|
clients: Arc<Mutex<HashMap<SocketAddr, Subscription>>>,
|
||||||
|
subscriptions: Arc<RwLock<HashMap<String, HashMap<SocketAddr, subscriber::SubscriberClient>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct PublisherAddrs {
|
||||||
|
publisher: SocketAddr,
|
||||||
|
subscriptions: SocketAddr,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Publisher {
|
impl Publisher {
|
||||||
fn new() -> Publisher {
|
async fn start(self) -> io::Result<PublisherAddrs> {
|
||||||
Publisher {
|
let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?;
|
||||||
clients: Arc::new(Mutex::new(HashMap::new())),
|
|
||||||
}
|
let publisher_addrs = PublisherAddrs {
|
||||||
|
publisher: connecting_publishers.local_addr(),
|
||||||
|
subscriptions: self.clone().start_subscription_manager().await?,
|
||||||
|
};
|
||||||
|
|
||||||
|
info!(publisher_addr = %publisher_addrs.publisher, "listening for publishers.",);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
// Because this is just an example, we know there will only be one publisher. In more
|
||||||
|
// realistic code, this would be a loop to continually accept new publisher
|
||||||
|
// connections.
|
||||||
|
let publisher = connecting_publishers.next().await.unwrap().unwrap();
|
||||||
|
info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected.");
|
||||||
|
|
||||||
|
server::BaseChannel::with_defaults(publisher)
|
||||||
|
.execute(self.serve())
|
||||||
|
.await
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(publisher_addrs)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl publisher::Publisher for Publisher {
|
async fn start_subscription_manager(mut self) -> io::Result<SocketAddr> {
|
||||||
type BroadcastFut = Pin<Box<dyn Future<Output = ()> + Send>>;
|
let mut connecting_subscribers = tcp::listen("localhost:0", Json::default)
|
||||||
|
.await?
|
||||||
|
.filter_map(|r| future::ready(r.ok()));
|
||||||
|
let new_subscriber_addr = connecting_subscribers.get_ref().local_addr();
|
||||||
|
info!(?new_subscriber_addr, "listening for subscribers.");
|
||||||
|
|
||||||
fn broadcast(self, _: context::Context, message: String) -> Self::BroadcastFut {
|
tokio::spawn(async move {
|
||||||
async fn broadcast(
|
while let Some(conn) = connecting_subscribers.next().await {
|
||||||
clients: Arc<Mutex<HashMap<u32, subscriber::SubscriberClient>>>,
|
let subscriber_addr = conn.peer_addr().unwrap();
|
||||||
message: String,
|
|
||||||
) {
|
let tarpc::client::NewClient {
|
||||||
let mut clients = clients.lock().unwrap().clone();
|
client: subscriber,
|
||||||
for client in clients.values_mut() {
|
dispatch,
|
||||||
// Ignore failing subscribers. In a real pubsub,
|
} = subscriber::SubscriberClient::new(client::Config::default(), conn);
|
||||||
// you'd want to continually retry until subscribers
|
let (ready_tx, ready) = oneshot::channel();
|
||||||
// ack.
|
self.clone()
|
||||||
let _ = client.receive(context::current(), message.clone()).await;
|
.start_subscriber_gc(subscriber_addr, dispatch, ready);
|
||||||
|
|
||||||
|
// Populate the topics
|
||||||
|
self.initialize_subscription(subscriber_addr, subscriber)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Signal that initialization is done.
|
||||||
|
ready_tx.send(()).unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(new_subscriber_addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn initialize_subscription(
|
||||||
|
&mut self,
|
||||||
|
subscriber_addr: SocketAddr,
|
||||||
|
subscriber: subscriber::SubscriberClient,
|
||||||
|
) {
|
||||||
|
// Populate the topics
|
||||||
|
if let Ok(topics) = subscriber.topics(context::current()).await {
|
||||||
|
self.clients.lock().unwrap().insert(
|
||||||
|
subscriber_addr,
|
||||||
|
Subscription {
|
||||||
|
subscriber: subscriber.clone(),
|
||||||
|
topics: topics.clone(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
info!(%subscriber_addr, ?topics, "subscribed to new topics");
|
||||||
|
let mut subscriptions = self.subscriptions.write().unwrap();
|
||||||
|
for topic in topics {
|
||||||
|
subscriptions
|
||||||
|
.entry(topic)
|
||||||
|
.or_insert_with(HashMap::new)
|
||||||
|
.insert(subscriber_addr, subscriber.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
broadcast(self.clients.clone(), message).boxed()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type SubscribeFut = Pin<Box<dyn Future<Output = Result<(), String>> + Send>>;
|
fn start_subscriber_gc<E: Error>(
|
||||||
|
self,
|
||||||
fn subscribe(self, _: context::Context, id: u32, addr: SocketAddr) -> Self::SubscribeFut {
|
subscriber_addr: SocketAddr,
|
||||||
async fn subscribe(
|
client_dispatch: impl Future<Output = Result<(), E>> + Send + 'static,
|
||||||
clients: Arc<Mutex<HashMap<u32, subscriber::SubscriberClient>>>,
|
subscriber_ready: oneshot::Receiver<()>,
|
||||||
id: u32,
|
) {
|
||||||
addr: SocketAddr,
|
tokio::spawn(async move {
|
||||||
) -> io::Result<()> {
|
if let Err(e) = client_dispatch.await {
|
||||||
let conn = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?;
|
info!(
|
||||||
let subscriber =
|
%subscriber_addr,
|
||||||
subscriber::SubscriberClient::new(client::Config::default(), conn).spawn()?;
|
error = %e,
|
||||||
eprintln!("Subscribing {}.", id);
|
"subscriber connection broken");
|
||||||
clients.lock().unwrap().insert(id, subscriber);
|
}
|
||||||
Ok(())
|
// Don't clean up the subscriber until initialization is done.
|
||||||
}
|
let _ = subscriber_ready.await;
|
||||||
|
if let Some(subscription) = self.clients.lock().unwrap().remove(&subscriber_addr) {
|
||||||
subscribe(Arc::clone(&self.clients), id, addr)
|
info!(
|
||||||
.map_err(|e| e.to_string())
|
"[{} unsubscribing from topics: {:?}",
|
||||||
.boxed()
|
subscriber_addr, subscription.topics
|
||||||
}
|
);
|
||||||
|
let mut subscriptions = self.subscriptions.write().unwrap();
|
||||||
type UnsubscribeFut = Pin<Box<dyn Future<Output = ()> + Send>>;
|
for topic in subscription.topics {
|
||||||
|
let subscribers = subscriptions.get_mut(&topic).unwrap();
|
||||||
fn unsubscribe(self, _: context::Context, id: u32) -> Self::UnsubscribeFut {
|
subscribers.remove(&subscriber_addr);
|
||||||
eprintln!("Unsubscribing {}", id);
|
if subscribers.is_empty() {
|
||||||
let mut clients = self.clients.lock().unwrap();
|
subscriptions.remove(&topic);
|
||||||
if clients.remove(&id).is_none() {
|
}
|
||||||
eprintln!(
|
}
|
||||||
"Client {} not found. Existings clients: {:?}",
|
}
|
||||||
id, &*clients
|
});
|
||||||
);
|
|
||||||
}
|
|
||||||
future::ready(()).boxed()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tarpc::server]
|
||||||
async fn main() -> io::Result<()> {
|
impl publisher::Publisher for Publisher {
|
||||||
env_logger::init();
|
async fn publish(self, _: context::Context, topic: String, message: String) {
|
||||||
|
info!("received message to publish.");
|
||||||
let transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) {
|
||||||
.await?
|
None => return,
|
||||||
.filter_map(|r| future::ready(r.ok()));
|
Some(subscriptions) => subscriptions.clone(),
|
||||||
let publisher_addr = transport.get_ref().local_addr();
|
};
|
||||||
tokio::spawn(
|
let mut publications = Vec::new();
|
||||||
transport
|
for client in subscribers.values_mut() {
|
||||||
.take(1)
|
publications.push(client.receive(context::current(), topic.clone(), message.clone()));
|
||||||
.map(server::BaseChannel::with_defaults)
|
}
|
||||||
.respond_with(Publisher::new().serve()),
|
// Ignore failing subscribers. In a real pubsub, you'd want to continually retry until
|
||||||
);
|
// subscribers ack. Of course, a lot would be different in a real pubsub :)
|
||||||
|
for response in future::join_all(publications).await {
|
||||||
let subscriber1 = Subscriber::listen(0, server::Config::default()).await?;
|
if let Err(e) = response {
|
||||||
let subscriber2 = Subscriber::listen(1, server::Config::default()).await?;
|
info!("failed to broadcast to subscriber: {}", e);
|
||||||
|
}
|
||||||
let publisher_conn = tarpc::serde_transport::tcp::connect(publisher_addr, Json::default());
|
}
|
||||||
let publisher_conn = publisher_conn.await?;
|
|
||||||
let mut publisher =
|
|
||||||
publisher::PublisherClient::new(client::Config::default(), publisher_conn).spawn()?;
|
|
||||||
|
|
||||||
if let Err(e) = publisher
|
|
||||||
.subscribe(context::current(), 0, subscriber1)
|
|
||||||
.await?
|
|
||||||
{
|
|
||||||
eprintln!("Couldn't subscribe subscriber 0: {}", e);
|
|
||||||
}
|
|
||||||
if let Err(e) = publisher
|
|
||||||
.subscribe(context::current(), 1, subscriber2)
|
|
||||||
.await?
|
|
||||||
{
|
|
||||||
eprintln!("Couldn't subscribe subscriber 1: {}", e);
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
println!("Broadcasting...");
|
/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
|
||||||
publisher
|
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||||
.broadcast(context::current(), "hello to all".to_string())
|
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
||||||
.await?;
|
let tracer = opentelemetry_jaeger::new_pipeline()
|
||||||
publisher.unsubscribe(context::current(), 1).await?;
|
.with_service_name(service_name)
|
||||||
publisher
|
.with_max_packet_size(2usize.pow(13))
|
||||||
.broadcast(context::current(), "hi again".to_string())
|
.install_batch(opentelemetry::runtime::Tokio)?;
|
||||||
.await?;
|
|
||||||
drop(publisher);
|
|
||||||
|
|
||||||
tokio::time::delay_for(Duration::from_millis(100)).await;
|
tracing_subscriber::registry()
|
||||||
println!("Done.");
|
.with(tracing_subscriber::EnvFilter::from_default_env())
|
||||||
|
.with(tracing_subscriber::fmt::layer())
|
||||||
|
.with(tracing_opentelemetry::layer().with_tracer(tracer))
|
||||||
|
.try_init()?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
init_tracing("Pub/Sub")?;
|
||||||
|
|
||||||
|
let addrs = Publisher {
|
||||||
|
clients: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
subscriptions: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
}
|
||||||
|
.start()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let _subscriber0 = Subscriber::connect(
|
||||||
|
addrs.subscriptions,
|
||||||
|
vec!["calculus".into(), "cool shorts".into()],
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let _subscriber1 = Subscriber::connect(
|
||||||
|
addrs.subscriptions,
|
||||||
|
vec!["cool shorts".into(), "history".into()],
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let publisher = publisher::PublisherClient::new(
|
||||||
|
client::Config::default(),
|
||||||
|
tcp::connect(addrs.publisher, Json::default).await?,
|
||||||
|
)
|
||||||
|
.spawn();
|
||||||
|
|
||||||
|
publisher
|
||||||
|
.publish(context::current(), "calculus".into(), "sqrt(2)".into())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
publisher
|
||||||
|
.publish(
|
||||||
|
context::current(),
|
||||||
|
"cool shorts".into(),
|
||||||
|
"hello to all".into(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
publisher
|
||||||
|
.publish(context::current(), "history".into(), "napoleon".to_string())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
drop(_subscriber0);
|
||||||
|
|
||||||
|
publisher
|
||||||
|
.publish(
|
||||||
|
context::current(),
|
||||||
|
"cool shorts".into(),
|
||||||
|
"hello to who?".into(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
opentelemetry::global::shutdown_tracer_provider();
|
||||||
|
info!("done.");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,16 +4,11 @@
|
|||||||
// license that can be found in the LICENSE file or at
|
// license that can be found in the LICENSE file or at
|
||||||
// https://opensource.org/licenses/MIT.
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
use futures::{
|
use futures::future::{self, Ready};
|
||||||
future::{self, Ready},
|
|
||||||
prelude::*,
|
|
||||||
};
|
|
||||||
use std::io;
|
|
||||||
use tarpc::{
|
use tarpc::{
|
||||||
client, context,
|
client, context,
|
||||||
server::{BaseChannel, Channel},
|
server::{self, Channel},
|
||||||
};
|
};
|
||||||
use tokio_serde::formats::Json;
|
|
||||||
|
|
||||||
/// This is the service definition. It looks a lot like a trait definition.
|
/// This is the service definition. It looks a lot like a trait definition.
|
||||||
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
|
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||||
@@ -39,41 +34,22 @@ impl World for HelloServer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> io::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
// tarpc_json_transport is provided by the associated crate json_transport. It makes it
|
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||||
// easy to start up a serde-powered JSON serialization strategy over TCP.
|
|
||||||
let mut transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default).await?;
|
|
||||||
let addr = transport.local_addr();
|
|
||||||
|
|
||||||
let server = async move {
|
let server = server::BaseChannel::with_defaults(server_transport);
|
||||||
// For this example, we're just going to wait for one connection.
|
tokio::spawn(server.execute(HelloServer.serve()));
|
||||||
let client = transport.next().await.unwrap().unwrap();
|
|
||||||
|
|
||||||
// `Channel` is a trait representing a server-side connection. It is a trait to allow
|
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||||
// for some channels to be instrumented: for example, to track the number of open connections.
|
// that takes a config and any Transport as input.
|
||||||
// BaseChannel is the most basic channel, simply wrapping a transport with no added
|
let client = WorldClient::new(client::Config::default(), client_transport).spawn();
|
||||||
// functionality.
|
|
||||||
BaseChannel::with_defaults(client)
|
|
||||||
// serve_world is generated by the tarpc::service attribute. It takes as input any type
|
|
||||||
// implementing the generated World trait.
|
|
||||||
.respond_with(HelloServer.serve())
|
|
||||||
.execute()
|
|
||||||
.await;
|
|
||||||
};
|
|
||||||
tokio::spawn(server);
|
|
||||||
|
|
||||||
let transport = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?;
|
|
||||||
|
|
||||||
// WorldClient is generated by the tarpc::service attribute. It has a constructor `new` that
|
|
||||||
// takes a config and any Transport as input.
|
|
||||||
let mut client = WorldClient::new(client::Config::default(), transport).spawn()?;
|
|
||||||
|
|
||||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||||
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||||
// specifies a deadline and trace information which can be helpful in debugging requests.
|
// specifies a deadline and trace information which can be helpful in debugging requests.
|
||||||
let hello = client.hello(context::current(), "Stim".to_string()).await?;
|
let hello = client.hello(context::current(), "Stim".to_string()).await?;
|
||||||
|
|
||||||
eprintln!("{}", hello);
|
println!("{}", hello);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,101 +0,0 @@
|
|||||||
// Copyright 2018 Google LLC
|
|
||||||
//
|
|
||||||
// Use of this source code is governed by an MIT-style
|
|
||||||
// license that can be found in the LICENSE file or at
|
|
||||||
// https://opensource.org/licenses/MIT.
|
|
||||||
|
|
||||||
use crate::{add::Add as AddService, double::Double as DoubleService};
|
|
||||||
use futures::{
|
|
||||||
future::{self, Ready},
|
|
||||||
prelude::*,
|
|
||||||
};
|
|
||||||
use std::{io, pin::Pin};
|
|
||||||
use tarpc::{
|
|
||||||
client, context,
|
|
||||||
server::{Handler, Server},
|
|
||||||
};
|
|
||||||
use tokio_serde::formats::Json;
|
|
||||||
|
|
||||||
pub mod add {
|
|
||||||
#[tarpc::service]
|
|
||||||
pub trait Add {
|
|
||||||
/// Add two ints together.
|
|
||||||
async fn add(x: i32, y: i32) -> i32;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub mod double {
|
|
||||||
#[tarpc::service]
|
|
||||||
pub trait Double {
|
|
||||||
/// 2 * x
|
|
||||||
async fn double(x: i32) -> Result<i32, String>;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct AddServer;
|
|
||||||
|
|
||||||
impl AddService for AddServer {
|
|
||||||
type AddFut = Ready<i32>;
|
|
||||||
|
|
||||||
fn add(self, _: context::Context, x: i32, y: i32) -> Self::AddFut {
|
|
||||||
future::ready(x + y)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct DoubleServer {
|
|
||||||
add_client: add::AddClient,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DoubleService for DoubleServer {
|
|
||||||
type DoubleFut = Pin<Box<dyn Future<Output = Result<i32, String>> + Send>>;
|
|
||||||
|
|
||||||
fn double(self, _: context::Context, x: i32) -> Self::DoubleFut {
|
|
||||||
async fn double(mut client: add::AddClient, x: i32) -> Result<i32, String> {
|
|
||||||
client
|
|
||||||
.add(context::current(), x, x)
|
|
||||||
.await
|
|
||||||
.map_err(|e| e.to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
double(self.add_client.clone(), x).boxed()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::main]
|
|
||||||
async fn main() -> io::Result<()> {
|
|
||||||
env_logger::init();
|
|
||||||
|
|
||||||
let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
|
||||||
.await?
|
|
||||||
.filter_map(|r| future::ready(r.ok()));
|
|
||||||
let addr = add_listener.get_ref().local_addr();
|
|
||||||
let add_server = Server::default()
|
|
||||||
.incoming(add_listener)
|
|
||||||
.take(1)
|
|
||||||
.respond_with(AddServer.serve());
|
|
||||||
tokio::spawn(add_server);
|
|
||||||
|
|
||||||
let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?;
|
|
||||||
let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn()?;
|
|
||||||
|
|
||||||
let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
|
||||||
.await?
|
|
||||||
.filter_map(|r| future::ready(r.ok()));
|
|
||||||
let addr = double_listener.get_ref().local_addr();
|
|
||||||
let double_server = tarpc::Server::default()
|
|
||||||
.incoming(double_listener)
|
|
||||||
.take(1)
|
|
||||||
.respond_with(DoubleServer { add_client }.serve());
|
|
||||||
tokio::spawn(double_server);
|
|
||||||
|
|
||||||
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default()).await?;
|
|
||||||
let mut double_client =
|
|
||||||
double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?;
|
|
||||||
|
|
||||||
for i in 1..=5 {
|
|
||||||
eprintln!("{:?}", double_client.double(context::current(), i).await?);
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
113
tarpc/examples/tracing.rs
Normal file
113
tarpc/examples/tracing.rs
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
// Copyright 2018 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
use crate::{add::Add as AddService, double::Double as DoubleService};
|
||||||
|
use futures::{future, prelude::*};
|
||||||
|
use std::env;
|
||||||
|
use tarpc::{
|
||||||
|
client, context,
|
||||||
|
server::{incoming::Incoming, BaseChannel},
|
||||||
|
};
|
||||||
|
use tokio_serde::formats::Json;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
pub mod add {
|
||||||
|
#[tarpc::service]
|
||||||
|
pub trait Add {
|
||||||
|
/// Add two ints together.
|
||||||
|
async fn add(x: i32, y: i32) -> i32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub mod double {
|
||||||
|
#[tarpc::service]
|
||||||
|
pub trait Double {
|
||||||
|
/// 2 * x
|
||||||
|
async fn double(x: i32) -> Result<i32, String>;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct AddServer;
|
||||||
|
|
||||||
|
#[tarpc::server]
|
||||||
|
impl AddService for AddServer {
|
||||||
|
async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
|
||||||
|
x + y
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct DoubleServer {
|
||||||
|
add_client: add::AddClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tarpc::server]
|
||||||
|
impl DoubleService for DoubleServer {
|
||||||
|
async fn double(self, _: context::Context, x: i32) -> Result<i32, String> {
|
||||||
|
self.add_client
|
||||||
|
.add(context::current(), x, x)
|
||||||
|
.await
|
||||||
|
.map_err(|e| e.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||||
|
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
||||||
|
let tracer = opentelemetry_jaeger::new_pipeline()
|
||||||
|
.with_service_name(service_name)
|
||||||
|
.with_max_packet_size(2usize.pow(13))
|
||||||
|
.install_batch(opentelemetry::runtime::Tokio)?;
|
||||||
|
|
||||||
|
tracing_subscriber::registry()
|
||||||
|
.with(tracing_subscriber::EnvFilter::from_default_env())
|
||||||
|
.with(tracing_subscriber::fmt::layer())
|
||||||
|
.with(tracing_opentelemetry::layer().with_tracer(tracer))
|
||||||
|
.try_init()?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Set up stderr logging + Jaeger span export before any RPC activity.
    init_tracing("tarpc_tracing_example")?;

    // Start the Add server on an OS-assigned port; `take(1)` serves exactly one
    // connection before the listener future completes.
    let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
        .await?
        .filter_map(|r| future::ready(r.ok()));
    let addr = add_listener.get_ref().local_addr();
    let add_server = add_listener
        .map(BaseChannel::with_defaults)
        .take(1)
        .execute(AddServer.serve());
    tokio::spawn(add_server);

    // Connect a client to the Add server; the Double server will use it upstream.
    let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
    let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn();

    // Start the Double server the same way, handing it the Add client.
    let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
        .await?
        .filter_map(|r| future::ready(r.ok()));
    let addr = double_listener.get_ref().local_addr();
    let double_server = double_listener
        .map(BaseChannel::with_defaults)
        .take(1)
        .execute(DoubleServer { add_client }.serve());
    tokio::spawn(double_server);

    let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
    let double_client =
        double::DoubleClient::new(client::Config::default(), to_double_server).spawn();

    // Issue a handful of RPCs so the exporter has spans to batch.
    // NOTE(review): the same context (and thus the same deadline) is reused for
    // all five calls — confirm this is intended for the example.
    let ctx = context::current();
    for _ in 1..=5 {
        tracing::info!("{:?}", double_client.double(ctx, 1).await?);
    }

    // Flush any spans still buffered in the batch exporter before exiting.
    opentelemetry::global::shutdown_tracer_provider();

    Ok(())
}
|
||||||
919
tarpc/src/client.rs
Normal file
919
tarpc/src/client.rs
Normal file
@@ -0,0 +1,919 @@
|
|||||||
|
// Copyright 2018 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
//! Provides a client that connects to a server and sends multiplexed requests.
|
||||||
|
|
||||||
|
mod in_flight_requests;
|
||||||
|
|
||||||
|
use crate::{context, trace, ClientMessage, Request, Response, ServerError, Transport};
|
||||||
|
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
||||||
|
use in_flight_requests::{DeadlineExceededError, InFlightRequests};
|
||||||
|
use pin_project::pin_project;
|
||||||
|
use std::{
|
||||||
|
convert::TryFrom,
|
||||||
|
error::Error,
|
||||||
|
fmt, mem,
|
||||||
|
pin::Pin,
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicUsize, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use tokio::sync::{mpsc, oneshot};
|
||||||
|
use tracing::Span;
|
||||||
|
|
||||||
|
/// Settings that control the behavior of the client.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct Config {
    /// The number of requests that can be in flight at once.
    /// `max_in_flight_requests` controls the size of the map used by the client
    /// for storing pending requests.
    pub max_in_flight_requests: usize,
    /// The number of requests that can be buffered client-side before being sent.
    /// `pending_requests_buffer` controls the size of the channel clients use
    /// to communicate with the request dispatch task.
    pub pending_request_buffer: usize,
}

impl Default for Config {
    /// Defaults to 1,000 in-flight requests and a 100-request staging buffer.
    fn default() -> Self {
        Self {
            max_in_flight_requests: 1_000,
            pending_request_buffer: 100,
        }
    }
}
|
||||||
|
|
||||||
|
/// A channel and dispatch pair. The dispatch drives the sending and receiving of requests
/// and must be polled continuously or spawned.
///
/// See [`NewClient::spawn`] for a helper that spawns the dispatch on tokio.
pub struct NewClient<C, D> {
    /// The new client.
    pub client: C,
    /// The client's dispatch.
    pub dispatch: D,
}
|
||||||
|
|
||||||
|
impl<C, D, E> NewClient<C, D>
where
    D: Future<Output = Result<(), E>> + Send + 'static,
    E: std::error::Error + Send + Sync + 'static,
{
    /// Helper method to spawn the dispatch on the default executor.
    ///
    /// Returns the client half; the dispatch runs as a detached tokio task.
    #[cfg(feature = "tokio1")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
    pub fn spawn(self) -> C {
        // Dispatch errors are terminal for the connection; the detached task has
        // no caller to return to, so log the error instead of propagating it.
        let dispatch = self.dispatch.unwrap_or_else(move |e| {
            let e = anyhow::Error::new(e);
            tracing::warn!("Connection broken: {:?}", e);
        });
        tokio::spawn(dispatch);
        self.client
    }
}
|
||||||
|
|
||||||
|
impl<C, D> fmt::Debug for NewClient<C, D> {
|
||||||
|
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(fmt, "NewClient")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compile-time guard: the client converts `usize` request counters into `u64`
/// request IDs with `u64::try_from(..).unwrap()`, which is only infallible if
/// `usize` fits in `u64`. This const fails the build otherwise.
///
/// Implements the old TODO: `assert!`/`panic!` in const contexts has been
/// stable since Rust 1.57 (RFC 2345), replacing the out-of-bounds-index hack.
#[allow(dead_code)]
const CHECK_USIZE: () = assert!(
    std::mem::size_of::<usize>() <= std::mem::size_of::<u64>(),
    "usize is too big to fit in u64"
);
|
||||||
|
|
||||||
|
/// Handles communication from the client to request dispatch.
#[derive(Debug)]
pub struct Channel<Req, Resp> {
    /// Sends staged requests to the dispatch task, which writes them to the wire.
    to_dispatch: mpsc::Sender<DispatchRequest<Req, Resp>>,
    /// Channel to send a cancel message to the dispatcher.
    cancellation: RequestCancellation,
    /// The ID to use for the next request to stage.
    /// Shared across clones so IDs stay unique per connection.
    next_request_id: Arc<AtomicUsize>,
}
|
||||||
|
|
||||||
|
impl<Req, Resp> Clone for Channel<Req, Resp> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
to_dispatch: self.to_dispatch.clone(),
|
||||||
|
cancellation: self.cancellation.clone(),
|
||||||
|
next_request_id: self.next_request_id.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req, Resp> Channel<Req, Resp> {
    /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
    /// resolves to the response.
    ///
    /// Errors with [`RpcError::Disconnected`] if the dispatch task has shut down,
    /// and propagates deadline/server errors from the response.
    #[tracing::instrument(
        name = "RPC",
        skip(self, ctx, request_name, request),
        fields(
            rpc.trace_id = tracing::field::Empty,
            otel.kind = "client",
            otel.name = request_name)
    )]
    pub async fn call(
        &self,
        mut ctx: context::Context,
        request_name: &str,
        request: Req,
    ) -> Result<Resp, RpcError> {
        let span = Span::current();
        // Prefer the OpenTelemetry trace context from the current span; fall back
        // to a fresh unsampled child of the caller-supplied context.
        ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
            tracing::warn!(
                "OpenTelemetry subscriber not installed; making unsampled child context."
            );
            ctx.trace_context.new_child()
        });
        span.record("rpc.trace_id", &tracing::field::display(ctx.trace_id()));
        let (response_completion, mut response) = oneshot::channel();
        // Infallible in practice: CHECK_USIZE guarantees usize fits in u64.
        let request_id =
            u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();

        // ResponseGuard impls Drop to cancel in-flight requests. It should be created before
        // sending out the request; otherwise, the response future could be dropped after the
        // request is sent out but before ResponseGuard is created, rendering the cancellation
        // logic inactive.
        let response_guard = ResponseGuard {
            response: &mut response,
            request_id,
            cancellation: &self.cancellation,
        };
        self.to_dispatch
            .send(DispatchRequest {
                ctx,
                span,
                request_id,
                request,
                response_completion,
            })
            .await
            // A send error means the dispatch task (the receiver) is gone.
            .map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?;
        response_guard.response().await
    }
}
|
||||||
|
|
||||||
|
/// A server response that is completed by request dispatch when the corresponding response
/// arrives off the wire.
///
/// On drop it closes the receiver and sends a cancellation for `request_id`
/// (see the `Drop` impl); consuming it via `response()` skips that.
struct ResponseGuard<'a, Resp> {
    /// Receiver the dispatch task completes with the response or deadline error.
    response: &'a mut oneshot::Receiver<Result<Response<Resp>, DeadlineExceededError>>,
    /// Handle used to notify the dispatch task of cancellation on drop.
    cancellation: &'a RequestCancellation,
    /// ID of the request this guard is waiting on.
    request_id: u64,
}
|
||||||
|
|
||||||
|
/// An error that can occur in the processing of an RPC. This is not request-specific errors but
/// rather cross-cutting errors that can always occur.
#[derive(thiserror::Error, Debug)]
pub enum RpcError {
    /// The client disconnected from the server.
    #[error("the client disconnected from the server")]
    Disconnected,
    /// The request exceeded its deadline.
    #[error("the request exceeded its deadline")]
    DeadlineExceeded,
    /// The server aborted request processing.
    #[error("the server aborted request processing")]
    Server(#[from] ServerError),
}
|
||||||
|
|
||||||
|
/// A deadline error from the in-flight request map maps directly onto the
/// public [`RpcError::DeadlineExceeded`] variant.
impl From<DeadlineExceededError> for RpcError {
    fn from(_: DeadlineExceededError) -> Self {
        Self::DeadlineExceeded
    }
}
|
||||||
|
|
||||||
|
impl<Resp> ResponseGuard<'_, Resp> {
    /// Awaits the response, consuming the guard without triggering its
    /// cancel-on-drop behavior once an outcome is known.
    async fn response(mut self) -> Result<Resp, RpcError> {
        let response = (&mut self.response).await;
        // Cancel drop logic once a response has been received.
        // Forgetting `self` leaks nothing: the guard holds only borrows and a
        // u64. It just skips the Drop impl that would send a cancellation.
        mem::forget(self);
        match response {
            // Surface deadline errors and server errors via the `?` From impls.
            Ok(resp) => Ok(resp?.message?),
            Err(oneshot::error::RecvError { .. }) => {
                // The oneshot is Canceled when the dispatch task ends. In that case,
                // there's nothing listening on the other side, so there's no point in
                // propagating cancellation.
                Err(RpcError::Disconnected)
            }
        }
    }
}
|
||||||
|
|
||||||
|
// Cancels the request when dropped, if not already complete.
impl<Resp> Drop for ResponseGuard<'_, Resp> {
    fn drop(&mut self) {
        // Ordering matters here: close the receiver first, then send the
        // cancellation, per the race described below.
        //
        // The receiver needs to be closed to handle the edge case that the request has not
        // yet been received by the dispatch task. It is possible for the cancel message to
        // arrive before the request itself, in which case the request could get stuck in the
        // dispatch map forever if the server never responds (e.g. if the server dies while
        // responding). Even if the server does respond, it will have unnecessarily done work
        // for a client no longer waiting for a response. To avoid this, the dispatch task
        // checks if the receiver is closed before inserting the request in the map. By
        // closing the receiver before sending the cancel message, it is guaranteed that if the
        // dispatch task misses an early-arriving cancellation message, then it will see the
        // receiver as closed.
        self.response.close();
        self.cancellation.cancel(self.request_id);
    }
}
|
||||||
|
|
||||||
|
/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the
|
||||||
|
/// channel.
|
||||||
|
pub fn new<Req, Resp, C>(
|
||||||
|
config: Config,
|
||||||
|
transport: C,
|
||||||
|
) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C>>
|
||||||
|
where
|
||||||
|
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
||||||
|
{
|
||||||
|
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
|
||||||
|
let (cancellation, canceled_requests) = cancellations();
|
||||||
|
let canceled_requests = canceled_requests;
|
||||||
|
|
||||||
|
NewClient {
|
||||||
|
client: Channel {
|
||||||
|
to_dispatch,
|
||||||
|
cancellation,
|
||||||
|
next_request_id: Arc::new(AtomicUsize::new(0)),
|
||||||
|
},
|
||||||
|
dispatch: RequestDispatch {
|
||||||
|
config,
|
||||||
|
canceled_requests,
|
||||||
|
transport: transport.fuse(),
|
||||||
|
in_flight_requests: InFlightRequests::default(),
|
||||||
|
pending_requests,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
/// and dispatching responses to the appropriate channel.
#[pin_project]
#[derive(Debug)]
pub struct RequestDispatch<Req, Resp, C> {
    /// Writes requests to the wire and reads responses off the wire.
    #[pin]
    transport: Fuse<C>,
    /// Requests waiting to be written to the wire.
    pending_requests: mpsc::Receiver<DispatchRequest<Req, Resp>>,
    /// Requests that were dropped.
    canceled_requests: CanceledRequests,
    /// Requests already written to the wire that haven't yet received responses.
    in_flight_requests: InFlightRequests<Resp>,
    /// Configures limits to prevent unlimited resource usage.
    config: Config,
}
|
||||||
|
|
||||||
|
/// Critical errors that result in a Channel disconnecting.
///
/// `E` is the transport's error type; each variant records which transport
/// operation failed.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError<E>
where
    E: Error + Send + Sync + 'static,
{
    /// Could not read from the transport.
    #[error("could not read from the transport")]
    Read(#[source] E),
    /// Could not ready the transport for writes.
    #[error("could not ready the transport for writes")]
    Ready(#[source] E),
    /// Could not write to the transport.
    #[error("could not write to the transport")]
    Write(#[source] E),
    /// Could not flush the transport.
    #[error("could not flush the transport")]
    Flush(#[source] E),
    /// Could not close the write end of the transport.
    #[error("could not close the write end of the transport")]
    Close(#[source] E),
    /// Could not poll expired requests.
    #[error("could not poll expired requests")]
    Timer(#[source] tokio::time::error::Error),
}
|
||||||
|
|
||||||
|
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
where
    C: Transport<ClientMessage<Req>, Response<Resp>>,
{
    /// Pin-projects to the in-flight request map.
    fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests<Resp> {
        self.as_mut().project().in_flight_requests
    }

    /// Pin-projects to the (fused) transport.
    fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<C>> {
        self.as_mut().project().transport
    }

    /// Polls the transport for readiness to accept an outbound message.
    fn poll_ready<'a>(
        self: &'a mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), ChannelError<C::Error>>> {
        self.transport_pin_mut()
            .poll_ready(cx)
            .map_err(ChannelError::Ready)
    }

    /// Buffers one message into the transport. Only call after `poll_ready`
    /// returned Ready (see `ensure_writeable`).
    fn start_send(
        self: &mut Pin<&mut Self>,
        message: ClientMessage<Req>,
    ) -> Result<(), ChannelError<C::Error>> {
        self.transport_pin_mut()
            .start_send(message)
            .map_err(ChannelError::Write)
    }

    /// Flushes messages buffered in the transport.
    fn poll_flush<'a>(
        self: &'a mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), ChannelError<C::Error>>> {
        self.transport_pin_mut()
            .poll_flush(cx)
            .map_err(ChannelError::Flush)
    }

    /// Closes the write end of the transport.
    fn poll_close<'a>(
        self: &'a mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), ChannelError<C::Error>>> {
        self.transport_pin_mut()
            .poll_close(cx)
            .map_err(ChannelError::Close)
    }

    /// Pin-projects to the stream of canceled request IDs.
    fn canceled_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut CanceledRequests {
        self.as_mut().project().canceled_requests
    }

    /// Pin-projects to the receiver of requests staged by Channel handles.
    fn pending_requests_mut<'a>(
        self: &'a mut Pin<&mut Self>,
    ) -> &'a mut mpsc::Receiver<DispatchRequest<Req, Resp>> {
        self.as_mut().project().pending_requests
    }

    /// Reads one response off the wire, if available, and completes the matching
    /// in-flight request. Ready(None) means the read half has terminated.
    fn pump_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
        self.transport_pin_mut()
            .poll_next(cx)
            .map_err(ChannelError::Read)
            .map_ok(|response| {
                self.complete(response);
            })
    }

    /// Makes progress on the write side: sends one staged request or one
    /// cancellation, reaps expired requests, and flushes/closes when idle.
    /// Ready(None) means both request sources are closed and the transport's
    /// write half has been closed.
    fn pump_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
        // Tracks whether each request source is merely idle or gone for good.
        enum ReceiverStatus {
            Pending,
            Closed,
        }

        let pending_requests_status = match self.as_mut().poll_write_request(cx)? {
            Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
            Poll::Ready(None) => ReceiverStatus::Closed,
            Poll::Pending => ReceiverStatus::Pending,
        };

        let canceled_requests_status = match self.as_mut().poll_write_cancel(cx)? {
            Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
            Poll::Ready(None) => ReceiverStatus::Closed,
            Poll::Pending => ReceiverStatus::Pending,
        };

        // Receiving Poll::Ready(None) when polling expired requests never indicates "Closed",
        // because there can temporarily be zero in-flight requests. Therefore, there is no need to
        // track the status like is done with pending and cancelled requests.
        if let Poll::Ready(Some(_)) = self
            .in_flight_requests()
            .poll_expired(cx)
            .map_err(ChannelError::Timer)?
        {
            // Expired requests are considered complete; there is no compelling reason to send a
            // cancellation message to the server, since it will have already exhausted its
            // allotted processing time.
            return Poll::Ready(Some(Ok(())));
        }

        match (pending_requests_status, canceled_requests_status) {
            (ReceiverStatus::Closed, ReceiverStatus::Closed) => {
                // No request source remains; shut down the write half cleanly.
                ready!(self.poll_close(cx)?);
                Poll::Ready(None)
            }
            (ReceiverStatus::Pending, _) | (_, ReceiverStatus::Pending) => {
                // No more messages to process, so flush any messages buffered in the transport.
                ready!(self.poll_flush(cx)?);

                // Even if we fully-flush, we return Pending, because we have no more requests
                // or cancellations right now.
                Poll::Pending
            }
        }
    }

    /// Yields the next pending request, if one is ready to be sent.
    ///
    /// Note that a request will only be yielded if the transport is *ready* to be written to (i.e.
    /// start_send would succeed).
    fn poll_next_request(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<DispatchRequest<Req, Resp>, ChannelError<C::Error>>>> {
        if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
            tracing::info!(
                "At in-flight request capacity ({}/{}).",
                self.in_flight_requests().len(),
                self.config.max_in_flight_requests
            );

            // No need to schedule a wakeup, because timers and responses are responsible
            // for clearing out in-flight requests.
            return Poll::Pending;
        }

        ready!(self.ensure_writeable(cx)?);

        loop {
            match ready!(self.pending_requests_mut().poll_recv(cx)) {
                Some(request) => {
                    // A closed completion means the caller already dropped the
                    // response future; skip the request entirely.
                    if request.response_completion.is_closed() {
                        let _entered = request.span.enter();
                        tracing::info!("AbortRequest");
                        continue;
                    }

                    return Poll::Ready(Some(Ok(request)));
                }
                None => return Poll::Ready(None),
            }
        }
    }

    /// Yields the next pending cancellation, and, if one is ready, cancels the associated request.
    ///
    /// Note that a request to cancel will only be yielded if the transport is *ready* to be
    /// written to (i.e. start_send would succeed).
    fn poll_next_cancellation(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(context::Context, Span, u64), ChannelError<C::Error>>>> {
        ready!(self.ensure_writeable(cx)?);

        loop {
            match ready!(self.canceled_requests_mut().poll_next_unpin(cx)) {
                Some(request_id) => {
                    // cancel_request returns None if the request already
                    // completed or expired; then there is nothing to send.
                    if let Some((ctx, span)) = self.in_flight_requests().cancel_request(request_id)
                    {
                        return Poll::Ready(Some(Ok((ctx, span, request_id))));
                    }
                }
                None => return Poll::Ready(None),
            }
        }
    }

    /// Returns Ready if writing a message to the transport (i.e. via write_request or
    /// write_cancel) would not fail due to a full buffer. If the transport is not ready to be
    /// written to, flushes it until it is ready.
    fn ensure_writeable<'a>(
        self: &'a mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
        while self.poll_ready(cx)?.is_pending() {
            ready!(self.poll_flush(cx)?);
        }
        Poll::Ready(Some(Ok(())))
    }

    /// Writes one staged request to the transport and records it as in-flight.
    /// Ready(None) means all Channel handles have been dropped.
    fn poll_write_request<'a>(
        self: &'a mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
        let DispatchRequest {
            ctx,
            span,
            request_id,
            request,
            response_completion,
        } = match ready!(self.as_mut().poll_next_request(cx)?) {
            Some(dispatch_request) => dispatch_request,
            None => return Poll::Ready(None),
        };
        let entered = span.enter();
        // poll_next_request only returns Ready if there is room to buffer another request.
        // Therefore, we can call write_request without fear of erroring due to a full
        // buffer.
        // NOTE(review): this rebinding is a no-op; kept byte-identical here.
        let request_id = request_id;
        let request = ClientMessage::Request(Request {
            id: request_id,
            message: request,
            context: context::Context {
                deadline: ctx.deadline,
                trace_context: ctx.trace_context,
            },
        });
        self.start_send(request)?;
        let deadline = ctx.deadline;
        tracing::info!(
            tarpc.deadline = %humantime::format_rfc3339(deadline),
            "SendRequest"
        );
        // Exit the span before inserting: the span is moved into the in-flight map.
        drop(entered);

        self.in_flight_requests()
            .insert_request(request_id, ctx, span, response_completion)
            .expect("Request IDs should be unique");
        Poll::Ready(Some(Ok(())))
    }

    /// Writes one cancellation message to the transport.
    /// Ready(None) means the cancellation stream has terminated.
    fn poll_write_cancel<'a>(
        self: &'a mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
        let (context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
            Some(triple) => triple,
            None => return Poll::Ready(None),
        };
        let _entered = span.enter();

        let cancel = ClientMessage::Cancel {
            trace_context: context.trace_context,
            request_id,
        };
        self.start_send(cancel)?;
        tracing::info!("CancelRequest");
        Poll::Ready(Some(Ok(())))
    }

    /// Sends a server response to the client task that initiated the associated request.
    fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
        self.in_flight_requests().complete_request(response)
    }
}
|
||||||
|
|
||||||
|
impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
where
    C: Transport<ClientMessage<Req>, Response<Resp>>,
{
    type Output = Result<(), ChannelError<C::Error>>;

    /// Drives reads and writes until the connection shuts down:
    /// - read half closed => done (no more responses can arrive);
    /// - write half closed => done only once all in-flight requests resolve;
    /// - otherwise keep looping while either side makes progress.
    fn poll(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), ChannelError<C::Error>>> {
        loop {
            match (self.as_mut().pump_read(cx)?, self.as_mut().pump_write(cx)?) {
                (Poll::Ready(None), _) => {
                    tracing::info!("Shutdown: read half closed, so shutting down.");
                    return Poll::Ready(Ok(()));
                }
                (read, Poll::Ready(None)) => {
                    if self.in_flight_requests.is_empty() {
                        tracing::info!("Shutdown: write half closed, and no requests in flight.");
                        return Poll::Ready(Ok(()));
                    }
                    tracing::info!(
                        "Shutdown: write half closed, and {} requests in flight.",
                        self.in_flight_requests().len()
                    );
                    // Keep draining responses for the remaining in-flight
                    // requests; park only when the read side is also idle.
                    match read {
                        Poll::Ready(Some(())) => continue,
                        _ => return Poll::Pending,
                    }
                }
                // Either side made progress; loop to try again immediately.
                (Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
                _ => return Poll::Pending,
            }
        }
    }
}
|
||||||
|
|
||||||
|
/// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage
/// the lifecycle of the request.
#[derive(Debug)]
struct DispatchRequest<Req, Resp> {
    /// Deadline and trace context propagated with the request.
    pub ctx: context::Context,
    /// Span under which the request's lifecycle events are traced.
    pub span: Span,
    /// Client-assigned unique ID for correlating the response.
    pub request_id: u64,
    /// The request payload to send to the server.
    pub request: Req,
    /// Completed by dispatch with the response, or a deadline error.
    pub response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
}
|
||||||
|
|
||||||
|
/// Sends request cancellation signals.
///
/// Cloneable handle; the paired receiver is [`CanceledRequests`].
#[derive(Debug, Clone)]
struct RequestCancellation(mpsc::UnboundedSender<u64>);
|
||||||
|
|
||||||
|
/// A stream of IDs of requests that have been canceled.
///
/// Drained by the dispatch task; the paired sender is [`RequestCancellation`].
#[derive(Debug)]
struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
|
||||||
|
|
||||||
|
/// Returns a channel to send request cancellation messages.
|
||||||
|
fn cancellations() -> (RequestCancellation, CanceledRequests) {
|
||||||
|
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
|
||||||
|
// bounded by the number of in-flight requests.
|
||||||
|
let (tx, rx) = mpsc::unbounded_channel();
|
||||||
|
(RequestCancellation(tx), CanceledRequests(rx))
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RequestCancellation {
|
||||||
|
/// Cancels the request with ID `request_id`.
|
||||||
|
fn cancel(&self, request_id: u64) {
|
||||||
|
let _ = self.0.send(request_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CanceledRequests {
    /// Polls for the next canceled request ID; None means all senders dropped.
    fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<u64>> {
        self.0.poll_recv(cx)
    }
}
|
||||||
|
|
||||||
|
/// Stream adapter so dispatch can use combinators like `poll_next_unpin`.
impl Stream for CanceledRequests {
    type Item = u64;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
        self.poll_recv(cx)
    }
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::{
|
||||||
|
cancellations, CanceledRequests, Channel, DispatchRequest, RequestCancellation,
|
||||||
|
RequestDispatch, ResponseGuard,
|
||||||
|
};
|
||||||
|
use crate::{
|
||||||
|
client::{
|
||||||
|
in_flight_requests::{DeadlineExceededError, InFlightRequests},
|
||||||
|
Config,
|
||||||
|
},
|
||||||
|
context,
|
||||||
|
transport::{self, channel::UnboundedChannel},
|
||||||
|
ClientMessage, Response,
|
||||||
|
};
|
||||||
|
use assert_matches::assert_matches;
|
||||||
|
use futures::{prelude::*, task::*};
|
||||||
|
use std::{
|
||||||
|
convert::TryFrom,
|
||||||
|
pin::Pin,
|
||||||
|
sync::atomic::{AtomicUsize, Ordering},
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
use tokio::sync::{mpsc, oneshot};
|
||||||
|
use tracing::Span;
|
||||||
|
|
||||||
|
#[tokio::test]
async fn response_completes_request_future() {
    // A response arriving over the transport must be routed to the oneshot
    // registered under its request ID, while dispatch itself stays Pending.
    let (mut dispatch, mut _channel, mut server_channel) = set_up();
    let cx = &mut Context::from_waker(&noop_waker_ref());
    let (tx, mut rx) = oneshot::channel();

    dispatch
        .in_flight_requests
        .insert_request(0, context::current(), Span::current(), tx)
        .unwrap();
    server_channel
        .send(Response {
            request_id: 0,
            message: Ok("Resp".into()),
        })
        .await
        .unwrap();
    assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
    assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp");
}
|
||||||
|
|
||||||
|
#[tokio::test]
async fn dispatch_response_cancels_on_drop() {
    // Dropping an unconsumed ResponseGuard must enqueue its request ID on the
    // cancellation channel.
    let (cancellation, mut canceled_requests) = cancellations();
    let (_, mut response) = oneshot::channel();
    drop(ResponseGuard::<u32> {
        response: &mut response,
        cancellation: &cancellation,
        request_id: 3,
    });
    // resp's drop() is run, which should send a cancel message.
    let cx = &mut Context::from_waker(&noop_waker_ref());
    assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(Some(3)));
}
|
||||||
|
|
||||||
|
#[tokio::test]
async fn dispatch_response_doesnt_cancel_after_complete() {
    // Consuming a guard via response() skips the cancel-on-drop path
    // (mem::forget in response()), so no cancellation should be observed.
    let (cancellation, mut canceled_requests) = cancellations();
    let (tx, mut response) = oneshot::channel();
    tx.send(Ok(Response {
        request_id: 0,
        message: Ok("well done"),
    }))
    .unwrap();
    // resp's drop() is run, but should not send a cancel message.
    ResponseGuard {
        response: &mut response,
        cancellation: &cancellation,
        request_id: 3,
    }
    .response()
    .await
    .unwrap();
    // Dropping the sender lets poll_recv distinguish "empty" from "closed".
    drop(cancellation);
    let cx = &mut Context::from_waker(&noop_waker_ref());
    assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(None));
}
|
||||||
|
|
||||||
|
#[tokio::test]
async fn stage_request() {
    // A request staged via the channel should surface from poll_next_request
    // with its ID and payload intact.
    let (mut dispatch, mut channel, _server_channel) = set_up();
    let cx = &mut Context::from_waker(&noop_waker_ref());
    let (tx, mut rx) = oneshot::channel();

    let _resp = send_request(&mut channel, "hi", tx, &mut rx).await;

    let req = dispatch.as_mut().poll_next_request(cx).ready();
    assert!(req.is_some());

    let req = req.unwrap();
    assert_eq!(req.request_id, 0);
    assert_eq!(req.request, "hi".to_string());
}
|
||||||
|
|
||||||
|
// Regression test for https://github.com/google/tarpc/issues/220
#[tokio::test]
async fn stage_request_channel_dropped_doesnt_panic() {
    // Dropping the only Channel handle while a request is staged must not
    // panic the dispatch task; it should still run to clean completion.
    let (mut dispatch, mut channel, mut server_channel) = set_up();
    let cx = &mut Context::from_waker(&noop_waker_ref());
    let (tx, mut rx) = oneshot::channel();

    let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
    drop(channel);

    assert!(dispatch.as_mut().poll(cx).is_ready());
    send_response(
        &mut server_channel,
        Response {
            request_id: 0,
            message: Ok("hello".into()),
        },
    )
    .await;
    dispatch.await.unwrap();
}
|
||||||
|
|
||||||
|
#[tokio::test]
async fn stage_request_response_future_dropped_is_canceled_before_sending() {
    let (mut dispatch, mut channel, _server_channel) = set_up();
    let cx = &mut Context::from_waker(&noop_waker_ref());
    let (tx, mut rx) = oneshot::channel();

    let _ = send_request(&mut channel, "hi", tx, &mut rx).await;

    // Drop the channel so polling returns none if no requests are currently ready.
    drop(channel);
    // Test that a request future dropped before it's processed by dispatch will cause the request
    // to not be added to the in-flight request map.
    assert!(dispatch.as_mut().poll_next_request(cx).ready().is_none());
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn stage_request_response_future_dropped_is_canceled_after_sending() {
|
||||||
|
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||||
|
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||||
|
let (tx, mut rx) = oneshot::channel();
|
||||||
|
|
||||||
|
let req = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||||
|
|
||||||
|
assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
|
||||||
|
assert!(!dispatch.in_flight_requests.is_empty());
|
||||||
|
|
||||||
|
// Test that a request future dropped after it's processed by dispatch will cause the request
|
||||||
|
// to be removed from the in-flight request map.
|
||||||
|
drop(req);
|
||||||
|
assert_matches!(
|
||||||
|
dispatch.as_mut().poll_next_cancellation(cx),
|
||||||
|
Poll::Ready(Some(Ok(_)))
|
||||||
|
);
|
||||||
|
assert!(dispatch.in_flight_requests.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn stage_request_response_closed_skipped() {
|
||||||
|
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||||
|
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||||
|
let (tx, mut rx) = oneshot::channel();
|
||||||
|
|
||||||
|
// Test that a request future that's closed its receiver but not yet canceled its request --
|
||||||
|
// i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
|
||||||
|
// map.
|
||||||
|
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||||
|
resp.response.close();
|
||||||
|
|
||||||
|
assert!(dispatch.as_mut().poll_next_request(cx).is_pending());
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_up() -> (
|
||||||
|
Pin<
|
||||||
|
Box<
|
||||||
|
RequestDispatch<
|
||||||
|
String,
|
||||||
|
String,
|
||||||
|
UnboundedChannel<Response<String>, ClientMessage<String>>,
|
||||||
|
>,
|
||||||
|
>,
|
||||||
|
>,
|
||||||
|
Channel<String, String>,
|
||||||
|
UnboundedChannel<ClientMessage<String>, Response<String>>,
|
||||||
|
) {
|
||||||
|
let _ = tracing_subscriber::fmt().with_test_writer().try_init();
|
||||||
|
|
||||||
|
let (to_dispatch, pending_requests) = mpsc::channel(1);
|
||||||
|
let (cancel_tx, canceled_requests) = mpsc::unbounded_channel();
|
||||||
|
let (client_channel, server_channel) = transport::channel::unbounded();
|
||||||
|
|
||||||
|
let dispatch = RequestDispatch::<String, String, _> {
|
||||||
|
transport: client_channel.fuse(),
|
||||||
|
pending_requests: pending_requests,
|
||||||
|
canceled_requests: CanceledRequests(canceled_requests),
|
||||||
|
in_flight_requests: InFlightRequests::default(),
|
||||||
|
config: Config::default(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let cancellation = RequestCancellation(cancel_tx);
|
||||||
|
let channel = Channel {
|
||||||
|
to_dispatch,
|
||||||
|
cancellation,
|
||||||
|
next_request_id: Arc::new(AtomicUsize::new(0)),
|
||||||
|
};
|
||||||
|
|
||||||
|
(Box::pin(dispatch), channel, server_channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_request<'a>(
|
||||||
|
channel: &'a mut Channel<String, String>,
|
||||||
|
request: &str,
|
||||||
|
response_completion: oneshot::Sender<Result<Response<String>, DeadlineExceededError>>,
|
||||||
|
response: &'a mut oneshot::Receiver<Result<Response<String>, DeadlineExceededError>>,
|
||||||
|
) -> ResponseGuard<'a, String> {
|
||||||
|
let request_id =
|
||||||
|
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
|
||||||
|
let request = DispatchRequest {
|
||||||
|
ctx: context::current(),
|
||||||
|
span: Span::current(),
|
||||||
|
request_id,
|
||||||
|
request: request.to_string(),
|
||||||
|
response_completion,
|
||||||
|
};
|
||||||
|
channel.to_dispatch.send(request).await.unwrap();
|
||||||
|
|
||||||
|
ResponseGuard {
|
||||||
|
response,
|
||||||
|
cancellation: &channel.cancellation,
|
||||||
|
request_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_response(
|
||||||
|
channel: &mut UnboundedChannel<ClientMessage<String>, Response<String>>,
|
||||||
|
response: Response<String>,
|
||||||
|
) {
|
||||||
|
channel.send(response).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
trait PollTest {
|
||||||
|
type T;
|
||||||
|
fn unwrap(self) -> Poll<Self::T>;
|
||||||
|
fn ready(self) -> Self::T;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, E> PollTest for Poll<Option<Result<T, E>>>
|
||||||
|
where
|
||||||
|
E: ::std::fmt::Display,
|
||||||
|
{
|
||||||
|
type T = Option<T>;
|
||||||
|
|
||||||
|
fn unwrap(self) -> Poll<Option<T>> {
|
||||||
|
match self {
|
||||||
|
Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)),
|
||||||
|
Poll::Ready(None) => Poll::Ready(None),
|
||||||
|
Poll::Ready(Some(Err(e))) => panic!("{}", e.to_string()),
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ready(self) -> Option<T> {
|
||||||
|
match self {
|
||||||
|
Poll::Ready(Some(Ok(t))) => Some(t),
|
||||||
|
Poll::Ready(None) => None,
|
||||||
|
Poll::Ready(Some(Err(e))) => panic!("{}", e.to_string()),
|
||||||
|
Poll::Pending => panic!("Pending"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
137
tarpc/src/client/in_flight_requests.rs
Normal file
137
tarpc/src/client/in_flight_requests.rs
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
use crate::{
|
||||||
|
context,
|
||||||
|
util::{Compact, TimeUntil},
|
||||||
|
Response,
|
||||||
|
};
|
||||||
|
use fnv::FnvHashMap;
|
||||||
|
use std::{
|
||||||
|
collections::hash_map,
|
||||||
|
task::{Context, Poll},
|
||||||
|
};
|
||||||
|
use tokio::sync::oneshot;
|
||||||
|
use tokio_util::time::delay_queue::{self, DelayQueue};
|
||||||
|
use tracing::Span;
|
||||||
|
|
||||||
|
/// Requests already written to the wire that haven't yet received responses.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct InFlightRequests<Resp> {
|
||||||
|
request_data: FnvHashMap<u64, RequestData<Resp>>,
|
||||||
|
deadlines: DelayQueue<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Resp> Default for InFlightRequests<Resp> {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
request_data: Default::default(),
|
||||||
|
deadlines: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The request exceeded its deadline.
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
#[non_exhaustive]
|
||||||
|
#[error("the request exceeded its deadline")]
|
||||||
|
pub struct DeadlineExceededError;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct RequestData<Resp> {
|
||||||
|
ctx: context::Context,
|
||||||
|
span: Span,
|
||||||
|
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||||
|
/// The key to remove the timer for the request's deadline.
|
||||||
|
deadline_key: delay_queue::Key,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An error returned when an attempt is made to insert a request with an ID that is already in
|
||||||
|
/// use.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct AlreadyExistsError;
|
||||||
|
|
||||||
|
impl<Resp> InFlightRequests<Resp> {
|
||||||
|
/// Returns the number of in-flight requests.
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.request_data.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true iff there are no requests in flight.
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.request_data.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Starts a request, unless a request with the same ID is already in flight.
|
||||||
|
pub fn insert_request(
|
||||||
|
&mut self,
|
||||||
|
request_id: u64,
|
||||||
|
ctx: context::Context,
|
||||||
|
span: Span,
|
||||||
|
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||||
|
) -> Result<(), AlreadyExistsError> {
|
||||||
|
match self.request_data.entry(request_id) {
|
||||||
|
hash_map::Entry::Vacant(vacant) => {
|
||||||
|
let timeout = ctx.deadline.time_until();
|
||||||
|
let deadline_key = self.deadlines.insert(request_id, timeout);
|
||||||
|
vacant.insert(RequestData {
|
||||||
|
ctx,
|
||||||
|
span,
|
||||||
|
response_completion,
|
||||||
|
deadline_key,
|
||||||
|
});
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
hash_map::Entry::Occupied(_) => Err(AlreadyExistsError),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Removes a request without aborting. Returns true iff the request was found.
|
||||||
|
pub fn complete_request(&mut self, response: Response<Resp>) -> bool {
|
||||||
|
if let Some(request_data) = self.request_data.remove(&response.request_id) {
|
||||||
|
let _entered = request_data.span.enter();
|
||||||
|
tracing::info!("ReceiveResponse");
|
||||||
|
self.request_data.compact(0.1);
|
||||||
|
self.deadlines.remove(&request_data.deadline_key);
|
||||||
|
let _ = request_data.response_completion.send(Ok(response));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
"No in-flight request found for request_id = {}.",
|
||||||
|
response.request_id
|
||||||
|
);
|
||||||
|
|
||||||
|
// If the response completion was absent, then the request was already canceled.
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Cancels a request without completing (typically used when a request handle was dropped
|
||||||
|
/// before the request completed).
|
||||||
|
pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> {
|
||||||
|
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||||
|
self.request_data.compact(0.1);
|
||||||
|
self.deadlines.remove(&request_data.deadline_key);
|
||||||
|
Some((request_data.ctx, request_data.span))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Yields a request that has expired, completing it with a TimedOut error.
|
||||||
|
/// The caller should send cancellation messages for any yielded request ID.
|
||||||
|
pub fn poll_expired(
|
||||||
|
&mut self,
|
||||||
|
cx: &mut Context,
|
||||||
|
) -> Poll<Option<Result<u64, tokio::time::error::Error>>> {
|
||||||
|
self.deadlines.poll_expired(cx).map_ok(|expired| {
|
||||||
|
let request_id = expired.into_inner();
|
||||||
|
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||||
|
let _entered = request_data.span.enter();
|
||||||
|
tracing::error!("DeadlineExceeded");
|
||||||
|
self.request_data.compact(0.1);
|
||||||
|
let _ = request_data
|
||||||
|
.response_completion
|
||||||
|
.send(Err(DeadlineExceededError));
|
||||||
|
}
|
||||||
|
request_id
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
102
tarpc/src/context.rs
Normal file
102
tarpc/src/context.rs
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
// Copyright 2018 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
//! Provides a request context that carries a deadline and trace context. This context is sent from
|
||||||
|
//! client to server and is used by the server to enforce response deadlines.
|
||||||
|
|
||||||
|
use crate::trace::{self, TraceId};
|
||||||
|
use opentelemetry::trace::TraceContextExt;
|
||||||
|
use static_assertions::assert_impl_all;
|
||||||
|
use std::{
|
||||||
|
convert::TryFrom,
|
||||||
|
time::{Duration, SystemTime},
|
||||||
|
};
|
||||||
|
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||||
|
|
||||||
|
/// A request context that carries request-scoped information like deadlines and trace information.
|
||||||
|
/// It is sent from client to server and is used by the server to enforce response deadlines.
|
||||||
|
///
|
||||||
|
/// The context should not be stored directly in a server implementation, because the context will
|
||||||
|
/// be different for each request in scope.
|
||||||
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
#[non_exhaustive]
|
||||||
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
|
pub struct Context {
|
||||||
|
/// When the client expects the request to be complete by. The server should cancel the request
|
||||||
|
/// if it is not complete by this time.
|
||||||
|
#[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))]
|
||||||
|
pub deadline: SystemTime,
|
||||||
|
/// Uniquely identifies requests originating from the same source.
|
||||||
|
/// When a service handles a request by making requests itself, those requests should
|
||||||
|
/// include the same `trace_id` as that included on the original request. This way,
|
||||||
|
/// users can trace related actions across a distributed system.
|
||||||
|
pub trace_context: trace::Context,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_impl_all!(Context: Send, Sync);
|
||||||
|
|
||||||
|
fn ten_seconds_from_now() -> SystemTime {
|
||||||
|
SystemTime::now() + Duration::from_secs(10)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the context for the current request, or a default Context if no request is active.
|
||||||
|
pub fn current() -> Context {
|
||||||
|
Context::current()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct Deadline(SystemTime);
|
||||||
|
|
||||||
|
impl Default for Deadline {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self(ten_seconds_from_now())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Context {
|
||||||
|
/// Returns the context for the current request, or a default Context if no request is active.
|
||||||
|
pub fn current() -> Self {
|
||||||
|
let span = tracing::Span::current();
|
||||||
|
Self {
|
||||||
|
trace_context: trace::Context::try_from(&span)
|
||||||
|
.unwrap_or_else(|_| trace::Context::default()),
|
||||||
|
deadline: span
|
||||||
|
.context()
|
||||||
|
.get::<Deadline>()
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_default()
|
||||||
|
.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the ID of the request-scoped trace.
|
||||||
|
pub fn trace_id(&self) -> &TraceId {
|
||||||
|
&self.trace_context.trace_id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An extension trait for [`tracing::Span`] for propagating tarpc Contexts.
|
||||||
|
pub(crate) trait SpanExt {
|
||||||
|
/// Sets the given context on this span. Newly-created spans will be children of the given
|
||||||
|
/// context's trace context.
|
||||||
|
fn set_context(&self, context: &Context);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SpanExt for tracing::Span {
|
||||||
|
fn set_context(&self, context: &Context) {
|
||||||
|
self.set_parent(
|
||||||
|
opentelemetry::Context::new()
|
||||||
|
.with_remote_span_context(opentelemetry::trace::SpanContext::new(
|
||||||
|
opentelemetry::trace::TraceId::from(context.trace_context.trace_id),
|
||||||
|
opentelemetry::trace::SpanId::from(context.trace_context.span_id),
|
||||||
|
opentelemetry::trace::TraceFlags::from(context.trace_context.sampling_decision),
|
||||||
|
true,
|
||||||
|
opentelemetry::trace::TraceState::default(),
|
||||||
|
))
|
||||||
|
.with_value(Deadline(context.deadline)),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
241
tarpc/src/lib.rs
241
tarpc/src/lib.rs
@@ -3,10 +3,6 @@
|
|||||||
// Use of this source code is governed by an MIT-style
|
// Use of this source code is governed by an MIT-style
|
||||||
// license that can be found in the LICENSE file or at
|
// license that can be found in the LICENSE file or at
|
||||||
// https://opensource.org/licenses/MIT.
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
//! [](https://crates.io/crates/tarpc)
|
|
||||||
//! [](https://gitter.im/tarpc/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
|
||||||
//!
|
|
||||||
//! *Disclaimer*: This is not an official Google product.
|
//! *Disclaimer*: This is not an official Google product.
|
||||||
//!
|
//!
|
||||||
//! tarpc is an RPC framework for rust with a focus on ease of use. Defining a
|
//! tarpc is an RPC framework for rust with a focus on ease of use. Defining a
|
||||||
@@ -42,6 +38,14 @@
|
|||||||
//! requests sent by the server that use the request context will propagate the request deadline.
|
//! requests sent by the server that use the request context will propagate the request deadline.
|
||||||
//! For example, if a server is handling a request with a 10s deadline, does 2s of work, then
|
//! For example, if a server is handling a request with a 10s deadline, does 2s of work, then
|
||||||
//! sends a request to another server, that server will see an 8s deadline.
|
//! sends a request to another server, that server will see an 8s deadline.
|
||||||
|
//! - Distributed tracing: tarpc is instrumented with
|
||||||
|
//! [tracing](https://github.com/tokio-rs/tracing) primitives extended with
|
||||||
|
//! [OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
|
||||||
|
//! [Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
|
||||||
|
//! each RPC can be traced through the client, server, amd other dependencies downstream of the
|
||||||
|
//! server. Even for applications not connected to a distributed tracing collector, the
|
||||||
|
//! instrumentation can also be ingested by regular loggers like
|
||||||
|
//! [env_logger](https://github.com/env-logger-rs/env_logger/).
|
||||||
//! - Serde serialization: enabling the `serde1` Cargo feature will make service requests and
|
//! - Serde serialization: enabling the `serde1` Cargo feature will make service requests and
|
||||||
//! responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
|
//! responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
|
||||||
//! be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
|
//! be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
|
||||||
@@ -50,7 +54,7 @@
|
|||||||
//! Add to your `Cargo.toml` dependencies:
|
//! Add to your `Cargo.toml` dependencies:
|
||||||
//!
|
//!
|
||||||
//! ```toml
|
//! ```toml
|
||||||
//! tarpc = "0.20.0"
|
//! tarpc = "0.27"
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||||
@@ -59,12 +63,14 @@
|
|||||||
//!
|
//!
|
||||||
//! ## Example
|
//! ## Example
|
||||||
//!
|
//!
|
||||||
//! For this example, in addition to tarpc, also add two other dependencies to
|
//! This example uses [tokio](https://tokio.rs), so add the following dependencies to
|
||||||
//! your `Cargo.toml`:
|
//! your `Cargo.toml`:
|
||||||
//!
|
//!
|
||||||
//! ```toml
|
//! ```toml
|
||||||
|
//! anyhow = "1.0"
|
||||||
//! futures = "0.3"
|
//! futures = "0.3"
|
||||||
//! tokio = "0.2"
|
//! tarpc = { version = "0.27", features = ["tokio1"] }
|
||||||
|
//! tokio = { version = "1.0", features = ["macros"] }
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! In the following example, we use an in-process channel for communication between
|
//! In the following example, we use an in-process channel for communication between
|
||||||
@@ -82,9 +88,8 @@
|
|||||||
//! };
|
//! };
|
||||||
//! use tarpc::{
|
//! use tarpc::{
|
||||||
//! client, context,
|
//! client, context,
|
||||||
//! server::{self, Handler},
|
//! server::{self, incoming::Incoming},
|
||||||
//! };
|
//! };
|
||||||
//! use std::io;
|
|
||||||
//!
|
//!
|
||||||
//! // This is the service definition. It looks a lot like a trait definition.
|
//! // This is the service definition. It looks a lot like a trait definition.
|
||||||
//! // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
//! // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||||
@@ -106,9 +111,8 @@
|
|||||||
//! # };
|
//! # };
|
||||||
//! # use tarpc::{
|
//! # use tarpc::{
|
||||||
//! # client, context,
|
//! # client, context,
|
||||||
//! # server::{self, Handler},
|
//! # server::{self, incoming::Incoming},
|
||||||
//! # };
|
//! # };
|
||||||
//! # use std::io;
|
|
||||||
//! # // This is the service definition. It looks a lot like a trait definition.
|
//! # // This is the service definition. It looks a lot like a trait definition.
|
||||||
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||||
//! # #[tarpc::service]
|
//! # #[tarpc::service]
|
||||||
@@ -134,7 +138,7 @@
|
|||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! Lastly let's write our `main` that will start the server. While this example uses an
|
//! Lastly let's write our `main` that will start the server. While this example uses an
|
||||||
//! [in-process channel](rpc::transport::channel), tarpc also ships a generic [`serde_transport`]
|
//! [in-process channel](transport::channel), tarpc also ships a generic [`serde_transport`]
|
||||||
//! behind the `serde-transport` feature, with additional [TCP](serde_transport::tcp) functionality
|
//! behind the `serde-transport` feature, with additional [TCP](serde_transport::tcp) functionality
|
||||||
//! available behind the `tcp` feature.
|
//! available behind the `tcp` feature.
|
||||||
//!
|
//!
|
||||||
@@ -146,9 +150,8 @@
|
|||||||
//! # };
|
//! # };
|
||||||
//! # use tarpc::{
|
//! # use tarpc::{
|
||||||
//! # client, context,
|
//! # client, context,
|
||||||
//! # server::{self, Handler},
|
//! # server::{self, Channel},
|
||||||
//! # };
|
//! # };
|
||||||
//! # use std::io;
|
|
||||||
//! # // This is the service definition. It looks a lot like a trait definition.
|
//! # // This is the service definition. It looks a lot like a trait definition.
|
||||||
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||||
//! # #[tarpc::service]
|
//! # #[tarpc::service]
|
||||||
@@ -168,21 +171,19 @@
|
|||||||
//! # future::ready(format!("Hello, {}!", name))
|
//! # future::ready(format!("Hello, {}!", name))
|
||||||
//! # }
|
//! # }
|
||||||
//! # }
|
//! # }
|
||||||
|
//! # #[cfg(not(feature = "tokio1"))]
|
||||||
|
//! # fn main() {}
|
||||||
|
//! # #[cfg(feature = "tokio1")]
|
||||||
//! #[tokio::main]
|
//! #[tokio::main]
|
||||||
//! async fn main() -> io::Result<()> {
|
//! async fn main() -> anyhow::Result<()> {
|
||||||
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||||
//!
|
//!
|
||||||
//! let server = server::new(server::Config::default())
|
//! let server = server::BaseChannel::with_defaults(server_transport);
|
||||||
//! // incoming() takes a stream of transports such as would be returned by
|
//! tokio::spawn(server.execute(HelloServer.serve()));
|
||||||
//! // TcpListener::incoming (but a stream instead of an iterator).
|
|
||||||
//! .incoming(stream::once(future::ready(server_transport)))
|
|
||||||
//! .respond_with(HelloServer.serve());
|
|
||||||
//!
|
//!
|
||||||
//! tokio::spawn(server);
|
//! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||||
//!
|
//! // that takes a config and any Transport as input.
|
||||||
//! // WorldClient is generated by the macro. It has a constructor `new` that takes a config and
|
//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn();
|
||||||
//! // any Transport as input
|
|
||||||
//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
|
|
||||||
//!
|
//!
|
||||||
//! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
//! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||||
//! // args as defined, with the addition of a Context, which is always the first arg. The Context
|
//! // args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||||
@@ -201,21 +202,29 @@
|
|||||||
//! items expanded by a `service!` invocation.
|
//! items expanded by a `service!` invocation.
|
||||||
#![deny(missing_docs)]
|
#![deny(missing_docs)]
|
||||||
#![allow(clippy::type_complexity)]
|
#![allow(clippy::type_complexity)]
|
||||||
|
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||||
|
|
||||||
pub mod rpc;
|
#[cfg(feature = "serde1")]
|
||||||
pub use rpc::*;
|
#[doc(hidden)]
|
||||||
|
pub use serde;
|
||||||
|
|
||||||
#[cfg(feature = "serde-transport")]
|
#[cfg(feature = "serde-transport")]
|
||||||
|
pub use tokio_serde;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde-transport")]
|
||||||
|
#[cfg_attr(docsrs, doc(cfg(feature = "serde-transport")))]
|
||||||
pub mod serde_transport;
|
pub mod serde_transport;
|
||||||
|
|
||||||
pub mod trace;
|
pub mod trace;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde1")]
|
||||||
|
pub use tarpc_plugins::derive_serde;
|
||||||
|
|
||||||
/// The main macro that creates RPC services.
|
/// The main macro that creates RPC services.
|
||||||
///
|
///
|
||||||
/// Rpc methods are specified, mirroring trait syntax:
|
/// Rpc methods are specified, mirroring trait syntax:
|
||||||
///
|
///
|
||||||
/// ```
|
/// ```
|
||||||
/// # fn main() {}
|
|
||||||
/// #[tarpc::service]
|
/// #[tarpc::service]
|
||||||
/// trait Service {
|
/// trait Service {
|
||||||
/// /// Say hello
|
/// /// Say hello
|
||||||
@@ -234,3 +243,179 @@ pub mod trace;
|
|||||||
/// * `Client` -- a client stub with a fn for each RPC.
|
/// * `Client` -- a client stub with a fn for each RPC.
|
||||||
/// * `fn new_stub` -- creates a new Client stub.
|
/// * `fn new_stub` -- creates a new Client stub.
|
||||||
pub use tarpc_plugins::service;
|
pub use tarpc_plugins::service;
|
||||||
|
|
||||||
|
/// A utility macro that can be used for RPC server implementations.
|
||||||
|
///
|
||||||
|
/// Syntactic sugar to make using async functions in the server implementation
|
||||||
|
/// easier. It does this by rewriting code like this, which would normally not
|
||||||
|
/// compile because async functions are disallowed in trait implementations:
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use tarpc::context;
|
||||||
|
/// # use std::net::SocketAddr;
|
||||||
|
/// #[tarpc::service]
|
||||||
|
/// trait World {
|
||||||
|
/// async fn hello(name: String) -> String;
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// #[derive(Clone)]
|
||||||
|
/// struct HelloServer(SocketAddr);
|
||||||
|
///
|
||||||
|
/// #[tarpc::server]
|
||||||
|
/// impl World for HelloServer {
|
||||||
|
/// async fn hello(self, _: context::Context, name: String) -> String {
|
||||||
|
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// Into code like this, which matches the service trait definition:
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use tarpc::context;
|
||||||
|
/// # use std::pin::Pin;
|
||||||
|
/// # use futures::Future;
|
||||||
|
/// # use std::net::SocketAddr;
|
||||||
|
/// #[derive(Clone)]
|
||||||
|
/// struct HelloServer(SocketAddr);
|
||||||
|
///
|
||||||
|
/// #[tarpc::service]
|
||||||
|
/// trait World {
|
||||||
|
/// async fn hello(name: String) -> String;
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// impl World for HelloServer {
|
||||||
|
/// type HelloFut = Pin<Box<dyn Future<Output = String> + Send>>;
|
||||||
|
///
|
||||||
|
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
|
||||||
|
/// + Send>> {
|
||||||
|
/// Box::pin(async move {
|
||||||
|
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
|
||||||
|
/// })
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// Note that this won't touch functions unless they have been annotated with
|
||||||
|
/// `async`, meaning that this should not break existing code.
|
||||||
|
pub use tarpc_plugins::server;
|
||||||
|
|
||||||
|
pub mod client;
|
||||||
|
pub mod context;
|
||||||
|
pub mod server;
|
||||||
|
pub mod transport;
|
||||||
|
pub(crate) mod util;
|
||||||
|
|
||||||
|
pub use crate::transport::sealed::Transport;
|
||||||
|
|
||||||
|
use anyhow::Context as _;
|
||||||
|
use futures::task::*;
|
||||||
|
use std::{error::Error, fmt::Display, io, time::SystemTime};
|
||||||
|
|
||||||
|
/// A message from a client to a server.
|
||||||
|
#[derive(Debug)]
|
||||||
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
|
#[non_exhaustive]
|
||||||
|
pub enum ClientMessage<T> {
|
||||||
|
/// A request initiated by a user. The server responds to a request by invoking a
|
||||||
|
/// service-provided request handler. The handler completes with a [`response`](Response), which
|
||||||
|
/// the server sends back to the client.
|
||||||
|
Request(Request<T>),
|
||||||
|
/// A command to cancel an in-flight request, automatically sent by the client when a response
|
||||||
|
/// future is dropped.
|
||||||
|
///
|
||||||
|
/// When received, the server will immediately cancel the main task (top-level future) of the
|
||||||
|
/// request handler for the associated request. Any tasks spawned by the request handler will
|
||||||
|
/// not be canceled, because the framework layer does not
|
||||||
|
/// know about them.
|
||||||
|
Cancel {
|
||||||
|
/// The trace context associates the message with a specific chain of causally-related actions,
|
||||||
|
/// possibly orchestrated across many distributed systems.
|
||||||
|
#[cfg_attr(feature = "serde1", serde(default))]
|
||||||
|
trace_context: trace::Context,
|
||||||
|
/// The ID of the request to cancel.
|
||||||
|
request_id: u64,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A request from a client to a server.
|
||||||
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
#[non_exhaustive]
|
||||||
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
|
pub struct Request<T> {
|
||||||
|
/// Trace context, deadline, and other cross-cutting concerns.
|
||||||
|
pub context: context::Context,
|
||||||
|
/// Uniquely identifies the request across all requests sent over a single channel.
|
||||||
|
pub id: u64,
|
||||||
|
/// The request body.
|
||||||
|
pub message: T,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A response from a server to a client.
|
||||||
|
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||||
|
#[non_exhaustive]
|
||||||
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
|
pub struct Response<T> {
|
||||||
|
/// The ID of the request being responded to.
|
||||||
|
pub request_id: u64,
|
||||||
|
/// The response body, or an error if the request failed.
|
||||||
|
pub message: Result<T, ServerError>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An error indicating the server aborted the request early, e.g., due to request throttling.
|
||||||
|
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
|
||||||
|
#[error("{kind:?}: {detail}")]
|
||||||
|
#[non_exhaustive]
|
||||||
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
|
pub struct ServerError {
|
||||||
|
#[cfg_attr(
|
||||||
|
feature = "serde1",
|
||||||
|
serde(serialize_with = "util::serde::serialize_io_error_kind_as_u32")
|
||||||
|
)]
|
||||||
|
#[cfg_attr(
|
||||||
|
feature = "serde1",
|
||||||
|
serde(deserialize_with = "util::serde::deserialize_io_error_kind_from_u32")
|
||||||
|
)]
|
||||||
|
/// The type of error that occurred to fail the request.
|
||||||
|
pub kind: io::ErrorKind,
|
||||||
|
/// A message describing more detail about the error that occurred.
|
||||||
|
pub detail: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Request<T> {
|
||||||
|
/// Returns the deadline for this request.
|
||||||
|
pub fn deadline(&self) -> &SystemTime {
|
||||||
|
&self.context.deadline
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) trait PollContext<T> {
|
||||||
|
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
|
||||||
|
where
|
||||||
|
C: Display + Send + Sync + 'static;
|
||||||
|
|
||||||
|
fn with_context<C, F>(self, f: F) -> Poll<Option<anyhow::Result<T>>>
|
||||||
|
where
|
||||||
|
C: Display + Send + Sync + 'static,
|
||||||
|
F: FnOnce() -> C;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, E> PollContext<T> for Poll<Option<Result<T, E>>>
|
||||||
|
where
|
||||||
|
E: Error + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
|
||||||
|
where
|
||||||
|
C: Display + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
self.map(|o| o.map(|r| r.context(context)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn with_context<C, F>(self, f: F) -> Poll<Option<anyhow::Result<T>>>
|
||||||
|
where
|
||||||
|
C: Display + Send + Sync + 'static,
|
||||||
|
F: FnOnce() -> C,
|
||||||
|
{
|
||||||
|
self.map(|o| o.map(|r| r.with_context(f)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,906 +0,0 @@
|
|||||||
// Copyright 2018 Google LLC
|
|
||||||
//
|
|
||||||
// Use of this source code is governed by an MIT-style
|
|
||||||
// license that can be found in the LICENSE file or at
|
|
||||||
// https://opensource.org/licenses/MIT.
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
context,
|
|
||||||
trace::SpanId,
|
|
||||||
util::{Compact, TimeUntil},
|
|
||||||
ClientMessage, PollIo, Request, Response, Transport,
|
|
||||||
};
|
|
||||||
use fnv::FnvHashMap;
|
|
||||||
use futures::{
|
|
||||||
channel::{mpsc, oneshot},
|
|
||||||
prelude::*,
|
|
||||||
ready,
|
|
||||||
stream::Fuse,
|
|
||||||
task::*,
|
|
||||||
};
|
|
||||||
use log::{debug, info, trace};
|
|
||||||
use pin_project::{pin_project, pinned_drop};
|
|
||||||
use std::{
|
|
||||||
io,
|
|
||||||
pin::Pin,
|
|
||||||
sync::{
|
|
||||||
atomic::{AtomicU64, Ordering},
|
|
||||||
Arc,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::{Config, NewClient};
|
|
||||||
|
|
||||||
/// Handles communication from the client to request dispatch.
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct Channel<Req, Resp> {
|
|
||||||
to_dispatch: mpsc::Sender<DispatchRequest<Req, Resp>>,
|
|
||||||
/// Channel to send a cancel message to the dispatcher.
|
|
||||||
cancellation: RequestCancellation,
|
|
||||||
/// The ID to use for the next request to stage.
|
|
||||||
next_request_id: Arc<AtomicU64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Req, Resp> Clone for Channel<Req, Resp> {
|
|
||||||
fn clone(&self) -> Self {
|
|
||||||
Self {
|
|
||||||
to_dispatch: self.to_dispatch.clone(),
|
|
||||||
cancellation: self.cancellation.clone(),
|
|
||||||
next_request_id: self.next_request_id.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A future returned by [`Channel::send`] that resolves to a server response.
|
|
||||||
#[pin_project]
|
|
||||||
#[derive(Debug)]
|
|
||||||
#[must_use = "futures do nothing unless polled"]
|
|
||||||
struct Send<'a, Req, Resp> {
|
|
||||||
#[pin]
|
|
||||||
fut: MapOkDispatchResponse<SendMapErrConnectionReset<'a, Req, Resp>, Resp>,
|
|
||||||
}
|
|
||||||
|
|
||||||
type SendMapErrConnectionReset<'a, Req, Resp> = MapErrConnectionReset<
|
|
||||||
futures::sink::Send<'a, mpsc::Sender<DispatchRequest<Req, Resp>>, DispatchRequest<Req, Resp>>,
|
|
||||||
>;
|
|
||||||
|
|
||||||
impl<'a, Req, Resp> Future for Send<'a, Req, Resp> {
|
|
||||||
type Output = io::Result<DispatchResponse<Resp>>;
|
|
||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
||||||
self.as_mut().project().fut.poll(cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A future returned by [`Channel::call`] that resolves to a server response.
|
|
||||||
#[pin_project]
|
|
||||||
#[derive(Debug)]
|
|
||||||
#[must_use = "futures do nothing unless polled"]
|
|
||||||
pub struct Call<'a, Req, Resp> {
|
|
||||||
#[pin]
|
|
||||||
fut: AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, Req, Resp> Future for Call<'a, Req, Resp> {
|
|
||||||
type Output = io::Result<Resp>;
|
|
||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
||||||
self.as_mut().project().fut.poll(cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Req, Resp> Channel<Req, Resp> {
|
|
||||||
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
|
|
||||||
/// resolves when the request is sent (not when the response is received).
|
|
||||||
fn send(&mut self, mut ctx: context::Context, request: Req) -> Send<Req, Resp> {
|
|
||||||
// Convert the context to the call context.
|
|
||||||
ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
|
|
||||||
ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
|
|
||||||
|
|
||||||
let timeout = ctx.deadline.time_until();
|
|
||||||
trace!(
|
|
||||||
"[{}] Queuing request with timeout {:?}.",
|
|
||||||
ctx.trace_id(),
|
|
||||||
timeout,
|
|
||||||
);
|
|
||||||
|
|
||||||
let (response_completion, response) = oneshot::channel();
|
|
||||||
let cancellation = self.cancellation.clone();
|
|
||||||
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
|
||||||
Send {
|
|
||||||
fut: MapOkDispatchResponse::new(
|
|
||||||
MapErrConnectionReset::new(self.to_dispatch.send(DispatchRequest {
|
|
||||||
ctx,
|
|
||||||
request_id,
|
|
||||||
request,
|
|
||||||
response_completion,
|
|
||||||
})),
|
|
||||||
DispatchResponse {
|
|
||||||
response: tokio::time::timeout(timeout, response),
|
|
||||||
complete: false,
|
|
||||||
request_id,
|
|
||||||
cancellation,
|
|
||||||
ctx,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
|
|
||||||
/// resolves to the response.
|
|
||||||
pub fn call(&mut self, context: context::Context, request: Req) -> Call<Req, Resp> {
|
|
||||||
Call {
|
|
||||||
fut: AndThenIdent::new(self.send(context, request)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A server response that is completed by request dispatch when the corresponding response
|
|
||||||
/// arrives off the wire.
|
|
||||||
#[pin_project(PinnedDrop)]
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct DispatchResponse<Resp> {
|
|
||||||
response: tokio::time::Timeout<oneshot::Receiver<Response<Resp>>>,
|
|
||||||
ctx: context::Context,
|
|
||||||
complete: bool,
|
|
||||||
cancellation: RequestCancellation,
|
|
||||||
request_id: u64,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Resp> Future for DispatchResponse<Resp> {
|
|
||||||
type Output = io::Result<Resp>;
|
|
||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> {
|
|
||||||
let resp = ready!(self.response.poll_unpin(cx));
|
|
||||||
|
|
||||||
Poll::Ready(match resp {
|
|
||||||
Ok(resp) => {
|
|
||||||
self.complete = true;
|
|
||||||
match resp {
|
|
||||||
Ok(resp) => Ok(resp.message?),
|
|
||||||
Err(oneshot::Canceled) => {
|
|
||||||
// The oneshot is Canceled when the dispatch task ends. In that case,
|
|
||||||
// there's nothing listening on the other side, so there's no point in
|
|
||||||
// propagating cancellation.
|
|
||||||
Err(io::Error::from(io::ErrorKind::ConnectionReset))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(tokio::time::Elapsed { .. }) => Err(io::Error::new(
|
|
||||||
io::ErrorKind::TimedOut,
|
|
||||||
"Client dropped expired request.".to_string(),
|
|
||||||
)),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cancels the request when dropped, if not already complete.
|
|
||||||
#[pinned_drop]
|
|
||||||
impl<Resp> PinnedDrop for DispatchResponse<Resp> {
|
|
||||||
fn drop(mut self: Pin<&mut Self>) {
|
|
||||||
if !self.complete {
|
|
||||||
// The receiver needs to be closed to handle the edge case that the request has not
|
|
||||||
// yet been received by the dispatch task. It is possible for the cancel message to
|
|
||||||
// arrive before the request itself, in which case the request could get stuck in the
|
|
||||||
// dispatch map forever if the server never responds (e.g. if the server dies while
|
|
||||||
// responding). Even if the server does respond, it will have unnecessarily done work
|
|
||||||
// for a client no longer waiting for a response. To avoid this, the dispatch task
|
|
||||||
// checks if the receiver is closed before inserting the request in the map. By
|
|
||||||
// closing the receiver before sending the cancel message, it is guaranteed that if the
|
|
||||||
// dispatch task misses an early-arriving cancellation message, then it will see the
|
|
||||||
// receiver as closed.
|
|
||||||
self.response.get_mut().close();
|
|
||||||
let request_id = self.request_id;
|
|
||||||
self.cancellation.cancel(request_id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the
|
|
||||||
/// channel.
|
|
||||||
pub fn new<Req, Resp, C>(
|
|
||||||
config: Config,
|
|
||||||
transport: C,
|
|
||||||
) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C>>
|
|
||||||
where
|
|
||||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
|
||||||
{
|
|
||||||
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
|
|
||||||
let (cancellation, canceled_requests) = cancellations();
|
|
||||||
let canceled_requests = canceled_requests.fuse();
|
|
||||||
|
|
||||||
NewClient {
|
|
||||||
client: Channel {
|
|
||||||
to_dispatch,
|
|
||||||
cancellation,
|
|
||||||
next_request_id: Arc::new(AtomicU64::new(0)),
|
|
||||||
},
|
|
||||||
dispatch: RequestDispatch {
|
|
||||||
config,
|
|
||||||
canceled_requests,
|
|
||||||
transport: transport.fuse(),
|
|
||||||
in_flight_requests: FnvHashMap::default(),
|
|
||||||
pending_requests: pending_requests.fuse(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
|
|
||||||
/// and dispatching responses to the appropriate channel.
|
|
||||||
#[pin_project]
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct RequestDispatch<Req, Resp, C> {
|
|
||||||
/// Writes requests to the wire and reads responses off the wire.
|
|
||||||
#[pin]
|
|
||||||
transport: Fuse<C>,
|
|
||||||
/// Requests waiting to be written to the wire.
|
|
||||||
#[pin]
|
|
||||||
pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>,
|
|
||||||
/// Requests that were dropped.
|
|
||||||
#[pin]
|
|
||||||
canceled_requests: Fuse<CanceledRequests>,
|
|
||||||
/// Requests already written to the wire that haven't yet received responses.
|
|
||||||
in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>,
|
|
||||||
/// Configures limits to prevent unlimited resource usage.
|
|
||||||
config: Config,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
|
|
||||||
where
|
|
||||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
|
||||||
{
|
|
||||||
fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
|
||||||
Poll::Ready(
|
|
||||||
match ready!(self.as_mut().project().transport.poll_next(cx)?) {
|
|
||||||
Some(response) => {
|
|
||||||
self.complete(response);
|
|
||||||
Some(Ok(()))
|
|
||||||
}
|
|
||||||
None => None,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
|
||||||
enum ReceiverStatus {
|
|
||||||
NotReady,
|
|
||||||
Closed,
|
|
||||||
}
|
|
||||||
|
|
||||||
let pending_requests_status = match self.as_mut().poll_next_request(cx)? {
|
|
||||||
Poll::Ready(Some(dispatch_request)) => {
|
|
||||||
self.as_mut().write_request(dispatch_request)?;
|
|
||||||
return Poll::Ready(Some(Ok(())));
|
|
||||||
}
|
|
||||||
Poll::Ready(None) => ReceiverStatus::Closed,
|
|
||||||
Poll::Pending => ReceiverStatus::NotReady,
|
|
||||||
};
|
|
||||||
|
|
||||||
let canceled_requests_status = match self.as_mut().poll_next_cancellation(cx)? {
|
|
||||||
Poll::Ready(Some((context, request_id))) => {
|
|
||||||
self.as_mut().write_cancel(context, request_id)?;
|
|
||||||
return Poll::Ready(Some(Ok(())));
|
|
||||||
}
|
|
||||||
Poll::Ready(None) => ReceiverStatus::Closed,
|
|
||||||
Poll::Pending => ReceiverStatus::NotReady,
|
|
||||||
};
|
|
||||||
|
|
||||||
match (pending_requests_status, canceled_requests_status) {
|
|
||||||
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
|
|
||||||
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
|
||||||
Poll::Ready(None)
|
|
||||||
}
|
|
||||||
(ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
|
|
||||||
// No more messages to process, so flush any messages buffered in the transport.
|
|
||||||
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
|
||||||
|
|
||||||
// Even if we fully-flush, we return Pending, because we have no more requests
|
|
||||||
// or cancellations right now.
|
|
||||||
Poll::Pending
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Yields the next pending request, if one is ready to be sent.
|
|
||||||
fn poll_next_request(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
) -> PollIo<DispatchRequest<Req, Resp>> {
|
|
||||||
if self.as_mut().project().in_flight_requests.len() >= self.config.max_in_flight_requests {
|
|
||||||
info!(
|
|
||||||
"At in-flight request capacity ({}/{}).",
|
|
||||||
self.as_mut().project().in_flight_requests.len(),
|
|
||||||
self.config.max_in_flight_requests
|
|
||||||
);
|
|
||||||
|
|
||||||
// No need to schedule a wakeup, because timers and responses are responsible
|
|
||||||
// for clearing out in-flight requests.
|
|
||||||
return Poll::Pending;
|
|
||||||
}
|
|
||||||
|
|
||||||
while let Poll::Pending = self.as_mut().project().transport.poll_ready(cx)? {
|
|
||||||
// We can't yield a request-to-be-sent before the transport is capable of buffering it.
|
|
||||||
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
|
||||||
}
|
|
||||||
|
|
||||||
loop {
|
|
||||||
match ready!(self.as_mut().project().pending_requests.poll_next_unpin(cx)) {
|
|
||||||
Some(request) => {
|
|
||||||
if request.response_completion.is_canceled() {
|
|
||||||
trace!(
|
|
||||||
"[{}] Request canceled before being sent.",
|
|
||||||
request.ctx.trace_id()
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
return Poll::Ready(Some(Ok(request)));
|
|
||||||
}
|
|
||||||
None => return Poll::Ready(None),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Yields the next pending cancellation, and, if one is ready, cancels the associated request.
|
|
||||||
fn poll_next_cancellation(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
) -> PollIo<(context::Context, u64)> {
|
|
||||||
while let Poll::Pending = self.as_mut().project().transport.poll_ready(cx)? {
|
|
||||||
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
|
||||||
}
|
|
||||||
|
|
||||||
loop {
|
|
||||||
let cancellation = self
|
|
||||||
.as_mut()
|
|
||||||
.project()
|
|
||||||
.canceled_requests
|
|
||||||
.poll_next_unpin(cx);
|
|
||||||
match ready!(cancellation) {
|
|
||||||
Some(request_id) => {
|
|
||||||
if let Some(in_flight_data) = self
|
|
||||||
.as_mut()
|
|
||||||
.project()
|
|
||||||
.in_flight_requests
|
|
||||||
.remove(&request_id)
|
|
||||||
{
|
|
||||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
|
||||||
debug!("[{}] Removed request.", in_flight_data.ctx.trace_id());
|
|
||||||
return Poll::Ready(Some(Ok((in_flight_data.ctx, request_id))));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => return Poll::Ready(None),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn write_request(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
dispatch_request: DispatchRequest<Req, Resp>,
|
|
||||||
) -> io::Result<()> {
|
|
||||||
let request_id = dispatch_request.request_id;
|
|
||||||
let request = ClientMessage::Request(Request {
|
|
||||||
id: request_id,
|
|
||||||
message: dispatch_request.request,
|
|
||||||
context: context::Context {
|
|
||||||
deadline: dispatch_request.ctx.deadline,
|
|
||||||
trace_context: dispatch_request.ctx.trace_context,
|
|
||||||
_non_exhaustive: (),
|
|
||||||
},
|
|
||||||
_non_exhaustive: (),
|
|
||||||
});
|
|
||||||
self.as_mut().project().transport.start_send(request)?;
|
|
||||||
self.as_mut().project().in_flight_requests.insert(
|
|
||||||
request_id,
|
|
||||||
InFlightData {
|
|
||||||
ctx: dispatch_request.ctx,
|
|
||||||
response_completion: dispatch_request.response_completion,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn write_cancel(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
context: context::Context,
|
|
||||||
request_id: u64,
|
|
||||||
) -> io::Result<()> {
|
|
||||||
let trace_id = *context.trace_id();
|
|
||||||
let cancel = ClientMessage::Cancel {
|
|
||||||
trace_context: context.trace_context,
|
|
||||||
request_id,
|
|
||||||
};
|
|
||||||
self.as_mut().project().transport.start_send(cancel)?;
|
|
||||||
trace!("[{}] Cancel message sent.", trace_id);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sends a server response to the client task that initiated the associated request.
|
|
||||||
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
|
|
||||||
if let Some(in_flight_data) = self
|
|
||||||
.as_mut()
|
|
||||||
.project()
|
|
||||||
.in_flight_requests
|
|
||||||
.remove(&response.request_id)
|
|
||||||
{
|
|
||||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
|
||||||
|
|
||||||
trace!("[{}] Received response.", in_flight_data.ctx.trace_id());
|
|
||||||
let _ = in_flight_data.response_completion.send(response);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
debug!(
|
|
||||||
"No in-flight request found for request_id = {}.",
|
|
||||||
response.request_id
|
|
||||||
);
|
|
||||||
|
|
||||||
// If the response completion was absent, then the request was already canceled.
|
|
||||||
false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
|
|
||||||
where
|
|
||||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
|
||||||
{
|
|
||||||
type Output = io::Result<()>;
|
|
||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
||||||
loop {
|
|
||||||
match (self.as_mut().pump_read(cx)?, self.as_mut().pump_write(cx)?) {
|
|
||||||
(read, Poll::Ready(None)) => {
|
|
||||||
if self.as_mut().project().in_flight_requests.is_empty() {
|
|
||||||
info!("Shutdown: write half closed, and no requests in flight.");
|
|
||||||
return Poll::Ready(Ok(()));
|
|
||||||
}
|
|
||||||
info!(
|
|
||||||
"Shutdown: write half closed, and {} requests in flight.",
|
|
||||||
self.as_mut().project().in_flight_requests.len()
|
|
||||||
);
|
|
||||||
match read {
|
|
||||||
Poll::Ready(Some(())) => continue,
|
|
||||||
_ => return Poll::Pending,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
|
|
||||||
_ => return Poll::Pending,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage
|
|
||||||
/// the lifecycle of the request.
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct DispatchRequest<Req, Resp> {
|
|
||||||
ctx: context::Context,
|
|
||||||
request_id: u64,
|
|
||||||
request: Req,
|
|
||||||
response_completion: oneshot::Sender<Response<Resp>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct InFlightData<Resp> {
|
|
||||||
ctx: context::Context,
|
|
||||||
response_completion: oneshot::Sender<Response<Resp>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sends request cancellation signals.
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct RequestCancellation(mpsc::UnboundedSender<u64>);
|
|
||||||
|
|
||||||
/// A stream of IDs of requests that have been canceled.
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
|
|
||||||
|
|
||||||
/// Returns a channel to send request cancellation messages.
|
|
||||||
fn cancellations() -> (RequestCancellation, CanceledRequests) {
|
|
||||||
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
|
|
||||||
// bounded by the number of in-flight requests. Additionally, each request has a clone
|
|
||||||
// of the sender, so the bounded channel would have the same behavior,
|
|
||||||
// since it guarantees a slot.
|
|
||||||
let (tx, rx) = mpsc::unbounded();
|
|
||||||
(RequestCancellation(tx), CanceledRequests(rx))
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RequestCancellation {
|
|
||||||
/// Cancels the request with ID `request_id`.
|
|
||||||
fn cancel(&mut self, request_id: u64) {
|
|
||||||
let _ = self.0.unbounded_send(request_id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Stream for CanceledRequests {
|
|
||||||
type Item = u64;
|
|
||||||
|
|
||||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
|
|
||||||
self.0.poll_next_unpin(cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pin_project]
|
|
||||||
#[derive(Debug)]
|
|
||||||
#[must_use = "futures do nothing unless polled"]
|
|
||||||
struct MapErrConnectionReset<Fut> {
|
|
||||||
#[pin]
|
|
||||||
future: Fut,
|
|
||||||
finished: Option<()>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Fut> MapErrConnectionReset<Fut> {
|
|
||||||
fn new(future: Fut) -> MapErrConnectionReset<Fut> {
|
|
||||||
MapErrConnectionReset {
|
|
||||||
future,
|
|
||||||
finished: Some(()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Fut> Future for MapErrConnectionReset<Fut>
|
|
||||||
where
|
|
||||||
Fut: TryFuture,
|
|
||||||
{
|
|
||||||
type Output = io::Result<Fut::Ok>;
|
|
||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
||||||
match self.as_mut().project().future.try_poll(cx) {
|
|
||||||
Poll::Pending => Poll::Pending,
|
|
||||||
Poll::Ready(result) => {
|
|
||||||
self.project().finished.take().expect(
|
|
||||||
"MapErrConnectionReset must not be polled after it returned `Poll::Ready`",
|
|
||||||
);
|
|
||||||
Poll::Ready(result.map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pin_project]
|
|
||||||
#[derive(Debug)]
|
|
||||||
#[must_use = "futures do nothing unless polled"]
|
|
||||||
struct MapOkDispatchResponse<Fut, Resp> {
|
|
||||||
#[pin]
|
|
||||||
future: Fut,
|
|
||||||
response: Option<DispatchResponse<Resp>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Fut, Resp> MapOkDispatchResponse<Fut, Resp> {
|
|
||||||
fn new(future: Fut, response: DispatchResponse<Resp>) -> MapOkDispatchResponse<Fut, Resp> {
|
|
||||||
MapOkDispatchResponse {
|
|
||||||
future,
|
|
||||||
response: Some(response),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Fut, Resp> Future for MapOkDispatchResponse<Fut, Resp>
|
|
||||||
where
|
|
||||||
Fut: TryFuture,
|
|
||||||
{
|
|
||||||
type Output = Result<DispatchResponse<Resp>, Fut::Error>;
|
|
||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
||||||
match self.as_mut().project().future.try_poll(cx) {
|
|
||||||
Poll::Pending => Poll::Pending,
|
|
||||||
Poll::Ready(result) => {
|
|
||||||
let response = self
|
|
||||||
.as_mut()
|
|
||||||
.project()
|
|
||||||
.response
|
|
||||||
.take()
|
|
||||||
.expect("MapOk must not be polled after it returned `Poll::Ready`");
|
|
||||||
Poll::Ready(result.map(|_| response))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pin_project]
|
|
||||||
#[derive(Debug)]
|
|
||||||
#[must_use = "futures do nothing unless polled"]
|
|
||||||
struct AndThenIdent<Fut1, Fut2> {
|
|
||||||
#[pin]
|
|
||||||
try_chain: TryChain<Fut1, Fut2>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Fut1, Fut2> AndThenIdent<Fut1, Fut2>
|
|
||||||
where
|
|
||||||
Fut1: TryFuture<Ok = Fut2>,
|
|
||||||
Fut2: TryFuture,
|
|
||||||
{
|
|
||||||
/// Creates a new `Then`.
|
|
||||||
fn new(future: Fut1) -> AndThenIdent<Fut1, Fut2> {
|
|
||||||
AndThenIdent {
|
|
||||||
try_chain: TryChain::new(future),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Fut1, Fut2> Future for AndThenIdent<Fut1, Fut2>
|
|
||||||
where
|
|
||||||
Fut1: TryFuture<Ok = Fut2>,
|
|
||||||
Fut2: TryFuture<Error = Fut1::Error>,
|
|
||||||
{
|
|
||||||
type Output = Result<Fut2::Ok, Fut2::Error>;
|
|
||||||
|
|
||||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
||||||
self.project().try_chain.poll(cx, |result| match result {
|
|
||||||
Ok(ok) => TryChainAction::Future(ok),
|
|
||||||
Err(err) => TryChainAction::Output(Err(err)),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use = "futures do nothing unless polled"]
|
|
||||||
#[derive(Debug)]
|
|
||||||
enum TryChain<Fut1, Fut2> {
|
|
||||||
First(Fut1),
|
|
||||||
Second(Fut2),
|
|
||||||
Empty,
|
|
||||||
}
|
|
||||||
|
|
||||||
enum TryChainAction<Fut2>
|
|
||||||
where
|
|
||||||
Fut2: TryFuture,
|
|
||||||
{
|
|
||||||
Future(Fut2),
|
|
||||||
Output(Result<Fut2::Ok, Fut2::Error>),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Fut1, Fut2> TryChain<Fut1, Fut2>
|
|
||||||
where
|
|
||||||
Fut1: TryFuture<Ok = Fut2>,
|
|
||||||
Fut2: TryFuture,
|
|
||||||
{
|
|
||||||
fn new(fut1: Fut1) -> TryChain<Fut1, Fut2> {
|
|
||||||
TryChain::First(fut1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll<F>(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
f: F,
|
|
||||||
) -> Poll<Result<Fut2::Ok, Fut2::Error>>
|
|
||||||
where
|
|
||||||
F: FnOnce(Result<Fut1::Ok, Fut1::Error>) -> TryChainAction<Fut2>,
|
|
||||||
{
|
|
||||||
let mut f = Some(f);
|
|
||||||
|
|
||||||
// Safe to call `get_unchecked_mut` because we won't move the futures.
|
|
||||||
let this = unsafe { Pin::get_unchecked_mut(self) };
|
|
||||||
|
|
||||||
loop {
|
|
||||||
let output = match this {
|
|
||||||
TryChain::First(fut1) => {
|
|
||||||
// Poll the first future
|
|
||||||
match unsafe { Pin::new_unchecked(fut1) }.try_poll(cx) {
|
|
||||||
Poll::Pending => return Poll::Pending,
|
|
||||||
Poll::Ready(output) => output,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
TryChain::Second(fut2) => {
|
|
||||||
// Poll the second future
|
|
||||||
return unsafe { Pin::new_unchecked(fut2) }.try_poll(cx);
|
|
||||||
}
|
|
||||||
TryChain::Empty => {
|
|
||||||
panic!("future must not be polled after it returned `Poll::Ready`");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
*this = TryChain::Empty; // Drop fut1
|
|
||||||
let f = f.take().unwrap();
|
|
||||||
match f(output) {
|
|
||||||
TryChainAction::Future(fut2) => *this = TryChain::Second(fut2),
|
|
||||||
TryChainAction::Output(output) => return Poll::Ready(output),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::{
|
|
||||||
cancellations, CanceledRequests, Channel, DispatchResponse, RequestCancellation,
|
|
||||||
RequestDispatch,
|
|
||||||
};
|
|
||||||
use crate::{
|
|
||||||
client::Config,
|
|
||||||
context,
|
|
||||||
transport::{self, channel::UnboundedChannel},
|
|
||||||
ClientMessage, Response,
|
|
||||||
};
|
|
||||||
use fnv::FnvHashMap;
|
|
||||||
use futures::{
|
|
||||||
channel::{mpsc, oneshot},
|
|
||||||
prelude::*,
|
|
||||||
task::*,
|
|
||||||
};
|
|
||||||
use std::time::Duration;
|
|
||||||
use std::{pin::Pin, sync::atomic::AtomicU64, sync::Arc};
|
|
||||||
|
|
||||||
#[tokio::test(threaded_scheduler)]
|
|
||||||
async fn dispatch_response_cancels_on_timeout() {
|
|
||||||
let (_response_completion, response) = oneshot::channel();
|
|
||||||
let (cancellation, mut canceled_requests) = cancellations();
|
|
||||||
let resp = DispatchResponse::<u64> {
|
|
||||||
// Timeout in the past should cause resp to error out when polled.
|
|
||||||
response: tokio::time::timeout(Duration::from_secs(0), response),
|
|
||||||
complete: false,
|
|
||||||
request_id: 3,
|
|
||||||
cancellation,
|
|
||||||
ctx: context::current(),
|
|
||||||
};
|
|
||||||
let _ = futures::poll!(resp);
|
|
||||||
// resp's drop() is run, which should send a cancel message.
|
|
||||||
assert!(canceled_requests.0.try_next().unwrap() == Some(3));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test(threaded_scheduler)]
|
|
||||||
async fn stage_request() {
|
|
||||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
|
||||||
let dispatch = Pin::new(&mut dispatch);
|
|
||||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
|
||||||
|
|
||||||
let _resp = send_request(&mut channel, "hi").await;
|
|
||||||
|
|
||||||
let req = dispatch.poll_next_request(cx).ready();
|
|
||||||
assert!(req.is_some());
|
|
||||||
|
|
||||||
let req = req.unwrap();
|
|
||||||
assert_eq!(req.request_id, 0);
|
|
||||||
assert_eq!(req.request, "hi".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Regression test for https://github.com/google/tarpc/issues/220
|
|
||||||
#[tokio::test(threaded_scheduler)]
|
|
||||||
async fn stage_request_channel_dropped_doesnt_panic() {
|
|
||||||
let (mut dispatch, mut channel, mut server_channel) = set_up();
|
|
||||||
let mut dispatch = Pin::new(&mut dispatch);
|
|
||||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
|
||||||
|
|
||||||
let _ = send_request(&mut channel, "hi").await;
|
|
||||||
drop(channel);
|
|
||||||
|
|
||||||
assert!(dispatch.as_mut().poll(cx).is_ready());
|
|
||||||
send_response(
|
|
||||||
&mut server_channel,
|
|
||||||
Response {
|
|
||||||
request_id: 0,
|
|
||||||
message: Ok("hello".into()),
|
|
||||||
_non_exhaustive: (),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
dispatch.await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test(threaded_scheduler)]
|
|
||||||
async fn stage_request_response_future_dropped_is_canceled_before_sending() {
|
|
||||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
|
||||||
let dispatch = Pin::new(&mut dispatch);
|
|
||||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
|
||||||
|
|
||||||
let _ = send_request(&mut channel, "hi").await;
|
|
||||||
|
|
||||||
// Drop the channel so polling returns none if no requests are currently ready.
|
|
||||||
drop(channel);
|
|
||||||
// Test that a request future dropped before it's processed by dispatch will cause the request
|
|
||||||
// to not be added to the in-flight request map.
|
|
||||||
assert!(dispatch.poll_next_request(cx).ready().is_none());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verifies that a response future dropped *after* the dispatch task has sent the
// request triggers a cancellation that removes it from the in-flight request map.
#[tokio::test(threaded_scheduler)]
async fn stage_request_response_future_dropped_is_canceled_after_sending() {
    let (mut dispatch, mut channel, _server_channel) = set_up();
    let cx = &mut Context::from_waker(&noop_waker_ref());
    let mut dispatch = Pin::new(&mut dispatch);

    // Keep the response future alive so the request is actually dispatched.
    let req = send_request(&mut channel, "hi").await;

    // Drive the write half once; the request should now be tracked in flight.
    assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
    assert!(!dispatch.as_mut().project().in_flight_requests.is_empty());

    // Test that a request future dropped after it's processed by dispatch will cause the request
    // to be removed from the in-flight request map.
    drop(req);
    if let Poll::Ready(Some(_)) = dispatch.as_mut().poll_next_cancellation(cx).unwrap() {
        // ok
    } else {
        panic!("Expected request to be cancelled")
    };
    assert!(dispatch.project().in_flight_requests.is_empty());
}
|
|
||||||
|
|
||||||
// Verifies that a request whose response receiver is already closed (but whose drop
// has not yet canceled the request) is skipped by the dispatch task.
#[tokio::test(threaded_scheduler)]
async fn stage_request_response_closed_skipped() {
    let (mut dispatch, mut channel, _server_channel) = set_up();
    let dispatch = Pin::new(&mut dispatch);
    let cx = &mut Context::from_waker(&noop_waker_ref());

    // Test that a request future that's closed its receiver but not yet canceled its request --
    // i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
    // map.
    let mut resp = send_request(&mut channel, "hi").await;
    resp.response.get_mut().close();

    assert!(dispatch.poll_next_request(cx).is_pending());
}
|
|
||||||
|
|
||||||
/// Builds a fresh client-side test fixture: a request dispatch task, the user-facing
/// channel that feeds it, and the server half of the in-memory transport.
fn set_up() -> (
    RequestDispatch<String, String, UnboundedChannel<Response<String>, ClientMessage<String>>>,
    Channel<String, String>,
    UnboundedChannel<ClientMessage<String>, Response<String>>,
) {
    // Ignore the error: the logger may already be initialized by another test.
    let _ = env_logger::try_init();

    // Capacity 1 keeps backpressure behavior easy to trigger in tests.
    let (to_dispatch, pending_requests) = mpsc::channel(1);
    let (cancel_tx, canceled_requests) = mpsc::unbounded();
    let (client_channel, server_channel) = transport::channel::unbounded();

    let dispatch = RequestDispatch::<String, String, _> {
        transport: client_channel.fuse(),
        pending_requests: pending_requests.fuse(),
        canceled_requests: CanceledRequests(canceled_requests).fuse(),
        in_flight_requests: FnvHashMap::default(),
        config: Config::default(),
    };

    let cancellation = RequestCancellation(cancel_tx);
    let channel = Channel {
        to_dispatch,
        cancellation,
        next_request_id: Arc::new(AtomicU64::new(0)),
    };

    (dispatch, channel, server_channel)
}
|
|
||||||
|
|
||||||
async fn send_request(
|
|
||||||
channel: &mut Channel<String, String>,
|
|
||||||
request: &str,
|
|
||||||
) -> DispatchResponse<String> {
|
|
||||||
channel
|
|
||||||
.send(context::current(), request.to_string())
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send_response(
|
|
||||||
channel: &mut UnboundedChannel<ClientMessage<String>, Response<String>>,
|
|
||||||
response: Response<String>,
|
|
||||||
) {
|
|
||||||
channel.send(response).await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Test-only conveniences for asserting on `Poll<Option<Result<_, _>>>` values.
trait PollTest {
    /// The success value yielded by the stream being polled.
    type T;
    /// Strips the `Result` layer, panicking if the poll yielded an error.
    fn unwrap(self) -> Poll<Self::T>;
    /// Asserts the poll is `Ready`, panicking on `Pending` or on an error.
    fn ready(self) -> Self::T;
}
|
|
||||||
|
|
||||||
impl<T, E> PollTest for Poll<Option<Result<T, E>>>
|
|
||||||
where
|
|
||||||
E: ::std::fmt::Display,
|
|
||||||
{
|
|
||||||
type T = Option<T>;
|
|
||||||
|
|
||||||
fn unwrap(self) -> Poll<Option<T>> {
|
|
||||||
match self {
|
|
||||||
Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)),
|
|
||||||
Poll::Ready(None) => Poll::Ready(None),
|
|
||||||
Poll::Ready(Some(Err(e))) => panic!(e.to_string()),
|
|
||||||
Poll::Pending => Poll::Pending,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ready(self) -> Option<T> {
|
|
||||||
match self {
|
|
||||||
Poll::Ready(Some(Ok(t))) => Some(t),
|
|
||||||
Poll::Ready(None) => None,
|
|
||||||
Poll::Ready(Some(Err(e))) => panic!(e.to_string()),
|
|
||||||
Poll::Pending => panic!("Pending"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,155 +0,0 @@
|
|||||||
// Copyright 2018 Google LLC
|
|
||||||
//
|
|
||||||
// Use of this source code is governed by an MIT-style
|
|
||||||
// license that can be found in the LICENSE file or at
|
|
||||||
// https://opensource.org/licenses/MIT.
|
|
||||||
|
|
||||||
//! Provides a client that connects to a server and sends multiplexed requests.
|
|
||||||
|
|
||||||
use crate::context;
|
|
||||||
use futures::prelude::*;
|
|
||||||
use std::io;
|
|
||||||
|
|
||||||
/// Provides a [`Client`] backed by a transport.
|
|
||||||
pub mod channel;
|
|
||||||
pub use channel::{new, Channel};
|
|
||||||
|
|
||||||
/// Sends multiplexed requests to, and receives responses from, a server.
pub trait Client<'a, Req> {
    /// The response type.
    type Response;

    /// The future response. The `'a` bound lets the future borrow from the client.
    type Future: Future<Output = io::Result<Self::Response>> + 'a;

    /// Initiates a request, sending it to the dispatch task.
    ///
    /// Returns a [`Future`] that resolves to the server's response (or an I/O
    /// error) once the request is successfully enqueued.
    ///
    /// [`Future`]: futures::Future
    fn call(&'a mut self, ctx: context::Context, request: Req) -> Self::Future;

    /// Returns a Client that applies a post-processing function to the returned response.
    fn map_response<F, R>(self, f: F) -> MapResponse<Self, F>
    where
        F: FnMut(Self::Response) -> R,
        Self: Sized,
    {
        MapResponse { inner: self, f }
    }

    /// Returns a Client that applies a pre-processing function to the request.
    fn with_request<F, Req2>(self, f: F) -> WithRequest<Self, F>
    where
        F: FnMut(Req2) -> Req,
        Self: Sized,
    {
        WithRequest { inner: self, f }
    }
}
|
|
||||||
|
|
||||||
/// A Client that applies a function to the returned response.
/// Created by [`Client::map_response`].
#[derive(Clone, Debug)]
pub struct MapResponse<C, F> {
    // The wrapped client that performs the actual call.
    inner: C,
    // Post-processing function applied to each successful response.
    f: F,
}
|
|
||||||
|
|
||||||
impl<'a, C, F, Req, Resp, Resp2> Client<'a, Req> for MapResponse<C, F>
where
    C: Client<'a, Req, Response = Resp>,
    F: FnMut(Resp) -> Resp2 + 'a,
{
    type Response = Resp2;
    // `MapOk` applies `f` only on the success path; errors pass through untouched.
    type Future = futures::future::MapOk<<C as Client<'a, Req>>::Future, &'a mut F>;

    fn call(&'a mut self, ctx: context::Context, request: Req) -> Self::Future {
        // Delegate to the inner client, then post-process the Ok value.
        self.inner.call(ctx, request).map_ok(&mut self.f)
    }
}
|
|
||||||
|
|
||||||
/// A Client that applies a pre-processing function to the request.
/// Created by [`Client::with_request`].
#[derive(Clone, Debug)]
pub struct WithRequest<C, F> {
    // The wrapped client that performs the actual call.
    inner: C,
    // Conversion applied to each request before it is sent.
    f: F,
}
|
|
||||||
|
|
||||||
impl<'a, C, F, Req, Req2, Resp> Client<'a, Req2> for WithRequest<C, F>
where
    C: Client<'a, Req, Response = Resp>,
    F: FnMut(Req2) -> Req,
{
    type Response = Resp;
    // The response side is untouched, so the inner client's future is reused as-is.
    type Future = <C as Client<'a, Req>>::Future;

    fn call(&'a mut self, ctx: context::Context, request: Req2) -> Self::Future {
        // Convert the request first, then delegate to the inner client.
        self.inner.call(ctx, (self.f)(request))
    }
}
|
|
||||||
|
|
||||||
impl<'a, Req, Resp> Client<'a, Req> for Channel<Req, Resp>
where
    Req: 'a,
    Resp: 'a,
{
    type Response = Resp;
    type Future = channel::Call<'a, Req, Resp>;

    fn call(&'a mut self, ctx: context::Context, request: Req) -> channel::Call<'a, Req, Resp> {
        // Delegates to the inherent `Channel::call`; inherent methods take
        // precedence over trait methods, so this is not a recursive call.
        self.call(ctx, request)
    }
}
|
|
||||||
|
|
||||||
/// Settings that control the behavior of the client.
#[derive(Clone, Debug)]
pub struct Config {
    /// The number of requests that can be in flight at once.
    /// `max_in_flight_requests` controls the size of the map used by the client
    /// for storing pending requests.
    pub max_in_flight_requests: usize,
    /// The number of requests that can be buffered client-side before being sent.
    /// `pending_request_buffer` controls the size of the channel clients use
    /// to communicate with the request dispatch task.
    pub pending_request_buffer: usize,
    // Private unit field reserving the right to add fields without a breaking change.
    #[doc(hidden)]
    _non_exhaustive: (),
}
|
|
||||||
|
|
||||||
impl Default for Config {
    /// Defaults: 1,000 in-flight requests; a 100-request pending buffer.
    fn default() -> Self {
        Config {
            max_in_flight_requests: 1_000,
            pending_request_buffer: 100,
            _non_exhaustive: (),
        }
    }
}
|
|
||||||
|
|
||||||
/// A channel and dispatch pair. The dispatch drives the sending and receiving of requests
/// and must be polled continuously or spawned.
#[derive(Debug)]
pub struct NewClient<C, D> {
    /// The new client.
    pub client: C,
    /// The client's dispatch. Requests sent via `client` make no progress
    /// until this future is driven.
    pub dispatch: D,
}
|
|
||||||
|
|
||||||
impl<C, D> NewClient<C, D>
where
    D: Future<Output = io::Result<()>> + Send + 'static,
{
    /// Helper method to spawn the dispatch on the default executor,
    /// returning the client. Dispatch errors are logged rather than propagated.
    #[cfg(feature = "tokio1")]
    pub fn spawn(self) -> io::Result<C> {
        use log::error;

        // Convert the io::Result output into () so the task can be detached;
        // a broken connection is reported via the log only.
        let dispatch = self
            .dispatch
            .unwrap_or_else(move |e| error!("Connection broken: {}", e));
        tokio::spawn(dispatch);
        Ok(self.client)
    }
}
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
// Copyright 2018 Google LLC
|
|
||||||
//
|
|
||||||
// Use of this source code is governed by an MIT-style
|
|
||||||
// license that can be found in the LICENSE file or at
|
|
||||||
// https://opensource.org/licenses/MIT.
|
|
||||||
|
|
||||||
//! Provides a request context that carries a deadline and trace context. This context is sent from
|
|
||||||
//! client to server and is used by the server to enforce response deadlines.
|
|
||||||
|
|
||||||
use crate::trace::{self, TraceId};
|
|
||||||
use std::time::{Duration, SystemTime};
|
|
||||||
|
|
||||||
/// A request context that carries request-scoped information like deadlines and trace information.
/// It is sent from client to server and is used by the server to enforce response deadlines.
///
/// The context should not be stored directly in a server implementation, because the context will
/// be different for each request in scope.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Context {
    /// When the client expects the request to be complete by. The server should cancel the request
    /// if it is not complete by this time.
    // Serialized as whole seconds since the Unix epoch (see util::serde helpers).
    #[cfg_attr(
        feature = "serde1",
        serde(serialize_with = "crate::util::serde::serialize_epoch_secs")
    )]
    #[cfg_attr(
        feature = "serde1",
        serde(deserialize_with = "crate::util::serde::deserialize_epoch_secs")
    )]
    #[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))]
    pub deadline: SystemTime,
    /// Uniquely identifies requests originating from the same source.
    /// When a service handles a request by making requests itself, those requests should
    /// include the same `trace_id` as that included on the original request. This way,
    /// users can trace related actions across a distributed system.
    pub trace_context: trace::Context,
    // Private unit field reserving the right to add fields without a breaking change.
    #[doc(hidden)]
    #[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
    pub(crate) _non_exhaustive: (),
}
|
|
||||||
|
|
||||||
#[cfg(feature = "serde1")]
|
|
||||||
fn ten_seconds_from_now() -> SystemTime {
|
|
||||||
SystemTime::now() + Duration::from_secs(10)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the context for the current request, or a default Context if no request is active.
///
/// The default carries a deadline 10 seconds from now and a fresh root trace.
// TODO: populate Context with request-scoped data, with default fallbacks.
pub fn current() -> Context {
    Context {
        deadline: SystemTime::now() + Duration::from_secs(10),
        trace_context: trace::Context::new_root(),
        _non_exhaustive: (),
    }
}
|
|
||||||
|
|
||||||
impl Context {
    /// Returns the ID of the request-scoped trace, borrowed from `trace_context`.
    pub fn trace_id(&self) -> &TraceId {
        &self.trace_context.trace_id
    }
}
|
|
||||||
@@ -1,127 +0,0 @@
|
|||||||
// Copyright 2018 Google LLC
|
|
||||||
//
|
|
||||||
// Use of this source code is governed by an MIT-style
|
|
||||||
// license that can be found in the LICENSE file or at
|
|
||||||
// https://opensource.org/licenses/MIT.
|
|
||||||
|
|
||||||
#![deny(missing_docs, missing_debug_implementations)]
|
|
||||||
|
|
||||||
//! An RPC framework providing client and server.
|
|
||||||
//!
|
|
||||||
//! Features:
|
|
||||||
//! * RPC deadlines, both client- and server-side.
|
|
||||||
//! * Cascading cancellation (works with multiple hops).
|
|
||||||
//! * Configurable limits
|
|
||||||
//! * In-flight requests, both client and server-side.
|
|
||||||
//! * Server-side limit is per-connection.
|
|
||||||
//! * When the server reaches the in-flight request maximum, it returns a throttled error
|
|
||||||
//! to the client.
|
|
||||||
//! * When the client reaches the in-flight request max, messages are buffered up to a
|
|
||||||
//! configurable maximum, beyond which the requests are back-pressured.
|
|
||||||
//! * Server connections.
|
|
||||||
//! * Total and per-IP limits.
|
|
||||||
//! * When an incoming connection is accepted, if already at maximum, the connection is
|
|
||||||
//! dropped.
|
|
||||||
//! * Transport agnostic.
|
|
||||||
|
|
||||||
pub mod client;
|
|
||||||
pub mod context;
|
|
||||||
pub mod server;
|
|
||||||
pub mod transport;
|
|
||||||
pub(crate) mod util;
|
|
||||||
|
|
||||||
pub use crate::{client::Client, server::Server, trace, transport::sealed::Transport};
|
|
||||||
|
|
||||||
use futures::task::*;
|
|
||||||
use std::{io, time::SystemTime};
|
|
||||||
|
|
||||||
/// A message from a client to a server.
|
|
||||||
#[derive(Debug)]
|
|
||||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
|
||||||
pub enum ClientMessage<T> {
|
|
||||||
/// A request initiated by a user. The server responds to a request by invoking a
|
|
||||||
/// service-provided request handler. The handler completes with a [`response`](Response), which
|
|
||||||
/// the server sends back to the client.
|
|
||||||
Request(Request<T>),
|
|
||||||
/// A command to cancel an in-flight request, automatically sent by the client when a response
|
|
||||||
/// future is dropped.
|
|
||||||
///
|
|
||||||
/// When received, the server will immediately cancel the main task (top-level future) of the
|
|
||||||
/// request handler for the associated request. Any tasks spawned by the request handler will
|
|
||||||
/// not be canceled, because the framework layer does not
|
|
||||||
/// know about them.
|
|
||||||
Cancel {
|
|
||||||
/// The trace context associates the message with a specific chain of causally-related actions,
|
|
||||||
/// possibly orchestrated across many distributed systems.
|
|
||||||
#[cfg_attr(feature = "serde", serde(default))]
|
|
||||||
trace_context: trace::Context,
|
|
||||||
/// The ID of the request to cancel.
|
|
||||||
request_id: u64,
|
|
||||||
},
|
|
||||||
#[doc(hidden)]
|
|
||||||
_NonExhaustive,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A request from a client to a server.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Request<T> {
    /// Trace context, deadline, and other cross-cutting concerns.
    pub context: context::Context,
    /// Uniquely identifies the request across all requests sent over a single channel.
    pub id: u64,
    /// The request body.
    pub message: T,
    // Private unit field reserving the right to add fields without a breaking change.
    #[doc(hidden)]
    #[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
    _non_exhaustive: (),
}
|
|
||||||
|
|
||||||
/// A response from a server to a client.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Response<T> {
    /// The ID of the request being responded to.
    pub request_id: u64,
    /// The response body, or an error if the request failed.
    pub message: Result<T, ServerError>,
    // Private unit field reserving the right to add fields without a breaking change.
    #[doc(hidden)]
    #[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
    _non_exhaustive: (),
}
|
|
||||||
|
|
||||||
/// An error response from a server to a client.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct ServerError {
    // `io::ErrorKind` has no serde impls, so it is serialized as a u32 via
    // the util::serde helpers.
    #[cfg_attr(
        feature = "serde1",
        serde(serialize_with = "util::serde::serialize_io_error_kind_as_u32")
    )]
    #[cfg_attr(
        feature = "serde1",
        serde(deserialize_with = "util::serde::deserialize_io_error_kind_from_u32")
    )]
    /// The type of error that occurred to fail the request.
    pub kind: io::ErrorKind,
    /// A message describing more detail about the error that occurred.
    pub detail: Option<String>,
    // Private unit field reserving the right to add fields without a breaking change.
    #[doc(hidden)]
    #[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
    _non_exhaustive: (),
}
|
|
||||||
|
|
||||||
impl From<ServerError> for io::Error {
|
|
||||||
fn from(e: ServerError) -> io::Error {
|
|
||||||
io::Error::new(e.kind, e.detail.unwrap_or_default())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T> Request<T> {
    /// Returns the deadline for this request, borrowed from its context.
    pub fn deadline(&self) -> &SystemTime {
        &self.context.deadline
    }
}
|
|
||||||
|
|
||||||
/// Shorthand for the poll result of a fallible, optionally-yielding stream operation.
pub(crate) type PollIo<T> = Poll<Option<io::Result<T>>>;
|
|
||||||
@@ -1,702 +0,0 @@
|
|||||||
// Copyright 2018 Google LLC
|
|
||||||
//
|
|
||||||
// Use of this source code is governed by an MIT-style
|
|
||||||
// license that can be found in the LICENSE file or at
|
|
||||||
// https://opensource.org/licenses/MIT.
|
|
||||||
|
|
||||||
//! Provides a server that concurrently handles many connections sending multiplexed requests.
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
context, trace, util::Compact, util::TimeUntil, ClientMessage, PollIo, Request, Response,
|
|
||||||
ServerError, Transport,
|
|
||||||
};
|
|
||||||
use fnv::FnvHashMap;
|
|
||||||
use futures::{
|
|
||||||
channel::mpsc,
|
|
||||||
future::{AbortHandle, AbortRegistration, Abortable},
|
|
||||||
prelude::*,
|
|
||||||
ready,
|
|
||||||
stream::Fuse,
|
|
||||||
task::*,
|
|
||||||
};
|
|
||||||
use humantime::format_rfc3339;
|
|
||||||
use log::{debug, trace};
|
|
||||||
use pin_project::pin_project;
|
|
||||||
use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
|
|
||||||
use tokio::time::Timeout;
|
|
||||||
|
|
||||||
mod filter;
|
|
||||||
#[cfg(test)]
|
|
||||||
mod testing;
|
|
||||||
mod throttle;
|
|
||||||
|
|
||||||
pub use self::{
|
|
||||||
filter::ChannelFilter,
|
|
||||||
throttle::{Throttler, ThrottlerStream},
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Manages clients, serving multiplexed requests over each connection.
#[derive(Debug)]
pub struct Server<Req, Resp> {
    // Shared settings applied to every accepted channel.
    config: Config,
    // Zero-sized marker tying the server to its request/response types.
    ghost: PhantomData<(Req, Resp)>,
}
|
|
||||||
|
|
||||||
impl<Req, Resp> Default for Server<Req, Resp> {
    /// Builds a server with the default [`Config`].
    fn default() -> Self {
        new(Config::default())
    }
}
|
|
||||||
|
|
||||||
/// Settings that control the behavior of the server.
#[derive(Clone, Debug)]
pub struct Config {
    /// The number of responses per client that can be buffered server-side before being sent.
    /// `pending_response_buffer` controls the buffer size of the channel that a server's
    /// response tasks use to send responses to the client handler task.
    pub pending_response_buffer: usize,
}
|
|
||||||
|
|
||||||
impl Default for Config {
|
|
||||||
fn default() -> Self {
|
|
||||||
Config {
|
|
||||||
pending_response_buffer: 100,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Config {
    /// Returns a channel backed by `transport` and configured with `self`.
    /// Convenience wrapper around [`BaseChannel::new`].
    pub fn channel<Req, Resp, T>(self, transport: T) -> BaseChannel<Req, Resp, T>
    where
        T: Transport<Response<Resp>, ClientMessage<Req>>,
    {
        BaseChannel::new(self, transport)
    }
}
|
|
||||||
|
|
||||||
/// Returns a new server with configuration specified `config`.
pub fn new<Req, Resp>(config: Config) -> Server<Req, Resp> {
    Server {
        config,
        // PhantomData: Req/Resp only type the server; nothing is stored.
        ghost: PhantomData,
    }
}
|
|
||||||
|
|
||||||
impl<Req, Resp> Server<Req, Resp> {
    /// Returns the config for this server.
    pub fn config(&self) -> &Config {
        &self.config
    }

    /// Returns a stream of server channels, one per transport yielded by `listener`.
    pub fn incoming<S, T>(self, listener: S) -> impl Stream<Item = BaseChannel<Req, Resp, T>>
    where
        S: Stream<Item = T>,
        T: Transport<Response<Resp>, ClientMessage<Req>>,
    {
        // Each accepted transport gets its own copy of the server config.
        listener.map(move |t| BaseChannel::new(self.config.clone(), t))
    }
}
|
|
||||||
|
|
||||||
/// Basically a Fn(Req) -> impl Future<Output = Resp>;
///
/// `Clone` is required so one handler can serve many concurrent requests.
pub trait Serve<Req>: Sized + Clone {
    /// Type of response.
    type Resp;

    /// Type of response future.
    type Fut: Future<Output = Self::Resp>;

    /// Responds to a single request. Consumes `self` (a clone per request).
    fn serve(self, ctx: context::Context, req: Req) -> Self::Fut;
}
|
|
||||||
|
|
||||||
// Blanket impl: any cloneable async-returning closure is a request handler.
impl<Req, Resp, Fut, F> Serve<Req> for F
where
    F: FnOnce(context::Context, Req) -> Fut + Clone,
    Fut: Future<Output = Resp>,
{
    type Resp = Resp;
    type Fut = Fut;

    fn serve(self, ctx: context::Context, req: Req) -> Self::Fut {
        self(ctx, req)
    }
}
|
|
||||||
|
|
||||||
/// A utility trait enabling a stream to fluently chain a request handler.
pub trait Handler<C>
where
    Self: Sized + Stream<Item = C>,
    C: Channel,
{
    /// Enforces channel per-key limits. `keymaker` derives the key (e.g. a peer
    /// address) from each channel; at most `n` channels per key are admitted.
    fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> filter::ChannelFilter<Self, K, KF>
    where
        K: fmt::Display + Eq + Hash + Clone + Unpin,
        KF: Fn(&C) -> K,
    {
        ChannelFilter::new(self, n, keymaker)
    }

    /// Caps the number of concurrent requests per channel.
    fn max_concurrent_requests_per_channel(self, n: usize) -> ThrottlerStream<Self> {
        ThrottlerStream::new(self, n)
    }

    /// Responds to all requests with `server`.
    #[cfg(feature = "tokio1")]
    fn respond_with<S>(self, server: S) -> Running<Self, S>
    where
        S: Serve<C::Req, Resp = C::Resp>,
    {
        Running {
            incoming: self,
            server,
        }
    }
}
|
|
||||||
|
|
||||||
// Blanket impl: every stream of channels gets the Handler combinators for free.
impl<S, C> Handler<C> for S
where
    S: Sized + Stream<Item = C>,
    C: Channel,
{
}
|
|
||||||
|
|
||||||
/// BaseChannel lifts a Transport to a Channel by tracking in-flight requests.
#[pin_project]
#[derive(Debug)]
pub struct BaseChannel<Req, Resp, T> {
    config: Config,
    /// Writes responses to the wire and reads requests off the wire.
    #[pin]
    transport: Fuse<T>,
    /// Number of requests currently being responded to.
    // Maps request ID -> abort handle for the handler's top-level future.
    in_flight_requests: FnvHashMap<u64, AbortHandle>,
    /// Types the request and response.
    ghost: PhantomData<(Req, Resp)>,
}
|
|
||||||
|
|
||||||
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
where
    T: Transport<Response<Resp>, ClientMessage<Req>>,
{
    /// Creates a new channel backed by `transport` and configured with `config`.
    pub fn new(config: Config, transport: T) -> Self {
        BaseChannel {
            config,
            // Fused so polling past end-of-stream is safe.
            transport: transport.fuse(),
            in_flight_requests: FnvHashMap::default(),
            ghost: PhantomData,
        }
    }

    /// Creates a new channel backed by `transport` and configured with the defaults.
    pub fn with_defaults(transport: T) -> Self {
        Self::new(Config::default(), transport)
    }

    /// Returns the inner transport.
    pub fn get_ref(&self) -> &T {
        self.transport.get_ref()
    }

    /// Aborts the handler for `request_id`, if it is still in flight.
    fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
        // It's possible the request was already completed, so it's fine
        // if this is None.
        if let Some(cancel_handle) = self
            .as_mut()
            .project()
            .in_flight_requests
            .remove(&request_id)
        {
            // Shrink the map opportunistically so it doesn't stay sized for a
            // past peak of in-flight requests.
            self.as_mut().project().in_flight_requests.compact(0.1);

            cancel_handle.abort();
            let remaining = self.as_mut().project().in_flight_requests.len();
            trace!(
                "[{}] Request canceled. In-flight requests = {}",
                trace_context.trace_id,
                remaining,
            );
        } else {
            trace!(
                "[{}] Received cancellation, but response handler \
                 is already complete.",
                trace_context.trace_id,
            );
        }
    }
}
|
|
||||||
|
|
||||||
/// The server end of an open connection with a client, streaming in requests from, and sinking
/// responses to, the client.
///
/// Channels are free to somewhat rely on the assumption that all in-flight requests are eventually
/// either [cancelled](BaseChannel::cancel_request) or [responded to](Sink::start_send). Safety cannot
/// rely on this assumption, but it is best for `Channel` users to always account for all outstanding
/// requests.
pub trait Channel
where
    Self: Transport<Response<<Self as Channel>::Resp>, Request<<Self as Channel>::Req>>,
{
    /// Type of request item.
    type Req;

    /// Type of response sink item.
    type Resp;

    /// Configuration of the channel.
    fn config(&self) -> &Config;

    /// Returns the number of in-flight requests over this channel.
    fn in_flight_requests(self: Pin<&mut Self>) -> usize;

    /// Caps the number of concurrent requests.
    fn max_concurrent_requests(self, n: usize) -> Throttler<Self>
    where
        Self: Sized,
    {
        Throttler::new(self, n)
    }

    /// Tells the Channel that request with ID `request_id` is being handled.
    /// The request will be tracked until a response with the same ID is sent
    /// to the Channel.
    fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration;

    /// Respond to requests coming over the channel with `f`. Returns a future that drives the
    /// responses and resolves when the connection is closed.
    fn respond_with<S>(self, server: S) -> ClientHandler<Self, S>
    where
        S: Serve<Self::Req, Resp = Self::Resp>,
        Self: Sized,
    {
        // Bounded fan-in channel through which per-request handler tasks send
        // their responses back to the single handler task.
        let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer);
        let responses = responses.fuse();

        ClientHandler {
            channel: self,
            server,
            pending_responses: responses,
            responses_tx,
        }
    }
}
|
|
||||||
|
|
||||||
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
where
    T: Transport<Response<Resp>, ClientMessage<Req>>,
{
    type Item = io::Result<Request<Req>>;

    /// Yields the next request from the transport. Cancellation messages are
    /// handled internally (aborting the in-flight handler) and never yielded,
    /// hence the loop.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        loop {
            match ready!(self.as_mut().project().transport.poll_next(cx)?) {
                Some(message) => match message {
                    ClientMessage::Request(request) => {
                        return Poll::Ready(Some(Ok(request)));
                    }
                    ClientMessage::Cancel {
                        trace_context,
                        request_id,
                    } => {
                        // Abort the matching handler, then keep polling for a request.
                        self.as_mut().cancel_request(&trace_context, request_id);
                    }
                    ClientMessage::_NonExhaustive => unreachable!(),
                },
                None => return Poll::Ready(None),
            }
        }
    }
}
|
|
||||||
|
|
||||||
impl<Req, Resp, T> Sink<Response<Resp>> for BaseChannel<Req, Resp, T>
where
    T: Transport<Response<Resp>, ClientMessage<Req>>,
{
    type Error = io::Error;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        self.project().transport.poll_ready(cx)
    }

    /// Sending a response retires the corresponding in-flight request before
    /// the bytes are handed to the transport.
    fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
        if self
            .as_mut()
            .project()
            .in_flight_requests
            .remove(&response.request_id)
            .is_some()
        {
            // Shrink the map opportunistically after a removal.
            self.as_mut().project().in_flight_requests.compact(0.1);
        }

        self.project().transport.start_send(response)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        self.project().transport.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        self.project().transport.poll_close(cx)
    }
}
|
|
||||||
|
|
||||||
// Exposes the inner transport without requiring the Transport bound.
impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
    fn as_ref(&self) -> &T {
        self.transport.get_ref()
    }
}
|
|
||||||
|
|
||||||
impl<Req, Resp, T> Channel for BaseChannel<Req, Resp, T>
where
    T: Transport<Response<Resp>, ClientMessage<Req>>,
{
    type Req = Req;
    type Resp = Resp;

    fn config(&self) -> &Config {
        &self.config
    }

    fn in_flight_requests(mut self: Pin<&mut Self>) -> usize {
        self.as_mut().project().in_flight_requests.len()
    }

    /// Registers `request_id` as in flight, returning the registration the
    /// handler's future is wrapped in so it can be aborted on cancellation.
    fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
        let (abort_handle, abort_registration) = AbortHandle::new_pair();
        // A duplicate request ID would mean two handlers for one request;
        // that is a protocol violation, so assert on it.
        assert!(self
            .project()
            .in_flight_requests
            .insert(request_id, abort_handle)
            .is_none());
        abort_registration
    }
}
|
|
||||||
|
|
||||||
/// A running handler serving all requests coming over a channel.
///
/// Reads requests off the channel and fans completed responses back in
/// through an internal mpsc queue before writing them to the channel's sink.
#[pin_project]
#[derive(Debug)]
pub struct ClientHandler<C, S>
where
    C: Channel,
{
    /// The channel requests arrive on and responses leave by.
    #[pin]
    channel: C,
    /// Responses waiting to be written to the wire.
    #[pin]
    pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>,
    /// Handed out to request handlers to fan in responses.
    #[pin]
    responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
    /// The service that produces a response future for each request.
    server: S,
}
|
|
||||||
|
|
||||||
impl<C, S> ClientHandler<C, S>
where
    C: Channel,
    S: Serve<C::Req, Resp = C::Resp>,
{
    /// Polls the channel for the next incoming request and wraps it in a
    /// `RequestHandler` future that will serve it and queue its response.
    ///
    /// Yields `Ready(None)` once the channel's request stream is exhausted.
    fn pump_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> PollIo<RequestHandler<S::Fut, C::Resp>> {
        match ready!(self.as_mut().project().channel.poll_next(cx)?) {
            Some(request) => Poll::Ready(Some(Ok(self.handle_request(request)))),
            None => Poll::Ready(None),
        }
    }

    /// Moves at most one completed response from the internal queue to the
    /// channel's sink, flushing when nothing is staged.
    ///
    /// Returns `Ready(None)` when the write half can be shut down: either the
    /// response queue has closed, or the read half is closed and no requests
    /// remain in flight (so no further responses can appear).
    fn pump_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        read_half_closed: bool,
    ) -> PollIo<()> {
        match self.as_mut().poll_next_response(cx)? {
            Poll::Ready(Some((ctx, response))) => {
                trace!(
                    "[{}] Staging response. In-flight requests = {}.",
                    ctx.trace_id(),
                    self.as_mut().project().channel.in_flight_requests(),
                );
                self.as_mut().project().channel.start_send(response)?;
                Poll::Ready(Some(Ok(())))
            }
            Poll::Ready(None) => {
                // Shutdown can't be done before we finish pumping out remaining responses.
                ready!(self.as_mut().project().channel.poll_flush(cx)?);
                Poll::Ready(None)
            }
            Poll::Pending => {
                // No more requests to process, so flush any requests buffered in the transport.
                ready!(self.as_mut().project().channel.poll_flush(cx)?);

                // Being here means there are no staged requests and all written responses are
                // fully flushed. So, if the read half is closed and there are no in-flight
                // requests, then we can close the write half.
                if read_half_closed && self.as_mut().project().channel.in_flight_requests() == 0 {
                    Poll::Ready(None)
                } else {
                    Poll::Pending
                }
            }
        }
    }

    /// Waits until the channel sink has room, then pulls the next completed
    /// response off the internal queue.
    fn poll_next_response(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> PollIo<(context::Context, Response<C::Resp>)> {
        // Ensure there's room to write a response.
        while let Poll::Pending = self.as_mut().project().channel.poll_ready(cx)? {
            // Not ready yet: flush to make progress before re-checking readiness.
            ready!(self.as_mut().project().channel.poll_flush(cx)?);
        }

        match ready!(self.as_mut().project().pending_responses.poll_next(cx)) {
            Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))),
            None => {
                // This branch likely won't happen, since the ClientHandler is holding a Sender.
                Poll::Ready(None)
            }
        }
    }

    /// Registers `request` with the channel (obtaining an abort registration)
    /// and builds the future that serves it, bounded by the request deadline.
    fn handle_request(
        mut self: Pin<&mut Self>,
        request: Request<C::Req>,
    ) -> RequestHandler<S::Fut, C::Resp> {
        let request_id = request.id;
        let deadline = request.context.deadline;
        let timeout = deadline.time_until();
        trace!(
            "[{}] Received request with deadline {} (timeout {:?}).",
            request.context.trace_id(),
            format_rfc3339(deadline),
            timeout,
        );
        let ctx = request.context;
        let request = request.message;

        let response = self.as_mut().project().server.clone().serve(ctx, request);
        let response = Resp {
            state: RespState::PollResp,
            request_id,
            ctx,
            deadline,
            // Race the serve future against the deadline.
            f: tokio::time::timeout(timeout, response),
            response: None,
            response_tx: self.as_mut().project().responses_tx.clone(),
        };
        let abort_registration = self.as_mut().project().channel.start_request(request_id);
        RequestHandler {
            // Abortable so the channel can cancel serving this request.
            resp: Abortable::new(response, abort_registration),
        }
    }
}
|
|
||||||
|
|
||||||
/// A future fulfilling a single client request.
#[pin_project]
#[derive(Debug)]
pub struct RequestHandler<F, R> {
    // The response state machine, wrapped in Abortable so the channel can
    // cancel it via the registration created in `start_request`.
    #[pin]
    resp: Abortable<Resp<F, R>>,
}
|
|
||||||
|
|
||||||
impl<F, R> Future for RequestHandler<F, R>
|
|
||||||
where
|
|
||||||
F: Future<Output = R>,
|
|
||||||
{
|
|
||||||
type Output = ();
|
|
||||||
|
|
||||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
|
||||||
let _ = ready!(self.project().resp.poll(cx));
|
|
||||||
Poll::Ready(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// State machine serving one request: await the (deadline-bounded) serve
// future, then push the resulting response into the handler's queue.
#[pin_project]
#[derive(Debug)]
struct Resp<F, R> {
    // Current phase of the response; see RespState.
    state: RespState,
    // ID of the request being answered.
    request_id: u64,
    // Request context, forwarded alongside the response.
    ctx: context::Context,
    // Absolute deadline; used here for log/error messages.
    deadline: SystemTime,
    // The serve future wrapped in a timeout that fires at the deadline.
    #[pin]
    f: Timeout<F>,
    // Holds the computed response between PollResp and PollReady.
    response: Option<Response<R>>,
    // Queue feeding the ClientHandler's write pump.
    #[pin]
    response_tx: mpsc::Sender<(context::Context, Response<R>)>,
}
|
|
||||||
|
|
||||||
// Phases of producing and sending one response.
#[derive(Debug)]
#[allow(clippy::enum_variant_names)]
enum RespState {
    // Waiting on the (timeout-wrapped) serve future.
    PollResp,
    // Waiting for capacity in the response queue.
    PollReady,
    // Waiting for the queued response to flush.
    PollFlush,
}
|
|
||||||
|
|
||||||
impl<F, R> Future for Resp<F, R>
|
|
||||||
where
|
|
||||||
F: Future<Output = R>,
|
|
||||||
{
|
|
||||||
type Output = ();
|
|
||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
|
||||||
loop {
|
|
||||||
match self.as_mut().project().state {
|
|
||||||
RespState::PollResp => {
|
|
||||||
let result = ready!(self.as_mut().project().f.poll(cx));
|
|
||||||
*self.as_mut().project().response = Some(Response {
|
|
||||||
request_id: self.request_id,
|
|
||||||
message: match result {
|
|
||||||
Ok(message) => Ok(message),
|
|
||||||
Err(tokio::time::Elapsed { .. }) => {
|
|
||||||
debug!(
|
|
||||||
"[{}] Response did not complete before deadline of {}s.",
|
|
||||||
self.ctx.trace_id(),
|
|
||||||
format_rfc3339(self.deadline)
|
|
||||||
);
|
|
||||||
// No point in responding, since the client will have dropped the
|
|
||||||
// request.
|
|
||||||
Err(ServerError {
|
|
||||||
kind: io::ErrorKind::TimedOut,
|
|
||||||
detail: Some(format!(
|
|
||||||
"Response did not complete before deadline of {}s.",
|
|
||||||
format_rfc3339(self.deadline)
|
|
||||||
)),
|
|
||||||
_non_exhaustive: (),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
},
|
|
||||||
_non_exhaustive: (),
|
|
||||||
});
|
|
||||||
*self.as_mut().project().state = RespState::PollReady;
|
|
||||||
}
|
|
||||||
RespState::PollReady => {
|
|
||||||
let ready = ready!(self.as_mut().project().response_tx.poll_ready(cx));
|
|
||||||
if ready.is_err() {
|
|
||||||
return Poll::Ready(());
|
|
||||||
}
|
|
||||||
let resp = (self.ctx, self.as_mut().project().response.take().unwrap());
|
|
||||||
if self
|
|
||||||
.as_mut()
|
|
||||||
.project()
|
|
||||||
.response_tx
|
|
||||||
.start_send(resp)
|
|
||||||
.is_err()
|
|
||||||
{
|
|
||||||
return Poll::Ready(());
|
|
||||||
}
|
|
||||||
*self.as_mut().project().state = RespState::PollFlush;
|
|
||||||
}
|
|
||||||
RespState::PollFlush => {
|
|
||||||
let ready = ready!(self.as_mut().project().response_tx.poll_flush(cx));
|
|
||||||
if ready.is_err() {
|
|
||||||
return Poll::Ready(());
|
|
||||||
}
|
|
||||||
return Poll::Ready(());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<C, S> Stream for ClientHandler<C, S>
|
|
||||||
where
|
|
||||||
C: Channel,
|
|
||||||
S: Serve<C::Req, Resp = C::Resp>,
|
|
||||||
{
|
|
||||||
type Item = io::Result<RequestHandler<S::Fut, C::Resp>>;
|
|
||||||
|
|
||||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
||||||
loop {
|
|
||||||
let read = self.as_mut().pump_read(cx)?;
|
|
||||||
let read_closed = if let Poll::Ready(None) = read {
|
|
||||||
true
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
};
|
|
||||||
match (read, self.as_mut().pump_write(cx, read_closed)?) {
|
|
||||||
(Poll::Ready(None), Poll::Ready(None)) => {
|
|
||||||
return Poll::Ready(None);
|
|
||||||
}
|
|
||||||
(Poll::Ready(Some(request_handler)), _) => {
|
|
||||||
return Poll::Ready(Some(Ok(request_handler)));
|
|
||||||
}
|
|
||||||
(_, Poll::Ready(Some(()))) => {}
|
|
||||||
_ => {
|
|
||||||
return Poll::Pending;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send + 'static execution helper methods.
|
|
||||||
|
|
||||||
impl<C, S> ClientHandler<C, S>
where
    C: Channel + 'static,
    C::Req: Send + 'static,
    C::Resp: Send + 'static,
    S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
    S::Fut: Send + 'static,
{
    /// Runs the client handler until completion by spawning each
    /// request handler onto the default executor.
    ///
    /// Any stream error terminates the handler and is logged at info level.
    #[cfg(feature = "tokio1")]
    pub fn execute(self) -> impl Future<Output = ()> {
        use log::info;

        self.try_for_each(|request_handler| {
            async {
                // Serve each request concurrently on its own task.
                tokio::spawn(request_handler);
                Ok(())
            }
        })
        .unwrap_or_else(|e| info!("ClientHandler errored out: {}", e))
    }
}
|
|
||||||
|
|
||||||
/// A future that drives the server by spawning channels and request handlers on the default
/// executor.
#[pin_project]
#[derive(Debug)]
#[cfg(feature = "tokio1")]
pub struct Running<St, Se> {
    /// Stream of incoming channels, typically one per connected client.
    #[pin]
    incoming: St,
    /// The service; cloned into each spawned channel handler.
    server: Se,
}
|
|
||||||
|
|
||||||
#[cfg(feature = "tokio1")]
impl<St, C, Se> Future for Running<St, Se>
where
    St: Sized + Stream<Item = C>,
    C: Channel + Send + 'static,
    C::Req: Send + 'static,
    C::Resp: Send + 'static,
    Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
    Se::Fut: Send + 'static,
{
    type Output = ();

    /// Spawns a handler task for every incoming channel; completes when the
    /// incoming stream terminates.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        use log::info;

        while let Some(channel) = ready!(self.as_mut().project().incoming.poll_next(cx)) {
            // Each channel is served on its own task with a clone of the server.
            tokio::spawn(
                channel
                    .respond_with(self.as_mut().project().server.clone())
                    .execute(),
            );
        }
        info!("Server shutting down.");
        Poll::Ready(())
    }
}
|
|
||||||
@@ -1,326 +0,0 @@
|
|||||||
use super::{Channel, Config};
|
|
||||||
use crate::{Response, ServerError};
|
|
||||||
use futures::{future::AbortRegistration, prelude::*, ready, task::*};
|
|
||||||
use log::debug;
|
|
||||||
use pin_project::pin_project;
|
|
||||||
use std::{io, pin::Pin};
|
|
||||||
|
|
||||||
/// A [`Channel`] that limits the number of concurrent
/// requests by throttling.
#[pin_project]
#[derive(Debug)]
pub struct Throttler<C> {
    // Ceiling on simultaneously in-flight requests; requests beyond this are
    // rejected with a WouldBlock server error (see the Stream impl).
    max_in_flight_requests: usize,
    // The wrapped channel that actually carries requests and responses.
    #[pin]
    inner: C,
}
|
|
||||||
|
|
||||||
impl<C> Throttler<C> {
|
|
||||||
/// Returns the inner channel.
|
|
||||||
pub fn get_ref(&self) -> &C {
|
|
||||||
&self.inner
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<C> Throttler<C>
|
|
||||||
where
|
|
||||||
C: Channel,
|
|
||||||
{
|
|
||||||
/// Returns a new `Throttler` that wraps the given channel and limits concurrent requests to
|
|
||||||
/// `max_in_flight_requests`.
|
|
||||||
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
|
|
||||||
Throttler {
|
|
||||||
inner,
|
|
||||||
max_in_flight_requests,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<C> Stream for Throttler<C>
where
    C: Channel,
{
    type Item = <C as Stream>::Item;

    /// Yields the next request from the inner channel. While the channel is
    /// at its in-flight limit, incoming requests are rejected immediately
    /// with a WouldBlock server error instead of being surfaced.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
        {
            // Rejections are written to the sink, so the sink must be ready
            // before reading a request that may need rejecting.
            ready!(self.as_mut().project().inner.poll_ready(cx)?);

            match ready!(self.as_mut().project().inner.poll_next(cx)?) {
                Some(request) => {
                    debug!(
                        "[{}] Client has reached in-flight request limit ({}/{}).",
                        request.context.trace_id(),
                        self.as_mut().in_flight_requests(),
                        self.as_mut().project().max_in_flight_requests,
                    );

                    // Reject the request outright; it never becomes in-flight.
                    self.as_mut().start_send(Response {
                        request_id: request.id,
                        message: Err(ServerError {
                            kind: io::ErrorKind::WouldBlock,
                            detail: Some("Server throttled the request.".into()),
                            _non_exhaustive: (),
                        }),
                        _non_exhaustive: (),
                    })?;
                }
                None => return Poll::Ready(None),
            }
        }
        // Below the limit: pass the next request through untouched.
        self.project().inner.poll_next(cx)
    }
}
|
|
||||||
|
|
||||||
impl<C> Sink<Response<<C as Channel>::Resp>> for Throttler<C>
|
|
||||||
where
|
|
||||||
C: Channel,
|
|
||||||
{
|
|
||||||
type Error = io::Error;
|
|
||||||
|
|
||||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
|
||||||
self.project().inner.poll_ready(cx)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn start_send(self: Pin<&mut Self>, item: Response<<C as Channel>::Resp>) -> io::Result<()> {
|
|
||||||
self.project().inner.start_send(item)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
|
||||||
self.project().inner.poll_flush(cx)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
|
||||||
self.project().inner.poll_close(cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<C> AsRef<C> for Throttler<C> {
|
|
||||||
fn as_ref(&self) -> &C {
|
|
||||||
&self.inner
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<C> Channel for Throttler<C>
|
|
||||||
where
|
|
||||||
C: Channel,
|
|
||||||
{
|
|
||||||
type Req = <C as Channel>::Req;
|
|
||||||
type Resp = <C as Channel>::Resp;
|
|
||||||
|
|
||||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
|
||||||
self.project().inner.in_flight_requests()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn config(&self) -> &Config {
|
|
||||||
self.inner.config()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
|
||||||
self.project().inner.start_request(request_id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A stream of throttling channels.
#[pin_project]
#[derive(Debug)]
pub struct ThrottlerStream<S> {
    // The underlying stream of channels to wrap.
    #[pin]
    inner: S,
    // Limit applied to each yielded Throttler.
    max_in_flight_requests: usize,
}
|
|
||||||
|
|
||||||
impl<S> ThrottlerStream<S>
|
|
||||||
where
|
|
||||||
S: Stream,
|
|
||||||
<S as Stream>::Item: Channel,
|
|
||||||
{
|
|
||||||
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
|
|
||||||
Self {
|
|
||||||
inner,
|
|
||||||
max_in_flight_requests,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> Stream for ThrottlerStream<S>
|
|
||||||
where
|
|
||||||
S: Stream,
|
|
||||||
<S as Stream>::Item: Channel,
|
|
||||||
{
|
|
||||||
type Item = Throttler<<S as Stream>::Item>;
|
|
||||||
|
|
||||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
|
||||||
match ready!(self.as_mut().project().inner.poll_next(cx)) {
|
|
||||||
Some(channel) => Poll::Ready(Some(Throttler::new(
|
|
||||||
channel,
|
|
||||||
*self.project().max_in_flight_requests,
|
|
||||||
))),
|
|
||||||
None => Poll::Ready(None),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
use super::testing::{self, FakeChannel, PollExt};
|
|
||||||
#[cfg(test)]
|
|
||||||
use crate::Request;
|
|
||||||
#[cfg(test)]
|
|
||||||
use pin_utils::pin_mut;
|
|
||||||
#[cfg(test)]
|
|
||||||
use std::marker::PhantomData;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn throttler_in_flight_requests() {
|
|
||||||
let throttler = Throttler {
|
|
||||||
max_in_flight_requests: 0,
|
|
||||||
inner: FakeChannel::default::<isize, isize>(),
|
|
||||||
};
|
|
||||||
|
|
||||||
pin_mut!(throttler);
|
|
||||||
for i in 0..5 {
|
|
||||||
throttler.inner.in_flight_requests.insert(i);
|
|
||||||
}
|
|
||||||
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn throttler_start_request() {
|
|
||||||
let throttler = Throttler {
|
|
||||||
max_in_flight_requests: 0,
|
|
||||||
inner: FakeChannel::default::<isize, isize>(),
|
|
||||||
};
|
|
||||||
|
|
||||||
pin_mut!(throttler);
|
|
||||||
throttler.as_mut().start_request(1);
|
|
||||||
assert_eq!(throttler.inner.in_flight_requests.len(), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn throttler_poll_next_done() {
|
|
||||||
let throttler = Throttler {
|
|
||||||
max_in_flight_requests: 0,
|
|
||||||
inner: FakeChannel::default::<isize, isize>(),
|
|
||||||
};
|
|
||||||
|
|
||||||
pin_mut!(throttler);
|
|
||||||
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn throttler_poll_next_some() -> io::Result<()> {
|
|
||||||
let throttler = Throttler {
|
|
||||||
max_in_flight_requests: 1,
|
|
||||||
inner: FakeChannel::default::<isize, isize>(),
|
|
||||||
};
|
|
||||||
|
|
||||||
pin_mut!(throttler);
|
|
||||||
throttler.inner.push_req(0, 1);
|
|
||||||
assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());
|
|
||||||
assert_eq!(
|
|
||||||
throttler
|
|
||||||
.as_mut()
|
|
||||||
.poll_next(&mut testing::cx())?
|
|
||||||
.map(|r| r.map(|r| (r.id, r.message))),
|
|
||||||
Poll::Ready(Some((0, 1)))
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn throttler_poll_next_throttled() {
|
|
||||||
let throttler = Throttler {
|
|
||||||
max_in_flight_requests: 0,
|
|
||||||
inner: FakeChannel::default::<isize, isize>(),
|
|
||||||
};
|
|
||||||
|
|
||||||
pin_mut!(throttler);
|
|
||||||
throttler.inner.push_req(1, 1);
|
|
||||||
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
|
|
||||||
assert_eq!(throttler.inner.sink.len(), 1);
|
|
||||||
let resp = throttler.inner.sink.get(0).unwrap();
|
|
||||||
assert_eq!(resp.request_id, 1);
|
|
||||||
assert!(resp.message.is_err());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn throttler_poll_next_throttled_sink_not_ready() {
|
|
||||||
let throttler = Throttler {
|
|
||||||
max_in_flight_requests: 0,
|
|
||||||
inner: PendingSink::default::<isize, isize>(),
|
|
||||||
};
|
|
||||||
pin_mut!(throttler);
|
|
||||||
assert!(throttler.poll_next(&mut testing::cx()).is_pending());
|
|
||||||
|
|
||||||
struct PendingSink<In, Out> {
|
|
||||||
ghost: PhantomData<fn(Out) -> In>,
|
|
||||||
}
|
|
||||||
impl PendingSink<(), ()> {
|
|
||||||
pub fn default<Req, Resp>() -> PendingSink<io::Result<Request<Req>>, Response<Resp>> {
|
|
||||||
PendingSink { ghost: PhantomData }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
impl<In, Out> Stream for PendingSink<In, Out> {
|
|
||||||
type Item = In;
|
|
||||||
fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
impl<In, Out> Sink<Out> for PendingSink<In, Out> {
|
|
||||||
type Error = io::Error;
|
|
||||||
fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
|
||||||
Poll::Pending
|
|
||||||
}
|
|
||||||
fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
|
|
||||||
Err(io::Error::from(io::ErrorKind::WouldBlock))
|
|
||||||
}
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
|
||||||
Poll::Pending
|
|
||||||
}
|
|
||||||
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
|
||||||
Poll::Pending
|
|
||||||
}
|
|
||||||
}
|
|
||||||
impl<Req, Resp> Channel for PendingSink<io::Result<Request<Req>>, Response<Resp>> {
|
|
||||||
type Req = Req;
|
|
||||||
type Resp = Resp;
|
|
||||||
fn config(&self) -> &Config {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
fn start_request(self: Pin<&mut Self>, _: u64) -> AbortRegistration {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn throttler_start_send() {
|
|
||||||
let throttler = Throttler {
|
|
||||||
max_in_flight_requests: 0,
|
|
||||||
inner: FakeChannel::default::<isize, isize>(),
|
|
||||||
};
|
|
||||||
|
|
||||||
pin_mut!(throttler);
|
|
||||||
throttler.inner.in_flight_requests.insert(0);
|
|
||||||
throttler
|
|
||||||
.as_mut()
|
|
||||||
.start_send(Response {
|
|
||||||
request_id: 0,
|
|
||||||
message: Ok(1),
|
|
||||||
_non_exhaustive: (),
|
|
||||||
})
|
|
||||||
.unwrap();
|
|
||||||
assert!(throttler.inner.in_flight_requests.is_empty());
|
|
||||||
assert_eq!(
|
|
||||||
throttler.inner.sink.get(0),
|
|
||||||
Some(&Response {
|
|
||||||
request_id: 0,
|
|
||||||
message: Ok(1),
|
|
||||||
_non_exhaustive: ()
|
|
||||||
})
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,123 +0,0 @@
|
|||||||
// Copyright 2018 Google LLC
|
|
||||||
//
|
|
||||||
// Use of this source code is governed by an MIT-style
|
|
||||||
// license that can be found in the LICENSE file or at
|
|
||||||
// https://opensource.org/licenses/MIT.
|
|
||||||
|
|
||||||
//! Transports backed by in-memory channels.
|
|
||||||
|
|
||||||
use crate::PollIo;
|
|
||||||
use futures::{channel::mpsc, task::*, Sink, Stream};
|
|
||||||
use pin_project::pin_project;
|
|
||||||
use std::io;
|
|
||||||
use std::pin::Pin;
|
|
||||||
|
|
||||||
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
|
|
||||||
/// [`Sink`].
|
|
||||||
pub fn unbounded<SinkItem, Item>() -> (
|
|
||||||
UnboundedChannel<SinkItem, Item>,
|
|
||||||
UnboundedChannel<Item, SinkItem>,
|
|
||||||
) {
|
|
||||||
let (tx1, rx2) = mpsc::unbounded();
|
|
||||||
let (tx2, rx1) = mpsc::unbounded();
|
|
||||||
(
|
|
||||||
UnboundedChannel { tx: tx1, rx: rx1 },
|
|
||||||
UnboundedChannel { tx: tx2, rx: rx2 },
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender)
/// and [`UnboundedReceiver`](mpsc::UnboundedReceiver).
#[pin_project]
#[derive(Debug)]
pub struct UnboundedChannel<Item, SinkItem> {
    // Receives items sent by the peer's `tx`.
    #[pin]
    rx: mpsc::UnboundedReceiver<Item>,
    // Sends items to the peer's `rx`.
    #[pin]
    tx: mpsc::UnboundedSender<SinkItem>,
}
|
|
||||||
|
|
||||||
impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
|
|
||||||
type Item = Result<Item, io::Error>;
|
|
||||||
|
|
||||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Item> {
|
|
||||||
self.project().rx.poll_next(cx).map(|option| option.map(Ok))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
|
|
||||||
type Error = io::Error;
|
|
||||||
|
|
||||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
||||||
self.project()
|
|
||||||
.tx
|
|
||||||
.poll_ready(cx)
|
|
||||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
|
|
||||||
self.project()
|
|
||||||
.tx
|
|
||||||
.start_send(item)
|
|
||||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
||||||
self.project()
|
|
||||||
.tx
|
|
||||||
.poll_flush(cx)
|
|
||||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
||||||
self.project()
|
|
||||||
.tx
|
|
||||||
.poll_close(cx)
|
|
||||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use crate::{
|
|
||||||
client, context,
|
|
||||||
server::{Handler, Server},
|
|
||||||
transport,
|
|
||||||
};
|
|
||||||
use assert_matches::assert_matches;
|
|
||||||
use futures::{prelude::*, stream};
|
|
||||||
use log::trace;
|
|
||||||
use std::io;
|
|
||||||
|
|
||||||
#[cfg(feature = "tokio1")]
|
|
||||||
#[tokio::test(threaded_scheduler)]
|
|
||||||
async fn integration() -> io::Result<()> {
|
|
||||||
let _ = env_logger::try_init();
|
|
||||||
|
|
||||||
let (client_channel, server_channel) = transport::channel::unbounded();
|
|
||||||
tokio::spawn(
|
|
||||||
Server::default()
|
|
||||||
.incoming(stream::once(future::ready(server_channel)))
|
|
||||||
.respond_with(|_ctx, request: String| {
|
|
||||||
future::ready(request.parse::<u64>().map_err(|_| {
|
|
||||||
io::Error::new(
|
|
||||||
io::ErrorKind::InvalidInput,
|
|
||||||
format!("{:?} is not an int", request),
|
|
||||||
)
|
|
||||||
}))
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut client = client::new(client::Config::default(), client_channel).spawn()?;
|
|
||||||
|
|
||||||
let response1 = client.call(context::current(), "123".into()).await?;
|
|
||||||
let response2 = client.call(context::current(), "abc".into()).await?;
|
|
||||||
|
|
||||||
trace!("response1: {:?}, response2: {:?}", response1, response2);
|
|
||||||
|
|
||||||
assert_matches!(response1, Ok(123));
|
|
||||||
assert_matches!(response2, Err(ref e) if e.kind() == io::ErrorKind::InvalidInput);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
// Copyright 2018 Google LLC
|
|
||||||
//
|
|
||||||
// Use of this source code is governed by an MIT-style
|
|
||||||
// license that can be found in the LICENSE file or at
|
|
||||||
// https://opensource.org/licenses/MIT.
|
|
||||||
|
|
||||||
//! Provides a [`Transport`] trait as well as implementations.
|
|
||||||
//!
|
|
||||||
//! The rpc crate is transport- and protocol-agnostic. Any transport that impls [`Transport`]
|
|
||||||
//! can be plugged in, using whatever protocol it wants.
|
|
||||||
|
|
||||||
use futures::prelude::*;
|
|
||||||
use std::io;
|
|
||||||
|
|
||||||
pub mod channel;
|
|
||||||
|
|
||||||
// Sealed-trait pattern: the Transport trait lives in a crate-private module
// so downstream crates cannot implement it directly; the blanket impl below
// covers every qualifying type.
pub(crate) mod sealed {
    use super::*;

    /// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
    pub trait Transport<SinkItem, Item>:
        Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error>
    {
    }

    // Blanket impl: anything that is both the right Stream and Sink is a
    // Transport; the trait adds no methods of its own.
    impl<T, SinkItem, Item> Transport<SinkItem, Item> for T where
        T: Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error> + ?Sized
    {
    }
}
|
|
||||||
395
tarpc/src/serde_transport.rs
Normal file
395
tarpc/src/serde_transport.rs
Normal file
@@ -0,0 +1,395 @@
|
|||||||
|
// Copyright 2019 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
//! A generic Serde-based `Transport` that can serialize anything supported by `tokio-serde` via any medium that implements `AsyncRead` and `AsyncWrite`.
|
||||||
|
|
||||||
|
#![deny(missing_docs)]
|
||||||
|
|
||||||
|
use futures::{prelude::*, task::*};
|
||||||
|
use pin_project::pin_project;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::{error::Error, io, pin::Pin};
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
use tokio_serde::{Framed as SerdeFramed, *};
|
||||||
|
use tokio_util::codec::{length_delimited::LengthDelimitedCodec, Framed};
|
||||||
|
|
||||||
|
/// A transport that serializes to, and deserializes from, a byte stream.
#[pin_project]
pub struct Transport<S, Item, SinkItem, Codec> {
    // Length-delimited framing over the raw byte stream, with serde-based
    // (de)serialization of each frame via `Codec`.
    #[pin]
    inner: SerdeFramed<Framed<S, LengthDelimitedCodec>, Item, SinkItem, Codec>,
}
|
||||||
|
|
||||||
|
impl<S, Item, SinkItem, Codec> Transport<S, Item, SinkItem, Codec> {
|
||||||
|
/// Returns the inner transport over which messages are sent and received.
|
||||||
|
pub fn get_ref(&self) -> &S {
|
||||||
|
self.inner.get_ref().get_ref()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, Item, SinkItem, Codec, CodecError> Stream for Transport<S, Item, SinkItem, Codec>
|
||||||
|
where
|
||||||
|
S: AsyncWrite + AsyncRead,
|
||||||
|
Item: for<'a> Deserialize<'a>,
|
||||||
|
Codec: Deserializer<Item>,
|
||||||
|
CodecError: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||||
|
SerdeFramed<Framed<S, LengthDelimitedCodec>, Item, SinkItem, Codec>:
|
||||||
|
Stream<Item = Result<Item, CodecError>>,
|
||||||
|
{
|
||||||
|
type Item = io::Result<Item>;
|
||||||
|
|
||||||
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
|
||||||
|
self.project()
|
||||||
|
.inner
|
||||||
|
.poll_next(cx)
|
||||||
|
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, Item, SinkItem, Codec, CodecError> Sink<SinkItem> for Transport<S, Item, SinkItem, Codec>
|
||||||
|
where
|
||||||
|
S: AsyncWrite,
|
||||||
|
SinkItem: Serialize,
|
||||||
|
Codec: Serializer<SinkItem>,
|
||||||
|
CodecError: Into<Box<dyn Error + Send + Sync>>,
|
||||||
|
SerdeFramed<Framed<S, LengthDelimitedCodec>, Item, SinkItem, Codec>:
|
||||||
|
Sink<SinkItem, Error = CodecError>,
|
||||||
|
{
|
||||||
|
type Error = io::Error;
|
||||||
|
|
||||||
|
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
self.project()
|
||||||
|
.inner
|
||||||
|
.poll_ready(cx)
|
||||||
|
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
|
||||||
|
self.project()
|
||||||
|
.inner
|
||||||
|
.start_send(item)
|
||||||
|
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
self.project()
|
||||||
|
.inner
|
||||||
|
.poll_flush(cx)
|
||||||
|
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
self.project()
|
||||||
|
.inner
|
||||||
|
.poll_close(cx)
|
||||||
|
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Constructs a new transport from a framed transport and a serialization codec.
|
||||||
|
pub fn new<S, Item, SinkItem, Codec>(
|
||||||
|
framed_io: Framed<S, LengthDelimitedCodec>,
|
||||||
|
codec: Codec,
|
||||||
|
) -> Transport<S, Item, SinkItem, Codec>
|
||||||
|
where
|
||||||
|
S: AsyncWrite + AsyncRead,
|
||||||
|
Item: for<'de> Deserialize<'de>,
|
||||||
|
SinkItem: Serialize,
|
||||||
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
|
{
|
||||||
|
Transport {
|
||||||
|
inner: SerdeFramed::new(framed_io, codec),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, Item, SinkItem, Codec> From<(S, Codec)> for Transport<S, Item, SinkItem, Codec>
|
||||||
|
where
|
||||||
|
S: AsyncWrite + AsyncRead,
|
||||||
|
Item: for<'de> Deserialize<'de>,
|
||||||
|
SinkItem: Serialize,
|
||||||
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
|
{
|
||||||
|
fn from((io, codec): (S, Codec)) -> Self {
|
||||||
|
new(Framed::new(io, LengthDelimitedCodec::new()), codec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "tcp")]
|
||||||
|
#[cfg_attr(docsrs, doc(cfg(feature = "tcp")))]
|
||||||
|
/// TCP support for generic transport using Tokio.
|
||||||
|
pub mod tcp {
|
||||||
|
use {
|
||||||
|
super::*,
|
||||||
|
futures::ready,
|
||||||
|
std::{marker::PhantomData, net::SocketAddr},
|
||||||
|
tokio::net::{TcpListener, TcpStream, ToSocketAddrs},
|
||||||
|
tokio_util::codec::length_delimited,
|
||||||
|
};
|
||||||
|
|
||||||
|
mod private {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
pub trait Sealed {}
|
||||||
|
|
||||||
|
impl<Item, SinkItem, Codec> Sealed for Transport<TcpStream, Item, SinkItem, Codec> {}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Item, SinkItem, Codec> Transport<TcpStream, Item, SinkItem, Codec> {
|
||||||
|
/// Returns the peer address of the underlying TcpStream.
|
||||||
|
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
|
||||||
|
self.inner.get_ref().get_ref().peer_addr()
|
||||||
|
}
|
||||||
|
/// Returns the local address of the underlying TcpStream.
|
||||||
|
pub fn local_addr(&self) -> io::Result<SocketAddr> {
|
||||||
|
self.inner.get_ref().get_ref().local_addr()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A connection Future that also exposes the length-delimited framing config.
|
||||||
|
#[pin_project]
|
||||||
|
pub struct Connect<T, Item, SinkItem, CodecFn> {
|
||||||
|
#[pin]
|
||||||
|
inner: T,
|
||||||
|
codec_fn: CodecFn,
|
||||||
|
config: length_delimited::Builder,
|
||||||
|
ghost: PhantomData<(fn(SinkItem), fn() -> Item)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, Item, SinkItem, Codec, CodecFn> Future for Connect<T, Item, SinkItem, CodecFn>
|
||||||
|
where
|
||||||
|
T: Future<Output = io::Result<TcpStream>>,
|
||||||
|
Item: for<'de> Deserialize<'de>,
|
||||||
|
SinkItem: Serialize,
|
||||||
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
|
CodecFn: Fn() -> Codec,
|
||||||
|
{
|
||||||
|
type Output = io::Result<Transport<TcpStream, Item, SinkItem, Codec>>;
|
||||||
|
|
||||||
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
|
||||||
|
let io = ready!(self.as_mut().project().inner.poll(cx))?;
|
||||||
|
Poll::Ready(Ok(new(self.config.new_framed(io), (self.codec_fn)())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, Item, SinkItem, CodecFn> Connect<T, Item, SinkItem, CodecFn> {
|
||||||
|
/// Returns an immutable reference to the length-delimited codec's config.
|
||||||
|
pub fn config(&self) -> &length_delimited::Builder {
|
||||||
|
&self.config
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a mutable reference to the length-delimited codec's config.
|
||||||
|
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
|
||||||
|
&mut self.config
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Connects to `addr`, wrapping the connection in a TCP transport.
|
||||||
|
pub fn connect<A, Item, SinkItem, Codec, CodecFn>(
|
||||||
|
addr: A,
|
||||||
|
codec_fn: CodecFn,
|
||||||
|
) -> Connect<impl Future<Output = io::Result<TcpStream>>, Item, SinkItem, CodecFn>
|
||||||
|
where
|
||||||
|
A: ToSocketAddrs,
|
||||||
|
Item: for<'de> Deserialize<'de>,
|
||||||
|
SinkItem: Serialize,
|
||||||
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
|
CodecFn: Fn() -> Codec,
|
||||||
|
{
|
||||||
|
Connect {
|
||||||
|
inner: TcpStream::connect(addr),
|
||||||
|
codec_fn,
|
||||||
|
config: LengthDelimitedCodec::builder(),
|
||||||
|
ghost: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Listens on `addr`, wrapping accepted connections in TCP transports.
|
||||||
|
pub async fn listen<A, Item, SinkItem, Codec, CodecFn>(
|
||||||
|
addr: A,
|
||||||
|
codec_fn: CodecFn,
|
||||||
|
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
|
||||||
|
where
|
||||||
|
A: ToSocketAddrs,
|
||||||
|
Item: for<'de> Deserialize<'de>,
|
||||||
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
|
CodecFn: Fn() -> Codec,
|
||||||
|
{
|
||||||
|
let listener = TcpListener::bind(addr).await?;
|
||||||
|
let local_addr = listener.local_addr()?;
|
||||||
|
Ok(Incoming {
|
||||||
|
listener,
|
||||||
|
codec_fn,
|
||||||
|
local_addr,
|
||||||
|
config: LengthDelimitedCodec::builder(),
|
||||||
|
ghost: PhantomData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A [`TcpListener`] that wraps connections in [transports](Transport).
|
||||||
|
#[pin_project]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Incoming<Item, SinkItem, Codec, CodecFn> {
|
||||||
|
listener: TcpListener,
|
||||||
|
local_addr: SocketAddr,
|
||||||
|
codec_fn: CodecFn,
|
||||||
|
config: length_delimited::Builder,
|
||||||
|
ghost: PhantomData<(fn() -> Item, fn(SinkItem), Codec)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Item, SinkItem, Codec, CodecFn> Incoming<Item, SinkItem, Codec, CodecFn> {
|
||||||
|
/// Returns the address being listened on.
|
||||||
|
pub fn local_addr(&self) -> SocketAddr {
|
||||||
|
self.local_addr
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an immutable reference to the length-delimited codec's config.
|
||||||
|
pub fn config(&self) -> &length_delimited::Builder {
|
||||||
|
&self.config
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a mutable reference to the length-delimited codec's config.
|
||||||
|
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
|
||||||
|
&mut self.config
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Item, SinkItem, Codec, CodecFn> Stream for Incoming<Item, SinkItem, Codec, CodecFn>
|
||||||
|
where
|
||||||
|
Item: for<'de> Deserialize<'de>,
|
||||||
|
SinkItem: Serialize,
|
||||||
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
|
CodecFn: Fn() -> Codec,
|
||||||
|
{
|
||||||
|
type Item = io::Result<Transport<TcpStream, Item, SinkItem, Codec>>;
|
||||||
|
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
let conn: TcpStream =
|
||||||
|
ready!(Pin::new(&mut self.as_mut().project().listener).poll_accept(cx)?).0;
|
||||||
|
Poll::Ready(Some(Ok(new(
|
||||||
|
self.config.new_framed(conn),
|
||||||
|
(self.codec_fn)(),
|
||||||
|
))))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::Transport;
|
||||||
|
use assert_matches::assert_matches;
|
||||||
|
use futures::{task::*, Sink, Stream};
|
||||||
|
use pin_utils::pin_mut;
|
||||||
|
use std::{
|
||||||
|
io::{self, Cursor},
|
||||||
|
pin::Pin,
|
||||||
|
};
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
|
use tokio_serde::formats::SymmetricalJson;
|
||||||
|
|
||||||
|
fn ctx() -> Context<'static> {
|
||||||
|
Context::from_waker(&noop_waker_ref())
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TestIo(Cursor<Vec<u8>>);
|
||||||
|
|
||||||
|
impl AsyncRead for TestIo {
|
||||||
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for TestIo {
|
||||||
|
fn poll_write(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn close() {
|
||||||
|
let (tx, _rx) = crate::transport::channel::bounded::<(), ()>(0);
|
||||||
|
pin_mut!(tx);
|
||||||
|
assert_matches!(tx.as_mut().poll_close(&mut ctx()), Poll::Ready(Ok(())));
|
||||||
|
assert_matches!(tx.as_mut().start_send(()), Err(_));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_stream() {
|
||||||
|
let data: &[u8] = b"\x00\x00\x00\x18\"Test one, check check.\"";
|
||||||
|
let transport = Transport::from((
|
||||||
|
TestIo(Cursor::new(Vec::from(data))),
|
||||||
|
SymmetricalJson::<String>::default(),
|
||||||
|
));
|
||||||
|
pin_mut!(transport);
|
||||||
|
|
||||||
|
assert_matches!(
|
||||||
|
transport.as_mut().poll_next(&mut ctx()),
|
||||||
|
Poll::Ready(Some(Ok(ref s))) if s == "Test one, check check.");
|
||||||
|
assert_matches!(transport.as_mut().poll_next(&mut ctx()), Poll::Ready(None));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sink() {
|
||||||
|
let writer = Cursor::new(vec![]);
|
||||||
|
let mut transport = Box::pin(Transport::from((
|
||||||
|
TestIo(writer),
|
||||||
|
SymmetricalJson::<String>::default(),
|
||||||
|
)));
|
||||||
|
|
||||||
|
assert_matches!(
|
||||||
|
transport.as_mut().poll_ready(&mut ctx()),
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
);
|
||||||
|
assert_matches!(
|
||||||
|
transport
|
||||||
|
.as_mut()
|
||||||
|
.start_send("Test one, check check.".into()),
|
||||||
|
Ok(())
|
||||||
|
);
|
||||||
|
assert_matches!(
|
||||||
|
transport.as_mut().poll_flush(&mut ctx()),
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
transport.get_ref().0.get_ref(),
|
||||||
|
b"\x00\x00\x00\x18\"Test one, check check.\""
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(tcp)]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn tcp() -> io::Result<()> {
|
||||||
|
use super::tcp;
|
||||||
|
|
||||||
|
let mut listener = tcp::listen("0.0.0.0:0", SymmetricalJson::<String>::default).await?;
|
||||||
|
let addr = listener.local_addr();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut transport = listener.next().await.unwrap().unwrap();
|
||||||
|
let message = transport.next().await.unwrap().unwrap();
|
||||||
|
transport.send(message).await.unwrap();
|
||||||
|
});
|
||||||
|
let mut transport = tcp::connect(addr, SymmetricalJson::<String>::default).await?;
|
||||||
|
transport.send(String::from("test")).await?;
|
||||||
|
assert_matches!(transport.next().await, Some(Ok(s)) if s == "test");
|
||||||
|
assert_matches!(transport.next().await, None);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,326 +0,0 @@
|
|||||||
// Copyright 2019 Google LLC
|
|
||||||
//
|
|
||||||
// Use of this source code is governed by an MIT-style
|
|
||||||
// license that can be found in the LICENSE file or at
|
|
||||||
// https://opensource.org/licenses/MIT.
|
|
||||||
|
|
||||||
//! A generic Serde-based `Transport` that can serialize anything supported by `tokio-serde` via any medium that implements `AsyncRead` and `AsyncWrite`.
|
|
||||||
|
|
||||||
#![deny(missing_docs)]
|
|
||||||
|
|
||||||
use futures::{prelude::*, task::*};
|
|
||||||
use pin_project::pin_project;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::{error::Error, io, pin::Pin};
|
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
|
||||||
use tokio_serde::{Framed as SerdeFramed, *};
|
|
||||||
use tokio_util::codec::{length_delimited::LengthDelimitedCodec, Framed};
|
|
||||||
|
|
||||||
/// A transport that serializes to, and deserializes from, a [`TcpStream`].
|
|
||||||
#[pin_project]
|
|
||||||
pub struct Transport<S, Item, SinkItem, Codec> {
|
|
||||||
#[pin]
|
|
||||||
inner: SerdeFramed<Framed<S, LengthDelimitedCodec>, Item, SinkItem, Codec>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S, Item, SinkItem, Codec, CodecError> Stream for Transport<S, Item, SinkItem, Codec>
|
|
||||||
where
|
|
||||||
S: AsyncWrite + AsyncRead,
|
|
||||||
Item: for<'a> Deserialize<'a>,
|
|
||||||
Codec: Deserializer<Item>,
|
|
||||||
CodecError: Into<Box<dyn std::error::Error + Send + Sync>>,
|
|
||||||
SerdeFramed<Framed<S, LengthDelimitedCodec>, Item, SinkItem, Codec>:
|
|
||||||
Stream<Item = Result<Item, CodecError>>,
|
|
||||||
{
|
|
||||||
type Item = io::Result<Item>;
|
|
||||||
|
|
||||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
|
|
||||||
match self.project().inner.poll_next(cx) {
|
|
||||||
Poll::Pending => Poll::Pending,
|
|
||||||
Poll::Ready(None) => Poll::Ready(None),
|
|
||||||
Poll::Ready(Some(Ok::<_, CodecError>(next))) => Poll::Ready(Some(Ok(next))),
|
|
||||||
Poll::Ready(Some(Err::<_, CodecError>(e))) => {
|
|
||||||
Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e))))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S, Item, SinkItem, Codec, CodecError> Sink<SinkItem> for Transport<S, Item, SinkItem, Codec>
|
|
||||||
where
|
|
||||||
S: AsyncWrite,
|
|
||||||
SinkItem: Serialize,
|
|
||||||
Codec: Serializer<SinkItem>,
|
|
||||||
CodecError: Into<Box<dyn Error + Send + Sync>>,
|
|
||||||
SerdeFramed<Framed<S, LengthDelimitedCodec>, Item, SinkItem, Codec>:
|
|
||||||
Sink<SinkItem, Error = CodecError>,
|
|
||||||
{
|
|
||||||
type Error = io::Error;
|
|
||||||
|
|
||||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
||||||
convert(self.project().inner.poll_ready(cx))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
|
|
||||||
self.project()
|
|
||||||
.inner
|
|
||||||
.start_send(item)
|
|
||||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
||||||
convert(self.project().inner.poll_flush(cx))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
||||||
convert(self.project().inner.poll_close(cx))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn convert<E: Into<Box<dyn Error + Send + Sync>>>(
|
|
||||||
poll: Poll<Result<(), E>>,
|
|
||||||
) -> Poll<io::Result<()>> {
|
|
||||||
poll.map(|ready| ready.map_err(|e| io::Error::new(io::ErrorKind::Other, e)))
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S, Item, SinkItem, Codec> From<(S, Codec)> for Transport<S, Item, SinkItem, Codec>
|
|
||||||
where
|
|
||||||
S: AsyncWrite + AsyncRead,
|
|
||||||
Item: for<'de> Deserialize<'de>,
|
|
||||||
SinkItem: Serialize,
|
|
||||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
|
||||||
{
|
|
||||||
fn from((inner, codec): (S, Codec)) -> Self {
|
|
||||||
Transport {
|
|
||||||
inner: SerdeFramed::new(Framed::new(inner, LengthDelimitedCodec::new()), codec),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "tcp")]
|
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "tcp")))]
|
|
||||||
/// TCP support for generic transport using Tokio.
|
|
||||||
pub mod tcp {
|
|
||||||
use {
|
|
||||||
super::*,
|
|
||||||
futures::ready,
|
|
||||||
std::{marker::PhantomData, net::SocketAddr},
|
|
||||||
tokio::net::{TcpListener, TcpStream, ToSocketAddrs},
|
|
||||||
};
|
|
||||||
|
|
||||||
mod private {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
pub trait Sealed {}
|
|
||||||
|
|
||||||
impl<Item, SinkItem, Codec> Sealed for Transport<TcpStream, Item, SinkItem, Codec> {}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Item, SinkItem, Codec> Transport<TcpStream, Item, SinkItem, Codec> {
|
|
||||||
/// Returns the peer address of the underlying TcpStream.
|
|
||||||
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
|
|
||||||
self.inner.get_ref().get_ref().peer_addr()
|
|
||||||
}
|
|
||||||
/// Returns the local address of the underlying TcpStream.
|
|
||||||
pub fn local_addr(&self) -> io::Result<SocketAddr> {
|
|
||||||
self.inner.get_ref().get_ref().local_addr()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a new JSON transport that reads from and writes to `io`.
|
|
||||||
pub fn new<Item, SinkItem, Codec>(
|
|
||||||
io: TcpStream,
|
|
||||||
codec: Codec,
|
|
||||||
) -> Transport<TcpStream, Item, SinkItem, Codec>
|
|
||||||
where
|
|
||||||
Item: for<'de> Deserialize<'de>,
|
|
||||||
SinkItem: Serialize,
|
|
||||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
|
||||||
{
|
|
||||||
Transport::from((io, codec))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Connects to `addr`, wrapping the connection in a JSON transport.
|
|
||||||
pub async fn connect<A, Item, SinkItem, Codec>(
|
|
||||||
addr: A,
|
|
||||||
codec: Codec,
|
|
||||||
) -> io::Result<Transport<TcpStream, Item, SinkItem, Codec>>
|
|
||||||
where
|
|
||||||
A: ToSocketAddrs,
|
|
||||||
Item: for<'de> Deserialize<'de>,
|
|
||||||
SinkItem: Serialize,
|
|
||||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
|
||||||
{
|
|
||||||
Ok(new(TcpStream::connect(addr).await?, codec))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Listens on `addr`, wrapping accepted connections in JSON transports.
|
|
||||||
pub async fn listen<A, Item, SinkItem, Codec, CodecFn>(
|
|
||||||
addr: A,
|
|
||||||
codec_fn: CodecFn,
|
|
||||||
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
|
|
||||||
where
|
|
||||||
A: ToSocketAddrs,
|
|
||||||
Item: for<'de> Deserialize<'de>,
|
|
||||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
|
||||||
CodecFn: Fn() -> Codec,
|
|
||||||
{
|
|
||||||
let listener = TcpListener::bind(addr).await?;
|
|
||||||
let local_addr = listener.local_addr()?;
|
|
||||||
Ok(Incoming {
|
|
||||||
listener,
|
|
||||||
codec_fn,
|
|
||||||
local_addr,
|
|
||||||
ghost: PhantomData,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A [`TcpListener`] that wraps connections in JSON transports.
|
|
||||||
#[pin_project]
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct Incoming<Item, SinkItem, Codec, CodecFn> {
|
|
||||||
listener: TcpListener,
|
|
||||||
local_addr: SocketAddr,
|
|
||||||
codec_fn: CodecFn,
|
|
||||||
ghost: PhantomData<(Item, SinkItem, Codec)>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Item, SinkItem, Codec, CodecFn> Incoming<Item, SinkItem, Codec, CodecFn> {
|
|
||||||
/// Returns the address being listened on.
|
|
||||||
pub fn local_addr(&self) -> SocketAddr {
|
|
||||||
self.local_addr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Item, SinkItem, Codec, CodecFn> Stream for Incoming<Item, SinkItem, Codec, CodecFn>
|
|
||||||
where
|
|
||||||
Item: for<'de> Deserialize<'de>,
|
|
||||||
SinkItem: Serialize,
|
|
||||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
|
||||||
CodecFn: Fn() -> Codec,
|
|
||||||
{
|
|
||||||
type Item = io::Result<Transport<TcpStream, Item, SinkItem, Codec>>;
|
|
||||||
|
|
||||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
||||||
let next =
|
|
||||||
ready!(Pin::new(&mut self.as_mut().project().listener.incoming()).poll_next(cx)?);
|
|
||||||
Poll::Ready(next.map(|conn| Ok(new(conn, (self.codec_fn)()))))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::Transport;
|
|
||||||
use assert_matches::assert_matches;
|
|
||||||
use futures::{task::*, Sink, Stream};
|
|
||||||
use pin_utils::pin_mut;
|
|
||||||
use std::{
|
|
||||||
io::{self, Cursor},
|
|
||||||
pin::Pin,
|
|
||||||
};
|
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
|
||||||
use tokio_serde::formats::SymmetricalJson;
|
|
||||||
|
|
||||||
fn ctx() -> Context<'static> {
|
|
||||||
Context::from_waker(&noop_waker_ref())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_stream() {
|
|
||||||
struct TestIo(Cursor<&'static [u8]>);
|
|
||||||
|
|
||||||
impl AsyncRead for TestIo {
|
|
||||||
fn poll_read(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &mut [u8],
|
|
||||||
) -> Poll<io::Result<usize>> {
|
|
||||||
AsyncRead::poll_read(Pin::new(self.0.get_mut()), cx, buf)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AsyncWrite for TestIo {
|
|
||||||
fn poll_write(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
_cx: &mut Context<'_>,
|
|
||||||
_buf: &[u8],
|
|
||||||
) -> Poll<io::Result<usize>> {
|
|
||||||
unreachable!()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
||||||
unreachable!()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
||||||
unreachable!()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let data = b"\x00\x00\x00\x18\"Test one, check check.\"";
|
|
||||||
let transport = Transport::from((
|
|
||||||
TestIo(Cursor::new(data)),
|
|
||||||
SymmetricalJson::<String>::default(),
|
|
||||||
));
|
|
||||||
pin_mut!(transport);
|
|
||||||
|
|
||||||
assert_matches!(
|
|
||||||
transport.poll_next(&mut ctx()),
|
|
||||||
Poll::Ready(Some(Ok(ref s))) if s == "Test one, check check.");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_sink() {
|
|
||||||
struct TestIo<'a>(&'a mut Vec<u8>);
|
|
||||||
|
|
||||||
impl<'a> AsyncRead for TestIo<'a> {
|
|
||||||
fn poll_read(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
_cx: &mut Context<'_>,
|
|
||||||
_buf: &mut [u8],
|
|
||||||
) -> Poll<io::Result<usize>> {
|
|
||||||
unreachable!()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> AsyncWrite for TestIo<'a> {
|
|
||||||
fn poll_write(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &[u8],
|
|
||||||
) -> Poll<io::Result<usize>> {
|
|
||||||
AsyncWrite::poll_write(Pin::new(&mut *self.0), cx, buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
||||||
AsyncWrite::poll_flush(Pin::new(&mut *self.0), cx)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_shutdown(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
) -> Poll<io::Result<()>> {
|
|
||||||
AsyncWrite::poll_shutdown(Pin::new(&mut *self.0), cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut writer = vec![];
|
|
||||||
let transport =
|
|
||||||
Transport::from((TestIo(&mut writer), SymmetricalJson::<String>::default()));
|
|
||||||
pin_mut!(transport);
|
|
||||||
|
|
||||||
assert_matches!(
|
|
||||||
transport.as_mut().poll_ready(&mut ctx()),
|
|
||||||
Poll::Ready(Ok(()))
|
|
||||||
);
|
|
||||||
assert_matches!(
|
|
||||||
transport
|
|
||||||
.as_mut()
|
|
||||||
.start_send("Test one, check check.".into()),
|
|
||||||
Ok(())
|
|
||||||
);
|
|
||||||
assert_matches!(transport.poll_flush(&mut ctx()), Poll::Ready(Ok(())));
|
|
||||||
assert_eq!(writer, b"\x00\x00\x00\x18\"Test one, check check.\"");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
1053
tarpc/src/server.rs
Normal file
1053
tarpc/src/server.rs
Normal file
File diff suppressed because it is too large
Load Diff
223
tarpc/src/server/in_flight_requests.rs
Normal file
223
tarpc/src/server/in_flight_requests.rs
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
use crate::util::{Compact, TimeUntil};
|
||||||
|
use fnv::FnvHashMap;
|
||||||
|
use futures::future::{AbortHandle, AbortRegistration};
|
||||||
|
use std::{
|
||||||
|
collections::hash_map,
|
||||||
|
task::{Context, Poll},
|
||||||
|
time::SystemTime,
|
||||||
|
};
|
||||||
|
use tokio_util::time::delay_queue::{self, DelayQueue};
|
||||||
|
use tracing::Span;
|
||||||
|
|
||||||
|
/// A data structure that tracks in-flight requests. It aborts requests,
|
||||||
|
/// either on demand or when a request deadline expires.
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
pub struct InFlightRequests {
|
||||||
|
request_data: FnvHashMap<u64, RequestData>,
|
||||||
|
deadlines: DelayQueue<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Data needed to clean up a single in-flight request.
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct RequestData {
|
||||||
|
/// Aborts the response handler for the associated request.
|
||||||
|
abort_handle: AbortHandle,
|
||||||
|
/// The key to remove the timer for the request's deadline.
|
||||||
|
deadline_key: delay_queue::Key,
|
||||||
|
/// The client span.
|
||||||
|
span: Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An error returned when a request attempted to start with the same ID as a request already
|
||||||
|
/// in flight.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct AlreadyExistsError;
|
||||||
|
|
||||||
|
impl InFlightRequests {
|
||||||
|
/// Returns the number of in-flight requests.
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.request_data.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Starts a request, unless a request with the same ID is already in flight.
|
||||||
|
pub fn start_request(
|
||||||
|
&mut self,
|
||||||
|
request_id: u64,
|
||||||
|
deadline: SystemTime,
|
||||||
|
span: Span,
|
||||||
|
) -> Result<AbortRegistration, AlreadyExistsError> {
|
||||||
|
match self.request_data.entry(request_id) {
|
||||||
|
hash_map::Entry::Vacant(vacant) => {
|
||||||
|
let timeout = deadline.time_until();
|
||||||
|
let (abort_handle, abort_registration) = AbortHandle::new_pair();
|
||||||
|
let deadline_key = self.deadlines.insert(request_id, timeout);
|
||||||
|
vacant.insert(RequestData {
|
||||||
|
abort_handle,
|
||||||
|
deadline_key,
|
||||||
|
span,
|
||||||
|
});
|
||||||
|
Ok(abort_registration)
|
||||||
|
}
|
||||||
|
hash_map::Entry::Occupied(_) => Err(AlreadyExistsError),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Cancels an in-flight request. Returns true iff the request was found.
|
||||||
|
pub fn cancel_request(&mut self, request_id: u64) -> bool {
|
||||||
|
if let Some(RequestData {
|
||||||
|
span,
|
||||||
|
abort_handle,
|
||||||
|
deadline_key,
|
||||||
|
}) = self.request_data.remove(&request_id)
|
||||||
|
{
|
||||||
|
let _entered = span.enter();
|
||||||
|
self.request_data.compact(0.1);
|
||||||
|
abort_handle.abort();
|
||||||
|
self.deadlines.remove(&deadline_key);
|
||||||
|
tracing::info!("ReceiveCancel");
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Removes a request without aborting. Returns true iff the request was found.
|
||||||
|
/// This method should be used when a response is being sent.
|
||||||
|
pub fn remove_request(&mut self, request_id: u64) -> Option<Span> {
|
||||||
|
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||||
|
self.request_data.compact(0.1);
|
||||||
|
self.deadlines.remove(&request_data.deadline_key);
|
||||||
|
Some(request_data.span)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Yields a request that has expired, aborting any ongoing processing of that request.
|
||||||
|
pub fn poll_expired(
|
||||||
|
&mut self,
|
||||||
|
cx: &mut Context,
|
||||||
|
) -> Poll<Option<Result<u64, tokio::time::error::Error>>> {
|
||||||
|
if self.deadlines.is_empty() {
|
||||||
|
// TODO(https://github.com/tokio-rs/tokio/issues/4161)
|
||||||
|
// This is a workaround for DelayQueue not always treating this case correctly.
|
||||||
|
return Poll::Ready(None);
|
||||||
|
}
|
||||||
|
self.deadlines.poll_expired(cx).map_ok(|expired| {
|
||||||
|
if let Some(RequestData {
|
||||||
|
abort_handle, span, ..
|
||||||
|
}) = self.request_data.remove(expired.get_ref())
|
||||||
|
{
|
||||||
|
let _entered = span.enter();
|
||||||
|
self.request_data.compact(0.1);
|
||||||
|
abort_handle.abort();
|
||||||
|
tracing::error!("DeadlineExceeded");
|
||||||
|
}
|
||||||
|
expired.into_inner()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// When InFlightRequests is dropped, any outstanding requests are aborted.
|
||||||
|
impl Drop for InFlightRequests {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.request_data
|
||||||
|
.values()
|
||||||
|
.for_each(|request_data| request_data.abort_handle.abort())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
use assert_matches::assert_matches;
|
||||||
|
use futures::{
|
||||||
|
future::{pending, Abortable},
|
||||||
|
FutureExt,
|
||||||
|
};
|
||||||
|
use futures_test::task::noop_context;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn start_request_increases_len() {
|
||||||
|
let mut in_flight_requests = InFlightRequests::default();
|
||||||
|
assert_eq!(in_flight_requests.len(), 0);
|
||||||
|
in_flight_requests
|
||||||
|
.start_request(0, SystemTime::now(), Span::current())
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(in_flight_requests.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn polling_expired_aborts() {
|
||||||
|
let mut in_flight_requests = InFlightRequests::default();
|
||||||
|
let abort_registration = in_flight_requests
|
||||||
|
.start_request(0, SystemTime::now(), Span::current())
|
||||||
|
.unwrap();
|
||||||
|
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||||
|
|
||||||
|
tokio::time::pause();
|
||||||
|
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
|
||||||
|
|
||||||
|
assert_matches!(
|
||||||
|
in_flight_requests.poll_expired(&mut noop_context()),
|
||||||
|
Poll::Ready(Some(Ok(_)))
|
||||||
|
);
|
||||||
|
assert_matches!(
|
||||||
|
abortable_future.poll_unpin(&mut noop_context()),
|
||||||
|
Poll::Ready(Err(_))
|
||||||
|
);
|
||||||
|
assert_eq!(in_flight_requests.len(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn cancel_request_aborts() {
|
||||||
|
let mut in_flight_requests = InFlightRequests::default();
|
||||||
|
let abort_registration = in_flight_requests
|
||||||
|
.start_request(0, SystemTime::now(), Span::current())
|
||||||
|
.unwrap();
|
||||||
|
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||||
|
|
||||||
|
assert_eq!(in_flight_requests.cancel_request(0), true);
|
||||||
|
assert_matches!(
|
||||||
|
abortable_future.poll_unpin(&mut noop_context()),
|
||||||
|
Poll::Ready(Err(_))
|
||||||
|
);
|
||||||
|
assert_eq!(in_flight_requests.len(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn remove_request_doesnt_abort() {
|
||||||
|
let mut in_flight_requests = InFlightRequests::default();
|
||||||
|
assert!(in_flight_requests.deadlines.is_empty());
|
||||||
|
|
||||||
|
let abort_registration = in_flight_requests
|
||||||
|
.start_request(
|
||||||
|
0,
|
||||||
|
SystemTime::now() + std::time::Duration::from_secs(10),
|
||||||
|
Span::current(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||||
|
|
||||||
|
// Precondition: Pending expiration
|
||||||
|
assert_matches!(
|
||||||
|
in_flight_requests.poll_expired(&mut noop_context()),
|
||||||
|
Poll::Pending
|
||||||
|
);
|
||||||
|
assert!(!in_flight_requests.deadlines.is_empty());
|
||||||
|
|
||||||
|
assert_matches!(in_flight_requests.remove_request(0), Some(_));
|
||||||
|
// Postcondition: No pending expirations
|
||||||
|
assert!(in_flight_requests.deadlines.is_empty());
|
||||||
|
assert_matches!(
|
||||||
|
in_flight_requests.poll_expired(&mut noop_context()),
|
||||||
|
Poll::Ready(None)
|
||||||
|
);
|
||||||
|
assert_matches!(
|
||||||
|
abortable_future.poll_unpin(&mut noop_context()),
|
||||||
|
Poll::Pending
|
||||||
|
);
|
||||||
|
assert_eq!(in_flight_requests.len(), 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
49
tarpc/src/server/incoming.rs
Normal file
49
tarpc/src/server/incoming.rs
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
use super::{
|
||||||
|
limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel},
|
||||||
|
Channel,
|
||||||
|
};
|
||||||
|
use futures::prelude::*;
|
||||||
|
use std::{fmt, hash::Hash};
|
||||||
|
|
||||||
|
#[cfg(feature = "tokio1")]
|
||||||
|
use super::{tokio::TokioServerExecutor, Serve};
|
||||||
|
|
||||||
|
/// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel).
|
||||||
|
pub trait Incoming<C>
|
||||||
|
where
|
||||||
|
Self: Sized + Stream<Item = C>,
|
||||||
|
C: Channel,
|
||||||
|
{
|
||||||
|
/// Enforces channel per-key limits.
|
||||||
|
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> MaxChannelsPerKey<Self, K, KF>
|
||||||
|
where
|
||||||
|
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||||
|
KF: Fn(&C) -> K,
|
||||||
|
{
|
||||||
|
MaxChannelsPerKey::new(self, n, keymaker)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Caps the number of concurrent requests per channel.
|
||||||
|
fn max_concurrent_requests_per_channel(self, n: usize) -> MaxRequestsPerChannel<Self> {
|
||||||
|
MaxRequestsPerChannel::new(self, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled
|
||||||
|
/// concurrently by spawning on tokio's default executor, and each request will be also
|
||||||
|
/// be spawned on tokio's default executor.
|
||||||
|
#[cfg(feature = "tokio1")]
|
||||||
|
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||||
|
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
|
||||||
|
where
|
||||||
|
S: Serve<C::Req, Resp = C::Resp>,
|
||||||
|
{
|
||||||
|
TokioServerExecutor::new(self, serve)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, C> Incoming<C> for S
|
||||||
|
where
|
||||||
|
S: Sized + Stream<Item = C>,
|
||||||
|
C: Channel,
|
||||||
|
{
|
||||||
|
}
|
||||||
5
tarpc/src/server/limits.rs
Normal file
5
tarpc/src/server/limits.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
/// Provides functionality to limit the number of active channels.
|
||||||
|
pub mod channels_per_key;
|
||||||
|
|
||||||
|
/// Provides a [channel](crate::server::Channel) that limits the number of in-flight requests.
|
||||||
|
pub mod requests_per_channel;
|
||||||
@@ -9,78 +9,65 @@ use crate::{
|
|||||||
util::Compact,
|
util::Compact,
|
||||||
};
|
};
|
||||||
use fnv::FnvHashMap;
|
use fnv::FnvHashMap;
|
||||||
use futures::{channel::mpsc, future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
|
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
||||||
use log::{debug, info, trace};
|
|
||||||
use pin_project::pin_project;
|
use pin_project::pin_project;
|
||||||
use raii_counter::{Counter, WeakCounter};
|
|
||||||
use std::sync::{Arc, Weak};
|
use std::sync::{Arc, Weak};
|
||||||
use std::{
|
use std::{
|
||||||
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin,
|
collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin,
|
||||||
};
|
};
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tracing::{debug, info, trace};
|
||||||
|
|
||||||
/// A single-threaded filter that drops channels based on per-key limits.
|
/// An [`Incoming`](crate::server::incoming::Incoming) stream that drops new channels based on
|
||||||
|
/// per-key limits.
|
||||||
|
///
|
||||||
|
/// The decision to drop a Channel is made once at the time the Channel materializes. Once a
|
||||||
|
/// Channel is yielded, it will not be prematurely dropped.
|
||||||
#[pin_project]
|
#[pin_project]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct ChannelFilter<S, K, F>
|
pub struct MaxChannelsPerKey<S, K, F>
|
||||||
where
|
where
|
||||||
K: Eq + Hash,
|
K: Eq + Hash,
|
||||||
{
|
{
|
||||||
#[pin]
|
#[pin]
|
||||||
listener: Fuse<S>,
|
listener: Fuse<S>,
|
||||||
channels_per_key: u32,
|
channels_per_key: u32,
|
||||||
#[pin]
|
|
||||||
dropped_keys: mpsc::UnboundedReceiver<K>,
|
dropped_keys: mpsc::UnboundedReceiver<K>,
|
||||||
#[pin]
|
|
||||||
dropped_keys_tx: mpsc::UnboundedSender<K>,
|
dropped_keys_tx: mpsc::UnboundedSender<K>,
|
||||||
key_counts: FnvHashMap<K, TrackerPrototype<K>>,
|
key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
|
||||||
keymaker: F,
|
keymaker: F,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A channel that is tracked by a ChannelFilter.
|
/// A channel that is tracked by [`MaxChannelsPerKey`].
|
||||||
#[pin_project]
|
#[pin_project]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct TrackedChannel<C, K> {
|
pub struct TrackedChannel<C, K> {
|
||||||
#[pin]
|
#[pin]
|
||||||
inner: C,
|
inner: C,
|
||||||
tracker: Tracker<K>,
|
tracker: Arc<Tracker<K>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Debug)]
|
||||||
struct Tracker<K> {
|
struct Tracker<K> {
|
||||||
key: Option<Arc<K>>,
|
key: Option<K>,
|
||||||
counter: Counter,
|
|
||||||
dropped_keys: mpsc::UnboundedSender<K>,
|
dropped_keys: mpsc::UnboundedSender<K>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<K> Drop for Tracker<K> {
|
impl<K> Drop for Tracker<K> {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
if self.counter.count() <= 1 {
|
// Don't care if the listener is dropped.
|
||||||
// Don't care if the listener is dropped.
|
let _ = self.dropped_keys.send(self.key.take().unwrap());
|
||||||
match Arc::try_unwrap(self.key.take().unwrap()) {
|
|
||||||
Ok(key) => {
|
|
||||||
let _ = self.dropped_keys.unbounded_send(key);
|
|
||||||
}
|
|
||||||
_ => unreachable!(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
struct TrackerPrototype<K> {
|
|
||||||
key: Weak<K>,
|
|
||||||
counter: WeakCounter,
|
|
||||||
dropped_keys: mpsc::UnboundedSender<K>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<C, K> Stream for TrackedChannel<C, K>
|
impl<C, K> Stream for TrackedChannel<C, K>
|
||||||
where
|
where
|
||||||
C: Stream,
|
C: Stream,
|
||||||
{
|
{
|
||||||
type Item = <C as Stream>::Item;
|
type Item = <C as Stream>::Item;
|
||||||
|
|
||||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||||
self.channel().poll_next(cx)
|
self.inner_pin_mut().poll_next(cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,20 +77,20 @@ where
|
|||||||
{
|
{
|
||||||
type Error = C::Error;
|
type Error = C::Error;
|
||||||
|
|
||||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||||
self.channel().poll_ready(cx)
|
self.inner_pin_mut().poll_ready(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
|
fn start_send(mut self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
|
||||||
self.channel().start_send(item)
|
self.inner_pin_mut().start_send(item)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||||
self.channel().poll_flush(cx)
|
self.inner_pin_mut().poll_flush(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||||
self.channel().poll_close(cx)
|
self.inner_pin_mut().poll_close(cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,17 +106,18 @@ where
|
|||||||
{
|
{
|
||||||
type Req = C::Req;
|
type Req = C::Req;
|
||||||
type Resp = C::Resp;
|
type Resp = C::Resp;
|
||||||
|
type Transport = C::Transport;
|
||||||
|
|
||||||
fn config(&self) -> &server::Config {
|
fn config(&self) -> &server::Config {
|
||||||
self.inner.config()
|
self.inner.config()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
fn in_flight_requests(&self) -> usize {
|
||||||
self.project().inner.in_flight_requests()
|
self.inner.in_flight_requests()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
fn transport(&self) -> &Self::Transport {
|
||||||
self.project().inner.start_request(request_id)
|
self.inner.transport()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,12 +128,12 @@ impl<C, K> TrackedChannel<C, K> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the pinned inner channel.
|
/// Returns the pinned inner channel.
|
||||||
fn channel<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut C> {
|
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> {
|
||||||
self.project().inner
|
self.as_mut().project().inner
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, K, F> ChannelFilter<S, K, F>
|
impl<S, K, F> MaxChannelsPerKey<S, K, F>
|
||||||
where
|
where
|
||||||
K: Eq + Hash,
|
K: Eq + Hash,
|
||||||
S: Stream,
|
S: Stream,
|
||||||
@@ -153,8 +141,8 @@ where
|
|||||||
{
|
{
|
||||||
/// Sheds new channels to stay under configured limits.
|
/// Sheds new channels to stay under configured limits.
|
||||||
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
|
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
|
||||||
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded();
|
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel();
|
||||||
ChannelFilter {
|
MaxChannelsPerKey {
|
||||||
listener: listener.fuse(),
|
listener: listener.fuse(),
|
||||||
channels_per_key,
|
channels_per_key,
|
||||||
dropped_keys,
|
dropped_keys,
|
||||||
@@ -165,12 +153,16 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, K, F> ChannelFilter<S, K, F>
|
impl<S, K, F> MaxChannelsPerKey<S, K, F>
|
||||||
where
|
where
|
||||||
S: Stream,
|
S: Stream,
|
||||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||||
F: Fn(&S::Item) -> K,
|
F: Fn(&S::Item) -> K,
|
||||||
{
|
{
|
||||||
|
fn listener_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<S>> {
|
||||||
|
self.as_mut().project().listener
|
||||||
|
}
|
||||||
|
|
||||||
fn handle_new_channel(
|
fn handle_new_channel(
|
||||||
mut self: Pin<&mut Self>,
|
mut self: Pin<&mut Self>,
|
||||||
stream: S::Item,
|
stream: S::Item,
|
||||||
@@ -179,11 +171,10 @@ where
|
|||||||
let tracker = self.as_mut().increment_channels_for_key(key.clone())?;
|
let tracker = self.as_mut().increment_channels_for_key(key.clone())?;
|
||||||
|
|
||||||
trace!(
|
trace!(
|
||||||
"[{}] Opening channel ({}/{}) channels for key.",
|
channel_filter_key = %key,
|
||||||
key,
|
open_channels = Arc::strong_count(&tracker),
|
||||||
tracker.counter.count(),
|
max_open_channels = self.channels_per_key,
|
||||||
self.as_mut().project().channels_per_key
|
"Opening channel");
|
||||||
);
|
|
||||||
|
|
||||||
Ok(TrackedChannel {
|
Ok(TrackedChannel {
|
||||||
tracker,
|
tracker,
|
||||||
@@ -191,45 +182,38 @@ where
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Tracker<K>, K> {
|
fn increment_channels_for_key(self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
|
||||||
let channels_per_key = self.channels_per_key;
|
let self_ = self.project();
|
||||||
let dropped_keys = self.dropped_keys_tx.clone();
|
let dropped_keys = self_.dropped_keys_tx;
|
||||||
let key_counts = &mut self.as_mut().project().key_counts;
|
match self_.key_counts.entry(key.clone()) {
|
||||||
match key_counts.entry(key.clone()) {
|
|
||||||
Entry::Vacant(vacant) => {
|
Entry::Vacant(vacant) => {
|
||||||
let key = Arc::new(key);
|
let tracker = Arc::new(Tracker {
|
||||||
let counter = WeakCounter::new();
|
key: Some(key),
|
||||||
|
|
||||||
vacant.insert(TrackerPrototype {
|
|
||||||
key: Arc::downgrade(&key),
|
|
||||||
counter: counter.clone(),
|
|
||||||
dropped_keys: dropped_keys.clone(),
|
dropped_keys: dropped_keys.clone(),
|
||||||
});
|
});
|
||||||
Ok(Tracker {
|
|
||||||
key: Some(key),
|
vacant.insert(Arc::downgrade(&tracker));
|
||||||
counter: counter.upgrade(),
|
Ok(tracker)
|
||||||
dropped_keys,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
Entry::Occupied(o) => {
|
Entry::Occupied(mut o) => {
|
||||||
let count = o.get().counter.count();
|
let count = o.get().strong_count();
|
||||||
if count >= channels_per_key.try_into().unwrap() {
|
if count >= TryFrom::try_from(*self_.channels_per_key).unwrap() {
|
||||||
info!(
|
info!(
|
||||||
"[{}] Opened max channels from key ({}/{}).",
|
channel_filter_key = %key,
|
||||||
key, count, channels_per_key
|
open_channels = count,
|
||||||
);
|
max_open_channels = *self_.channels_per_key,
|
||||||
|
"At open channel limit");
|
||||||
Err(key)
|
Err(key)
|
||||||
} else {
|
} else {
|
||||||
let TrackerPrototype {
|
Ok(o.get().upgrade().unwrap_or_else(|| {
|
||||||
key,
|
let tracker = Arc::new(Tracker {
|
||||||
counter,
|
key: Some(key),
|
||||||
dropped_keys,
|
dropped_keys: dropped_keys.clone(),
|
||||||
} = o.get().clone();
|
});
|
||||||
Ok(Tracker {
|
|
||||||
counter: counter.upgrade(),
|
*o.get_mut() = Arc::downgrade(&tracker);
|
||||||
key: Some(key.upgrade().unwrap()),
|
tracker
|
||||||
dropped_keys,
|
}))
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -239,18 +223,21 @@ where
|
|||||||
mut self: Pin<&mut Self>,
|
mut self: Pin<&mut Self>,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
|
) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
|
||||||
match ready!(self.as_mut().project().listener.poll_next_unpin(cx)) {
|
match ready!(self.listener_pin_mut().poll_next_unpin(cx)) {
|
||||||
Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
|
Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
|
||||||
None => Poll::Ready(None),
|
None => Poll::Ready(None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_closed_channels(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
fn poll_closed_channels(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||||
match ready!(self.as_mut().project().dropped_keys.poll_next_unpin(cx)) {
|
let self_ = self.project();
|
||||||
|
match ready!(self_.dropped_keys.poll_recv(cx)) {
|
||||||
Some(key) => {
|
Some(key) => {
|
||||||
debug!("All channels dropped for key [{}]", key);
|
debug!(
|
||||||
self.as_mut().project().key_counts.remove(&key);
|
channel_filter_key = %key,
|
||||||
self.as_mut().project().key_counts.compact(0.1);
|
"All channels dropped");
|
||||||
|
self_.key_counts.remove(&key);
|
||||||
|
self_.key_counts.compact(0.1);
|
||||||
Poll::Ready(())
|
Poll::Ready(())
|
||||||
}
|
}
|
||||||
None => unreachable!("Holding a copy of closed_channels and didn't close it."),
|
None => unreachable!("Holding a copy of closed_channels and didn't close it."),
|
||||||
@@ -258,7 +245,7 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, K, F> Stream for ChannelFilter<S, K, F>
|
impl<S, K, F> Stream for MaxChannelsPerKey<S, K, F>
|
||||||
where
|
where
|
||||||
S: Stream,
|
S: Stream,
|
||||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||||
@@ -291,7 +278,6 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
fn ctx() -> Context<'static> {
|
fn ctx() -> Context<'static> {
|
||||||
use futures::task::*;
|
use futures::task::*;
|
||||||
@@ -302,32 +288,28 @@ fn ctx() -> Context<'static> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn tracker_drop() {
|
fn tracker_drop() {
|
||||||
use assert_matches::assert_matches;
|
use assert_matches::assert_matches;
|
||||||
use raii_counter::Counter;
|
|
||||||
|
|
||||||
let (tx, mut rx) = mpsc::unbounded();
|
let (tx, mut rx) = mpsc::unbounded_channel();
|
||||||
Tracker {
|
Tracker {
|
||||||
key: Some(Arc::new(1)),
|
key: Some(1),
|
||||||
counter: Counter::new(),
|
|
||||||
dropped_keys: tx,
|
dropped_keys: tx,
|
||||||
};
|
};
|
||||||
assert_matches!(rx.try_next(), Ok(Some(1)));
|
assert_matches!(rx.poll_recv(&mut ctx()), Poll::Ready(Some(1)));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn tracked_channel_stream() {
|
fn tracked_channel_stream() {
|
||||||
use assert_matches::assert_matches;
|
use assert_matches::assert_matches;
|
||||||
use pin_utils::pin_mut;
|
use pin_utils::pin_mut;
|
||||||
use raii_counter::Counter;
|
|
||||||
|
|
||||||
let (chan_tx, chan) = mpsc::unbounded();
|
let (chan_tx, chan) = futures::channel::mpsc::unbounded();
|
||||||
let (dropped_keys, _) = mpsc::unbounded();
|
let (dropped_keys, _) = mpsc::unbounded_channel();
|
||||||
let channel = TrackedChannel {
|
let channel = TrackedChannel {
|
||||||
inner: chan,
|
inner: chan,
|
||||||
tracker: Tracker {
|
tracker: Arc::new(Tracker {
|
||||||
key: Some(Arc::new(1)),
|
key: Some(1),
|
||||||
counter: Counter::new(),
|
|
||||||
dropped_keys,
|
dropped_keys,
|
||||||
},
|
}),
|
||||||
};
|
};
|
||||||
|
|
||||||
chan_tx.unbounded_send("test").unwrap();
|
chan_tx.unbounded_send("test").unwrap();
|
||||||
@@ -339,17 +321,15 @@ fn tracked_channel_stream() {
|
|||||||
fn tracked_channel_sink() {
|
fn tracked_channel_sink() {
|
||||||
use assert_matches::assert_matches;
|
use assert_matches::assert_matches;
|
||||||
use pin_utils::pin_mut;
|
use pin_utils::pin_mut;
|
||||||
use raii_counter::Counter;
|
|
||||||
|
|
||||||
let (chan, mut chan_rx) = mpsc::unbounded();
|
let (chan, mut chan_rx) = futures::channel::mpsc::unbounded();
|
||||||
let (dropped_keys, _) = mpsc::unbounded();
|
let (dropped_keys, _) = mpsc::unbounded_channel();
|
||||||
let channel = TrackedChannel {
|
let channel = TrackedChannel {
|
||||||
inner: chan,
|
inner: chan,
|
||||||
tracker: Tracker {
|
tracker: Arc::new(Tracker {
|
||||||
key: Some(Arc::new(1)),
|
key: Some(1),
|
||||||
counter: Counter::new(),
|
|
||||||
dropped_keys,
|
dropped_keys,
|
||||||
},
|
}),
|
||||||
};
|
};
|
||||||
|
|
||||||
pin_mut!(channel);
|
pin_mut!(channel);
|
||||||
@@ -367,16 +347,16 @@ fn channel_filter_increment_channels_for_key() {
|
|||||||
struct TestChannel {
|
struct TestChannel {
|
||||||
key: &'static str,
|
key: &'static str,
|
||||||
}
|
}
|
||||||
let (_, listener) = mpsc::unbounded();
|
let (_, listener) = futures::channel::mpsc::unbounded();
|
||||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||||
pin_mut!(filter);
|
pin_mut!(filter);
|
||||||
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
|
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
|
||||||
assert_eq!(tracker1.counter.count(), 1);
|
assert_eq!(Arc::strong_count(&tracker1), 1);
|
||||||
let tracker2 = filter.as_mut().increment_channels_for_key("key").unwrap();
|
let tracker2 = filter.as_mut().increment_channels_for_key("key").unwrap();
|
||||||
assert_eq!(tracker1.counter.count(), 2);
|
assert_eq!(Arc::strong_count(&tracker1), 2);
|
||||||
assert_matches!(filter.increment_channels_for_key("key"), Err("key"));
|
assert_matches!(filter.increment_channels_for_key("key"), Err("key"));
|
||||||
drop(tracker2);
|
drop(tracker2);
|
||||||
assert_eq!(tracker1.counter.count(), 1);
|
assert_eq!(Arc::strong_count(&tracker1), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -388,27 +368,27 @@ fn channel_filter_handle_new_channel() {
|
|||||||
struct TestChannel {
|
struct TestChannel {
|
||||||
key: &'static str,
|
key: &'static str,
|
||||||
}
|
}
|
||||||
let (_, listener) = mpsc::unbounded();
|
let (_, listener) = futures::channel::mpsc::unbounded();
|
||||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||||
pin_mut!(filter);
|
pin_mut!(filter);
|
||||||
let channel1 = filter
|
let channel1 = filter
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.handle_new_channel(TestChannel { key: "key" })
|
.handle_new_channel(TestChannel { key: "key" })
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(channel1.tracker.counter.count(), 1);
|
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
|
||||||
|
|
||||||
let channel2 = filter
|
let channel2 = filter
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.handle_new_channel(TestChannel { key: "key" })
|
.handle_new_channel(TestChannel { key: "key" })
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(channel1.tracker.counter.count(), 2);
|
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
|
||||||
|
|
||||||
assert_matches!(
|
assert_matches!(
|
||||||
filter.handle_new_channel(TestChannel { key: "key" }),
|
filter.handle_new_channel(TestChannel { key: "key" }),
|
||||||
Err("key")
|
Err("key")
|
||||||
);
|
);
|
||||||
drop(channel2);
|
drop(channel2);
|
||||||
assert_eq!(channel1.tracker.counter.count(), 1);
|
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -420,8 +400,8 @@ fn channel_filter_poll_listener() {
|
|||||||
struct TestChannel {
|
struct TestChannel {
|
||||||
key: &'static str,
|
key: &'static str,
|
||||||
}
|
}
|
||||||
let (new_channels, listener) = mpsc::unbounded();
|
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||||
pin_mut!(filter);
|
pin_mut!(filter);
|
||||||
|
|
||||||
new_channels
|
new_channels
|
||||||
@@ -429,14 +409,14 @@ fn channel_filter_poll_listener() {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
let channel1 =
|
let channel1 =
|
||||||
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
|
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
|
||||||
assert_eq!(channel1.tracker.counter.count(), 1);
|
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
|
||||||
|
|
||||||
new_channels
|
new_channels
|
||||||
.unbounded_send(TestChannel { key: "key" })
|
.unbounded_send(TestChannel { key: "key" })
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let _channel2 =
|
let _channel2 =
|
||||||
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
|
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
|
||||||
assert_eq!(channel1.tracker.counter.count(), 2);
|
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
|
||||||
|
|
||||||
new_channels
|
new_channels
|
||||||
.unbounded_send(TestChannel { key: "key" })
|
.unbounded_send(TestChannel { key: "key" })
|
||||||
@@ -444,7 +424,7 @@ fn channel_filter_poll_listener() {
|
|||||||
let key =
|
let key =
|
||||||
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k);
|
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k);
|
||||||
assert_eq!(key, "key");
|
assert_eq!(key, "key");
|
||||||
assert_eq!(channel1.tracker.counter.count(), 2);
|
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -456,8 +436,8 @@ fn channel_filter_poll_closed_channels() {
|
|||||||
struct TestChannel {
|
struct TestChannel {
|
||||||
key: &'static str,
|
key: &'static str,
|
||||||
}
|
}
|
||||||
let (new_channels, listener) = mpsc::unbounded();
|
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||||
pin_mut!(filter);
|
pin_mut!(filter);
|
||||||
|
|
||||||
new_channels
|
new_channels
|
||||||
@@ -484,8 +464,8 @@ fn channel_filter_stream() {
|
|||||||
struct TestChannel {
|
struct TestChannel {
|
||||||
key: &'static str,
|
key: &'static str,
|
||||||
}
|
}
|
||||||
let (new_channels, listener) = mpsc::unbounded();
|
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||||
pin_mut!(filter);
|
pin_mut!(filter);
|
||||||
|
|
||||||
new_channels
|
new_channels
|
||||||
349
tarpc/src/server/limits/requests_per_channel.rs
Normal file
349
tarpc/src/server/limits/requests_per_channel.rs
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
// Copyright 2020 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
server::{Channel, Config},
|
||||||
|
Response, ServerError,
|
||||||
|
};
|
||||||
|
use futures::{prelude::*, ready, task::*};
|
||||||
|
use pin_project::pin_project;
|
||||||
|
use std::{io, pin::Pin};
|
||||||
|
|
||||||
|
/// A [`Channel`] that limits the number of concurrent requests by throttling.
|
||||||
|
///
|
||||||
|
/// Note that this is a very basic throttling heuristic. It is easy to set a number that is too low
|
||||||
|
/// for the resources available to the server. For production use cases, a more advanced throttler
|
||||||
|
/// is likely needed.
|
||||||
|
#[pin_project]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct MaxRequests<C> {
|
||||||
|
max_in_flight_requests: usize,
|
||||||
|
#[pin]
|
||||||
|
inner: C,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C> MaxRequests<C> {
|
||||||
|
/// Returns the inner channel.
|
||||||
|
pub fn get_ref(&self) -> &C {
|
||||||
|
&self.inner
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C> MaxRequests<C>
|
||||||
|
where
|
||||||
|
C: Channel,
|
||||||
|
{
|
||||||
|
/// Returns a new `MaxRequests` that wraps the given channel and limits concurrent requests to
|
||||||
|
/// `max_in_flight_requests`.
|
||||||
|
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
|
||||||
|
MaxRequests {
|
||||||
|
max_in_flight_requests,
|
||||||
|
inner,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C> Stream for MaxRequests<C>
|
||||||
|
where
|
||||||
|
C: Channel,
|
||||||
|
{
|
||||||
|
type Item = <C as Stream>::Item;
|
||||||
|
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||||
|
while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
|
||||||
|
{
|
||||||
|
ready!(self.as_mut().project().inner.poll_ready(cx)?);
|
||||||
|
|
||||||
|
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
|
||||||
|
Some(r) => {
|
||||||
|
let _entered = r.span.enter();
|
||||||
|
tracing::info!(
|
||||||
|
in_flight_requests = self.as_mut().in_flight_requests(),
|
||||||
|
"ThrottleRequest",
|
||||||
|
);
|
||||||
|
|
||||||
|
self.as_mut().start_send(Response {
|
||||||
|
request_id: r.request.id,
|
||||||
|
message: Err(ServerError {
|
||||||
|
kind: io::ErrorKind::WouldBlock,
|
||||||
|
detail: "server throttled the request.".into(),
|
||||||
|
}),
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
None => return Poll::Ready(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.project().inner.poll_next(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C> Sink<Response<<C as Channel>::Resp>> for MaxRequests<C>
|
||||||
|
where
|
||||||
|
C: Channel,
|
||||||
|
{
|
||||||
|
type Error = C::Error;
|
||||||
|
|
||||||
|
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.project().inner.poll_ready(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_send(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
item: Response<<C as Channel>::Resp>,
|
||||||
|
) -> Result<(), Self::Error> {
|
||||||
|
self.project().inner.start_send(item)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.project().inner.poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.project().inner.poll_close(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C> AsRef<C> for MaxRequests<C> {
|
||||||
|
fn as_ref(&self) -> &C {
|
||||||
|
&self.inner
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C> Channel for MaxRequests<C>
|
||||||
|
where
|
||||||
|
C: Channel,
|
||||||
|
{
|
||||||
|
type Req = <C as Channel>::Req;
|
||||||
|
type Resp = <C as Channel>::Resp;
|
||||||
|
type Transport = <C as Channel>::Transport;
|
||||||
|
|
||||||
|
fn in_flight_requests(&self) -> usize {
|
||||||
|
self.inner.in_flight_requests()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn config(&self) -> &Config {
|
||||||
|
self.inner.config()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn transport(&self) -> &Self::Transport {
|
||||||
|
self.inner.transport()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on
|
||||||
|
/// the number of in-flight requests.
|
||||||
|
#[pin_project]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct MaxRequestsPerChannel<S> {
|
||||||
|
#[pin]
|
||||||
|
inner: S,
|
||||||
|
max_in_flight_requests: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> MaxRequestsPerChannel<S>
|
||||||
|
where
|
||||||
|
S: Stream,
|
||||||
|
<S as Stream>::Item: Channel,
|
||||||
|
{
|
||||||
|
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
inner,
|
||||||
|
max_in_flight_requests,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> Stream for MaxRequestsPerChannel<S>
|
||||||
|
where
|
||||||
|
S: Stream,
|
||||||
|
<S as Stream>::Item: Channel,
|
||||||
|
{
|
||||||
|
type Item = MaxRequests<<S as Stream>::Item>;
|
||||||
|
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||||
|
match ready!(self.as_mut().project().inner.poll_next(cx)) {
|
||||||
|
Some(channel) => Poll::Ready(Some(MaxRequests::new(
|
||||||
|
channel,
|
||||||
|
*self.project().max_in_flight_requests,
|
||||||
|
))),
|
||||||
|
None => Poll::Ready(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
use crate::server::{
|
||||||
|
testing::{self, FakeChannel, PollExt},
|
||||||
|
TrackedRequest,
|
||||||
|
};
|
||||||
|
use pin_utils::pin_mut;
|
||||||
|
use std::{
|
||||||
|
marker::PhantomData,
|
||||||
|
time::{Duration, SystemTime},
|
||||||
|
};
|
||||||
|
use tracing::Span;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn throttler_in_flight_requests() {
|
||||||
|
let throttler = MaxRequests {
|
||||||
|
max_in_flight_requests: 0,
|
||||||
|
inner: FakeChannel::default::<isize, isize>(),
|
||||||
|
};
|
||||||
|
|
||||||
|
pin_mut!(throttler);
|
||||||
|
for i in 0..5 {
|
||||||
|
throttler
|
||||||
|
.inner
|
||||||
|
.in_flight_requests
|
||||||
|
.start_request(
|
||||||
|
i,
|
||||||
|
SystemTime::now() + Duration::from_secs(1),
|
||||||
|
Span::current(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn throttler_poll_next_done() {
|
||||||
|
let throttler = MaxRequests {
|
||||||
|
max_in_flight_requests: 0,
|
||||||
|
inner: FakeChannel::default::<isize, isize>(),
|
||||||
|
};
|
||||||
|
|
||||||
|
pin_mut!(throttler);
|
||||||
|
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn throttler_poll_next_some() -> io::Result<()> {
|
||||||
|
let throttler = MaxRequests {
|
||||||
|
max_in_flight_requests: 1,
|
||||||
|
inner: FakeChannel::default::<isize, isize>(),
|
||||||
|
};
|
||||||
|
|
||||||
|
pin_mut!(throttler);
|
||||||
|
throttler.inner.push_req(0, 1);
|
||||||
|
assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());
|
||||||
|
assert_eq!(
|
||||||
|
throttler
|
||||||
|
.as_mut()
|
||||||
|
.poll_next(&mut testing::cx())?
|
||||||
|
.map(|r| r.map(|r| (r.request.id, r.request.message))),
|
||||||
|
Poll::Ready(Some((0, 1)))
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn throttler_poll_next_throttled() {
|
||||||
|
let throttler = MaxRequests {
|
||||||
|
max_in_flight_requests: 0,
|
||||||
|
inner: FakeChannel::default::<isize, isize>(),
|
||||||
|
};
|
||||||
|
|
||||||
|
pin_mut!(throttler);
|
||||||
|
throttler.inner.push_req(1, 1);
|
||||||
|
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
|
||||||
|
assert_eq!(throttler.inner.sink.len(), 1);
|
||||||
|
let resp = throttler.inner.sink.get(0).unwrap();
|
||||||
|
assert_eq!(resp.request_id, 1);
|
||||||
|
assert!(resp.message.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn throttler_poll_next_throttled_sink_not_ready() {
|
||||||
|
let throttler = MaxRequests {
|
||||||
|
max_in_flight_requests: 0,
|
||||||
|
inner: PendingSink::default::<isize, isize>(),
|
||||||
|
};
|
||||||
|
pin_mut!(throttler);
|
||||||
|
assert!(throttler.poll_next(&mut testing::cx()).is_pending());
|
||||||
|
|
||||||
|
struct PendingSink<In, Out> {
|
||||||
|
ghost: PhantomData<fn(Out) -> In>,
|
||||||
|
}
|
||||||
|
impl PendingSink<(), ()> {
|
||||||
|
pub fn default<Req, Resp>(
|
||||||
|
) -> PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||||
|
PendingSink { ghost: PhantomData }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<In, Out> Stream for PendingSink<In, Out> {
|
||||||
|
type Item = In;
|
||||||
|
fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<In, Out> Sink<Out> for PendingSink<In, Out> {
|
||||||
|
type Error = io::Error;
|
||||||
|
fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
|
||||||
|
Err(io::Error::from(io::ErrorKind::WouldBlock))
|
||||||
|
}
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<Req, Resp> Channel for PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||||
|
type Req = Req;
|
||||||
|
type Resp = Resp;
|
||||||
|
type Transport = ();
|
||||||
|
fn config(&self) -> &Config {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
fn in_flight_requests(&self) -> usize {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
fn transport(&self) -> &() {
|
||||||
|
&()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn throttler_start_send() {
|
||||||
|
let throttler = MaxRequests {
|
||||||
|
max_in_flight_requests: 0,
|
||||||
|
inner: FakeChannel::default::<isize, isize>(),
|
||||||
|
};
|
||||||
|
|
||||||
|
pin_mut!(throttler);
|
||||||
|
throttler
|
||||||
|
.inner
|
||||||
|
.in_flight_requests
|
||||||
|
.start_request(
|
||||||
|
0,
|
||||||
|
SystemTime::now() + Duration::from_secs(1),
|
||||||
|
Span::current(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
throttler
|
||||||
|
.as_mut()
|
||||||
|
.start_send(Response {
|
||||||
|
request_id: 0,
|
||||||
|
message: Ok(1),
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(throttler.inner.in_flight_requests.len(), 0);
|
||||||
|
assert_eq!(
|
||||||
|
throttler.inner.sink.get(0),
|
||||||
|
Some(&Response {
|
||||||
|
request_id: 0,
|
||||||
|
message: Ok(1),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,16 +1,18 @@
|
|||||||
use crate::server::{Channel, Config};
|
// Copyright 2020 Google LLC
|
||||||
use crate::{context, Request, Response};
|
//
|
||||||
use fnv::FnvHashSet;
|
// Use of this source code is governed by an MIT-style
|
||||||
use futures::{
|
// license that can be found in the LICENSE file or at
|
||||||
future::{AbortHandle, AbortRegistration},
|
// https://opensource.org/licenses/MIT.
|
||||||
task::*,
|
|
||||||
Sink, Stream,
|
use crate::{
|
||||||
|
context,
|
||||||
|
server::{Channel, Config, TrackedRequest},
|
||||||
|
Request, Response,
|
||||||
};
|
};
|
||||||
|
use futures::{task::*, Sink, Stream};
|
||||||
use pin_project::pin_project;
|
use pin_project::pin_project;
|
||||||
use std::collections::VecDeque;
|
use std::{collections::VecDeque, io, pin::Pin, time::SystemTime};
|
||||||
use std::io;
|
use tracing::Span;
|
||||||
use std::pin::Pin;
|
|
||||||
use std::time::SystemTime;
|
|
||||||
|
|
||||||
#[pin_project]
|
#[pin_project]
|
||||||
pub(crate) struct FakeChannel<In, Out> {
|
pub(crate) struct FakeChannel<In, Out> {
|
||||||
@@ -19,7 +21,7 @@ pub(crate) struct FakeChannel<In, Out> {
|
|||||||
#[pin]
|
#[pin]
|
||||||
pub sink: VecDeque<Out>,
|
pub sink: VecDeque<Out>,
|
||||||
pub config: Config,
|
pub config: Config,
|
||||||
pub in_flight_requests: FnvHashSet<u64>,
|
pub in_flight_requests: super::in_flight_requests::InFlightRequests,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<In, Out> Stream for FakeChannel<In, Out>
|
impl<In, Out> Stream for FakeChannel<In, Out>
|
||||||
@@ -44,7 +46,7 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
|
|||||||
self.as_mut()
|
self.as_mut()
|
||||||
.project()
|
.project()
|
||||||
.in_flight_requests
|
.in_flight_requests
|
||||||
.remove(&response.request_id);
|
.remove_request(response.request_id);
|
||||||
self.project()
|
self.project()
|
||||||
.sink
|
.sink
|
||||||
.start_send(response)
|
.start_send(response)
|
||||||
@@ -60,44 +62,47 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Req, Resp> Channel for FakeChannel<io::Result<Request<Req>>, Response<Resp>>
|
impl<Req, Resp> Channel for FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>>
|
||||||
where
|
where
|
||||||
Req: Unpin,
|
Req: Unpin,
|
||||||
{
|
{
|
||||||
type Req = Req;
|
type Req = Req;
|
||||||
type Resp = Resp;
|
type Resp = Resp;
|
||||||
|
type Transport = ();
|
||||||
|
|
||||||
fn config(&self) -> &Config {
|
fn config(&self) -> &Config {
|
||||||
&self.config
|
&self.config
|
||||||
}
|
}
|
||||||
|
|
||||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
fn in_flight_requests(&self) -> usize {
|
||||||
self.in_flight_requests.len()
|
self.in_flight_requests.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_request(self: Pin<&mut Self>, id: u64) -> AbortRegistration {
|
fn transport(&self) -> &() {
|
||||||
self.project().in_flight_requests.insert(id);
|
&()
|
||||||
AbortHandle::new_pair().1
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Req, Resp> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
|
impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||||
pub fn push_req(&mut self, id: u64, message: Req) {
|
pub fn push_req(&mut self, id: u64, message: Req) {
|
||||||
self.stream.push_back(Ok(Request {
|
let (_, abort_registration) = futures::future::AbortHandle::new_pair();
|
||||||
context: context::Context {
|
self.stream.push_back(Ok(TrackedRequest {
|
||||||
deadline: SystemTime::UNIX_EPOCH,
|
request: Request {
|
||||||
trace_context: Default::default(),
|
context: context::Context {
|
||||||
_non_exhaustive: (),
|
deadline: SystemTime::UNIX_EPOCH,
|
||||||
|
trace_context: Default::default(),
|
||||||
|
},
|
||||||
|
id,
|
||||||
|
message,
|
||||||
},
|
},
|
||||||
id,
|
abort_registration,
|
||||||
message,
|
span: Span::none(),
|
||||||
_non_exhaustive: (),
|
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FakeChannel<(), ()> {
|
impl FakeChannel<(), ()> {
|
||||||
pub fn default<Req, Resp>() -> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
|
pub fn default<Req, Resp>() -> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||||
FakeChannel {
|
FakeChannel {
|
||||||
stream: Default::default(),
|
stream: Default::default(),
|
||||||
sink: Default::default(),
|
sink: Default::default(),
|
||||||
@@ -113,10 +118,7 @@ pub trait PollExt {
|
|||||||
|
|
||||||
impl<T> PollExt for Poll<Option<T>> {
|
impl<T> PollExt for Poll<Option<T>> {
|
||||||
fn is_done(&self) -> bool {
|
fn is_done(&self) -> bool {
|
||||||
match self {
|
matches!(self, Poll::Ready(None))
|
||||||
Poll::Ready(None) => true,
|
|
||||||
_ => false,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
111
tarpc/src/server/tokio.rs
Normal file
111
tarpc/src/server/tokio.rs
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
use super::{Channel, Requests, Serve};
|
||||||
|
use futures::{prelude::*, ready, task::*};
|
||||||
|
use pin_project::pin_project;
|
||||||
|
use std::pin::Pin;
|
||||||
|
|
||||||
|
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
|
||||||
|
/// for each new channel. Returned by
|
||||||
|
/// [`Incoming::execute`](crate::server::incoming::Incoming::execute).
|
||||||
|
#[pin_project]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct TokioServerExecutor<T, S> {
|
||||||
|
#[pin]
|
||||||
|
inner: T,
|
||||||
|
serve: S,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, S> TokioServerExecutor<T, S> {
|
||||||
|
pub(crate) fn new(inner: T, serve: S) -> Self {
|
||||||
|
Self { inner, serve }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A future that drives the server by [spawning](tokio::spawn) each [response
|
||||||
|
/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by
|
||||||
|
/// [`Channel::execute`](crate::server::Channel::execute).
|
||||||
|
#[pin_project]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct TokioChannelExecutor<T, S> {
|
||||||
|
#[pin]
|
||||||
|
inner: T,
|
||||||
|
serve: S,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, S> TokioServerExecutor<T, S> {
|
||||||
|
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
||||||
|
self.as_mut().project().inner
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, S> TokioChannelExecutor<T, S> {
|
||||||
|
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
||||||
|
self.as_mut().project().inner
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send + 'static execution helper methods.
|
||||||
|
|
||||||
|
impl<C> Requests<C>
|
||||||
|
where
|
||||||
|
C: Channel,
|
||||||
|
C::Req: Send + 'static,
|
||||||
|
C::Resp: Send + 'static,
|
||||||
|
{
|
||||||
|
/// Executes all requests using the given service function. Requests are handled concurrently
|
||||||
|
/// by [spawning](::tokio::spawn) each handler on tokio's default executor.
|
||||||
|
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
|
||||||
|
where
|
||||||
|
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
|
||||||
|
{
|
||||||
|
TokioChannelExecutor { inner: self, serve }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
|
||||||
|
where
|
||||||
|
St: Sized + Stream<Item = C>,
|
||||||
|
C: Channel + Send + 'static,
|
||||||
|
C::Req: Send + 'static,
|
||||||
|
C::Resp: Send + 'static,
|
||||||
|
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
||||||
|
Se::Fut: Send,
|
||||||
|
{
|
||||||
|
type Output = ();
|
||||||
|
|
||||||
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||||
|
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
||||||
|
tokio::spawn(channel.execute(self.serve.clone()));
|
||||||
|
}
|
||||||
|
tracing::info!("Server shutting down.");
|
||||||
|
Poll::Ready(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
|
||||||
|
where
|
||||||
|
C: Channel + 'static,
|
||||||
|
C::Req: Send + 'static,
|
||||||
|
C::Resp: Send + 'static,
|
||||||
|
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
||||||
|
S::Fut: Send,
|
||||||
|
{
|
||||||
|
type Output = ();
|
||||||
|
|
||||||
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
|
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
||||||
|
match response_handler {
|
||||||
|
Ok(resp) => {
|
||||||
|
let server = self.serve.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
resp.execute(server).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!("Requests stream errored out: {}", e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Poll::Ready(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -16,18 +16,21 @@
|
|||||||
//! This crate's design is based on [opencensus
|
//! This crate's design is based on [opencensus
|
||||||
//! tracing](https://opencensus.io/core-concepts/tracing/).
|
//! tracing](https://opencensus.io/core-concepts/tracing/).
|
||||||
|
|
||||||
|
use opentelemetry::trace::TraceContextExt;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use std::{
|
use std::{
|
||||||
|
convert::TryFrom,
|
||||||
fmt::{self, Formatter},
|
fmt::{self, Formatter},
|
||||||
mem,
|
num::{NonZeroU128, NonZeroU64},
|
||||||
};
|
};
|
||||||
|
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||||
|
|
||||||
/// A context for tracing the execution of processes, distributed or otherwise.
|
/// A context for tracing the execution of processes, distributed or otherwise.
|
||||||
///
|
///
|
||||||
/// Consists of a span identifying an event, an optional parent span identifying a causal event
|
/// Consists of a span identifying an event, an optional parent span identifying a causal event
|
||||||
/// that triggered the current span, and a trace with which all related spans are associated.
|
/// that triggered the current span, and a trace with which all related spans are associated.
|
||||||
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
|
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
|
||||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
pub struct Context {
|
pub struct Context {
|
||||||
/// An identifier of the trace associated with the current context. A trace ID is typically
|
/// An identifier of the trace associated with the current context. A trace ID is typically
|
||||||
/// created at a root span and passed along through all causal events.
|
/// created at a root span and passed along through all causal events.
|
||||||
@@ -36,33 +39,50 @@ pub struct Context {
|
|||||||
/// before making an RPC, and the span ID is sent to the server. The server is free to create
|
/// before making an RPC, and the span ID is sent to the server. The server is free to create
|
||||||
/// its own spans, for which it sets the client's span as the parent span.
|
/// its own spans, for which it sets the client's span as the parent span.
|
||||||
pub span_id: SpanId,
|
pub span_id: SpanId,
|
||||||
/// An identifier of the span that originated the current span. For example, if a server sends
|
/// Indicates whether a sampler has already decided whether or not to sample the trace
|
||||||
/// an RPC in response to a client request that included a span, the server would create a span
|
/// associated with the Context. If `sampling_decision` is None, then a decision has not yet
|
||||||
/// for the RPC and set its parent to the span_id in the incoming request's context.
|
/// been made. Downstream samplers do not need to abide by "no sample" decisions--for example,
|
||||||
///
|
/// an upstream client may choose to never sample, which may not make sense for the client's
|
||||||
/// If `parent_id` is `None`, then this is a root context.
|
/// dependencies. On the other hand, if an upstream process has chosen to sample this trace,
|
||||||
pub parent_id: Option<SpanId>,
|
/// then the downstream samplers are expected to respect that decision and also sample the
|
||||||
|
/// trace. Otherwise, the full trace would not be able to be reconstructed.
|
||||||
|
pub sampling_decision: SamplingDecision,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A 128-bit UUID identifying a trace. All spans caused by the same originating span share the
|
/// A 128-bit UUID identifying a trace. All spans caused by the same originating span share the
|
||||||
/// same trace ID.
|
/// same trace ID.
|
||||||
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
|
#[derive(Default, PartialEq, Eq, Hash, Clone, Copy)]
|
||||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
pub struct TraceId(u128);
|
pub struct TraceId(#[cfg_attr(feature = "serde1", serde(with = "u128_serde"))] u128);
|
||||||
|
|
||||||
/// A 64-bit identifier of a span within a trace. The identifier is unique within the span's trace.
|
/// A 64-bit identifier of a span within a trace. The identifier is unique within the span's trace.
|
||||||
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
|
#[derive(Default, PartialEq, Eq, Hash, Clone, Copy)]
|
||||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
pub struct SpanId(u64);
|
pub struct SpanId(u64);
|
||||||
|
|
||||||
|
/// Indicates whether a sampler has decided whether or not to sample the trace associated with the
|
||||||
|
/// Context. Downstream samplers do not need to abide by "no sample" decisions--for example, an
|
||||||
|
/// upstream client may choose to never sample, which may not make sense for the client's
|
||||||
|
/// dependencies. On the other hand, if an upstream process has chosen to sample this trace, then
|
||||||
|
/// the downstream samplers are expected to respect that decision and also sample the trace.
|
||||||
|
/// Otherwise, the full trace would not be able to be reconstructed reliably.
|
||||||
|
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
|
||||||
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
|
#[repr(u8)]
|
||||||
|
pub enum SamplingDecision {
|
||||||
|
/// The associated span was sampled by its creating process. Child spans must also be sampled.
|
||||||
|
Sampled,
|
||||||
|
/// The associated span was not sampled by its creating process.
|
||||||
|
Unsampled,
|
||||||
|
}
|
||||||
|
|
||||||
impl Context {
|
impl Context {
|
||||||
/// Constructs a new root context. A root context is one with no parent span.
|
/// Constructs a new context with the trace ID and sampling decision inherited from the parent.
|
||||||
pub fn new_root() -> Self {
|
pub(crate) fn new_child(&self) -> Self {
|
||||||
let rng = &mut rand::thread_rng();
|
Self {
|
||||||
Context {
|
trace_id: self.trace_id,
|
||||||
trace_id: TraceId::random(rng),
|
span_id: SpanId::random(&mut rand::thread_rng()),
|
||||||
span_id: SpanId::random(rng),
|
sampling_decision: self.sampling_decision,
|
||||||
parent_id: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -71,17 +91,128 @@ impl TraceId {
|
|||||||
/// Returns a random trace ID that can be assumed to be globally unique if `rng` generates
|
/// Returns a random trace ID that can be assumed to be globally unique if `rng` generates
|
||||||
/// actually-random numbers.
|
/// actually-random numbers.
|
||||||
pub fn random<R: Rng>(rng: &mut R) -> Self {
|
pub fn random<R: Rng>(rng: &mut R) -> Self {
|
||||||
TraceId(u128::from(rng.next_u64()) << mem::size_of::<u64>() | u128::from(rng.next_u64()))
|
TraceId(rng.gen::<NonZeroU128>().get())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true iff the trace ID is 0.
|
||||||
|
pub fn is_none(&self) -> bool {
|
||||||
|
self.0 == 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SpanId {
|
impl SpanId {
|
||||||
/// Returns a random span ID that can be assumed to be unique within a single trace.
|
/// Returns a random span ID that can be assumed to be unique within a single trace.
|
||||||
pub fn random<R: Rng>(rng: &mut R) -> Self {
|
pub fn random<R: Rng>(rng: &mut R) -> Self {
|
||||||
SpanId(rng.next_u64())
|
SpanId(rng.gen::<NonZeroU64>().get())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true iff the span ID is 0.
|
||||||
|
pub fn is_none(&self) -> bool {
|
||||||
|
self.0 == 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<TraceId> for u128 {
|
||||||
|
fn from(trace_id: TraceId) -> Self {
|
||||||
|
trace_id.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<u128> for TraceId {
|
||||||
|
fn from(trace_id: u128) -> Self {
|
||||||
|
Self(trace_id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<SpanId> for u64 {
|
||||||
|
fn from(span_id: SpanId) -> Self {
|
||||||
|
span_id.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<u64> for SpanId {
|
||||||
|
fn from(span_id: u64) -> Self {
|
||||||
|
Self(span_id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<opentelemetry::trace::TraceId> for TraceId {
|
||||||
|
fn from(trace_id: opentelemetry::trace::TraceId) -> Self {
|
||||||
|
Self::from(trace_id.to_u128())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<TraceId> for opentelemetry::trace::TraceId {
|
||||||
|
fn from(trace_id: TraceId) -> Self {
|
||||||
|
Self::from_u128(trace_id.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<opentelemetry::trace::SpanId> for SpanId {
|
||||||
|
fn from(span_id: opentelemetry::trace::SpanId) -> Self {
|
||||||
|
Self::from(span_id.to_u64())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<SpanId> for opentelemetry::trace::SpanId {
|
||||||
|
fn from(span_id: SpanId) -> Self {
|
||||||
|
Self::from_u64(span_id.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<&tracing::Span> for Context {
|
||||||
|
type Error = NoActiveSpan;
|
||||||
|
|
||||||
|
fn try_from(span: &tracing::Span) -> Result<Self, NoActiveSpan> {
|
||||||
|
let context = span.context();
|
||||||
|
if context.has_active_span() {
|
||||||
|
Ok(Self::from(context.span()))
|
||||||
|
} else {
|
||||||
|
Err(NoActiveSpan)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<opentelemetry::trace::SpanRef<'_>> for Context {
|
||||||
|
fn from(span: opentelemetry::trace::SpanRef<'_>) -> Self {
|
||||||
|
let otel_ctx = span.span_context();
|
||||||
|
Self {
|
||||||
|
trace_id: TraceId::from(otel_ctx.trace_id()),
|
||||||
|
span_id: SpanId::from(otel_ctx.span_id()),
|
||||||
|
sampling_decision: SamplingDecision::from(otel_ctx),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<SamplingDecision> for opentelemetry::trace::TraceFlags {
|
||||||
|
fn from(decision: SamplingDecision) -> Self {
|
||||||
|
match decision {
|
||||||
|
SamplingDecision::Sampled => opentelemetry::trace::TraceFlags::SAMPLED,
|
||||||
|
SamplingDecision::Unsampled => opentelemetry::trace::TraceFlags::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&opentelemetry::trace::SpanContext> for SamplingDecision {
|
||||||
|
fn from(context: &opentelemetry::trace::SpanContext) -> Self {
|
||||||
|
if context.is_sampled() {
|
||||||
|
SamplingDecision::Sampled
|
||||||
|
} else {
|
||||||
|
SamplingDecision::Unsampled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for SamplingDecision {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::Unsampled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returned when a [`Context`] cannot be constructed from a [`Span`](tracing::Span).
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct NoActiveSpan;
|
||||||
|
|
||||||
impl fmt::Display for TraceId {
|
impl fmt::Display for TraceId {
|
||||||
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
|
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
|
||||||
write!(f, "{:02x}", self.0)?;
|
write!(f, "{:02x}", self.0)?;
|
||||||
@@ -89,9 +220,42 @@ impl fmt::Display for TraceId {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl fmt::Debug for TraceId {
|
||||||
|
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
|
||||||
|
write!(f, "{:02x}", self.0)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl fmt::Display for SpanId {
|
impl fmt::Display for SpanId {
|
||||||
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
|
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
|
||||||
write!(f, "{:02x}", self.0)?;
|
write!(f, "{:02x}", self.0)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl fmt::Debug for SpanId {
|
||||||
|
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
|
||||||
|
write!(f, "{:02x}", self.0)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "serde1")]
|
||||||
|
mod u128_serde {
|
||||||
|
pub fn serialize<S>(u: &u128, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
serde::Serialize::serialize(&u.to_le_bytes(), serializer)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn deserialize<'de, D>(deserializer: D) -> Result<u128, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
Ok(u128::from_le_bytes(serde::Deserialize::deserialize(
|
||||||
|
deserializer,
|
||||||
|
)?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
40
tarpc/src/transport.rs
Normal file
40
tarpc/src/transport.rs
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
// Copyright 2018 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
//! Provides a [`Transport`](sealed::Transport) trait as well as implementations.
|
||||||
|
//!
|
||||||
|
//! The rpc crate is transport- and protocol-agnostic. Any transport that impls [`Transport`](sealed::Transport)
|
||||||
|
//! can be plugged in, using whatever protocol it wants.
|
||||||
|
|
||||||
|
pub mod channel;
|
||||||
|
|
||||||
|
pub(crate) mod sealed {
|
||||||
|
use futures::prelude::*;
|
||||||
|
use std::error::Error;
|
||||||
|
|
||||||
|
/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
|
||||||
|
pub trait Transport<SinkItem, Item>
|
||||||
|
where
|
||||||
|
Self: Stream<Item = Result<Item, <Self as Sink<SinkItem>>::Error>>,
|
||||||
|
Self: Sink<SinkItem, Error = <Self as Transport<SinkItem, Item>>::TransportError>,
|
||||||
|
<Self as Sink<SinkItem>>::Error: Error,
|
||||||
|
{
|
||||||
|
/// Associated type where clauses are not elaborated; this associated type allows users
|
||||||
|
/// bounding types by Transport to avoid having to explicitly add `T::Error: Error` to their
|
||||||
|
/// bounds.
|
||||||
|
type TransportError: Error + Send + Sync + 'static;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, SinkItem, Item, E> Transport<SinkItem, Item> for T
|
||||||
|
where
|
||||||
|
T: ?Sized,
|
||||||
|
T: Stream<Item = Result<Item, E>>,
|
||||||
|
T: Sink<SinkItem, Error = E>,
|
||||||
|
T::Error: Error + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
type TransportError = E;
|
||||||
|
}
|
||||||
|
}
|
||||||
202
tarpc/src/transport/channel.rs
Normal file
202
tarpc/src/transport/channel.rs
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
// Copyright 2018 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
//! Transports backed by in-memory channels.
|
||||||
|
|
||||||
|
use futures::{task::*, Sink, Stream};
|
||||||
|
use pin_project::pin_project;
|
||||||
|
use std::{error::Error, pin::Pin};
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
/// Errors that occur in the sending or receiving of messages over a channel.
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
pub enum ChannelError {
|
||||||
|
/// An error occurred sending over the channel.
|
||||||
|
#[error("an error occurred sending over the channel")]
|
||||||
|
Send(#[source] Box<dyn Error + Send + Sync + 'static>),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
|
||||||
|
/// [`Sink`].
|
||||||
|
pub fn unbounded<SinkItem, Item>() -> (
|
||||||
|
UnboundedChannel<SinkItem, Item>,
|
||||||
|
UnboundedChannel<Item, SinkItem>,
|
||||||
|
) {
|
||||||
|
let (tx1, rx2) = mpsc::unbounded_channel();
|
||||||
|
let (tx2, rx1) = mpsc::unbounded_channel();
|
||||||
|
(
|
||||||
|
UnboundedChannel { tx: tx1, rx: rx1 },
|
||||||
|
UnboundedChannel { tx: tx2, rx: rx2 },
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender)
|
||||||
|
/// and [`UnboundedReceiver`](mpsc::UnboundedReceiver).
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct UnboundedChannel<Item, SinkItem> {
|
||||||
|
rx: mpsc::UnboundedReceiver<Item>,
|
||||||
|
tx: mpsc::UnboundedSender<SinkItem>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
|
||||||
|
type Item = Result<Item, ChannelError>;
|
||||||
|
|
||||||
|
fn poll_next(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Option<Result<Item, ChannelError>>> {
|
||||||
|
self.rx.poll_recv(cx).map(|option| option.map(Ok))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const CLOSED_MESSAGE: &str = "the channel is closed and cannot accept new items for sending";
|
||||||
|
|
||||||
|
impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
|
||||||
|
type Error = ChannelError;
|
||||||
|
|
||||||
|
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(if self.tx.is_closed() {
|
||||||
|
Err(ChannelError::Send(CLOSED_MESSAGE.into()))
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||||
|
self.tx
|
||||||
|
.send(item)
|
||||||
|
.map_err(|_| ChannelError::Send(CLOSED_MESSAGE.into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
// UnboundedSender requires no flushing.
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
// UnboundedSender can't initiate closure.
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns two channel peers with buffer equal to `capacity`. Each [`Stream`] yields items sent
|
||||||
|
/// through the other's [`Sink`].
|
||||||
|
pub fn bounded<SinkItem, Item>(
|
||||||
|
capacity: usize,
|
||||||
|
) -> (Channel<SinkItem, Item>, Channel<Item, SinkItem>) {
|
||||||
|
let (tx1, rx2) = futures::channel::mpsc::channel(capacity);
|
||||||
|
let (tx2, rx1) = futures::channel::mpsc::channel(capacity);
|
||||||
|
(Channel { tx: tx1, rx: rx1 }, Channel { tx: tx2, rx: rx2 })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A bi-directional channel backed by a [`Sender`](futures::channel::mpsc::Sender)
|
||||||
|
/// and [`Receiver`](futures::channel::mpsc::Receiver).
|
||||||
|
#[pin_project]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Channel<Item, SinkItem> {
|
||||||
|
#[pin]
|
||||||
|
rx: futures::channel::mpsc::Receiver<Item>,
|
||||||
|
#[pin]
|
||||||
|
tx: futures::channel::mpsc::Sender<SinkItem>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Item, SinkItem> Stream for Channel<Item, SinkItem> {
|
||||||
|
type Item = Result<Item, ChannelError>;
|
||||||
|
|
||||||
|
fn poll_next(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Option<Result<Item, ChannelError>>> {
|
||||||
|
self.project().rx.poll_next(cx).map(|option| option.map(Ok))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
|
||||||
|
type Error = ChannelError;
|
||||||
|
|
||||||
|
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.project()
|
||||||
|
.tx
|
||||||
|
.poll_ready(cx)
|
||||||
|
.map_err(|e| ChannelError::Send(Box::new(e)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||||
|
self.project()
|
||||||
|
.tx
|
||||||
|
.start_send(item)
|
||||||
|
.map_err(|e| ChannelError::Send(Box::new(e)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.project()
|
||||||
|
.tx
|
||||||
|
.poll_flush(cx)
|
||||||
|
.map_err(|e| ChannelError::Send(Box::new(e)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.project()
|
||||||
|
.tx
|
||||||
|
.poll_close(cx)
|
||||||
|
.map_err(|e| ChannelError::Send(Box::new(e)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[cfg(feature = "tokio1")]
|
||||||
|
mod tests {
|
||||||
|
use crate::{
|
||||||
|
client, context,
|
||||||
|
server::{incoming::Incoming, BaseChannel},
|
||||||
|
transport::{
|
||||||
|
self,
|
||||||
|
channel::{Channel, UnboundedChannel},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use assert_matches::assert_matches;
|
||||||
|
use futures::{prelude::*, stream};
|
||||||
|
use std::io;
|
||||||
|
use tracing::trace;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ensure_is_transport() {
|
||||||
|
fn is_transport<SinkItem, Item, T: crate::Transport<SinkItem, Item>>() {}
|
||||||
|
is_transport::<(), (), UnboundedChannel<(), ()>>();
|
||||||
|
is_transport::<(), (), Channel<(), ()>>();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn integration() -> anyhow::Result<()> {
|
||||||
|
let _ = tracing_subscriber::fmt::try_init();
|
||||||
|
|
||||||
|
let (client_channel, server_channel) = transport::channel::unbounded();
|
||||||
|
tokio::spawn(
|
||||||
|
stream::once(future::ready(server_channel))
|
||||||
|
.map(BaseChannel::with_defaults)
|
||||||
|
.execute(|_ctx, request: String| {
|
||||||
|
future::ready(request.parse::<u64>().map_err(|_| {
|
||||||
|
io::Error::new(
|
||||||
|
io::ErrorKind::InvalidInput,
|
||||||
|
format!("{:?} is not an int", request),
|
||||||
|
)
|
||||||
|
}))
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
let client = client::new(client::Config::default(), client_channel).spawn();
|
||||||
|
|
||||||
|
let response1 = client.call(context::current(), "", "123".into()).await?;
|
||||||
|
let response2 = client.call(context::current(), "", "abc".into()).await?;
|
||||||
|
|
||||||
|
trace!("response1: {:?}, response2: {:?}", response1, response2);
|
||||||
|
|
||||||
|
assert_matches!(response1, Ok(123));
|
||||||
|
assert_matches!(response2, Err(ref e) if e.kind() == io::ErrorKind::InvalidInput);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,7 +10,8 @@ use std::{
|
|||||||
time::{Duration, SystemTime},
|
time::{Duration, SystemTime},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde1")]
|
||||||
|
#[cfg_attr(docsrs, doc(cfg(feature = "serde1")))]
|
||||||
pub mod serde;
|
pub mod serde;
|
||||||
|
|
||||||
/// Extension trait for [SystemTimes](SystemTime) in the future, i.e. deadlines.
|
/// Extension trait for [SystemTimes](SystemTime) in the future, i.e. deadlines.
|
||||||
@@ -5,30 +5,7 @@
|
|||||||
// https://opensource.org/licenses/MIT.
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||||
use std::{
|
use std::io;
|
||||||
io,
|
|
||||||
time::{Duration, SystemTime},
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Serializes `system_time` as a `u64` equal to the number of seconds since the epoch.
|
|
||||||
pub fn serialize_epoch_secs<S>(system_time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
|
|
||||||
where
|
|
||||||
S: Serializer,
|
|
||||||
{
|
|
||||||
system_time
|
|
||||||
.duration_since(SystemTime::UNIX_EPOCH)
|
|
||||||
.unwrap_or(Duration::from_secs(0))
|
|
||||||
.as_secs() // Only care about second precision
|
|
||||||
.serialize(serializer)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Deserializes [`SystemTime`] from a `u64` equal to the number of seconds since the epoch.
|
|
||||||
pub fn deserialize_epoch_secs<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
|
|
||||||
where
|
|
||||||
D: Deserializer<'de>,
|
|
||||||
{
|
|
||||||
Ok(SystemTime::UNIX_EPOCH + Duration::from_secs(u64::deserialize(deserializer)?))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Serializes [`io::ErrorKind`] as a `u32`.
|
/// Serializes [`io::ErrorKind`] as a `u32`.
|
||||||
#[allow(clippy::trivially_copy_pass_by_ref)] // Exact fn signature required by serde derive
|
#[allow(clippy::trivially_copy_pass_by_ref)] // Exact fn signature required by serde derive
|
||||||
5
tarpc/tests/compile_fail.rs
Normal file
5
tarpc/tests/compile_fail.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
#[test]
|
||||||
|
fn ui() {
|
||||||
|
let t = trybuild::TestCases::new();
|
||||||
|
t.compile_fail("tests/compile_fail/*.rs");
|
||||||
|
}
|
||||||
15
tarpc/tests/compile_fail/tarpc_server_missing_async.rs
Normal file
15
tarpc/tests/compile_fail/tarpc_server_missing_async.rs
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
#[tarpc::service(derive_serde = false)]
|
||||||
|
trait World {
|
||||||
|
async fn hello(name: String) -> String;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct HelloServer;
|
||||||
|
|
||||||
|
#[tarpc::server]
|
||||||
|
impl World for HelloServer {
|
||||||
|
fn hello(name: String) -> String {
|
||||||
|
format!("Hello, {}!", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
||||||
11
tarpc/tests/compile_fail/tarpc_server_missing_async.stderr
Normal file
11
tarpc/tests/compile_fail/tarpc_server_missing_async.stderr
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
error: not all trait items implemented, missing: `HelloFut`
|
||||||
|
--> $DIR/tarpc_server_missing_async.rs:9:1
|
||||||
|
|
|
||||||
|
9 | impl World for HelloServer {
|
||||||
|
| ^^^^
|
||||||
|
|
||||||
|
error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async
|
||||||
|
--> $DIR/tarpc_server_missing_async.rs:10:5
|
||||||
|
|
|
||||||
|
10 | fn hello(name: String) -> String {
|
||||||
|
| ^^
|
||||||
6
tarpc/tests/compile_fail/tarpc_service_arg_pat.rs
Normal file
6
tarpc/tests/compile_fail/tarpc_service_arg_pat.rs
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
#[tarpc::service]
|
||||||
|
trait World {
|
||||||
|
async fn pat((a, b): (u8, u32));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
||||||
5
tarpc/tests/compile_fail/tarpc_service_arg_pat.stderr
Normal file
5
tarpc/tests/compile_fail/tarpc_service_arg_pat.stderr
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
error: patterns aren't allowed in RPC args
|
||||||
|
--> $DIR/tarpc_service_arg_pat.rs:3:18
|
||||||
|
|
|
||||||
|
3 | async fn pat((a, b): (u8, u32));
|
||||||
|
| ^^^^^^
|
||||||
6
tarpc/tests/compile_fail/tarpc_service_fn_new.rs
Normal file
6
tarpc/tests/compile_fail/tarpc_service_fn_new.rs
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
#[tarpc::service]
|
||||||
|
trait World {
|
||||||
|
async fn new();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
||||||
5
tarpc/tests/compile_fail/tarpc_service_fn_new.stderr
Normal file
5
tarpc/tests/compile_fail/tarpc_service_fn_new.stderr
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
error: method name conflicts with generated fn `WorldClient::new`
|
||||||
|
--> $DIR/tarpc_service_fn_new.rs:3:14
|
||||||
|
|
|
||||||
|
3 | async fn new();
|
||||||
|
| ^^^
|
||||||
6
tarpc/tests/compile_fail/tarpc_service_fn_serve.rs
Normal file
6
tarpc/tests/compile_fail/tarpc_service_fn_serve.rs
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
#[tarpc::service]
|
||||||
|
trait World {
|
||||||
|
async fn serve();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
||||||
5
tarpc/tests/compile_fail/tarpc_service_fn_serve.stderr
Normal file
5
tarpc/tests/compile_fail/tarpc_service_fn_serve.stderr
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
error: method name conflicts with generated fn `World::serve`
|
||||||
|
--> $DIR/tarpc_service_fn_serve.rs:3:14
|
||||||
|
|
|
||||||
|
3 | async fn serve();
|
||||||
|
| ^^^^^
|
||||||
55
tarpc/tests/dataservice.rs
Normal file
55
tarpc/tests/dataservice.rs
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
use futures::prelude::*;
|
||||||
|
use tarpc::serde_transport;
|
||||||
|
use tarpc::{
|
||||||
|
client, context,
|
||||||
|
server::{incoming::Incoming, BaseChannel},
|
||||||
|
};
|
||||||
|
use tokio_serde::formats::Json;
|
||||||
|
|
||||||
|
#[tarpc::derive_serde]
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
|
pub enum TestData {
|
||||||
|
Black,
|
||||||
|
White,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tarpc::service]
|
||||||
|
pub trait ColorProtocol {
|
||||||
|
async fn get_opposite_color(color: TestData) -> TestData;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct ColorServer;
|
||||||
|
|
||||||
|
#[tarpc::server]
|
||||||
|
impl ColorProtocol for ColorServer {
|
||||||
|
async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData {
|
||||||
|
match color {
|
||||||
|
TestData::White => TestData::Black,
|
||||||
|
TestData::Black => TestData::White,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_call() -> anyhow::Result<()> {
|
||||||
|
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
|
||||||
|
let addr = transport.local_addr();
|
||||||
|
tokio::spawn(
|
||||||
|
transport
|
||||||
|
.take(1)
|
||||||
|
.filter_map(|r| async { r.ok() })
|
||||||
|
.map(BaseChannel::with_defaults)
|
||||||
|
.execute(ColorServer.serve()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
||||||
|
let client = ColorProtocolClient::new(client::Config::default(), transport).spawn();
|
||||||
|
|
||||||
|
let color = client
|
||||||
|
.get_opposite_color(context::current(), TestData::White)
|
||||||
|
.await?;
|
||||||
|
assert_eq!(color, TestData::Black);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -1,16 +1,16 @@
|
|||||||
use assert_matches::assert_matches;
|
use assert_matches::assert_matches;
|
||||||
use futures::{
|
use futures::{
|
||||||
future::{ready, Ready},
|
future::{join_all, ready, Ready},
|
||||||
prelude::*,
|
prelude::*,
|
||||||
};
|
};
|
||||||
use std::io;
|
use std::time::{Duration, SystemTime};
|
||||||
use tarpc::{
|
use tarpc::{
|
||||||
client::{self},
|
client::{self},
|
||||||
context, serde_transport,
|
context,
|
||||||
server::{self, BaseChannel, Channel, Handler},
|
server::{self, incoming::Incoming, BaseChannel, Channel},
|
||||||
transport::channel,
|
transport::channel,
|
||||||
};
|
};
|
||||||
use tokio_serde::formats::Json;
|
use tokio::join;
|
||||||
|
|
||||||
#[tarpc_plugins::service]
|
#[tarpc_plugins::service]
|
||||||
trait Service {
|
trait Service {
|
||||||
@@ -35,19 +35,19 @@ impl Service for Server {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test(threaded_scheduler)]
|
#[tokio::test]
|
||||||
async fn sequential() -> io::Result<()> {
|
async fn sequential() -> anyhow::Result<()> {
|
||||||
let _ = env_logger::try_init();
|
let _ = tracing_subscriber::fmt::try_init();
|
||||||
|
|
||||||
let (tx, rx) = channel::unbounded();
|
let (tx, rx) = channel::unbounded();
|
||||||
|
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
BaseChannel::new(server::Config::default(), rx)
|
BaseChannel::new(server::Config::default(), rx)
|
||||||
.respond_with(Server.serve())
|
.requests()
|
||||||
.execute(),
|
.execute(Server.serve()),
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||||
|
|
||||||
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
|
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
|
||||||
assert_matches!(
|
assert_matches!(
|
||||||
@@ -57,21 +57,75 @@ async fn sequential() -> io::Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "serde1")]
|
#[tokio::test]
|
||||||
#[tokio::test(threaded_scheduler)]
|
async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> {
|
||||||
async fn serde() -> io::Result<()> {
|
#[tarpc_plugins::service]
|
||||||
let _ = env_logger::try_init();
|
trait Loop {
|
||||||
|
async fn r#loop();
|
||||||
|
}
|
||||||
|
|
||||||
let transport = serde_transport::tcp::listen("localhost:56789", Json::default).await?;
|
#[derive(Clone)]
|
||||||
|
struct LoopServer;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct AllHandlersComplete;
|
||||||
|
|
||||||
|
#[tarpc::server]
|
||||||
|
impl Loop for LoopServer {
|
||||||
|
async fn r#loop(self, _: context::Context) {
|
||||||
|
loop {
|
||||||
|
futures::pending!();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = tracing_subscriber::fmt::try_init();
|
||||||
|
|
||||||
|
let (tx, rx) = channel::unbounded();
|
||||||
|
|
||||||
|
// Set up a client that initiates a long-lived request.
|
||||||
|
// The request will complete in error when the server drops the connection.
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let client = LoopClient::new(client::Config::default(), tx).spawn();
|
||||||
|
|
||||||
|
let mut ctx = context::current();
|
||||||
|
ctx.deadline = SystemTime::now() + Duration::from_secs(60 * 60);
|
||||||
|
let _ = client.r#loop(ctx).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut requests = BaseChannel::with_defaults(rx).requests();
|
||||||
|
// Reading a request should trigger the request being registered with BaseChannel.
|
||||||
|
let first_request = requests.next().await.unwrap()?;
|
||||||
|
// Dropping the channel should trigger cleanup of outstanding requests.
|
||||||
|
drop(requests);
|
||||||
|
// In-flight requests should be aborted by channel cleanup.
|
||||||
|
// The first and only request sent by the client is `loop`, which is an infinite loop
|
||||||
|
// on the server side, so if cleanup was not triggered, this line should hang indefinitely.
|
||||||
|
first_request.execute(LoopServer.serve()).await;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn serde() -> anyhow::Result<()> {
|
||||||
|
use tarpc::serde_transport;
|
||||||
|
use tokio_serde::formats::Json;
|
||||||
|
|
||||||
|
let _ = tracing_subscriber::fmt::try_init();
|
||||||
|
|
||||||
|
let transport = tarpc::serde_transport::tcp::listen("localhost:56789", Json::default).await?;
|
||||||
let addr = transport.local_addr();
|
let addr = transport.local_addr();
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
tarpc::Server::default()
|
transport
|
||||||
.incoming(transport.take(1).filter_map(|r| async { r.ok() }))
|
.take(1)
|
||||||
.respond_with(Server.serve()),
|
.filter_map(|r| async { r.ok() })
|
||||||
|
.map(BaseChannel::with_defaults)
|
||||||
|
.execute(Server.serve()),
|
||||||
);
|
);
|
||||||
|
|
||||||
let transport = serde_transport::tcp::connect(addr, Json::default()).await?;
|
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
||||||
let mut client = ServiceClient::new(client::Config::default(), transport).spawn()?;
|
let client = ServiceClient::new(client::Config::default(), transport).spawn();
|
||||||
|
|
||||||
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
|
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
|
||||||
assert_matches!(
|
assert_matches!(
|
||||||
@@ -82,27 +136,22 @@ async fn serde() -> io::Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test(threaded_scheduler)]
|
#[tokio::test]
|
||||||
async fn concurrent() -> io::Result<()> {
|
async fn concurrent() -> anyhow::Result<()> {
|
||||||
let _ = env_logger::try_init();
|
let _ = tracing_subscriber::fmt::try_init();
|
||||||
|
|
||||||
let (tx, rx) = channel::unbounded();
|
let (tx, rx) = channel::unbounded();
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
tarpc::Server::default()
|
stream::once(ready(rx))
|
||||||
.incoming(stream::once(ready(rx)))
|
.map(BaseChannel::with_defaults)
|
||||||
.respond_with(Server.serve()),
|
.execute(Server.serve()),
|
||||||
);
|
);
|
||||||
|
|
||||||
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||||
|
|
||||||
let mut c = client.clone();
|
let req1 = client.add(context::current(), 1, 2);
|
||||||
let req1 = c.add(context::current(), 1, 2);
|
let req2 = client.add(context::current(), 3, 4);
|
||||||
|
let req3 = client.hey(context::current(), "Tim".to_string());
|
||||||
let mut c = client.clone();
|
|
||||||
let req2 = c.add(context::current(), 3, 4);
|
|
||||||
|
|
||||||
let mut c = client.clone();
|
|
||||||
let req3 = c.hey(context::current(), "Tim".to_string());
|
|
||||||
|
|
||||||
assert_matches!(req1.await, Ok(3));
|
assert_matches!(req1.await, Ok(3));
|
||||||
assert_matches!(req2.await, Ok(7));
|
assert_matches!(req2.await, Ok(7));
|
||||||
@@ -110,3 +159,86 @@ async fn concurrent() -> io::Result<()> {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn concurrent_join() -> anyhow::Result<()> {
|
||||||
|
let _ = tracing_subscriber::fmt::try_init();
|
||||||
|
|
||||||
|
let (tx, rx) = channel::unbounded();
|
||||||
|
tokio::spawn(
|
||||||
|
stream::once(ready(rx))
|
||||||
|
.map(BaseChannel::with_defaults)
|
||||||
|
.execute(Server.serve()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||||
|
|
||||||
|
let req1 = client.add(context::current(), 1, 2);
|
||||||
|
let req2 = client.add(context::current(), 3, 4);
|
||||||
|
let req3 = client.hey(context::current(), "Tim".to_string());
|
||||||
|
|
||||||
|
let (resp1, resp2, resp3) = join!(req1, req2, req3);
|
||||||
|
assert_matches!(resp1, Ok(3));
|
||||||
|
assert_matches!(resp2, Ok(7));
|
||||||
|
assert_matches!(resp3, Ok(ref s) if s == "Hey, Tim.");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn concurrent_join_all() -> anyhow::Result<()> {
|
||||||
|
let _ = tracing_subscriber::fmt::try_init();
|
||||||
|
|
||||||
|
let (tx, rx) = channel::unbounded();
|
||||||
|
tokio::spawn(
|
||||||
|
stream::once(ready(rx))
|
||||||
|
.map(BaseChannel::with_defaults)
|
||||||
|
.execute(Server.serve()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||||
|
|
||||||
|
let req1 = client.add(context::current(), 1, 2);
|
||||||
|
let req2 = client.add(context::current(), 3, 4);
|
||||||
|
|
||||||
|
let responses = join_all(vec![req1, req2]).await;
|
||||||
|
assert_matches!(responses[0], Ok(3));
|
||||||
|
assert_matches!(responses[1], Ok(7));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn counter() -> anyhow::Result<()> {
|
||||||
|
#[tarpc::service]
|
||||||
|
trait Counter {
|
||||||
|
async fn count() -> u32;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CountService(u32);
|
||||||
|
|
||||||
|
impl Counter for &mut CountService {
|
||||||
|
type CountFut = futures::future::Ready<u32>;
|
||||||
|
|
||||||
|
fn count(self, _: context::Context) -> Self::CountFut {
|
||||||
|
self.0 += 1;
|
||||||
|
futures::future::ready(self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (tx, rx) = channel::unbounded();
|
||||||
|
tokio::spawn(async {
|
||||||
|
let mut requests = BaseChannel::with_defaults(rx).requests();
|
||||||
|
let mut counter = CountService(0);
|
||||||
|
|
||||||
|
while let Some(Ok(request)) = requests.next().await {
|
||||||
|
request.execute(counter.serve()).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let client = CounterClient::new(client::Config::default(), tx).spawn();
|
||||||
|
assert_matches!(client.count(context::current()).await, Ok(1));
|
||||||
|
assert_matches!(client.count(context::current()).await, Ok(2));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user