mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-02-23 15:49:54 +01:00
Compare commits
64 Commits
client-clo
...
request-ho
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a6758fd1f9 | ||
|
|
2c241cc809 | ||
|
|
263ef8a897 | ||
|
|
d50290a21c | ||
|
|
26988cb833 | ||
|
|
6cf18a1caf | ||
|
|
84932df9b4 | ||
|
|
8dc3711a80 | ||
|
|
7c5afa97bb | ||
|
|
324df5cd15 | ||
|
|
3264979993 | ||
|
|
dd63fb59bf | ||
|
|
f4db8cc5b4 | ||
|
|
e9ba350496 | ||
|
|
e6d779e70b | ||
|
|
ce5f8cfb0c | ||
|
|
4b69dc8db5 | ||
|
|
866db2a2cd | ||
|
|
bed85e2827 | ||
|
|
93f3880025 | ||
|
|
878f594d5b | ||
|
|
aa9bbad109 | ||
|
|
7e872ce925 | ||
|
|
62541b709d | ||
|
|
8c43f94fb6 | ||
|
|
7fa4e5064d | ||
|
|
94db7610bb | ||
|
|
0c08d5e8ca | ||
|
|
75b15fe2aa | ||
|
|
863a08d87e | ||
|
|
49ba8f8b1b | ||
|
|
d832209da3 | ||
|
|
584426d414 | ||
|
|
50eb80c883 | ||
|
|
1f0c80d8c9 | ||
|
|
99bf3e62a3 | ||
|
|
68863e3db0 | ||
|
|
453ba1c074 | ||
|
|
e3eac1b4f5 | ||
|
|
0e102288a5 | ||
|
|
4c8ba41b2f | ||
|
|
946c627579 | ||
|
|
104dd71bba | ||
|
|
012c481861 | ||
|
|
dc12bd09aa | ||
|
|
2594ea8ce9 | ||
|
|
839b87e394 | ||
|
|
57d0638a99 | ||
|
|
a3a6404a30 | ||
|
|
b36eac80b1 | ||
|
|
d7070e4bc3 | ||
|
|
b5d1828308 | ||
|
|
92cfe63c4f | ||
|
|
839a2f067c | ||
|
|
b5d593488c | ||
|
|
eea38b8bf4 | ||
|
|
70493c15f4 | ||
|
|
f7c5d6a7c3 | ||
|
|
98c5d2a18b | ||
|
|
46b534f7c6 | ||
|
|
42b4fc52b1 | ||
|
|
350dbcdad0 | ||
|
|
b1b4461d89 | ||
|
|
f694b7573a |
91
.github/workflows/main.yml
vendored
91
.github/workflows/main.yml
vendored
@@ -14,99 +14,54 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Cancel previous
|
- name: Cancel previous
|
||||||
uses: styfle/cancel-workflow-action@0.7.0
|
uses: styfle/cancel-workflow-action@0.10.0
|
||||||
with:
|
with:
|
||||||
access_token: ${{ github.token }}
|
access_token: ${{ github.token }}
|
||||||
- uses: actions/checkout@v1
|
- uses: actions/checkout@v3
|
||||||
- uses: actions-rs/toolchain@v1
|
- uses: dtolnay/rust-toolchain@stable
|
||||||
with:
|
- run: cargo check --all-features
|
||||||
profile: minimal
|
|
||||||
toolchain: stable
|
|
||||||
target: mipsel-unknown-linux-gnu
|
|
||||||
override: true
|
|
||||||
- uses: actions-rs/cargo@v1
|
|
||||||
with:
|
|
||||||
command: check
|
|
||||||
args: --all-features
|
|
||||||
- uses: actions-rs/cargo@v1
|
|
||||||
with:
|
|
||||||
command: check
|
|
||||||
args: --all-features --target mipsel-unknown-linux-gnu
|
|
||||||
|
|
||||||
test:
|
test:
|
||||||
name: Test Suite
|
name: Test Suite
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Cancel previous
|
- name: Cancel previous
|
||||||
uses: styfle/cancel-workflow-action@0.7.0
|
uses: styfle/cancel-workflow-action@0.10.0
|
||||||
with:
|
with:
|
||||||
access_token: ${{ github.token }}
|
access_token: ${{ github.token }}
|
||||||
- uses: actions/checkout@v1
|
- uses: actions/checkout@v3
|
||||||
- uses: actions-rs/toolchain@v1
|
- uses: dtolnay/rust-toolchain@stable
|
||||||
with:
|
- run: cargo test
|
||||||
profile: minimal
|
- run: cargo test --manifest-path tarpc/Cargo.toml --features serde1
|
||||||
toolchain: stable
|
- run: cargo test --manifest-path tarpc/Cargo.toml --features tokio1
|
||||||
override: true
|
- run: cargo test --manifest-path tarpc/Cargo.toml --features serde-transport
|
||||||
- uses: actions-rs/cargo@v1
|
- run: cargo test --manifest-path tarpc/Cargo.toml --features tcp
|
||||||
with:
|
- run: cargo test --all-features
|
||||||
command: test
|
|
||||||
- uses: actions-rs/cargo@v1
|
|
||||||
with:
|
|
||||||
command: test
|
|
||||||
args: --manifest-path tarpc/Cargo.toml --features serde1
|
|
||||||
- uses: actions-rs/cargo@v1
|
|
||||||
with:
|
|
||||||
command: test
|
|
||||||
args: --manifest-path tarpc/Cargo.toml --features tokio1
|
|
||||||
- uses: actions-rs/cargo@v1
|
|
||||||
with:
|
|
||||||
command: test
|
|
||||||
args: --manifest-path tarpc/Cargo.toml --features serde-transport
|
|
||||||
- uses: actions-rs/cargo@v1
|
|
||||||
with:
|
|
||||||
command: test
|
|
||||||
args: --manifest-path tarpc/Cargo.toml --features tcp
|
|
||||||
- uses: actions-rs/cargo@v1
|
|
||||||
with:
|
|
||||||
command: test
|
|
||||||
args: --all-features
|
|
||||||
|
|
||||||
fmt:
|
fmt:
|
||||||
name: Rustfmt
|
name: Rustfmt
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Cancel previous
|
- name: Cancel previous
|
||||||
uses: styfle/cancel-workflow-action@0.7.0
|
uses: styfle/cancel-workflow-action@0.10.0
|
||||||
with:
|
with:
|
||||||
access_token: ${{ github.token }}
|
access_token: ${{ github.token }}
|
||||||
- uses: actions/checkout@v1
|
- uses: actions/checkout@v3
|
||||||
- uses: actions-rs/toolchain@v1
|
- uses: dtolnay/rust-toolchain@stable
|
||||||
with:
|
with:
|
||||||
profile: minimal
|
components: rustfmt
|
||||||
toolchain: stable
|
- run: cargo fmt --all -- --check
|
||||||
override: true
|
|
||||||
- run: rustup component add rustfmt
|
|
||||||
- uses: actions-rs/cargo@v1
|
|
||||||
with:
|
|
||||||
command: fmt
|
|
||||||
args: --all -- --check
|
|
||||||
|
|
||||||
clippy:
|
clippy:
|
||||||
name: Clippy
|
name: Clippy
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Cancel previous
|
- name: Cancel previous
|
||||||
uses: styfle/cancel-workflow-action@0.7.0
|
uses: styfle/cancel-workflow-action@0.10.0
|
||||||
with:
|
with:
|
||||||
access_token: ${{ github.token }}
|
access_token: ${{ github.token }}
|
||||||
- uses: actions/checkout@v1
|
- uses: actions/checkout@v3
|
||||||
- uses: actions-rs/toolchain@v1
|
- uses: dtolnay/rust-toolchain@stable
|
||||||
with:
|
with:
|
||||||
profile: minimal
|
components: clippy
|
||||||
toolchain: stable
|
- run: cargo clippy --all-features -- -D warnings
|
||||||
override: true
|
|
||||||
- run: rustup component add clippy
|
|
||||||
- uses: actions-rs/cargo@v1
|
|
||||||
with:
|
|
||||||
command: clippy
|
|
||||||
args: --all-features -- -D warnings
|
|
||||||
|
|||||||
31
README.md
31
README.md
@@ -40,7 +40,7 @@ rather than in a separate language such as .proto. This means there's no separat
|
|||||||
process, and no context switching between different languages.
|
process, and no context switching between different languages.
|
||||||
|
|
||||||
Some other features of tarpc:
|
Some other features of tarpc:
|
||||||
- Pluggable transport: any type impling `Stream<Item = Request> + Sink<Response>` can be
|
- Pluggable transport: any type implementing `Stream<Item = Request> + Sink<Response>` can be
|
||||||
used as a transport to connect the client and server.
|
used as a transport to connect the client and server.
|
||||||
- `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
|
- `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
|
||||||
- Cascading cancellation: dropping a request will send a cancellation message to the server.
|
- Cascading cancellation: dropping a request will send a cancellation message to the server.
|
||||||
@@ -55,7 +55,7 @@ Some other features of tarpc:
|
|||||||
[tracing](https://github.com/tokio-rs/tracing) primitives extended with
|
[tracing](https://github.com/tokio-rs/tracing) primitives extended with
|
||||||
[OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
|
[OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
|
||||||
[Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
|
[Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
|
||||||
each RPC can be traced through the client, server, amd other dependencies downstream of the
|
each RPC can be traced through the client, server, and other dependencies downstream of the
|
||||||
server. Even for applications not connected to a distributed tracing collector, the
|
server. Even for applications not connected to a distributed tracing collector, the
|
||||||
instrumentation can also be ingested by regular loggers like
|
instrumentation can also be ingested by regular loggers like
|
||||||
[env_logger](https://github.com/env-logger-rs/env_logger/).
|
[env_logger](https://github.com/env-logger-rs/env_logger/).
|
||||||
@@ -67,7 +67,7 @@ Some other features of tarpc:
|
|||||||
Add to your `Cargo.toml` dependencies:
|
Add to your `Cargo.toml` dependencies:
|
||||||
|
|
||||||
```toml
|
```toml
|
||||||
tarpc = "0.27"
|
tarpc = "0.34"
|
||||||
```
|
```
|
||||||
|
|
||||||
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||||
@@ -82,8 +82,8 @@ your `Cargo.toml`:
|
|||||||
```toml
|
```toml
|
||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
tarpc = { version = "0.27", features = ["tokio1"] }
|
tarpc = { version = "0.31", features = ["tokio1"] }
|
||||||
tokio = { version = "1.0", features = ["macros"] }
|
tokio = { version = "1.0", features = ["rt-multi-thread", "macros"] }
|
||||||
```
|
```
|
||||||
|
|
||||||
In the following example, we use an in-process channel for communication between
|
In the following example, we use an in-process channel for communication between
|
||||||
@@ -93,14 +93,10 @@ For a more real-world example, see [example-service](example-service).
|
|||||||
First, let's set up the dependencies and service definition.
|
First, let's set up the dependencies and service definition.
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
|
use futures::future::{self, Ready};
|
||||||
use futures::{
|
|
||||||
future::{self, Ready},
|
|
||||||
prelude::*,
|
|
||||||
};
|
|
||||||
use tarpc::{
|
use tarpc::{
|
||||||
client, context,
|
client, context,
|
||||||
server::{self, incoming::Incoming},
|
server::{self, Channel},
|
||||||
};
|
};
|
||||||
|
|
||||||
// This is the service definition. It looks a lot like a trait definition.
|
// This is the service definition. It looks a lot like a trait definition.
|
||||||
@@ -122,13 +118,8 @@ implement it for our Server struct.
|
|||||||
struct HelloServer;
|
struct HelloServer;
|
||||||
|
|
||||||
impl World for HelloServer {
|
impl World for HelloServer {
|
||||||
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and
|
async fn hello(self, _: context::Context, name: String) -> String {
|
||||||
// an associated type representing the future output by the fn.
|
format!("Hello, {name}!")
|
||||||
|
|
||||||
type HelloFut = Ready<String>;
|
|
||||||
|
|
||||||
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
|
||||||
future::ready(format!("Hello, {}!", name))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
@@ -148,14 +139,14 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||||
// that takes a config and any Transport as input.
|
// that takes a config and any Transport as input.
|
||||||
let mut client = WorldClient::new(client::Config::default(), client_transport).spawn();
|
let client = WorldClient::new(client::Config::default(), client_transport).spawn();
|
||||||
|
|
||||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||||
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||||
// specifies a deadline and trace information which can be helpful in debugging requests.
|
// specifies a deadline and trace information which can be helpful in debugging requests.
|
||||||
let hello = client.hello(context::current(), "Stim".to_string()).await?;
|
let hello = client.hello(context::current(), "Stim".to_string()).await?;
|
||||||
|
|
||||||
println!("{}", hello);
|
println!("{hello}");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
94
RELEASES.md
94
RELEASES.md
@@ -1,8 +1,100 @@
|
|||||||
|
## 0.34.0 (2023-12-29)
|
||||||
|
|
||||||
|
### Breaking Changes
|
||||||
|
|
||||||
|
- `#[tarpc::server]` is no more! Service traits now use async fns.
|
||||||
|
- `Channel::execute` no longer spawns request handlers. Async-fn-in-traits makes it impossible to
|
||||||
|
add a Send bound to the future returned by `Serve::serve`. Instead, `Channel::execute` returns a
|
||||||
|
stream of futures, where each future is a request handler. To achieve the former behavior:
|
||||||
|
```rust
|
||||||
|
channel.execute(server.serve())
|
||||||
|
.for_each(|rpc| { tokio::spawn(rpc); })
|
||||||
|
```
|
||||||
|
|
||||||
|
### New Features
|
||||||
|
|
||||||
|
- Request hooks are added to the serve trait, so that it's easy to hook in cross-cutting
|
||||||
|
functionality like throttling, authorization, etc.
|
||||||
|
- The Client trait is back! This makes it possible to hook in generic client functionality like load
|
||||||
|
balancing, retries, etc.
|
||||||
|
|
||||||
|
## 0.33.0 (2023-04-01)
|
||||||
|
|
||||||
|
### Breaking Changes
|
||||||
|
|
||||||
|
Opentelemetry dependency version increased to 0.18.
|
||||||
|
|
||||||
|
## 0.32.0 (2023-03-24)
|
||||||
|
|
||||||
|
### Breaking Changes
|
||||||
|
|
||||||
|
- As part of a fix to return more channel errors in RPC results, a few error types have changed:
|
||||||
|
|
||||||
|
0. `client::RpcError::Disconnected` was split into the following errors:
|
||||||
|
- Shutdown: the client was shutdown, either intentionally or due to an error. If due to an
|
||||||
|
error, pending RPCs should see the more specific errors below.
|
||||||
|
- Send: an RPC message failed to send over the transport. Only the RPC that failed to be sent
|
||||||
|
will see this error.
|
||||||
|
- Receive: a fatal error occurred while receiving from the transport. All in-flight RPCs will
|
||||||
|
receive this error.
|
||||||
|
0. `client::ChannelError` and `server::ChannelError` are unified in `tarpc::ChannelError`.
|
||||||
|
Previously, server transport errors would not indicate during which activity the transport
|
||||||
|
error occurred. Now, just like the client already was, it will be specific: reading, readying,
|
||||||
|
sending, flushing, or closing.
|
||||||
|
|
||||||
|
## 0.31.0 (2022-11-03)
|
||||||
|
|
||||||
|
### New Features
|
||||||
|
|
||||||
|
This release adds Unix Domain Sockets to the `serde_transport` module.
|
||||||
|
To use it, enable the "unix" feature. See the docs for more information.
|
||||||
|
|
||||||
|
## 0.30.0 (2022-08-12)
|
||||||
|
|
||||||
|
### Breaking Changes
|
||||||
|
|
||||||
|
- Some types that impl Future are now annotated with `#[must_use]`. Code that previously created
|
||||||
|
these types but did not use them will now receive a warning. Code that disallows warnings will
|
||||||
|
receive a compilation error.
|
||||||
|
|
||||||
|
### Fixes
|
||||||
|
|
||||||
|
- Servers will more reliably clean up request state for requests with long deadlines when response
|
||||||
|
processing is aborted without sending a response.
|
||||||
|
|
||||||
|
### Other Changes
|
||||||
|
|
||||||
|
- `TrackedRequest` now contains a response guard that can be used to ensure state cleanup for
|
||||||
|
aborted requests. (This was already handled automatically by `InFlightRequests`).
|
||||||
|
- When the feature serde-transport is enabled, the crate tokio_serde is now re-exported.
|
||||||
|
|
||||||
|
## 0.29.0 (2022-05-26)
|
||||||
|
|
||||||
|
### Breaking Changes
|
||||||
|
|
||||||
|
`Context.deadline` is now serialized as a Duration. This prevents clock skew from affecting deadline
|
||||||
|
behavior. For more details see https://github.com/google/tarpc/pull/367 and its [related
|
||||||
|
issue](https://github.com/google/tarpc/issues/366).
|
||||||
|
|
||||||
|
## 0.28.0 (2022-04-06)
|
||||||
|
|
||||||
|
### Breaking Changes
|
||||||
|
|
||||||
|
- The minimum supported Rust version has increased to 1.58.0.
|
||||||
|
- The version of opentelemetry depended on by tarpc has increased to 0.17.0.
|
||||||
|
|
||||||
|
## 0.27.2 (2021-10-08)
|
||||||
|
|
||||||
|
### Fixes
|
||||||
|
|
||||||
|
Clients will now close their transport before dropping it. An attempt at a clean shutdown can help
|
||||||
|
the server drop its connections more quickly.
|
||||||
|
|
||||||
## 0.27.1 (2021-09-22)
|
## 0.27.1 (2021-09-22)
|
||||||
|
|
||||||
### Breaking Changes
|
### Breaking Changes
|
||||||
|
|
||||||
### RPC error type is changing
|
#### RPC error type is changing
|
||||||
|
|
||||||
RPC return types are changing from `Result<Response, io::Error>` to `Result<Response,
|
RPC return types are changing from `Result<Response, io::Error>` to `Result<Response,
|
||||||
tarpc::client::RpcError>`.
|
tarpc::client::RpcError>`.
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "tarpc-example-service"
|
name = "tarpc-example-service"
|
||||||
version = "0.10.0"
|
version = "0.15.0"
|
||||||
|
rust-version = "1.56"
|
||||||
authors = ["Tim Kuehn <tikue@google.com>"]
|
authors = ["Tim Kuehn <tikue@google.com>"]
|
||||||
edition = "2018"
|
edition = "2021"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
documentation = "https://docs.rs/tarpc-example-service"
|
documentation = "https://docs.rs/tarpc-example-service"
|
||||||
homepage = "https://github.com/google/tarpc"
|
homepage = "https://github.com/google/tarpc"
|
||||||
@@ -14,17 +15,18 @@ description = "An example server built on tarpc."
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
clap = "3.0.0-beta.2"
|
clap = { version = "3.0.0-rc.9", features = ["derive"] }
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
opentelemetry = { version = "0.16", features = ["rt-tokio"] }
|
opentelemetry = { version = "0.21.0" }
|
||||||
opentelemetry-jaeger = { version = "0.15", features = ["rt-tokio"] }
|
opentelemetry-jaeger = { version = "0.20.0", features = ["rt-tokio"] }
|
||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
tarpc = { version = "0.27", path = "../tarpc", features = ["full"] }
|
tarpc = { version = "0.34", path = "../tarpc", features = ["full"] }
|
||||||
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
|
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
|
||||||
tracing = { version = "0.1" }
|
tracing = { version = "0.1" }
|
||||||
tracing-opentelemetry = "0.15"
|
tracing-opentelemetry = "0.22.0"
|
||||||
tracing-subscriber = "0.2"
|
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
|
||||||
|
opentelemetry_sdk = "0.21.1"
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
name = "service"
|
name = "service"
|
||||||
|
|||||||
15
example-service/README.md
Normal file
15
example-service/README.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# Example
|
||||||
|
|
||||||
|
Example service to demonstrate how to set up `tarpc` with [Jaeger](https://www.jaegertracing.io). To see traces Jaeger, run the following with `RUST_LOG=trace`.
|
||||||
|
|
||||||
|
## Server
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --bin server -- --port 50051
|
||||||
|
```
|
||||||
|
|
||||||
|
## Client
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --bin client -- --server-addr "[::1]:50051" --name "Bob"
|
||||||
|
```
|
||||||
@@ -4,14 +4,14 @@
|
|||||||
// license that can be found in the LICENSE file or at
|
// license that can be found in the LICENSE file or at
|
||||||
// https://opensource.org/licenses/MIT.
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
use clap::Clap;
|
use clap::Parser;
|
||||||
use service::{init_tracing, WorldClient};
|
use service::{init_tracing, WorldClient};
|
||||||
use std::{net::SocketAddr, time::Duration};
|
use std::{net::SocketAddr, time::Duration};
|
||||||
use tarpc::{client, context, tokio_serde::formats::Json};
|
use tarpc::{client, context, tokio_serde::formats::Json};
|
||||||
use tokio::time::sleep;
|
use tokio::time::sleep;
|
||||||
use tracing::Instrument;
|
use tracing::Instrument;
|
||||||
|
|
||||||
#[derive(Clap)]
|
#[derive(Parser)]
|
||||||
struct Flags {
|
struct Flags {
|
||||||
/// Sets the server address to connect to.
|
/// Sets the server address to connect to.
|
||||||
#[clap(long)]
|
#[clap(long)]
|
||||||
@@ -26,7 +26,8 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
let flags = Flags::parse();
|
let flags = Flags::parse();
|
||||||
init_tracing("Tarpc Example Client")?;
|
init_tracing("Tarpc Example Client")?;
|
||||||
|
|
||||||
let transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
|
let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
|
||||||
|
transport.config_mut().max_frame_length(usize::MAX);
|
||||||
|
|
||||||
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
|
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
|
||||||
// config and any Transport as input.
|
// config and any Transport as input.
|
||||||
@@ -42,7 +43,10 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
.instrument(tracing::info_span!("Two Hellos"))
|
.instrument(tracing::info_span!("Two Hellos"))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
tracing::info!("{:?}", hello);
|
match hello {
|
||||||
|
Ok(hello) => tracing::info!("{hello:?}"),
|
||||||
|
Err(e) => tracing::warn!("{:?}", anyhow::Error::from(e)),
|
||||||
|
}
|
||||||
|
|
||||||
// Let the background span processor finish.
|
// Let the background span processor finish.
|
||||||
sleep(Duration::from_micros(1)).await;
|
sleep(Duration::from_micros(1)).await;
|
||||||
|
|||||||
@@ -19,10 +19,10 @@ pub trait World {
|
|||||||
pub fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
pub fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||||
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
||||||
|
|
||||||
let tracer = opentelemetry_jaeger::new_pipeline()
|
let tracer = opentelemetry_jaeger::new_agent_pipeline()
|
||||||
.with_service_name(service_name)
|
.with_service_name(service_name)
|
||||||
.with_max_packet_size(2usize.pow(13))
|
.with_max_packet_size(2usize.pow(13))
|
||||||
.install_batch(opentelemetry::runtime::Tokio)?;
|
.install_batch(opentelemetry_sdk::runtime::Tokio)?;
|
||||||
|
|
||||||
tracing_subscriber::registry()
|
tracing_subscriber::registry()
|
||||||
.with(tracing_subscriber::EnvFilter::from_default_env())
|
.with(tracing_subscriber::EnvFilter::from_default_env())
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
// license that can be found in the LICENSE file or at
|
// license that can be found in the LICENSE file or at
|
||||||
// https://opensource.org/licenses/MIT.
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
use clap::Clap;
|
use clap::Parser;
|
||||||
use futures::{future, prelude::*};
|
use futures::{future, prelude::*};
|
||||||
use rand::{
|
use rand::{
|
||||||
distributions::{Distribution, Uniform},
|
distributions::{Distribution, Uniform},
|
||||||
@@ -22,7 +22,7 @@ use tarpc::{
|
|||||||
};
|
};
|
||||||
use tokio::time;
|
use tokio::time;
|
||||||
|
|
||||||
#[derive(Clap)]
|
#[derive(Parser)]
|
||||||
struct Flags {
|
struct Flags {
|
||||||
/// Sets the port number to listen on.
|
/// Sets the port number to listen on.
|
||||||
#[clap(long)]
|
#[clap(long)]
|
||||||
@@ -34,16 +34,19 @@ struct Flags {
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct HelloServer(SocketAddr);
|
struct HelloServer(SocketAddr);
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl World for HelloServer {
|
impl World for HelloServer {
|
||||||
async fn hello(self, _: context::Context, name: String) -> String {
|
async fn hello(self, _: context::Context, name: String) -> String {
|
||||||
let sleep_time =
|
let sleep_time =
|
||||||
Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng()));
|
Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng()));
|
||||||
time::sleep(sleep_time).await;
|
time::sleep(sleep_time).await;
|
||||||
format!("Hello, {}! You are connected from {}", name, self.0)
|
format!("Hello, {name}! You are connected from {}", self.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
let flags = Flags::parse();
|
let flags = Flags::parse();
|
||||||
@@ -54,6 +57,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
// JSON transport is provided by the json_transport tarpc module. It makes it easy
|
// JSON transport is provided by the json_transport tarpc module. It makes it easy
|
||||||
// to start up a serde-powered json serialization strategy over TCP.
|
// to start up a serde-powered json serialization strategy over TCP.
|
||||||
let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?;
|
let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?;
|
||||||
|
tracing::info!("Listening on port {}", listener.local_addr().port());
|
||||||
listener.config_mut().max_frame_length(usize::MAX);
|
listener.config_mut().max_frame_length(usize::MAX);
|
||||||
listener
|
listener
|
||||||
// Ignore accept errors.
|
// Ignore accept errors.
|
||||||
@@ -65,7 +69,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
// the generated World trait.
|
// the generated World trait.
|
||||||
.map(|channel| {
|
.map(|channel| {
|
||||||
let server = HelloServer(channel.transport().peer_addr().unwrap());
|
let server = HelloServer(channel.transport().peer_addr().unwrap());
|
||||||
channel.execute(server.serve())
|
channel.execute(server.serve()).for_each(spawn)
|
||||||
})
|
})
|
||||||
// Max 10 channels.
|
// Max 10 channels.
|
||||||
.buffer_unordered(10)
|
.buffer_unordered(10)
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "tarpc-plugins"
|
name = "tarpc-plugins"
|
||||||
version = "0.12.0"
|
version = "0.13.0"
|
||||||
|
rust-version = "1.56"
|
||||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||||
edition = "2018"
|
edition = "2021"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
documentation = "https://docs.rs/tarpc-plugins"
|
documentation = "https://docs.rs/tarpc-plugins"
|
||||||
homepage = "https://github.com/google/tarpc"
|
homepage = "https://github.com/google/tarpc"
|
||||||
|
|||||||
9
plugins/LICENSE
Normal file
9
plugins/LICENSE
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright 2016 Google Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
@@ -12,18 +12,18 @@ extern crate quote;
|
|||||||
extern crate syn;
|
extern crate syn;
|
||||||
|
|
||||||
use proc_macro::TokenStream;
|
use proc_macro::TokenStream;
|
||||||
use proc_macro2::{Span, TokenStream as TokenStream2};
|
use proc_macro2::TokenStream as TokenStream2;
|
||||||
use quote::{format_ident, quote, ToTokens};
|
use quote::{format_ident, quote, ToTokens};
|
||||||
use syn::{
|
use syn::{
|
||||||
braced,
|
braced,
|
||||||
ext::IdentExt,
|
ext::IdentExt,
|
||||||
parenthesized,
|
parenthesized,
|
||||||
parse::{Parse, ParseStream},
|
parse::{Parse, ParseStream},
|
||||||
parse_macro_input, parse_quote, parse_str,
|
parse_macro_input, parse_quote,
|
||||||
spanned::Spanned,
|
spanned::Spanned,
|
||||||
token::Comma,
|
token::Comma,
|
||||||
Attribute, FnArg, Ident, ImplItem, ImplItemMethod, ImplItemType, ItemImpl, Lit, LitBool,
|
Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type,
|
||||||
MetaNameValue, Pat, PatType, ReturnType, Token, Type, Visibility,
|
Visibility,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Accumulates multiple errors into a result.
|
/// Accumulates multiple errors into a result.
|
||||||
@@ -83,7 +83,7 @@ impl Parse for Service {
|
|||||||
ident_errors,
|
ident_errors,
|
||||||
syn::Error::new(
|
syn::Error::new(
|
||||||
rpc.ident.span(),
|
rpc.ident.span(),
|
||||||
format!("method name conflicts with generated fn `{}::serve`", ident)
|
format!("method name conflicts with generated fn `{ident}::serve`")
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -257,7 +257,6 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
|||||||
.map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
|
.map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
|
||||||
.collect();
|
.collect();
|
||||||
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
|
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
|
||||||
let response_fut_name = &format!("{}ResponseFut", ident.unraw());
|
|
||||||
let derive_serialize = if derive_serde.0 {
|
let derive_serialize = if derive_serde.0 {
|
||||||
Some(
|
Some(
|
||||||
quote! {#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
|
quote! {#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
|
||||||
@@ -270,14 +269,13 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
|||||||
let methods = rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>();
|
let methods = rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>();
|
||||||
let request_names = methods
|
let request_names = methods
|
||||||
.iter()
|
.iter()
|
||||||
.map(|m| format!("{}.{}", ident, m))
|
.map(|m| format!("{ident}.{m}"))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
ServiceGenerator {
|
ServiceGenerator {
|
||||||
response_fut_name,
|
|
||||||
service_ident: ident,
|
service_ident: ident,
|
||||||
|
client_stub_ident: &format_ident!("{}Stub", ident),
|
||||||
server_ident: &format_ident!("Serve{}", ident),
|
server_ident: &format_ident!("Serve{}", ident),
|
||||||
response_fut_ident: &Ident::new(response_fut_name, ident.span()),
|
|
||||||
client_ident: &format_ident!("{}Client", ident),
|
client_ident: &format_ident!("{}Client", ident),
|
||||||
request_ident: &format_ident!("{}Request", ident),
|
request_ident: &format_ident!("{}Request", ident),
|
||||||
response_ident: &format_ident!("{}Response", ident),
|
response_ident: &format_ident!("{}Response", ident),
|
||||||
@@ -285,7 +283,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
|||||||
args,
|
args,
|
||||||
method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
|
method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
|
||||||
method_idents: &methods,
|
method_idents: &methods,
|
||||||
request_names: &*request_names,
|
request_names: &request_names,
|
||||||
attrs,
|
attrs,
|
||||||
rpcs,
|
rpcs,
|
||||||
return_types: &rpcs
|
return_types: &rpcs
|
||||||
@@ -304,137 +302,18 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
|||||||
.zip(camel_case_fn_names.iter())
|
.zip(camel_case_fn_names.iter())
|
||||||
.map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
|
.map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
future_types: &camel_case_fn_names
|
|
||||||
.iter()
|
|
||||||
.map(|name| parse_str(&format!("{}Fut", name)).unwrap())
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
derive_serialize: derive_serialize.as_ref(),
|
derive_serialize: derive_serialize.as_ref(),
|
||||||
}
|
}
|
||||||
.into_token_stream()
|
.into_token_stream()
|
||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// generate an identifier consisting of the method name to CamelCase with
|
|
||||||
/// Fut appended to it.
|
|
||||||
fn associated_type_for_rpc(method: &ImplItemMethod) -> String {
|
|
||||||
snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut"
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Transforms an async function into a sync one, returning a type declaration
|
|
||||||
/// for the return type (a future).
|
|
||||||
fn transform_method(method: &mut ImplItemMethod) -> ImplItemType {
|
|
||||||
method.sig.asyncness = None;
|
|
||||||
|
|
||||||
// get either the return type or ().
|
|
||||||
let ret = match &method.sig.output {
|
|
||||||
ReturnType::Default => quote!(()),
|
|
||||||
ReturnType::Type(_, ret) => quote!(#ret),
|
|
||||||
};
|
|
||||||
|
|
||||||
let fut_name = associated_type_for_rpc(method);
|
|
||||||
let fut_name_ident = Ident::new(&fut_name, method.sig.ident.span());
|
|
||||||
|
|
||||||
// generate the updated return signature.
|
|
||||||
method.sig.output = parse_quote! {
|
|
||||||
-> ::core::pin::Pin<Box<
|
|
||||||
dyn ::core::future::Future<Output = #ret> + ::core::marker::Send
|
|
||||||
>>
|
|
||||||
};
|
|
||||||
|
|
||||||
// transform the body of the method into Box::pin(async move { body }).
|
|
||||||
let block = method.block.clone();
|
|
||||||
method.block = parse_quote! [{
|
|
||||||
Box::pin(async move
|
|
||||||
#block
|
|
||||||
)
|
|
||||||
}];
|
|
||||||
|
|
||||||
// generate and return type declaration for return type.
|
|
||||||
let t: ImplItemType = parse_quote! {
|
|
||||||
type #fut_name_ident = ::core::pin::Pin<Box<dyn ::core::future::Future<Output = #ret> + ::core::marker::Send>>;
|
|
||||||
};
|
|
||||||
|
|
||||||
t
|
|
||||||
}
|
|
||||||
|
|
||||||
#[proc_macro_attribute]
|
|
||||||
pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream {
|
|
||||||
let mut item = syn::parse_macro_input!(input as ItemImpl);
|
|
||||||
let span = item.span();
|
|
||||||
|
|
||||||
// the generated type declarations
|
|
||||||
let mut types: Vec<ImplItemType> = Vec::new();
|
|
||||||
let mut expected_non_async_types: Vec<(&ImplItemMethod, String)> = Vec::new();
|
|
||||||
let mut found_non_async_types: Vec<&ImplItemType> = Vec::new();
|
|
||||||
|
|
||||||
for inner in &mut item.items {
|
|
||||||
match inner {
|
|
||||||
ImplItem::Method(method) => {
|
|
||||||
if method.sig.asyncness.is_some() {
|
|
||||||
// if this function is declared async, transform it into a regular function
|
|
||||||
let typedecl = transform_method(method);
|
|
||||||
types.push(typedecl);
|
|
||||||
} else {
|
|
||||||
// If it's not async, keep track of all required associated types for better
|
|
||||||
// error reporting.
|
|
||||||
expected_non_async_types.push((method, associated_type_for_rpc(method)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ImplItem::Type(typedecl) => found_non_async_types.push(typedecl),
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Err(e) =
|
|
||||||
verify_types_were_provided(span, &expected_non_async_types, &found_non_async_types)
|
|
||||||
{
|
|
||||||
return TokenStream::from(e.to_compile_error());
|
|
||||||
}
|
|
||||||
|
|
||||||
// add the type declarations into the impl block
|
|
||||||
for t in types.into_iter() {
|
|
||||||
item.items.push(syn::ImplItem::Type(t));
|
|
||||||
}
|
|
||||||
|
|
||||||
TokenStream::from(quote!(#item))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn verify_types_were_provided(
|
|
||||||
span: Span,
|
|
||||||
expected: &[(&ImplItemMethod, String)],
|
|
||||||
provided: &[&ImplItemType],
|
|
||||||
) -> syn::Result<()> {
|
|
||||||
let mut result = Ok(());
|
|
||||||
for (method, expected) in expected {
|
|
||||||
if !provided.iter().any(|typedecl| typedecl.ident == expected) {
|
|
||||||
let mut e = syn::Error::new(
|
|
||||||
span,
|
|
||||||
format!("not all trait items implemented, missing: `{}`", expected),
|
|
||||||
);
|
|
||||||
let fn_span = method.sig.fn_token.span();
|
|
||||||
e.extend(syn::Error::new(
|
|
||||||
fn_span.join(method.sig.ident.span()).unwrap_or(fn_span),
|
|
||||||
format!(
|
|
||||||
"hint: `#[tarpc::server]` only rewrites async fns, and `fn {}` is not async",
|
|
||||||
method.sig.ident
|
|
||||||
),
|
|
||||||
));
|
|
||||||
match result {
|
|
||||||
Ok(_) => result = Err(e),
|
|
||||||
Err(ref mut error) => error.extend(Some(e)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
// Things needed to generate the service items: trait, serve impl, request/response enums, and
|
// Things needed to generate the service items: trait, serve impl, request/response enums, and
|
||||||
// the client stub.
|
// the client stub.
|
||||||
struct ServiceGenerator<'a> {
|
struct ServiceGenerator<'a> {
|
||||||
service_ident: &'a Ident,
|
service_ident: &'a Ident,
|
||||||
|
client_stub_ident: &'a Ident,
|
||||||
server_ident: &'a Ident,
|
server_ident: &'a Ident,
|
||||||
response_fut_ident: &'a Ident,
|
|
||||||
response_fut_name: &'a str,
|
|
||||||
client_ident: &'a Ident,
|
client_ident: &'a Ident,
|
||||||
request_ident: &'a Ident,
|
request_ident: &'a Ident,
|
||||||
response_ident: &'a Ident,
|
response_ident: &'a Ident,
|
||||||
@@ -442,7 +321,6 @@ struct ServiceGenerator<'a> {
|
|||||||
attrs: &'a [Attribute],
|
attrs: &'a [Attribute],
|
||||||
rpcs: &'a [RpcMethod],
|
rpcs: &'a [RpcMethod],
|
||||||
camel_case_idents: &'a [Ident],
|
camel_case_idents: &'a [Ident],
|
||||||
future_types: &'a [Type],
|
|
||||||
method_idents: &'a [&'a Ident],
|
method_idents: &'a [&'a Ident],
|
||||||
request_names: &'a [String],
|
request_names: &'a [String],
|
||||||
method_attrs: &'a [&'a [Attribute]],
|
method_attrs: &'a [&'a [Attribute]],
|
||||||
@@ -458,42 +336,37 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
attrs,
|
attrs,
|
||||||
rpcs,
|
rpcs,
|
||||||
vis,
|
vis,
|
||||||
future_types,
|
|
||||||
return_types,
|
return_types,
|
||||||
service_ident,
|
service_ident,
|
||||||
|
client_stub_ident,
|
||||||
|
request_ident,
|
||||||
|
response_ident,
|
||||||
server_ident,
|
server_ident,
|
||||||
..
|
..
|
||||||
} = self;
|
} = self;
|
||||||
|
|
||||||
let types_and_fns = rpcs
|
let rpc_fns = rpcs
|
||||||
.iter()
|
.iter()
|
||||||
.zip(future_types.iter())
|
|
||||||
.zip(return_types.iter())
|
.zip(return_types.iter())
|
||||||
.map(
|
.map(
|
||||||
|(
|
|(
|
||||||
(
|
RpcMethod {
|
||||||
RpcMethod {
|
attrs, ident, args, ..
|
||||||
attrs, ident, args, ..
|
},
|
||||||
},
|
|
||||||
future_type,
|
|
||||||
),
|
|
||||||
output,
|
output,
|
||||||
)| {
|
)| {
|
||||||
let ty_doc = format!("The response future returned by [`{}::{}`].", service_ident, ident);
|
|
||||||
quote! {
|
quote! {
|
||||||
#[doc = #ty_doc]
|
|
||||||
type #future_type: std::future::Future<Output = #output>;
|
|
||||||
|
|
||||||
#( #attrs )*
|
#( #attrs )*
|
||||||
fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type;
|
async fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> #output;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let stub_doc = format!("The stub trait for service [`{service_ident}`].");
|
||||||
quote! {
|
quote! {
|
||||||
#( #attrs )*
|
#( #attrs )*
|
||||||
#vis trait #service_ident: Sized {
|
#vis trait #service_ident: Sized {
|
||||||
#( #types_and_fns )*
|
#( #rpc_fns )*
|
||||||
|
|
||||||
/// Returns a serving function to use with
|
/// Returns a serving function to use with
|
||||||
/// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute).
|
/// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute).
|
||||||
@@ -501,6 +374,15 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
#server_ident { service: self }
|
#server_ident { service: self }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[doc = #stub_doc]
|
||||||
|
#vis trait #client_stub_ident: tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident> {
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> #client_stub_ident for S
|
||||||
|
where S: tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident>
|
||||||
|
{
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -524,7 +406,6 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
server_ident,
|
server_ident,
|
||||||
service_ident,
|
service_ident,
|
||||||
response_ident,
|
response_ident,
|
||||||
response_fut_ident,
|
|
||||||
camel_case_idents,
|
camel_case_idents,
|
||||||
arg_pats,
|
arg_pats,
|
||||||
method_idents,
|
method_idents,
|
||||||
@@ -533,11 +414,11 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
} = self;
|
} = self;
|
||||||
|
|
||||||
quote! {
|
quote! {
|
||||||
impl<S> tarpc::server::Serve<#request_ident> for #server_ident<S>
|
impl<S> tarpc::server::Serve for #server_ident<S>
|
||||||
where S: #service_ident
|
where S: #service_ident
|
||||||
{
|
{
|
||||||
|
type Req = #request_ident;
|
||||||
type Resp = #response_ident;
|
type Resp = #response_ident;
|
||||||
type Fut = #response_fut_ident<S>;
|
|
||||||
|
|
||||||
fn method(&self, req: &#request_ident) -> Option<&'static str> {
|
fn method(&self, req: &#request_ident) -> Option<&'static str> {
|
||||||
Some(match req {
|
Some(match req {
|
||||||
@@ -549,15 +430,16 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
|
async fn serve(self, ctx: tarpc::context::Context, req: #request_ident)
|
||||||
|
-> Result<#response_ident, tarpc::ServerError> {
|
||||||
match req {
|
match req {
|
||||||
#(
|
#(
|
||||||
#request_ident::#camel_case_idents{ #( #arg_pats ),* } => {
|
#request_ident::#camel_case_idents{ #( #arg_pats ),* } => {
|
||||||
#response_fut_ident::#camel_case_idents(
|
Ok(#response_ident::#camel_case_idents(
|
||||||
#service_ident::#method_idents(
|
#service_ident::#method_idents(
|
||||||
self.service, ctx, #( #arg_pats ),*
|
self.service, ctx, #( #arg_pats ),*
|
||||||
)
|
).await
|
||||||
)
|
))
|
||||||
}
|
}
|
||||||
)*
|
)*
|
||||||
}
|
}
|
||||||
@@ -608,73 +490,6 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn enum_response_future(&self) -> TokenStream2 {
|
|
||||||
let &Self {
|
|
||||||
vis,
|
|
||||||
service_ident,
|
|
||||||
response_fut_ident,
|
|
||||||
camel_case_idents,
|
|
||||||
future_types,
|
|
||||||
..
|
|
||||||
} = self;
|
|
||||||
|
|
||||||
quote! {
|
|
||||||
/// A future resolving to a server response.
|
|
||||||
#[allow(missing_docs)]
|
|
||||||
#vis enum #response_fut_ident<S: #service_ident> {
|
|
||||||
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn impl_debug_for_response_future(&self) -> TokenStream2 {
|
|
||||||
let &Self {
|
|
||||||
service_ident,
|
|
||||||
response_fut_ident,
|
|
||||||
response_fut_name,
|
|
||||||
..
|
|
||||||
} = self;
|
|
||||||
|
|
||||||
quote! {
|
|
||||||
impl<S: #service_ident> std::fmt::Debug for #response_fut_ident<S> {
|
|
||||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
||||||
fmt.debug_struct(#response_fut_name).finish()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn impl_future_for_response_future(&self) -> TokenStream2 {
|
|
||||||
let &Self {
|
|
||||||
service_ident,
|
|
||||||
response_fut_ident,
|
|
||||||
response_ident,
|
|
||||||
camel_case_idents,
|
|
||||||
..
|
|
||||||
} = self;
|
|
||||||
|
|
||||||
quote! {
|
|
||||||
impl<S: #service_ident> std::future::Future for #response_fut_ident<S> {
|
|
||||||
type Output = #response_ident;
|
|
||||||
|
|
||||||
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>)
|
|
||||||
-> std::task::Poll<#response_ident>
|
|
||||||
{
|
|
||||||
unsafe {
|
|
||||||
match std::pin::Pin::get_unchecked_mut(self) {
|
|
||||||
#(
|
|
||||||
#response_fut_ident::#camel_case_idents(resp) =>
|
|
||||||
std::pin::Pin::new_unchecked(resp)
|
|
||||||
.poll(cx)
|
|
||||||
.map(#response_ident::#camel_case_idents),
|
|
||||||
)*
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn struct_client(&self) -> TokenStream2 {
|
fn struct_client(&self) -> TokenStream2 {
|
||||||
let &Self {
|
let &Self {
|
||||||
vis,
|
vis,
|
||||||
@@ -689,7 +504,9 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
/// The client stub that makes RPC calls to the server. All request methods return
|
/// The client stub that makes RPC calls to the server. All request methods return
|
||||||
/// [Futures](std::future::Future).
|
/// [Futures](std::future::Future).
|
||||||
#vis struct #client_ident(tarpc::client::Channel<#request_ident, #response_ident>);
|
#vis struct #client_ident<
|
||||||
|
Stub = tarpc::client::Channel<#request_ident, #response_ident>
|
||||||
|
>(Stub);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -719,6 +536,17 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
dispatch: new_client.dispatch,
|
dispatch: new_client.dispatch,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Stub> From<Stub> for #client_ident<Stub>
|
||||||
|
where Stub: tarpc::client::stub::Stub<
|
||||||
|
Req = #request_ident,
|
||||||
|
Resp = #response_ident>
|
||||||
|
{
|
||||||
|
/// Returns a new client stub that sends requests over the given transport.
|
||||||
|
fn from(stub: Stub) -> Self {
|
||||||
|
#client_ident(stub)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -741,7 +569,11 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
} = self;
|
} = self;
|
||||||
|
|
||||||
quote! {
|
quote! {
|
||||||
impl #client_ident {
|
impl<Stub> #client_ident<Stub>
|
||||||
|
where Stub: tarpc::client::stub::Stub<
|
||||||
|
Req = #request_ident,
|
||||||
|
Resp = #response_ident>
|
||||||
|
{
|
||||||
#(
|
#(
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
#( #method_attrs )*
|
#( #method_attrs )*
|
||||||
@@ -770,9 +602,6 @@ impl<'a> ToTokens for ServiceGenerator<'a> {
|
|||||||
self.impl_serve_for_server(),
|
self.impl_serve_for_server(),
|
||||||
self.enum_request(),
|
self.enum_request(),
|
||||||
self.enum_response(),
|
self.enum_response(),
|
||||||
self.enum_response_future(),
|
|
||||||
self.impl_debug_for_response_future(),
|
|
||||||
self.impl_future_for_response_future(),
|
|
||||||
self.struct_client(),
|
self.struct_client(),
|
||||||
self.impl_client_new(),
|
self.impl_client_new(),
|
||||||
self.impl_client_rpc_methods(),
|
self.impl_client_rpc_methods(),
|
||||||
|
|||||||
@@ -1,8 +1,3 @@
|
|||||||
use assert_type_eq::assert_type_eq;
|
|
||||||
use futures::Future;
|
|
||||||
use std::pin::Pin;
|
|
||||||
use tarpc::context;
|
|
||||||
|
|
||||||
// these need to be out here rather than inside the function so that the
|
// these need to be out here rather than inside the function so that the
|
||||||
// assert_type_eq macro can pick them up.
|
// assert_type_eq macro can pick them up.
|
||||||
#[tarpc::service]
|
#[tarpc::service]
|
||||||
@@ -12,42 +7,6 @@ trait Foo {
|
|||||||
async fn baz();
|
async fn baz();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn type_generation_works() {
|
|
||||||
#[tarpc::server]
|
|
||||||
impl Foo for () {
|
|
||||||
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
|
|
||||||
(s, i)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn bar(self, _: context::Context, s: String) -> String {
|
|
||||||
s
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn baz(self, _: context::Context) {}
|
|
||||||
}
|
|
||||||
|
|
||||||
// the assert_type_eq macro can only be used once per block.
|
|
||||||
{
|
|
||||||
assert_type_eq!(
|
|
||||||
<() as Foo>::TwoPartFut,
|
|
||||||
Pin<Box<dyn Future<Output = (String, i32)> + Send>>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
assert_type_eq!(
|
|
||||||
<() as Foo>::BarFut,
|
|
||||||
Pin<Box<dyn Future<Output = String> + Send>>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
assert_type_eq!(
|
|
||||||
<() as Foo>::BazFut,
|
|
||||||
Pin<Box<dyn Future<Output = ()> + Send>>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(non_camel_case_types)]
|
#[allow(non_camel_case_types)]
|
||||||
#[test]
|
#[test]
|
||||||
fn raw_idents_work() {
|
fn raw_idents_work() {
|
||||||
@@ -59,24 +18,6 @@ fn raw_idents_work() {
|
|||||||
async fn r#fn(r#impl: r#yield) -> r#yield;
|
async fn r#fn(r#impl: r#yield) -> r#yield;
|
||||||
async fn r#async();
|
async fn r#async();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl r#trait for () {
|
|
||||||
async fn r#await(
|
|
||||||
self,
|
|
||||||
_: context::Context,
|
|
||||||
r#struct: r#yield,
|
|
||||||
r#enum: i32,
|
|
||||||
) -> (r#yield, i32) {
|
|
||||||
(r#struct, r#enum)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
|
|
||||||
r#impl
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn r#async(self, _: context::Context) {}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -100,45 +41,4 @@ fn syntax() {
|
|||||||
#[doc = "attr"]
|
#[doc = "attr"]
|
||||||
async fn one_arg_implicit_return_error(one: String);
|
async fn one_arg_implicit_return_error(one: String);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl Syntax for () {
|
|
||||||
#[deny(warnings)]
|
|
||||||
#[allow(non_snake_case)]
|
|
||||||
async fn TestCamelCaseDoesntConflict(self, _: context::Context) {}
|
|
||||||
|
|
||||||
async fn hello(self, _: context::Context) -> String {
|
|
||||||
String::new()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn attr(self, _: context::Context, _s: String) -> String {
|
|
||||||
String::new()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn no_args_no_return(self, _: context::Context) {}
|
|
||||||
|
|
||||||
async fn no_args(self, _: context::Context) -> () {}
|
|
||||||
|
|
||||||
async fn one_arg(self, _: context::Context, _one: String) -> i32 {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn two_args_no_return(self, _: context::Context, _one: String, _two: u64) {}
|
|
||||||
|
|
||||||
async fn two_args(self, _: context::Context, _one: String, _two: u64) -> String {
|
|
||||||
String::new()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn no_args_ret_error(self, _: context::Context) -> i32 {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn one_arg_ret_error(self, _: context::Context, _one: String) -> String {
|
|
||||||
String::new()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn no_arg_implicit_return_error(self, _: context::Context) {}
|
|
||||||
|
|
||||||
async fn one_arg_implicit_return_error(self, _: context::Context, _one: String) {}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ use tarpc::context;
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn att_service_trait() {
|
fn att_service_trait() {
|
||||||
use futures::future::{ready, Ready};
|
|
||||||
|
|
||||||
#[tarpc::service]
|
#[tarpc::service]
|
||||||
trait Foo {
|
trait Foo {
|
||||||
async fn two_part(s: String, i: i32) -> (String, i32);
|
async fn two_part(s: String, i: i32) -> (String, i32);
|
||||||
@@ -12,19 +10,16 @@ fn att_service_trait() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Foo for () {
|
impl Foo for () {
|
||||||
type TwoPartFut = Ready<(String, i32)>;
|
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
|
||||||
fn two_part(self, _: context::Context, s: String, i: i32) -> Self::TwoPartFut {
|
(s, i)
|
||||||
ready((s, i))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type BarFut = Ready<String>;
|
async fn bar(self, _: context::Context, s: String) -> String {
|
||||||
fn bar(self, _: context::Context, s: String) -> Self::BarFut {
|
s
|
||||||
ready(s)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type BazFut = Ready<()>;
|
async fn baz(self, _: context::Context) {
|
||||||
fn baz(self, _: context::Context) -> Self::BazFut {
|
()
|
||||||
ready(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -32,8 +27,6 @@ fn att_service_trait() {
|
|||||||
#[allow(non_camel_case_types)]
|
#[allow(non_camel_case_types)]
|
||||||
#[test]
|
#[test]
|
||||||
fn raw_idents() {
|
fn raw_idents() {
|
||||||
use futures::future::{ready, Ready};
|
|
||||||
|
|
||||||
type r#yield = String;
|
type r#yield = String;
|
||||||
|
|
||||||
#[tarpc::service]
|
#[tarpc::service]
|
||||||
@@ -44,19 +37,21 @@ fn raw_idents() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl r#trait for () {
|
impl r#trait for () {
|
||||||
type AwaitFut = Ready<(r#yield, i32)>;
|
async fn r#await(
|
||||||
fn r#await(self, _: context::Context, r#struct: r#yield, r#enum: i32) -> Self::AwaitFut {
|
self,
|
||||||
ready((r#struct, r#enum))
|
_: context::Context,
|
||||||
|
r#struct: r#yield,
|
||||||
|
r#enum: i32,
|
||||||
|
) -> (r#yield, i32) {
|
||||||
|
(r#struct, r#enum)
|
||||||
}
|
}
|
||||||
|
|
||||||
type FnFut = Ready<r#yield>;
|
async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
|
||||||
fn r#fn(self, _: context::Context, r#impl: r#yield) -> Self::FnFut {
|
r#impl
|
||||||
ready(r#impl)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AsyncFut = Ready<()>;
|
async fn r#async(self, _: context::Context) {
|
||||||
fn r#async(self, _: context::Context) -> Self::AsyncFut {
|
()
|
||||||
ready(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,28 +1,41 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "tarpc"
|
name = "tarpc"
|
||||||
version = "0.27.1"
|
version = "0.34.0"
|
||||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
rust-version = "1.58.0"
|
||||||
edition = "2018"
|
authors = [
|
||||||
|
"Adam Wright <adam.austin.wright@gmail.com>",
|
||||||
|
"Tim Kuehn <timothy.j.kuehn@gmail.com>",
|
||||||
|
]
|
||||||
|
edition = "2021"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
documentation = "https://docs.rs/tarpc"
|
documentation = "https://docs.rs/tarpc"
|
||||||
homepage = "https://github.com/google/tarpc"
|
homepage = "https://github.com/google/tarpc"
|
||||||
repository = "https://github.com/google/tarpc"
|
repository = "https://github.com/google/tarpc"
|
||||||
keywords = ["rpc", "network", "server", "api", "microservices"]
|
keywords = ["rpc", "network", "server", "api", "microservices"]
|
||||||
categories = ["asynchronous", "network-programming"]
|
categories = ["asynchronous", "network-programming"]
|
||||||
readme = "../README.md"
|
readme = "README.md"
|
||||||
description = "An RPC framework for Rust with a focus on ease of use."
|
description = "An RPC framework for Rust with a focus on ease of use."
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
|
|
||||||
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"]
|
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive", "serde/rc"]
|
||||||
tokio1 = ["tokio/rt-multi-thread"]
|
tokio1 = ["tokio/rt"]
|
||||||
serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
|
serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
|
||||||
serde-transport-json = ["tokio-serde/json"]
|
serde-transport-json = ["tokio-serde/json"]
|
||||||
serde-transport-bincode = ["tokio-serde/bincode"]
|
serde-transport-bincode = ["tokio-serde/bincode"]
|
||||||
tcp = ["tokio/net"]
|
tcp = ["tokio/net"]
|
||||||
|
unix = ["tokio/net"]
|
||||||
|
|
||||||
full = ["serde1", "tokio1", "serde-transport", "serde-transport-json", "serde-transport-bincode", "tcp"]
|
full = [
|
||||||
|
"serde1",
|
||||||
|
"tokio1",
|
||||||
|
"serde-transport",
|
||||||
|
"serde-transport-json",
|
||||||
|
"serde-transport-bincode",
|
||||||
|
"tcp",
|
||||||
|
"unix",
|
||||||
|
]
|
||||||
|
|
||||||
[badges]
|
[badges]
|
||||||
travis-ci = { repository = "google/tarpc" }
|
travis-ci = { repository = "google/tarpc" }
|
||||||
@@ -36,14 +49,17 @@ pin-project = "1.0"
|
|||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
serde = { optional = true, version = "1.0", features = ["derive"] }
|
serde = { optional = true, version = "1.0", features = ["derive"] }
|
||||||
static_assertions = "1.1.0"
|
static_assertions = "1.1.0"
|
||||||
tarpc-plugins = { path = "../plugins", version = "0.12" }
|
tarpc-plugins = { path = "../plugins", version = "0.13" }
|
||||||
thiserror = "1.0"
|
thiserror = "1.0"
|
||||||
tokio = { version = "1", features = ["time"] }
|
tokio = { version = "1", features = ["time"] }
|
||||||
tokio-util = { version = "0.6.3", features = ["time"] }
|
tokio-util = { version = "0.7.3", features = ["time"] }
|
||||||
tokio-serde = { optional = true, version = "0.8" }
|
tokio-serde = { optional = true, version = "0.8" }
|
||||||
tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
|
tracing = { version = "0.1", default-features = false, features = [
|
||||||
tracing-opentelemetry = { version = "0.15", default-features = false }
|
"attributes",
|
||||||
opentelemetry = { version = "0.16", default-features = false }
|
"log",
|
||||||
|
] }
|
||||||
|
tracing-opentelemetry = { version = "0.18.0", default-features = false }
|
||||||
|
opentelemetry = { version = "0.18.0", default-features = false }
|
||||||
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
@@ -52,14 +68,19 @@ bincode = "1.3"
|
|||||||
bytes = { version = "1", features = ["serde"] }
|
bytes = { version = "1", features = ["serde"] }
|
||||||
flate2 = "1.0"
|
flate2 = "1.0"
|
||||||
futures-test = "0.3"
|
futures-test = "0.3"
|
||||||
opentelemetry = { version = "0.16", default-features = false, features = ["rt-tokio"] }
|
opentelemetry = { version = "0.18.0", default-features = false, features = [
|
||||||
opentelemetry-jaeger = { version = "0.15", features = ["rt-tokio"] }
|
"rt-tokio",
|
||||||
|
] }
|
||||||
|
opentelemetry-jaeger = { version = "0.17.0", features = ["rt-tokio"] }
|
||||||
pin-utils = "0.1.0-alpha"
|
pin-utils = "0.1.0-alpha"
|
||||||
serde_bytes = "0.11"
|
serde_bytes = "0.11"
|
||||||
tracing-subscriber = "0.2"
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
tokio = { version = "1", features = ["full", "test-util"] }
|
tokio = { version = "1", features = ["full", "test-util", "tracing"] }
|
||||||
|
console-subscriber = "0.1"
|
||||||
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
|
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
|
||||||
trybuild = "1.0"
|
trybuild = "1.0"
|
||||||
|
tokio-rustls = "0.23"
|
||||||
|
rustls-pemfile = "1.0"
|
||||||
|
|
||||||
[package.metadata.docs.rs]
|
[package.metadata.docs.rs]
|
||||||
all-features = true
|
all-features = true
|
||||||
@@ -85,6 +106,10 @@ required-features = ["full"]
|
|||||||
name = "custom_transport"
|
name = "custom_transport"
|
||||||
required-features = ["serde1", "tokio1", "serde-transport"]
|
required-features = ["serde1", "tokio1", "serde-transport"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "tls_over_tcp"
|
||||||
|
required-features = ["full"]
|
||||||
|
|
||||||
[[test]]
|
[[test]]
|
||||||
name = "service_functional"
|
name = "service_functional"
|
||||||
required-features = ["serde-transport"]
|
required-features = ["serde-transport"]
|
||||||
|
|||||||
9
tarpc/LICENSE
Normal file
9
tarpc/LICENSE
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright 2016 Google Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
11
tarpc/examples/certs/eddsa/client.cert
Normal file
11
tarpc/examples/certs/eddsa/client.cert
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIBlDCCAUagAwIBAgICAxUwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk
|
||||||
|
RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw
|
||||||
|
NjEwMTEwNFowGjEYMBYGA1UEAwwPcG9ueXRvd24gY2xpZW50MCowBQYDK2VwAyEA
|
||||||
|
NTKuLume19IhJfEFd/5OZUuYDKZH6xvy4AGver17OoejgZswgZgwDAYDVR0TAQH/
|
||||||
|
BAIwADALBgNVHQ8EBAMCBsAwFgYDVR0lAQH/BAwwCgYIKwYBBQUHAwIwHQYDVR0O
|
||||||
|
BBYEFDjdrlMu4tyw5MHtbg7WnzSGRBpFMEQGA1UdIwQ9MDuAFHIl7fHKWP6/l8FE
|
||||||
|
fI2YEIM3oHxKoSCkHjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQYIBezAF
|
||||||
|
BgMrZXADQQCaahfj/QLxoCOpvl6y0ZQ9CpojPqBnxV3460j5nUOp040Va2MpF137
|
||||||
|
izCBY7LwgUE/YG6E+kH30G4jMEnqVEYK
|
||||||
|
-----END CERTIFICATE-----
|
||||||
19
tarpc/examples/certs/eddsa/client.chain
Normal file
19
tarpc/examples/certs/eddsa/client.chain
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE
|
||||||
|
U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD
|
||||||
|
DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh
|
||||||
|
AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU
|
||||||
|
ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG
|
||||||
|
AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU
|
||||||
|
oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc
|
||||||
|
zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg=
|
||||||
|
-----END CERTIFICATE-----
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG
|
||||||
|
A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0
|
||||||
|
MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh
|
||||||
|
ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU
|
||||||
|
phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR
|
||||||
|
W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC
|
||||||
|
t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB
|
||||||
|
-----END CERTIFICATE-----
|
||||||
3
tarpc/examples/certs/eddsa/client.key
Normal file
3
tarpc/examples/certs/eddsa/client.key
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
-----BEGIN PRIVATE KEY-----
|
||||||
|
MC4CAQAwBQYDK2VwBCIEIIJX9ThTHpVS1SNZb6HP4myg4fRInIVGunTRdgnc+weH
|
||||||
|
-----END PRIVATE KEY-----
|
||||||
12
tarpc/examples/certs/eddsa/end.cert
Normal file
12
tarpc/examples/certs/eddsa/end.cert
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIBuDCCAWqgAwIBAgICAcgwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk
|
||||||
|
RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw
|
||||||
|
NjEwMTEwNFowGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wKjAFBgMrZXADIQDc
|
||||||
|
RLl3/N2tPoWnzBV3noVn/oheEl8IUtiY11Vg/QXTUKOBwDCBvTAMBgNVHRMBAf8E
|
||||||
|
AjAAMAsGA1UdDwQEAwIGwDAdBgNVHQ4EFgQUk7U2mnxedNWBAH84BsNy5si3ZQow
|
||||||
|
RAYDVR0jBD0wO4AUciXt8cpY/r+XwUR8jZgQgzegfEqhIKQeMBwxGjAYBgNVBAMM
|
||||||
|
EXBvbnl0b3duIEVkRFNBIENBggF7MDsGA1UdEQQ0MDKCDnRlc3RzZXJ2ZXIuY29t
|
||||||
|
ghVzZWNvbmQudGVzdHNlcnZlci5jb22CCWxvY2FsaG9zdDAFBgMrZXADQQCFWIcF
|
||||||
|
9FiztCuUNzgXDNu5kshuflt0RjkjWpGlWzQjGoYM2IvYhNVPeqnCiY92gqwDSBtq
|
||||||
|
amD2TBup4eNUCsQB
|
||||||
|
-----END CERTIFICATE-----
|
||||||
19
tarpc/examples/certs/eddsa/end.chain
Normal file
19
tarpc/examples/certs/eddsa/end.chain
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE
|
||||||
|
U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD
|
||||||
|
DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh
|
||||||
|
AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU
|
||||||
|
ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG
|
||||||
|
AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU
|
||||||
|
oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc
|
||||||
|
zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg=
|
||||||
|
-----END CERTIFICATE-----
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG
|
||||||
|
A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0
|
||||||
|
MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh
|
||||||
|
ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU
|
||||||
|
phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR
|
||||||
|
W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC
|
||||||
|
t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB
|
||||||
|
-----END CERTIFICATE-----
|
||||||
3
tarpc/examples/certs/eddsa/end.key
Normal file
3
tarpc/examples/certs/eddsa/end.key
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
-----BEGIN PRIVATE KEY-----
|
||||||
|
MC4CAQAwBQYDK2VwBCIEIMU6xGVe8JTpZ3bN/wajHfw6pEHt0Rd7wPBxds9eEFy2
|
||||||
|
-----END PRIVATE KEY-----
|
||||||
@@ -1,5 +1,11 @@
|
|||||||
|
// Copyright 2022 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression};
|
use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression};
|
||||||
use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt};
|
use futures::{prelude::*, Sink, SinkExt, Stream, StreamExt, TryStreamExt};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_bytes::ByteBuf;
|
use serde_bytes::ByteBuf;
|
||||||
use std::{io, io::Read, io::Write};
|
use std::{io, io::Read, io::Write};
|
||||||
@@ -54,7 +60,7 @@ where
|
|||||||
if algorithm != CompressionAlgorithm::Deflate {
|
if algorithm != CompressionAlgorithm::Deflate {
|
||||||
return Err(io::Error::new(
|
return Err(io::Error::new(
|
||||||
io::ErrorKind::InvalidData,
|
io::ErrorKind::InvalidData,
|
||||||
format!("Compression algorithm {:?} not supported", algorithm),
|
format!("Compression algorithm {algorithm:?} not supported"),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
let mut deflater = DeflateDecoder::new(payload.as_slice());
|
let mut deflater = DeflateDecoder::new(payload.as_slice());
|
||||||
@@ -99,13 +105,16 @@ pub trait World {
|
|||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
struct HelloServer;
|
struct HelloServer;
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl World for HelloServer {
|
impl World for HelloServer {
|
||||||
async fn hello(self, _: context::Context, name: String) -> String {
|
async fn hello(self, _: context::Context, name: String) -> String {
|
||||||
format!("Hey, {}!", name)
|
format!("Hey, {name}!")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
let mut incoming = tcp::listen("localhost:0", Bincode::default).await?;
|
let mut incoming = tcp::listen("localhost:0", Bincode::default).await?;
|
||||||
@@ -114,6 +123,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
let transport = incoming.next().await.unwrap().unwrap();
|
let transport = incoming.next().await.unwrap().unwrap();
|
||||||
BaseChannel::with_defaults(add_compression(transport))
|
BaseChannel::with_defaults(add_compression(transport))
|
||||||
.execute(HelloServer.serve())
|
.execute(HelloServer.serve())
|
||||||
|
.for_each(spawn)
|
||||||
.await;
|
.await;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,16 @@
|
|||||||
|
// Copyright 2022 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
use futures::prelude::*;
|
||||||
|
use tarpc::context::Context;
|
||||||
use tarpc::serde_transport as transport;
|
use tarpc::serde_transport as transport;
|
||||||
use tarpc::server::{BaseChannel, Channel};
|
use tarpc::server::{BaseChannel, Channel};
|
||||||
use tarpc::{context::Context, tokio_serde::formats::Bincode};
|
use tarpc::tokio_serde::formats::Bincode;
|
||||||
|
use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec;
|
||||||
use tokio::net::{UnixListener, UnixStream};
|
use tokio::net::{UnixListener, UnixStream};
|
||||||
use tokio_util::codec::length_delimited::LengthDelimitedCodec;
|
|
||||||
|
|
||||||
#[tarpc::service]
|
#[tarpc::service]
|
||||||
pub trait PingService {
|
pub trait PingService {
|
||||||
@@ -12,7 +20,6 @@ pub trait PingService {
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct Service;
|
struct Service;
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl PingService for Service {
|
impl PingService for Service {
|
||||||
async fn ping(self, _: Context) {}
|
async fn ping(self, _: Context) {}
|
||||||
}
|
}
|
||||||
@@ -25,13 +32,18 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let listener = UnixListener::bind(bind_addr).unwrap();
|
let listener = UnixListener::bind(bind_addr).unwrap();
|
||||||
let codec_builder = LengthDelimitedCodec::builder();
|
let codec_builder = LengthDelimitedCodec::builder();
|
||||||
|
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
let (conn, _addr) = listener.accept().await.unwrap();
|
let (conn, _addr) = listener.accept().await.unwrap();
|
||||||
let framed = codec_builder.new_framed(conn);
|
let framed = codec_builder.new_framed(conn);
|
||||||
let transport = transport::new(framed, Bincode::default());
|
let transport = transport::new(framed, Bincode::default());
|
||||||
|
|
||||||
let fut = BaseChannel::with_defaults(transport).execute(Service.serve());
|
let fut = BaseChannel::with_defaults(transport)
|
||||||
|
.execute(Service.serve())
|
||||||
|
.for_each(spawn);
|
||||||
tokio::spawn(fut);
|
tokio::spawn(fut);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -52,9 +52,9 @@ use tarpc::{
|
|||||||
client, context,
|
client, context,
|
||||||
serde_transport::tcp,
|
serde_transport::tcp,
|
||||||
server::{self, Channel},
|
server::{self, Channel},
|
||||||
|
tokio_serde::formats::Json,
|
||||||
};
|
};
|
||||||
use tokio::net::ToSocketAddrs;
|
use tokio::net::ToSocketAddrs;
|
||||||
use tokio_serde::formats::Json;
|
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
use tracing_subscriber::prelude::*;
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
@@ -79,7 +79,6 @@ struct Subscriber {
|
|||||||
topics: Vec<String>,
|
topics: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl subscriber::Subscriber for Subscriber {
|
impl subscriber::Subscriber for Subscriber {
|
||||||
async fn topics(self, _: context::Context) -> Vec<String> {
|
async fn topics(self, _: context::Context) -> Vec<String> {
|
||||||
self.topics.clone()
|
self.topics.clone()
|
||||||
@@ -117,7 +116,8 @@ impl Subscriber {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve()));
|
let (handler, abort_handle) =
|
||||||
|
future::abortable(handler.execute(subscriber.serve()).for_each(spawn));
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
match handler.await {
|
match handler.await {
|
||||||
Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."),
|
Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."),
|
||||||
@@ -129,7 +129,6 @@ impl Subscriber {
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct Subscription {
|
struct Subscription {
|
||||||
subscriber: subscriber::SubscriberClient,
|
|
||||||
topics: Vec<String>,
|
topics: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,6 +143,10 @@ struct PublisherAddrs {
|
|||||||
subscriptions: SocketAddr,
|
subscriptions: SocketAddr,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
|
|
||||||
impl Publisher {
|
impl Publisher {
|
||||||
async fn start(self) -> io::Result<PublisherAddrs> {
|
async fn start(self) -> io::Result<PublisherAddrs> {
|
||||||
let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?;
|
let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?;
|
||||||
@@ -163,6 +166,7 @@ impl Publisher {
|
|||||||
|
|
||||||
server::BaseChannel::with_defaults(publisher)
|
server::BaseChannel::with_defaults(publisher)
|
||||||
.execute(self.serve())
|
.execute(self.serve())
|
||||||
|
.for_each(spawn)
|
||||||
.await
|
.await
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -210,7 +214,6 @@ impl Publisher {
|
|||||||
self.clients.lock().unwrap().insert(
|
self.clients.lock().unwrap().insert(
|
||||||
subscriber_addr,
|
subscriber_addr,
|
||||||
Subscription {
|
Subscription {
|
||||||
subscriber: subscriber.clone(),
|
|
||||||
topics: topics.clone(),
|
topics: topics.clone(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
@@ -259,7 +262,6 @@ impl Publisher {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl publisher::Publisher for Publisher {
|
impl publisher::Publisher for Publisher {
|
||||||
async fn publish(self, _: context::Context, topic: String, message: String) {
|
async fn publish(self, _: context::Context, topic: String, message: String) {
|
||||||
info!("received message to publish.");
|
info!("received message to publish.");
|
||||||
@@ -284,13 +286,13 @@ impl publisher::Publisher for Publisher {
|
|||||||
/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
|
/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
|
||||||
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||||
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
||||||
let tracer = opentelemetry_jaeger::new_pipeline()
|
let tracer = opentelemetry_jaeger::new_agent_pipeline()
|
||||||
.with_service_name(service_name)
|
.with_service_name(service_name)
|
||||||
.with_max_packet_size(2usize.pow(13))
|
.with_max_packet_size(2usize.pow(13))
|
||||||
.install_batch(opentelemetry::runtime::Tokio)?;
|
.install_batch(opentelemetry::runtime::Tokio)?;
|
||||||
|
|
||||||
tracing_subscriber::registry()
|
tracing_subscriber::registry()
|
||||||
.with(tracing_subscriber::EnvFilter::from_default_env())
|
.with(tracing_subscriber::filter::EnvFilter::from_default_env())
|
||||||
.with(tracing_subscriber::fmt::layer())
|
.with(tracing_subscriber::fmt::layer())
|
||||||
.with(tracing_opentelemetry::layer().with_tracer(tracer))
|
.with(tracing_opentelemetry::layer().with_tracer(tracer))
|
||||||
.try_init()?;
|
.try_init()?;
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
// license that can be found in the LICENSE file or at
|
// license that can be found in the LICENSE file or at
|
||||||
// https://opensource.org/licenses/MIT.
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
use futures::future::{self, Ready};
|
use futures::prelude::*;
|
||||||
use tarpc::{
|
use tarpc::{
|
||||||
client, context,
|
client, context,
|
||||||
server::{self, Channel},
|
server::{self, Channel},
|
||||||
@@ -23,22 +23,21 @@ pub trait World {
|
|||||||
struct HelloServer;
|
struct HelloServer;
|
||||||
|
|
||||||
impl World for HelloServer {
|
impl World for HelloServer {
|
||||||
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and
|
async fn hello(self, _: context::Context, name: String) -> String {
|
||||||
// an associated type representing the future output by the fn.
|
format!("Hello, {name}!")
|
||||||
|
|
||||||
type HelloFut = Ready<String>;
|
|
||||||
|
|
||||||
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
|
||||||
future::ready(format!("Hello, {}!", name))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||||
|
|
||||||
let server = server::BaseChannel::with_defaults(server_transport);
|
let server = server::BaseChannel::with_defaults(server_transport);
|
||||||
tokio::spawn(server.execute(HelloServer.serve()));
|
tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn));
|
||||||
|
|
||||||
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||||
// that takes a config and any Transport as input.
|
// that takes a config and any Transport as input.
|
||||||
@@ -49,7 +48,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
// specifies a deadline and trace information which can be helpful in debugging requests.
|
// specifies a deadline and trace information which can be helpful in debugging requests.
|
||||||
let hello = client.hello(context::current(), "Stim".to_string()).await?;
|
let hello = client.hello(context::current(), "Stim".to_string()).await?;
|
||||||
|
|
||||||
println!("{}", hello);
|
println!("{hello}");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
150
tarpc/examples/tls_over_tcp.rs
Normal file
150
tarpc/examples/tls_over_tcp.rs
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
// Copyright 2023 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
use futures::prelude::*;
|
||||||
|
use rustls_pemfile::certs;
|
||||||
|
use std::io::{BufReader, Cursor};
|
||||||
|
use std::net::{IpAddr, Ipv4Addr};
|
||||||
|
use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use tokio_rustls::rustls::{self, RootCertStore};
|
||||||
|
use tokio_rustls::{TlsAcceptor, TlsConnector};
|
||||||
|
|
||||||
|
use tarpc::context::Context;
|
||||||
|
use tarpc::serde_transport as transport;
|
||||||
|
use tarpc::server::{BaseChannel, Channel};
|
||||||
|
use tarpc::tokio_serde::formats::Bincode;
|
||||||
|
use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec;
|
||||||
|
|
||||||
|
#[tarpc::service]
|
||||||
|
pub trait PingService {
|
||||||
|
async fn ping() -> String;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct Service;
|
||||||
|
|
||||||
|
impl PingService for Service {
|
||||||
|
async fn ping(self, _: Context) -> String {
|
||||||
|
"🔒".to_owned()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// certs were generated with openssl 3 https://github.com/rustls/rustls/tree/main/test-ca
|
||||||
|
// used on client-side for server tls
|
||||||
|
const END_CHAIN: &str = include_str!("certs/eddsa/end.chain");
|
||||||
|
// used on client-side for client-auth
|
||||||
|
const CLIENT_PRIVATEKEY_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.key");
|
||||||
|
const CLIENT_CERT_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.cert");
|
||||||
|
|
||||||
|
// used on server-side for server tls
|
||||||
|
const END_CERT: &str = include_str!("certs/eddsa/end.cert");
|
||||||
|
const END_PRIVATEKEY: &str = include_str!("certs/eddsa/end.key");
|
||||||
|
// used on server-side for client-auth
|
||||||
|
const CLIENT_CHAIN_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.chain");
|
||||||
|
|
||||||
|
pub fn load_certs(data: &str) -> Vec<rustls::Certificate> {
|
||||||
|
certs(&mut BufReader::new(Cursor::new(data)))
|
||||||
|
.unwrap()
|
||||||
|
.into_iter()
|
||||||
|
.map(rustls::Certificate)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_private_key(key: &str) -> rustls::PrivateKey {
|
||||||
|
let mut reader = BufReader::new(Cursor::new(key));
|
||||||
|
loop {
|
||||||
|
match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") {
|
||||||
|
Some(rustls_pemfile::Item::RSAKey(key)) => return rustls::PrivateKey(key),
|
||||||
|
Some(rustls_pemfile::Item::PKCS8Key(key)) => return rustls::PrivateKey(key),
|
||||||
|
Some(rustls_pemfile::Item::ECKey(key)) => return rustls::PrivateKey(key),
|
||||||
|
None => break,
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
panic!("no keys found in {:?} (encrypted keys not supported)", key);
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
// -------------------- start here to setup tls tcp tokio stream --------------------------
|
||||||
|
// ref certs and loading from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/tests/test.rs
|
||||||
|
// ref basic tls server setup from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/server/src/main.rs
|
||||||
|
let cert = load_certs(END_CERT);
|
||||||
|
let key = load_private_key(END_PRIVATEKEY);
|
||||||
|
let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), 5000);
|
||||||
|
|
||||||
|
// ------------- server side client_auth cert loading start
|
||||||
|
let mut client_auth_roots = RootCertStore::empty();
|
||||||
|
for root in load_certs(CLIENT_CHAIN_CLIENT_AUTH) {
|
||||||
|
client_auth_roots.add(&root).unwrap();
|
||||||
|
}
|
||||||
|
let client_auth = AllowAnyAuthenticatedClient::new(client_auth_roots);
|
||||||
|
// ------------- server side client_auth cert loading end
|
||||||
|
|
||||||
|
let config = rustls::ServerConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_client_cert_verifier(client_auth) // use .with_no_client_auth() instead if you don't want client-auth
|
||||||
|
.with_single_cert(cert, key)
|
||||||
|
.unwrap();
|
||||||
|
let acceptor = TlsAcceptor::from(Arc::new(config));
|
||||||
|
let listener = TcpListener::bind(&server_addr).await.unwrap();
|
||||||
|
let codec_builder = LengthDelimitedCodec::builder();
|
||||||
|
|
||||||
|
// ref ./custom_transport.rs server side
|
||||||
|
tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
let (stream, _peer_addr) = listener.accept().await.unwrap();
|
||||||
|
let tls_stream = acceptor.accept(stream).await.unwrap();
|
||||||
|
let framed = codec_builder.new_framed(tls_stream);
|
||||||
|
|
||||||
|
let transport = transport::new(framed, Bincode::default());
|
||||||
|
|
||||||
|
let fut = BaseChannel::with_defaults(transport)
|
||||||
|
.execute(Service.serve())
|
||||||
|
.for_each(spawn);
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// ---------------------- client connection ---------------------
|
||||||
|
// tls client connection from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
|
||||||
|
let mut root_store = rustls::RootCertStore::empty();
|
||||||
|
for root in load_certs(END_CHAIN) {
|
||||||
|
root_store.add(&root).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let client_auth_private_key = load_private_key(CLIENT_PRIVATEKEY_CLIENT_AUTH);
|
||||||
|
let client_auth_certs = load_certs(CLIENT_CERT_CLIENT_AUTH);
|
||||||
|
|
||||||
|
let config = rustls::ClientConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_root_certificates(root_store)
|
||||||
|
.with_single_cert(client_auth_certs, client_auth_private_key)?; // use .with_no_client_auth() instead if you don't want client-auth
|
||||||
|
|
||||||
|
let domain = rustls::ServerName::try_from("localhost")?;
|
||||||
|
let connector = TlsConnector::from(Arc::new(config));
|
||||||
|
|
||||||
|
let stream = TcpStream::connect(server_addr).await?;
|
||||||
|
let stream = connector.connect(domain, stream).await?;
|
||||||
|
|
||||||
|
let transport = transport::new(codec_builder.new_framed(stream), Bincode::default());
|
||||||
|
let answer = PingServiceClient::new(Default::default(), transport)
|
||||||
|
.spawn()
|
||||||
|
.ping(tarpc::context::current())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
println!("ping answer: {answer}");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -4,14 +4,34 @@
|
|||||||
// license that can be found in the LICENSE file or at
|
// license that can be found in the LICENSE file or at
|
||||||
// https://opensource.org/licenses/MIT.
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
use crate::{add::Add as AddService, double::Double as DoubleService};
|
use crate::{
|
||||||
use futures::{future, prelude::*};
|
add::{Add as AddService, AddStub},
|
||||||
use std::env;
|
double::Double as DoubleService,
|
||||||
use tarpc::{
|
|
||||||
client, context,
|
|
||||||
server::{incoming::Incoming, BaseChannel},
|
|
||||||
};
|
};
|
||||||
use tokio_serde::formats::Json;
|
use futures::{future, prelude::*};
|
||||||
|
use std::{
|
||||||
|
io,
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicBool, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use tarpc::{
|
||||||
|
client::{
|
||||||
|
self,
|
||||||
|
stub::{load_balance, retry},
|
||||||
|
RpcError,
|
||||||
|
},
|
||||||
|
context, serde_transport,
|
||||||
|
server::{
|
||||||
|
incoming::{spawn_incoming, Incoming},
|
||||||
|
request_hook::{self, BeforeRequestList},
|
||||||
|
BaseChannel,
|
||||||
|
},
|
||||||
|
tokio_serde::formats::Json,
|
||||||
|
ClientMessage, Response, ServerError, Transport,
|
||||||
|
};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
use tracing_subscriber::prelude::*;
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
pub mod add {
|
pub mod add {
|
||||||
@@ -33,7 +53,6 @@ pub mod double {
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct AddServer;
|
struct AddServer;
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl AddService for AddServer {
|
impl AddService for AddServer {
|
||||||
async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
|
async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
|
||||||
x + y
|
x + y
|
||||||
@@ -41,12 +60,14 @@ impl AddService for AddServer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct DoubleServer {
|
struct DoubleServer<Stub> {
|
||||||
add_client: add::AddClient,
|
add_client: add::AddClient<Stub>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tarpc::server]
|
impl<Stub> DoubleService for DoubleServer<Stub>
|
||||||
impl DoubleService for DoubleServer {
|
where
|
||||||
|
Stub: AddStub + Clone + Send + Sync + 'static,
|
||||||
|
{
|
||||||
async fn double(self, _: context::Context, x: i32) -> Result<i32, String> {
|
async fn double(self, _: context::Context, x: i32) -> Result<i32, String> {
|
||||||
self.add_client
|
self.add_client
|
||||||
.add(context::current(), x, x)
|
.add(context::current(), x, x)
|
||||||
@@ -56,9 +77,9 @@ impl DoubleService for DoubleServer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||||
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
let tracer = opentelemetry_jaeger::new_agent_pipeline()
|
||||||
let tracer = opentelemetry_jaeger::new_pipeline()
|
|
||||||
.with_service_name(service_name)
|
.with_service_name(service_name)
|
||||||
|
.with_auto_split_batch(true)
|
||||||
.with_max_packet_size(2usize.pow(13))
|
.with_max_packet_size(2usize.pow(13))
|
||||||
.install_batch(opentelemetry::runtime::Tokio)?;
|
.install_batch(opentelemetry::runtime::Tokio)?;
|
||||||
|
|
||||||
@@ -71,32 +92,88 @@ fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn listen_on_random_port<Item, SinkItem>() -> anyhow::Result<(
|
||||||
|
impl Stream<Item = serde_transport::Transport<TcpStream, Item, SinkItem, Json<Item, SinkItem>>>,
|
||||||
|
std::net::SocketAddr,
|
||||||
|
)>
|
||||||
|
where
|
||||||
|
Item: for<'de> serde::Deserialize<'de>,
|
||||||
|
SinkItem: serde::Serialize,
|
||||||
|
{
|
||||||
|
let listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
||||||
|
.await?
|
||||||
|
.filter_map(|r| future::ready(r.ok()))
|
||||||
|
.take(1);
|
||||||
|
let addr = listener.get_ref().get_ref().local_addr();
|
||||||
|
Ok((listener, addr))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_stub<Req, Resp, const N: usize>(
|
||||||
|
backends: [impl Transport<ClientMessage<Arc<Req>>, Response<Resp>> + Send + Sync + 'static; N],
|
||||||
|
) -> retry::Retry<
|
||||||
|
impl Fn(&Result<Resp, RpcError>, u32) -> bool + Clone,
|
||||||
|
load_balance::RoundRobin<client::Channel<Arc<Req>, Resp>>,
|
||||||
|
>
|
||||||
|
where
|
||||||
|
Req: Send + Sync + 'static,
|
||||||
|
Resp: Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
let stub = load_balance::RoundRobin::new(
|
||||||
|
backends
|
||||||
|
.into_iter()
|
||||||
|
.map(|transport| tarpc::client::new(client::Config::default(), transport).spawn())
|
||||||
|
.collect(),
|
||||||
|
);
|
||||||
|
let stub = retry::Retry::new(stub, |resp, attempts| {
|
||||||
|
if let Err(e) = resp {
|
||||||
|
tracing::warn!("Got an error: {e:?}");
|
||||||
|
attempts < 3
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
});
|
||||||
|
stub
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
init_tracing("tarpc_tracing_example")?;
|
init_tracing("tarpc_tracing_example")?;
|
||||||
|
|
||||||
let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
let (add_listener1, addr1) = listen_on_random_port().await?;
|
||||||
.await?
|
let (add_listener2, addr2) = listen_on_random_port().await?;
|
||||||
.filter_map(|r| future::ready(r.ok()));
|
let something_bad_happened = Arc::new(AtomicBool::new(false));
|
||||||
let addr = add_listener.get_ref().local_addr();
|
let server = request_hook::before()
|
||||||
let add_server = add_listener
|
.then_fn(move |_: &mut _, _: &_| {
|
||||||
.map(BaseChannel::with_defaults)
|
let something_bad_happened = something_bad_happened.clone();
|
||||||
.take(1)
|
async move {
|
||||||
.execute(AddServer.serve());
|
if something_bad_happened.fetch_xor(true, Ordering::Relaxed) {
|
||||||
tokio::spawn(add_server);
|
Err(ServerError::new(
|
||||||
|
io::ErrorKind::NotFound,
|
||||||
|
"Gamma Ray!".into(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.serving(AddServer.serve());
|
||||||
|
let add_server = add_listener1
|
||||||
|
.chain(add_listener2)
|
||||||
|
.map(BaseChannel::with_defaults);
|
||||||
|
tokio::spawn(spawn_incoming(add_server.execute(server)));
|
||||||
|
|
||||||
let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
let add_client = add::AddClient::from(make_stub([
|
||||||
let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn();
|
tarpc::serde_transport::tcp::connect(addr1, Json::default).await?,
|
||||||
|
tarpc::serde_transport::tcp::connect(addr2, Json::default).await?,
|
||||||
|
]));
|
||||||
|
|
||||||
let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
||||||
.await?
|
.await?
|
||||||
.filter_map(|r| future::ready(r.ok()));
|
.filter_map(|r| future::ready(r.ok()));
|
||||||
let addr = double_listener.get_ref().local_addr();
|
let addr = double_listener.get_ref().local_addr();
|
||||||
let double_server = double_listener
|
let double_server = double_listener.map(BaseChannel::with_defaults).take(1);
|
||||||
.map(BaseChannel::with_defaults)
|
let server = DoubleServer { add_client }.serve();
|
||||||
.take(1)
|
tokio::spawn(spawn_incoming(double_server.execute(server)));
|
||||||
.execute(DoubleServer { add_client }.serve());
|
|
||||||
tokio::spawn(double_server);
|
|
||||||
|
|
||||||
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||||
let double_client =
|
let double_client =
|
||||||
|
|||||||
49
tarpc/src/cancellations.rs
Normal file
49
tarpc/src/cancellations.rs
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
use futures::{prelude::*, task::*};
|
||||||
|
use std::pin::Pin;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
/// Sends request cancellation signals.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RequestCancellation(mpsc::UnboundedSender<u64>);
|
||||||
|
|
||||||
|
/// A stream of IDs of requests that have been canceled.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
|
||||||
|
|
||||||
|
/// Returns a channel to send request cancellation messages.
|
||||||
|
pub fn cancellations() -> (RequestCancellation, CanceledRequests) {
|
||||||
|
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
|
||||||
|
// bounded by the number of in-flight requests.
|
||||||
|
let (tx, rx) = mpsc::unbounded_channel();
|
||||||
|
(RequestCancellation(tx), CanceledRequests(rx))
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RequestCancellation {
|
||||||
|
/// Cancels the request with ID `request_id`.
|
||||||
|
///
|
||||||
|
/// No validation is done of `request_id`. There is no way to know if the request id provided
|
||||||
|
/// corresponds to a request actually tracked by the backing channel. `RequestCancellation` is
|
||||||
|
/// a one-way communication channel.
|
||||||
|
///
|
||||||
|
/// Once request data is cleaned up, a response will never be received by the client. This is
|
||||||
|
/// useful primarily when request processing ends prematurely for requests with long deadlines
|
||||||
|
/// which would otherwise continue to be tracked by the backing channel—a kind of leak.
|
||||||
|
pub fn cancel(&self, request_id: u64) {
|
||||||
|
let _ = self.0.send(request_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CanceledRequests {
|
||||||
|
/// Polls for a cancelled request.
|
||||||
|
pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<u64>> {
|
||||||
|
self.0.poll_recv(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Stream for CanceledRequests {
|
||||||
|
type Item = u64;
|
||||||
|
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
|
||||||
|
self.poll_recv(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,15 +7,18 @@
|
|||||||
//! Provides a client that connects to a server and sends multiplexed requests.
|
//! Provides a client that connects to a server and sends multiplexed requests.
|
||||||
|
|
||||||
mod in_flight_requests;
|
mod in_flight_requests;
|
||||||
|
pub mod stub;
|
||||||
|
|
||||||
use crate::{context, trace, ClientMessage, Request, Response, ServerError, Transport};
|
use crate::{
|
||||||
|
cancellations::{cancellations, CanceledRequests, RequestCancellation},
|
||||||
|
context, trace, ChannelError, ClientMessage, Request, Response, ServerError, Transport,
|
||||||
|
};
|
||||||
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
||||||
use in_flight_requests::{DeadlineExceededError, InFlightRequests};
|
use in_flight_requests::InFlightRequests;
|
||||||
use pin_project::pin_project;
|
use pin_project::pin_project;
|
||||||
use std::{
|
use std::{
|
||||||
convert::TryFrom,
|
convert::TryFrom,
|
||||||
error::Error,
|
fmt,
|
||||||
fmt, mem,
|
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
sync::{
|
sync::{
|
||||||
atomic::{AtomicUsize, Ordering},
|
atomic::{AtomicUsize, Ordering},
|
||||||
@@ -81,14 +84,10 @@ impl<C, D> fmt::Debug for NewClient<C, D> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
const _CHECK_USIZE: () = assert!(
|
||||||
#[allow(clippy::no_effect)]
|
std::mem::size_of::<usize>() <= std::mem::size_of::<u64>(),
|
||||||
const CHECK_USIZE: () = {
|
"usize is too big to fit in u64"
|
||||||
if std::mem::size_of::<usize>() > std::mem::size_of::<u64>() {
|
);
|
||||||
// TODO: replace this with panic!() as soon as RFC 2345 gets stabilized
|
|
||||||
["usize is too big to fit in u64"][42];
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Handles communication from the client to request dispatch.
|
/// Handles communication from the client to request dispatch.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -118,18 +117,19 @@ impl<Req, Resp> Channel<Req, Resp> {
|
|||||||
skip(self, ctx, request_name, request),
|
skip(self, ctx, request_name, request),
|
||||||
fields(
|
fields(
|
||||||
rpc.trace_id = tracing::field::Empty,
|
rpc.trace_id = tracing::field::Empty,
|
||||||
|
rpc.deadline = %humantime::format_rfc3339(ctx.deadline),
|
||||||
otel.kind = "client",
|
otel.kind = "client",
|
||||||
otel.name = request_name)
|
otel.name = request_name)
|
||||||
)]
|
)]
|
||||||
pub async fn call(
|
pub async fn call(
|
||||||
&self,
|
&self,
|
||||||
mut ctx: context::Context,
|
mut ctx: context::Context,
|
||||||
request_name: &str,
|
request_name: &'static str,
|
||||||
request: Req,
|
request: Req,
|
||||||
) -> Result<Resp, RpcError> {
|
) -> Result<Resp, RpcError> {
|
||||||
let span = Span::current();
|
let span = Span::current();
|
||||||
ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
|
ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
|
||||||
tracing::warn!(
|
tracing::trace!(
|
||||||
"OpenTelemetry subscriber not installed; making unsampled child context."
|
"OpenTelemetry subscriber not installed; making unsampled child context."
|
||||||
);
|
);
|
||||||
ctx.trace_context.new_child()
|
ctx.trace_context.new_child()
|
||||||
@@ -147,6 +147,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
|||||||
response: &mut response,
|
response: &mut response,
|
||||||
request_id,
|
request_id,
|
||||||
cancellation: &self.cancellation,
|
cancellation: &self.cancellation,
|
||||||
|
cancel: true,
|
||||||
};
|
};
|
||||||
self.to_dispatch
|
self.to_dispatch
|
||||||
.send(DispatchRequest {
|
.send(DispatchRequest {
|
||||||
@@ -157,7 +158,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
|||||||
response_completion,
|
response_completion,
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
.map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?;
|
.map_err(|mpsc::error::SendError(_)| RpcError::Shutdown)?;
|
||||||
response_guard.response().await
|
response_guard.response().await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -165,9 +166,10 @@ impl<Req, Resp> Channel<Req, Resp> {
|
|||||||
/// A server response that is completed by request dispatch when the corresponding response
|
/// A server response that is completed by request dispatch when the corresponding response
|
||||||
/// arrives off the wire.
|
/// arrives off the wire.
|
||||||
struct ResponseGuard<'a, Resp> {
|
struct ResponseGuard<'a, Resp> {
|
||||||
response: &'a mut oneshot::Receiver<Result<Response<Resp>, DeadlineExceededError>>,
|
response: &'a mut oneshot::Receiver<Result<Resp, RpcError>>,
|
||||||
cancellation: &'a RequestCancellation,
|
cancellation: &'a RequestCancellation,
|
||||||
request_id: u64,
|
request_id: u64,
|
||||||
|
cancel: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An error that can occur in the processing of an RPC. This is not request-specific errors but
|
/// An error that can occur in the processing of an RPC. This is not request-specific errors but
|
||||||
@@ -175,8 +177,14 @@ struct ResponseGuard<'a, Resp> {
|
|||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
pub enum RpcError {
|
pub enum RpcError {
|
||||||
/// The client disconnected from the server.
|
/// The client disconnected from the server.
|
||||||
#[error("the client disconnected from the server")]
|
#[error("the connection to the server was already shutdown")]
|
||||||
Disconnected,
|
Shutdown,
|
||||||
|
/// The client failed to send the request.
|
||||||
|
#[error("the client failed to send the request")]
|
||||||
|
Send(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||||
|
/// An error occurred while waiting for the server response.
|
||||||
|
#[error("an error occurred while waiting for the server response")]
|
||||||
|
Receive(#[source] Arc<dyn std::error::Error + Send + Sync + 'static>),
|
||||||
/// The request exceeded its deadline.
|
/// The request exceeded its deadline.
|
||||||
#[error("the request exceeded its deadline")]
|
#[error("the request exceeded its deadline")]
|
||||||
DeadlineExceeded,
|
DeadlineExceeded,
|
||||||
@@ -185,24 +193,18 @@ pub enum RpcError {
|
|||||||
Server(#[from] ServerError),
|
Server(#[from] ServerError),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<DeadlineExceededError> for RpcError {
|
|
||||||
fn from(_: DeadlineExceededError) -> Self {
|
|
||||||
RpcError::DeadlineExceeded
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Resp> ResponseGuard<'_, Resp> {
|
impl<Resp> ResponseGuard<'_, Resp> {
|
||||||
async fn response(mut self) -> Result<Resp, RpcError> {
|
async fn response(mut self) -> Result<Resp, RpcError> {
|
||||||
let response = (&mut self.response).await;
|
let response = (&mut self.response).await;
|
||||||
// Cancel drop logic once a response has been received.
|
// Cancel drop logic once a response has been received.
|
||||||
mem::forget(self);
|
self.cancel = false;
|
||||||
match response {
|
match response {
|
||||||
Ok(resp) => Ok(resp?.message?),
|
Ok(response) => response,
|
||||||
Err(oneshot::error::RecvError { .. }) => {
|
Err(oneshot::error::RecvError { .. }) => {
|
||||||
// The oneshot is Canceled when the dispatch task ends. In that case,
|
// The oneshot is Canceled when the dispatch task ends. In that case,
|
||||||
// there's nothing listening on the other side, so there's no point in
|
// there's nothing listening on the other side, so there's no point in
|
||||||
// propagating cancellation.
|
// propagating cancellation.
|
||||||
Err(RpcError::Disconnected)
|
Err(RpcError::Shutdown)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -222,7 +224,9 @@ impl<Resp> Drop for ResponseGuard<'_, Resp> {
|
|||||||
// dispatch task misses an early-arriving cancellation message, then it will see the
|
// dispatch task misses an early-arriving cancellation message, then it will see the
|
||||||
// receiver as closed.
|
// receiver as closed.
|
||||||
self.response.close();
|
self.response.close();
|
||||||
self.cancellation.cancel(self.request_id);
|
if self.cancel {
|
||||||
|
self.cancellation.cancel(self.request_id);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -237,7 +241,6 @@ where
|
|||||||
{
|
{
|
||||||
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
|
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
|
||||||
let (cancellation, canceled_requests) = cancellations();
|
let (cancellation, canceled_requests) = cancellations();
|
||||||
let canceled_requests = canceled_requests;
|
|
||||||
|
|
||||||
NewClient {
|
NewClient {
|
||||||
client: Channel {
|
client: Channel {
|
||||||
@@ -257,6 +260,7 @@ where
|
|||||||
|
|
||||||
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
|
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
|
||||||
/// and dispatching responses to the appropriate channel.
|
/// and dispatching responses to the appropriate channel.
|
||||||
|
#[must_use]
|
||||||
#[pin_project]
|
#[pin_project]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct RequestDispatch<Req, Resp, C> {
|
pub struct RequestDispatch<Req, Resp, C> {
|
||||||
@@ -268,42 +272,18 @@ pub struct RequestDispatch<Req, Resp, C> {
|
|||||||
/// Requests that were dropped.
|
/// Requests that were dropped.
|
||||||
canceled_requests: CanceledRequests,
|
canceled_requests: CanceledRequests,
|
||||||
/// Requests already written to the wire that haven't yet received responses.
|
/// Requests already written to the wire that haven't yet received responses.
|
||||||
in_flight_requests: InFlightRequests<Resp>,
|
in_flight_requests: InFlightRequests<Result<Resp, RpcError>>,
|
||||||
/// Configures limits to prevent unlimited resource usage.
|
/// Configures limits to prevent unlimited resource usage.
|
||||||
config: Config,
|
config: Config,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Critical errors that result in a Channel disconnecting.
|
|
||||||
#[derive(thiserror::Error, Debug)]
|
|
||||||
pub enum ChannelError<E>
|
|
||||||
where
|
|
||||||
E: Error + Send + Sync + 'static,
|
|
||||||
{
|
|
||||||
/// Could not read from the transport.
|
|
||||||
#[error("could not read from the transport")]
|
|
||||||
Read(#[source] E),
|
|
||||||
/// Could not ready the transport for writes.
|
|
||||||
#[error("could not ready the transport for writes")]
|
|
||||||
Ready(#[source] E),
|
|
||||||
/// Could not write to the transport.
|
|
||||||
#[error("could not write to the transport")]
|
|
||||||
Write(#[source] E),
|
|
||||||
/// Could not flush the transport.
|
|
||||||
#[error("could not flush the transport")]
|
|
||||||
Flush(#[source] E),
|
|
||||||
/// Could not close the write end of the transport.
|
|
||||||
#[error("could not close the write end of the transport")]
|
|
||||||
Close(#[source] E),
|
|
||||||
/// Could not poll expired requests.
|
|
||||||
#[error("could not poll expired requests")]
|
|
||||||
Timer(#[source] tokio::time::error::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
|
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
|
||||||
where
|
where
|
||||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
||||||
{
|
{
|
||||||
fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests<Resp> {
|
fn in_flight_requests<'a>(
|
||||||
|
self: &'a mut Pin<&mut Self>,
|
||||||
|
) -> &'a mut InFlightRequests<Result<Resp, RpcError>> {
|
||||||
self.as_mut().project().in_flight_requests
|
self.as_mut().project().in_flight_requests
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -363,7 +343,17 @@ where
|
|||||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||||
self.transport_pin_mut()
|
self.transport_pin_mut()
|
||||||
.poll_next(cx)
|
.poll_next(cx)
|
||||||
.map_err(ChannelError::Read)
|
.map_err(|e| {
|
||||||
|
let e = Arc::new(e);
|
||||||
|
for span in self
|
||||||
|
.in_flight_requests()
|
||||||
|
.complete_all_requests(|| Err(RpcError::Receive(e.clone())))
|
||||||
|
{
|
||||||
|
let _entered = span.enter();
|
||||||
|
tracing::info!("ReceiveError");
|
||||||
|
}
|
||||||
|
ChannelError::Read(e)
|
||||||
|
})
|
||||||
.map_ok(|response| {
|
.map_ok(|response| {
|
||||||
self.complete(response);
|
self.complete(response);
|
||||||
})
|
})
|
||||||
@@ -395,8 +385,7 @@ where
|
|||||||
// track the status like is done with pending and cancelled requests.
|
// track the status like is done with pending and cancelled requests.
|
||||||
if let Poll::Ready(Some(_)) = self
|
if let Poll::Ready(Some(_)) = self
|
||||||
.in_flight_requests()
|
.in_flight_requests()
|
||||||
.poll_expired(cx)
|
.poll_expired(cx, || Err(RpcError::DeadlineExceeded))
|
||||||
.map_err(ChannelError::Timer)?
|
|
||||||
{
|
{
|
||||||
// Expired requests are considered complete; there is no compelling reason to send a
|
// Expired requests are considered complete; there is no compelling reason to send a
|
||||||
// cancellation message to the server, since it will have already exhausted its
|
// cancellation message to the server, since it will have already exhausted its
|
||||||
@@ -508,11 +497,10 @@ where
|
|||||||
Some(dispatch_request) => dispatch_request,
|
Some(dispatch_request) => dispatch_request,
|
||||||
None => return Poll::Ready(None),
|
None => return Poll::Ready(None),
|
||||||
};
|
};
|
||||||
let entered = span.enter();
|
let _entered = span.enter();
|
||||||
// poll_next_request only returns Ready if there is room to buffer another request.
|
// poll_next_request only returns Ready if there is room to buffer another request.
|
||||||
// Therefore, we can call write_request without fear of erroring due to a full
|
// Therefore, we can call write_request without fear of erroring due to a full
|
||||||
// buffer.
|
// buffer.
|
||||||
let request_id = request_id;
|
|
||||||
let request = ClientMessage::Request(Request {
|
let request = ClientMessage::Request(Request {
|
||||||
id: request_id,
|
id: request_id,
|
||||||
message: request,
|
message: request,
|
||||||
@@ -521,17 +509,16 @@ where
|
|||||||
trace_context: ctx.trace_context,
|
trace_context: ctx.trace_context,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
self.start_send(request)?;
|
|
||||||
let deadline = ctx.deadline;
|
|
||||||
tracing::info!(
|
|
||||||
tarpc.deadline = %humantime::format_rfc3339(deadline),
|
|
||||||
"SendRequest"
|
|
||||||
);
|
|
||||||
drop(entered);
|
|
||||||
|
|
||||||
self.in_flight_requests()
|
self.in_flight_requests()
|
||||||
.insert_request(request_id, ctx, span, response_completion)
|
.insert_request(request_id, ctx, span.clone(), response_completion)
|
||||||
.expect("Request IDs should be unique");
|
.expect("Request IDs should be unique");
|
||||||
|
match self.start_send(request) {
|
||||||
|
Ok(()) => tracing::info!("SendRequest"),
|
||||||
|
Err(e) => {
|
||||||
|
self.in_flight_requests()
|
||||||
|
.complete_request(request_id, Err(RpcError::Send(Box::new(e))));
|
||||||
|
}
|
||||||
|
}
|
||||||
Poll::Ready(Some(Ok(())))
|
Poll::Ready(Some(Ok(())))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -556,7 +543,15 @@ where
|
|||||||
|
|
||||||
/// Sends a server response to the client task that initiated the associated request.
|
/// Sends a server response to the client task that initiated the associated request.
|
||||||
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
|
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
|
||||||
self.in_flight_requests().complete_request(response)
|
if let Some(span) = self.in_flight_requests().complete_request(
|
||||||
|
response.request_id,
|
||||||
|
response.message.map_err(RpcError::Server),
|
||||||
|
) {
|
||||||
|
let _entered = span.enter();
|
||||||
|
tracing::info!("ReceiveResponse");
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -605,76 +600,43 @@ struct DispatchRequest<Req, Resp> {
|
|||||||
pub span: Span,
|
pub span: Span,
|
||||||
pub request_id: u64,
|
pub request_id: u64,
|
||||||
pub request: Req,
|
pub request: Req,
|
||||||
pub response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
pub response_completion: oneshot::Sender<Result<Resp, RpcError>>,
|
||||||
}
|
|
||||||
|
|
||||||
/// Sends request cancellation signals.
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct RequestCancellation(mpsc::UnboundedSender<u64>);
|
|
||||||
|
|
||||||
/// A stream of IDs of requests that have been canceled.
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
|
|
||||||
|
|
||||||
/// Returns a channel to send request cancellation messages.
|
|
||||||
fn cancellations() -> (RequestCancellation, CanceledRequests) {
|
|
||||||
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
|
|
||||||
// bounded by the number of in-flight requests.
|
|
||||||
let (tx, rx) = mpsc::unbounded_channel();
|
|
||||||
(RequestCancellation(tx), CanceledRequests(rx))
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RequestCancellation {
|
|
||||||
/// Cancels the request with ID `request_id`.
|
|
||||||
fn cancel(&self, request_id: u64) {
|
|
||||||
let _ = self.0.send(request_id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CanceledRequests {
|
|
||||||
fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<u64>> {
|
|
||||||
self.0.poll_recv(cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Stream for CanceledRequests {
|
|
||||||
type Item = u64;
|
|
||||||
|
|
||||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
|
|
||||||
self.poll_recv(cx)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{
|
use super::{
|
||||||
cancellations, CanceledRequests, Channel, DispatchRequest, RequestCancellation,
|
cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError,
|
||||||
RequestDispatch, ResponseGuard,
|
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
client::{
|
client::{in_flight_requests::InFlightRequests, Config},
|
||||||
in_flight_requests::{DeadlineExceededError, InFlightRequests},
|
context::{self, current},
|
||||||
Config,
|
|
||||||
},
|
|
||||||
context,
|
|
||||||
transport::{self, channel::UnboundedChannel},
|
transport::{self, channel::UnboundedChannel},
|
||||||
ClientMessage, Response,
|
ChannelError, ClientMessage, Response,
|
||||||
};
|
};
|
||||||
use assert_matches::assert_matches;
|
use assert_matches::assert_matches;
|
||||||
use futures::{prelude::*, task::*};
|
use futures::{prelude::*, task::*};
|
||||||
use std::{
|
use std::{
|
||||||
convert::TryFrom,
|
convert::TryFrom,
|
||||||
|
fmt::Display,
|
||||||
|
marker::PhantomData,
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
sync::atomic::{AtomicUsize, Ordering},
|
sync::{
|
||||||
sync::Arc,
|
atomic::{AtomicUsize, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::sync::{
|
||||||
|
mpsc::{self},
|
||||||
|
oneshot,
|
||||||
};
|
};
|
||||||
use tokio::sync::{mpsc, oneshot};
|
|
||||||
use tracing::Span;
|
use tracing::Span;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn response_completes_request_future() {
|
async fn response_completes_request_future() {
|
||||||
let (mut dispatch, mut _channel, mut server_channel) = set_up();
|
let (mut dispatch, mut _channel, mut server_channel) = set_up();
|
||||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||||
let (tx, mut rx) = oneshot::channel();
|
let (tx, mut rx) = oneshot::channel();
|
||||||
|
|
||||||
dispatch
|
dispatch
|
||||||
@@ -689,7 +651,7 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
|
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
|
||||||
assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp");
|
assert_matches!(rx.try_recv(), Ok(Ok(resp)) if resp == "Resp");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -700,10 +662,11 @@ mod tests {
|
|||||||
response: &mut response,
|
response: &mut response,
|
||||||
cancellation: &cancellation,
|
cancellation: &cancellation,
|
||||||
request_id: 3,
|
request_id: 3,
|
||||||
|
cancel: true,
|
||||||
});
|
});
|
||||||
// resp's drop() is run, which should send a cancel message.
|
// resp's drop() is run, which should send a cancel message.
|
||||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||||
assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(Some(3)));
|
assert_eq!(canceled_requests.poll_recv(cx), Poll::Ready(Some(3)));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -720,23 +683,25 @@ mod tests {
|
|||||||
response: &mut response,
|
response: &mut response,
|
||||||
cancellation: &cancellation,
|
cancellation: &cancellation,
|
||||||
request_id: 3,
|
request_id: 3,
|
||||||
|
cancel: true,
|
||||||
}
|
}
|
||||||
.response()
|
.response()
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
drop(cancellation);
|
drop(cancellation);
|
||||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||||
assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(None));
|
assert_eq!(canceled_requests.poll_recv(cx), Poll::Ready(None));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn stage_request() {
|
async fn stage_request() {
|
||||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||||
let (tx, mut rx) = oneshot::channel();
|
let (tx, mut rx) = oneshot::channel();
|
||||||
|
|
||||||
let _resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
let _resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||||
|
|
||||||
|
#[allow(unstable_name_collisions)]
|
||||||
let req = dispatch.as_mut().poll_next_request(cx).ready();
|
let req = dispatch.as_mut().poll_next_request(cx).ready();
|
||||||
assert!(req.is_some());
|
assert!(req.is_some());
|
||||||
|
|
||||||
@@ -749,7 +714,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn stage_request_channel_dropped_doesnt_panic() {
|
async fn stage_request_channel_dropped_doesnt_panic() {
|
||||||
let (mut dispatch, mut channel, mut server_channel) = set_up();
|
let (mut dispatch, mut channel, mut server_channel) = set_up();
|
||||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||||
let (tx, mut rx) = oneshot::channel();
|
let (tx, mut rx) = oneshot::channel();
|
||||||
|
|
||||||
let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
|
let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||||
@@ -767,10 +732,11 @@ mod tests {
|
|||||||
dispatch.await.unwrap();
|
dispatch.await.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(unstable_name_collisions)]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn stage_request_response_future_dropped_is_canceled_before_sending() {
|
async fn stage_request_response_future_dropped_is_canceled_before_sending() {
|
||||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||||
let (tx, mut rx) = oneshot::channel();
|
let (tx, mut rx) = oneshot::channel();
|
||||||
|
|
||||||
let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
|
let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||||
@@ -782,10 +748,11 @@ mod tests {
|
|||||||
assert!(dispatch.as_mut().poll_next_request(cx).ready().is_none());
|
assert!(dispatch.as_mut().poll_next_request(cx).ready().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(unstable_name_collisions)]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn stage_request_response_future_dropped_is_canceled_after_sending() {
|
async fn stage_request_response_future_dropped_is_canceled_after_sending() {
|
||||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||||
let (tx, mut rx) = oneshot::channel();
|
let (tx, mut rx) = oneshot::channel();
|
||||||
|
|
||||||
let req = send_request(&mut channel, "hi", tx, &mut rx).await;
|
let req = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||||
@@ -806,7 +773,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn stage_request_response_closed_skipped() {
|
async fn stage_request_response_closed_skipped() {
|
||||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||||
let (tx, mut rx) = oneshot::channel();
|
let (tx, mut rx) = oneshot::channel();
|
||||||
|
|
||||||
// Test that a request future that's closed its receiver but not yet canceled its request --
|
// Test that a request future that's closed its receiver but not yet canceled its request --
|
||||||
@@ -818,6 +785,185 @@ mod tests {
|
|||||||
assert!(dispatch.as_mut().poll_next_request(cx).is_pending());
|
assert!(dispatch.as_mut().poll_next_request(cx).is_pending());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_shutdown_error() {
|
||||||
|
let _ = tracing_subscriber::fmt().with_test_writer().try_init();
|
||||||
|
let (dispatch, mut channel, _) = set_up();
|
||||||
|
let (tx, mut rx) = oneshot::channel();
|
||||||
|
// send succeeds
|
||||||
|
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||||
|
drop(dispatch);
|
||||||
|
// error on receive
|
||||||
|
assert_matches!(resp.response().await, Err(RpcError::Shutdown));
|
||||||
|
let (dispatch, channel, _) = set_up();
|
||||||
|
drop(dispatch);
|
||||||
|
// error on send
|
||||||
|
let resp = channel
|
||||||
|
.call(current(), "test_request", "hi".to_string())
|
||||||
|
.await;
|
||||||
|
assert_matches!(resp, Err(RpcError::Shutdown));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_transport_error_write() {
|
||||||
|
let cause = TransportError::Write;
|
||||||
|
let (mut dispatch, mut channel, mut cx) = setup_always_err(cause);
|
||||||
|
let (tx, mut rx) = oneshot::channel();
|
||||||
|
|
||||||
|
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||||
|
assert!(dispatch.as_mut().poll(&mut cx).is_pending());
|
||||||
|
let res = resp.response().await;
|
||||||
|
assert_matches!(res, Err(RpcError::Send(_)));
|
||||||
|
let client_error: anyhow::Error = res.unwrap_err().into();
|
||||||
|
let mut chain = client_error.chain();
|
||||||
|
chain.next(); // original RpcError
|
||||||
|
assert_eq!(
|
||||||
|
chain
|
||||||
|
.next()
|
||||||
|
.unwrap()
|
||||||
|
.downcast_ref::<ChannelError<TransportError>>(),
|
||||||
|
Some(&ChannelError::Write(cause))
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
client_error.root_cause().downcast_ref::<TransportError>(),
|
||||||
|
Some(&cause)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_transport_error_read() {
|
||||||
|
let cause = TransportError::Read;
|
||||||
|
let (mut dispatch, mut channel, mut cx) = setup_always_err(cause);
|
||||||
|
let (tx, mut rx) = oneshot::channel();
|
||||||
|
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||||
|
assert_eq!(
|
||||||
|
dispatch.as_mut().pump_write(&mut cx),
|
||||||
|
Poll::Ready(Some(Ok(())))
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
dispatch.as_mut().pump_read(&mut cx),
|
||||||
|
Poll::Ready(Some(Err(ChannelError::Read(Arc::new(cause)))))
|
||||||
|
);
|
||||||
|
assert_matches!(resp.response().await, Err(RpcError::Receive(_)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_transport_error_ready() {
|
||||||
|
let cause = TransportError::Ready;
|
||||||
|
let (mut dispatch, _, mut cx) = setup_always_err(cause);
|
||||||
|
assert_eq!(
|
||||||
|
dispatch.as_mut().poll(&mut cx),
|
||||||
|
Poll::Ready(Err(ChannelError::Ready(cause)))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_transport_error_flush() {
|
||||||
|
let cause = TransportError::Flush;
|
||||||
|
let (mut dispatch, _, mut cx) = setup_always_err(cause);
|
||||||
|
assert_eq!(
|
||||||
|
dispatch.as_mut().poll(&mut cx),
|
||||||
|
Poll::Ready(Err(ChannelError::Flush(cause)))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_transport_error_close() {
|
||||||
|
let cause = TransportError::Close;
|
||||||
|
let (mut dispatch, channel, mut cx) = setup_always_err(cause);
|
||||||
|
drop(channel);
|
||||||
|
assert_eq!(
|
||||||
|
dispatch.as_mut().poll(&mut cx),
|
||||||
|
Poll::Ready(Err(ChannelError::Close(cause)))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn setup_always_err(
|
||||||
|
cause: TransportError,
|
||||||
|
) -> (
|
||||||
|
Pin<Box<RequestDispatch<String, String, AlwaysErrorTransport<String>>>>,
|
||||||
|
Channel<String, String>,
|
||||||
|
Context<'static>,
|
||||||
|
) {
|
||||||
|
let (to_dispatch, pending_requests) = mpsc::channel(1);
|
||||||
|
let (cancellation, canceled_requests) = cancellations();
|
||||||
|
let transport: AlwaysErrorTransport<String> = AlwaysErrorTransport(cause, PhantomData);
|
||||||
|
let dispatch = Box::pin(RequestDispatch::<String, String, _> {
|
||||||
|
transport: transport.fuse(),
|
||||||
|
pending_requests,
|
||||||
|
canceled_requests,
|
||||||
|
in_flight_requests: InFlightRequests::default(),
|
||||||
|
config: Config::default(),
|
||||||
|
});
|
||||||
|
let channel = Channel {
|
||||||
|
to_dispatch,
|
||||||
|
cancellation,
|
||||||
|
next_request_id: Arc::new(AtomicUsize::new(0)),
|
||||||
|
};
|
||||||
|
let cx = Context::from_waker(noop_waker_ref());
|
||||||
|
(dispatch, channel, cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AlwaysErrorTransport<I>(TransportError, PhantomData<I>);
|
||||||
|
|
||||||
|
#[derive(Debug, Error, PartialEq, Eq, Clone, Copy)]
|
||||||
|
enum TransportError {
|
||||||
|
Read,
|
||||||
|
Ready,
|
||||||
|
Write,
|
||||||
|
Flush,
|
||||||
|
Close,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for TransportError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.write_str(&format!("{self:?}"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I: Clone, S> Sink<S> for AlwaysErrorTransport<I> {
|
||||||
|
type Error = TransportError;
|
||||||
|
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
match self.0 {
|
||||||
|
TransportError::Ready => Poll::Ready(Err(self.0)),
|
||||||
|
TransportError::Flush => Poll::Pending,
|
||||||
|
_ => Poll::Ready(Ok(())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn start_send(self: Pin<&mut Self>, _: S) -> Result<(), Self::Error> {
|
||||||
|
if matches!(self.0, TransportError::Write) {
|
||||||
|
Err(self.0)
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
if matches!(self.0, TransportError::Flush) {
|
||||||
|
Poll::Ready(Err(self.0))
|
||||||
|
} else {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
if matches!(self.0, TransportError::Close) {
|
||||||
|
Poll::Ready(Err(self.0))
|
||||||
|
} else {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I: Clone> Stream for AlwaysErrorTransport<I> {
|
||||||
|
type Item = Result<Response<I>, TransportError>;
|
||||||
|
fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
if matches!(self.0, TransportError::Read) {
|
||||||
|
Poll::Ready(Some(Err(self.0)))
|
||||||
|
} else {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn set_up() -> (
|
fn set_up() -> (
|
||||||
Pin<
|
Pin<
|
||||||
Box<
|
Box<
|
||||||
@@ -834,18 +980,17 @@ mod tests {
|
|||||||
let _ = tracing_subscriber::fmt().with_test_writer().try_init();
|
let _ = tracing_subscriber::fmt().with_test_writer().try_init();
|
||||||
|
|
||||||
let (to_dispatch, pending_requests) = mpsc::channel(1);
|
let (to_dispatch, pending_requests) = mpsc::channel(1);
|
||||||
let (cancel_tx, canceled_requests) = mpsc::unbounded_channel();
|
let (cancellation, canceled_requests) = cancellations();
|
||||||
let (client_channel, server_channel) = transport::channel::unbounded();
|
let (client_channel, server_channel) = transport::channel::unbounded();
|
||||||
|
|
||||||
let dispatch = RequestDispatch::<String, String, _> {
|
let dispatch = RequestDispatch::<String, String, _> {
|
||||||
transport: client_channel.fuse(),
|
transport: client_channel.fuse(),
|
||||||
pending_requests: pending_requests,
|
pending_requests,
|
||||||
canceled_requests: CanceledRequests(canceled_requests),
|
canceled_requests,
|
||||||
in_flight_requests: InFlightRequests::default(),
|
in_flight_requests: InFlightRequests::default(),
|
||||||
config: Config::default(),
|
config: Config::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let cancellation = RequestCancellation(cancel_tx);
|
|
||||||
let channel = Channel {
|
let channel = Channel {
|
||||||
to_dispatch,
|
to_dispatch,
|
||||||
cancellation,
|
cancellation,
|
||||||
@@ -858,8 +1003,8 @@ mod tests {
|
|||||||
async fn send_request<'a>(
|
async fn send_request<'a>(
|
||||||
channel: &'a mut Channel<String, String>,
|
channel: &'a mut Channel<String, String>,
|
||||||
request: &str,
|
request: &str,
|
||||||
response_completion: oneshot::Sender<Result<Response<String>, DeadlineExceededError>>,
|
response_completion: oneshot::Sender<Result<String, RpcError>>,
|
||||||
response: &'a mut oneshot::Receiver<Result<Response<String>, DeadlineExceededError>>,
|
response: &'a mut oneshot::Receiver<Result<String, RpcError>>,
|
||||||
) -> ResponseGuard<'a, String> {
|
) -> ResponseGuard<'a, String> {
|
||||||
let request_id =
|
let request_id =
|
||||||
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
|
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
|
||||||
@@ -870,13 +1015,14 @@ mod tests {
|
|||||||
request: request.to_string(),
|
request: request.to_string(),
|
||||||
response_completion,
|
response_completion,
|
||||||
};
|
};
|
||||||
channel.to_dispatch.send(request).await.unwrap();
|
let response_guard = ResponseGuard {
|
||||||
|
|
||||||
ResponseGuard {
|
|
||||||
response,
|
response,
|
||||||
cancellation: &channel.cancellation,
|
cancellation: &channel.cancellation,
|
||||||
request_id,
|
request_id,
|
||||||
}
|
cancel: true,
|
||||||
|
};
|
||||||
|
channel.to_dispatch.send(request).await.unwrap();
|
||||||
|
response_guard
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_response(
|
async fn send_response(
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
context,
|
context,
|
||||||
util::{Compact, TimeUntil},
|
util::{Compact, TimeUntil},
|
||||||
Response,
|
|
||||||
};
|
};
|
||||||
use fnv::FnvHashMap;
|
use fnv::FnvHashMap;
|
||||||
use std::{
|
use std::{
|
||||||
@@ -28,17 +27,11 @@ impl<Resp> Default for InFlightRequests<Resp> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The request exceeded its deadline.
|
|
||||||
#[derive(thiserror::Error, Debug)]
|
|
||||||
#[non_exhaustive]
|
|
||||||
#[error("the request exceeded its deadline")]
|
|
||||||
pub struct DeadlineExceededError;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct RequestData<Resp> {
|
struct RequestData<Res> {
|
||||||
ctx: context::Context,
|
ctx: context::Context,
|
||||||
span: Span,
|
span: Span,
|
||||||
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
response_completion: oneshot::Sender<Res>,
|
||||||
/// The key to remove the timer for the request's deadline.
|
/// The key to remove the timer for the request's deadline.
|
||||||
deadline_key: delay_queue::Key,
|
deadline_key: delay_queue::Key,
|
||||||
}
|
}
|
||||||
@@ -48,7 +41,7 @@ struct RequestData<Resp> {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct AlreadyExistsError;
|
pub struct AlreadyExistsError;
|
||||||
|
|
||||||
impl<Resp> InFlightRequests<Resp> {
|
impl<Res> InFlightRequests<Res> {
|
||||||
/// Returns the number of in-flight requests.
|
/// Returns the number of in-flight requests.
|
||||||
pub fn len(&self) -> usize {
|
pub fn len(&self) -> usize {
|
||||||
self.request_data.len()
|
self.request_data.len()
|
||||||
@@ -65,7 +58,7 @@ impl<Resp> InFlightRequests<Resp> {
|
|||||||
request_id: u64,
|
request_id: u64,
|
||||||
ctx: context::Context,
|
ctx: context::Context,
|
||||||
span: Span,
|
span: Span,
|
||||||
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
response_completion: oneshot::Sender<Res>,
|
||||||
) -> Result<(), AlreadyExistsError> {
|
) -> Result<(), AlreadyExistsError> {
|
||||||
match self.request_data.entry(request_id) {
|
match self.request_data.entry(request_id) {
|
||||||
hash_map::Entry::Vacant(vacant) => {
|
hash_map::Entry::Vacant(vacant) => {
|
||||||
@@ -84,23 +77,31 @@ impl<Resp> InFlightRequests<Resp> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Removes a request without aborting. Returns true iff the request was found.
|
/// Removes a request without aborting. Returns true iff the request was found.
|
||||||
pub fn complete_request(&mut self, response: Response<Resp>) -> bool {
|
pub fn complete_request(&mut self, request_id: u64, result: Res) -> Option<Span> {
|
||||||
if let Some(request_data) = self.request_data.remove(&response.request_id) {
|
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||||
let _entered = request_data.span.enter();
|
|
||||||
tracing::info!("ReceiveResponse");
|
|
||||||
self.request_data.compact(0.1);
|
self.request_data.compact(0.1);
|
||||||
self.deadlines.remove(&request_data.deadline_key);
|
self.deadlines.remove(&request_data.deadline_key);
|
||||||
let _ = request_data.response_completion.send(Ok(response));
|
let _ = request_data.response_completion.send(result);
|
||||||
return true;
|
return Some(request_data.span);
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::debug!(
|
tracing::debug!("No in-flight request found for request_id = {request_id}.");
|
||||||
"No in-flight request found for request_id = {}.",
|
|
||||||
response.request_id
|
|
||||||
);
|
|
||||||
|
|
||||||
// If the response completion was absent, then the request was already canceled.
|
// If the response completion was absent, then the request was already canceled.
|
||||||
false
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Completes all requests using the provided function.
|
||||||
|
/// Returns Spans for all completes requests.
|
||||||
|
pub fn complete_all_requests<'a>(
|
||||||
|
&'a mut self,
|
||||||
|
mut result: impl FnMut() -> Res + 'a,
|
||||||
|
) -> impl Iterator<Item = Span> + 'a {
|
||||||
|
self.deadlines.clear();
|
||||||
|
self.request_data.drain().map(move |(_, request_data)| {
|
||||||
|
let _ = request_data.response_completion.send(result());
|
||||||
|
request_data.span
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Cancels a request without completing (typically used when a request handle was dropped
|
/// Cancels a request without completing (typically used when a request handle was dropped
|
||||||
@@ -120,18 +121,17 @@ impl<Resp> InFlightRequests<Resp> {
|
|||||||
pub fn poll_expired(
|
pub fn poll_expired(
|
||||||
&mut self,
|
&mut self,
|
||||||
cx: &mut Context,
|
cx: &mut Context,
|
||||||
) -> Poll<Option<Result<u64, tokio::time::error::Error>>> {
|
expired_error: impl Fn() -> Res,
|
||||||
self.deadlines.poll_expired(cx).map_ok(|expired| {
|
) -> Poll<Option<u64>> {
|
||||||
let request_id = expired.into_inner();
|
self.deadlines.poll_expired(cx).map(|expired| {
|
||||||
|
let request_id = expired?.into_inner();
|
||||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||||
let _entered = request_data.span.enter();
|
let _entered = request_data.span.enter();
|
||||||
tracing::error!("DeadlineExceeded");
|
tracing::error!("DeadlineExceeded");
|
||||||
self.request_data.compact(0.1);
|
self.request_data.compact(0.1);
|
||||||
let _ = request_data
|
let _ = request_data.response_completion.send(expired_error());
|
||||||
.response_completion
|
|
||||||
.send(Err(DeadlineExceededError));
|
|
||||||
}
|
}
|
||||||
request_id
|
Some(request_id)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
45
tarpc/src/client/stub.rs
Normal file
45
tarpc/src/client/stub.rs
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
//! Provides a Stub trait, implemented by types that can call remote services.
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
client::{Channel, RpcError},
|
||||||
|
context,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub mod load_balance;
|
||||||
|
pub mod retry;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod mock;
|
||||||
|
|
||||||
|
/// A connection to a remote service.
|
||||||
|
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
|
||||||
|
#[allow(async_fn_in_trait)]
|
||||||
|
pub trait Stub {
|
||||||
|
/// The service request type.
|
||||||
|
type Req;
|
||||||
|
|
||||||
|
/// The service response type.
|
||||||
|
type Resp;
|
||||||
|
|
||||||
|
/// Calls a remote service.
|
||||||
|
async fn call(
|
||||||
|
&self,
|
||||||
|
ctx: context::Context,
|
||||||
|
request_name: &'static str,
|
||||||
|
request: Self::Req,
|
||||||
|
) -> Result<Self::Resp, RpcError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req, Resp> Stub for Channel<Req, Resp> {
|
||||||
|
type Req = Req;
|
||||||
|
type Resp = Resp;
|
||||||
|
|
||||||
|
async fn call(
|
||||||
|
&self,
|
||||||
|
ctx: context::Context,
|
||||||
|
request_name: &'static str,
|
||||||
|
request: Req,
|
||||||
|
) -> Result<Self::Resp, RpcError> {
|
||||||
|
Self::call(self, ctx, request_name, request).await
|
||||||
|
}
|
||||||
|
}
|
||||||
279
tarpc/src/client/stub/load_balance.rs
Normal file
279
tarpc/src/client/stub/load_balance.rs
Normal file
@@ -0,0 +1,279 @@
|
|||||||
|
//! Provides load-balancing [Stubs](crate::client::stub::Stub).
|
||||||
|
|
||||||
|
pub use consistent_hash::ConsistentHash;
|
||||||
|
pub use round_robin::RoundRobin;
|
||||||
|
|
||||||
|
/// Provides a stub that load-balances with a simple round-robin strategy.
|
||||||
|
mod round_robin {
|
||||||
|
use crate::{
|
||||||
|
client::{stub, RpcError},
|
||||||
|
context,
|
||||||
|
};
|
||||||
|
use cycle::AtomicCycle;
|
||||||
|
|
||||||
|
impl<Stub> stub::Stub for RoundRobin<Stub>
|
||||||
|
where
|
||||||
|
Stub: stub::Stub,
|
||||||
|
{
|
||||||
|
type Req = Stub::Req;
|
||||||
|
type Resp = Stub::Resp;
|
||||||
|
|
||||||
|
async fn call(
|
||||||
|
&self,
|
||||||
|
ctx: context::Context,
|
||||||
|
request_name: &'static str,
|
||||||
|
request: Self::Req,
|
||||||
|
) -> Result<Stub::Resp, RpcError> {
|
||||||
|
let next = self.stubs.next();
|
||||||
|
next.call(ctx, request_name, request).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Stub that load-balances across backing stubs by round robin.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct RoundRobin<Stub> {
|
||||||
|
stubs: AtomicCycle<Stub>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Stub> RoundRobin<Stub>
|
||||||
|
where
|
||||||
|
Stub: stub::Stub,
|
||||||
|
{
|
||||||
|
/// Returns a new RoundRobin stub.
|
||||||
|
pub fn new(stubs: Vec<Stub>) -> Self {
|
||||||
|
Self {
|
||||||
|
stubs: AtomicCycle::new(stubs),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mod cycle {
|
||||||
|
use std::sync::{
|
||||||
|
atomic::{AtomicUsize, Ordering},
|
||||||
|
Arc,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Cycles endlessly and atomically over a collection of elements of type T.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct AtomicCycle<T>(Arc<State<T>>);
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct State<T> {
|
||||||
|
elements: Vec<T>,
|
||||||
|
next: AtomicUsize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> AtomicCycle<T> {
|
||||||
|
pub fn new(elements: Vec<T>) -> Self {
|
||||||
|
Self(Arc::new(State {
|
||||||
|
elements,
|
||||||
|
next: Default::default(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn next(&self) -> &T {
|
||||||
|
self.0.next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> State<T> {
|
||||||
|
pub fn next(&self) -> &T {
|
||||||
|
let next = self.next.fetch_add(1, Ordering::Relaxed);
|
||||||
|
&self.elements[next % self.elements.len()]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cycle() {
|
||||||
|
let cycle = AtomicCycle::new(vec![1, 2, 3]);
|
||||||
|
assert_eq!(cycle.next(), &1);
|
||||||
|
assert_eq!(cycle.next(), &2);
|
||||||
|
assert_eq!(cycle.next(), &3);
|
||||||
|
assert_eq!(cycle.next(), &1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Provides a stub that load-balances with a consistent hashing strategy.
|
||||||
|
///
|
||||||
|
/// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use
|
||||||
|
/// the same stub.
|
||||||
|
mod consistent_hash {
|
||||||
|
use crate::{
|
||||||
|
client::{stub, RpcError},
|
||||||
|
context,
|
||||||
|
};
|
||||||
|
use std::{
|
||||||
|
collections::hash_map::RandomState,
|
||||||
|
hash::{BuildHasher, Hash, Hasher},
|
||||||
|
num::TryFromIntError,
|
||||||
|
};
|
||||||
|
|
||||||
|
impl<Stub, S> stub::Stub for ConsistentHash<Stub, S>
|
||||||
|
where
|
||||||
|
Stub: stub::Stub,
|
||||||
|
Stub::Req: Hash,
|
||||||
|
S: BuildHasher,
|
||||||
|
{
|
||||||
|
type Req = Stub::Req;
|
||||||
|
type Resp = Stub::Resp;
|
||||||
|
|
||||||
|
async fn call(
|
||||||
|
&self,
|
||||||
|
ctx: context::Context,
|
||||||
|
request_name: &'static str,
|
||||||
|
request: Self::Req,
|
||||||
|
) -> Result<Stub::Resp, RpcError> {
|
||||||
|
let index = usize::try_from(self.hash_request(&request) % self.stubs_len).expect(
|
||||||
|
"invariant broken: stubs_len is not larger than a usize, \
|
||||||
|
so the hash modulo stubs_len should always fit in a usize",
|
||||||
|
);
|
||||||
|
let next = &self.stubs[index];
|
||||||
|
next.call(ctx, request_name, request).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Stub that load-balances across backing stubs by round robin.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct ConsistentHash<Stub, S = RandomState> {
|
||||||
|
stubs: Vec<Stub>,
|
||||||
|
stubs_len: u64,
|
||||||
|
hasher: S,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Stub> ConsistentHash<Stub, RandomState>
|
||||||
|
where
|
||||||
|
Stub: stub::Stub,
|
||||||
|
Stub::Req: Hash,
|
||||||
|
{
|
||||||
|
/// Returns a new RoundRobin stub.
|
||||||
|
/// Returns an err if the length of `stubs` overflows a u64.
|
||||||
|
pub fn new(stubs: Vec<Stub>) -> Result<Self, TryFromIntError> {
|
||||||
|
Ok(Self {
|
||||||
|
stubs_len: stubs.len().try_into()?,
|
||||||
|
stubs,
|
||||||
|
hasher: RandomState::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Stub, S> ConsistentHash<Stub, S>
|
||||||
|
where
|
||||||
|
Stub: stub::Stub,
|
||||||
|
Stub::Req: Hash,
|
||||||
|
S: BuildHasher,
|
||||||
|
{
|
||||||
|
/// Returns a new RoundRobin stub.
|
||||||
|
/// Returns an err if the length of `stubs` overflows a u64.
|
||||||
|
pub fn with_hasher(stubs: Vec<Stub>, hasher: S) -> Result<Self, TryFromIntError> {
|
||||||
|
Ok(Self {
|
||||||
|
stubs_len: stubs.len().try_into()?,
|
||||||
|
stubs,
|
||||||
|
hasher,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn hash_request(&self, req: &Stub::Req) -> u64 {
|
||||||
|
let mut hasher = self.hasher.build_hasher();
|
||||||
|
req.hash(&mut hasher);
|
||||||
|
hasher.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::ConsistentHash;
|
||||||
|
use crate::{
|
||||||
|
client::stub::{mock::Mock, Stub},
|
||||||
|
context,
|
||||||
|
};
|
||||||
|
use std::{
|
||||||
|
collections::HashMap,
|
||||||
|
hash::{BuildHasher, Hash, Hasher},
|
||||||
|
rc::Rc,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test() -> anyhow::Result<()> {
|
||||||
|
let stub = ConsistentHash::<_, FakeHasherBuilder>::with_hasher(
|
||||||
|
vec![
|
||||||
|
// For easier reading of the assertions made in this test, each Mock's response
|
||||||
|
// value is equal to a hash value that should map to its index: 3 % 3 = 0, 1 %
|
||||||
|
// 3 = 1, etc.
|
||||||
|
Mock::new([('a', 3), ('b', 3), ('c', 3)]),
|
||||||
|
Mock::new([('a', 1), ('b', 1), ('c', 1)]),
|
||||||
|
Mock::new([('a', 2), ('b', 2), ('c', 2)]),
|
||||||
|
],
|
||||||
|
FakeHasherBuilder::new([('a', 1), ('b', 2), ('c', 3)]),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
for _ in 0..2 {
|
||||||
|
let resp = stub.call(context::current(), "", 'a').await?;
|
||||||
|
assert_eq!(resp, 1);
|
||||||
|
|
||||||
|
let resp = stub.call(context::current(), "", 'b').await?;
|
||||||
|
assert_eq!(resp, 2);
|
||||||
|
|
||||||
|
let resp = stub.call(context::current(), "", 'c').await?;
|
||||||
|
assert_eq!(resp, 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
struct HashRecorder(Vec<u8>);
|
||||||
|
impl Hasher for HashRecorder {
|
||||||
|
fn write(&mut self, bytes: &[u8]) {
|
||||||
|
self.0 = Vec::from(bytes);
|
||||||
|
}
|
||||||
|
fn finish(&self) -> u64 {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FakeHasherBuilder {
|
||||||
|
recorded_hashes: Rc<HashMap<Vec<u8>, u64>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FakeHasher {
|
||||||
|
recorded_hashes: Rc<HashMap<Vec<u8>, u64>>,
|
||||||
|
output: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BuildHasher for FakeHasherBuilder {
|
||||||
|
type Hasher = FakeHasher;
|
||||||
|
|
||||||
|
fn build_hasher(&self) -> Self::Hasher {
|
||||||
|
FakeHasher {
|
||||||
|
recorded_hashes: self.recorded_hashes.clone(),
|
||||||
|
output: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FakeHasherBuilder {
|
||||||
|
fn new<T: Hash, const N: usize>(fake_hashes: [(T, u64); N]) -> Self {
|
||||||
|
let mut recorded_hashes = HashMap::new();
|
||||||
|
for (to_hash, fake_hash) in fake_hashes {
|
||||||
|
let mut recorder = HashRecorder(vec![]);
|
||||||
|
to_hash.hash(&mut recorder);
|
||||||
|
recorded_hashes.insert(recorder.0, fake_hash);
|
||||||
|
}
|
||||||
|
Self {
|
||||||
|
recorded_hashes: Rc::new(recorded_hashes),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Hasher for FakeHasher {
|
||||||
|
fn write(&mut self, bytes: &[u8]) {
|
||||||
|
if let Some(hash) = self.recorded_hashes.get(bytes) {
|
||||||
|
self.output = *hash;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn finish(&self) -> u64 {
|
||||||
|
self.output
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
49
tarpc/src/client/stub/mock.rs
Normal file
49
tarpc/src/client/stub/mock.rs
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
use crate::{
|
||||||
|
client::{stub::Stub, RpcError},
|
||||||
|
context, ServerError,
|
||||||
|
};
|
||||||
|
use std::{collections::HashMap, hash::Hash, io};
|
||||||
|
|
||||||
|
/// A mock stub that returns user-specified responses.
|
||||||
|
pub struct Mock<Req, Resp> {
|
||||||
|
responses: HashMap<Req, Resp>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req, Resp> Mock<Req, Resp>
|
||||||
|
where
|
||||||
|
Req: Eq + Hash,
|
||||||
|
{
|
||||||
|
/// Returns a new mock, mocking the specified (request, response) pairs.
|
||||||
|
pub fn new<const N: usize>(responses: [(Req, Resp); N]) -> Self {
|
||||||
|
Self {
|
||||||
|
responses: HashMap::from(responses),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req, Resp> Stub for Mock<Req, Resp>
|
||||||
|
where
|
||||||
|
Req: Eq + Hash,
|
||||||
|
Resp: Clone,
|
||||||
|
{
|
||||||
|
type Req = Req;
|
||||||
|
type Resp = Resp;
|
||||||
|
|
||||||
|
async fn call(
|
||||||
|
&self,
|
||||||
|
_: context::Context,
|
||||||
|
_: &'static str,
|
||||||
|
request: Self::Req,
|
||||||
|
) -> Result<Resp, RpcError> {
|
||||||
|
self.responses
|
||||||
|
.get(&request)
|
||||||
|
.cloned()
|
||||||
|
.map(Ok)
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
Err(RpcError::Server(ServerError {
|
||||||
|
kind: io::ErrorKind::NotFound,
|
||||||
|
detail: "mock (request, response) entry not found".into(),
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
56
tarpc/src/client/stub/retry.rs
Normal file
56
tarpc/src/client/stub/retry.rs
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
//! Provides a stub that retries requests based on response contents..
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
client::{stub, RpcError},
|
||||||
|
context,
|
||||||
|
};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
impl<Stub, Req, F> stub::Stub for Retry<F, Stub>
|
||||||
|
where
|
||||||
|
Stub: stub::Stub<Req = Arc<Req>>,
|
||||||
|
F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool,
|
||||||
|
{
|
||||||
|
type Req = Req;
|
||||||
|
type Resp = Stub::Resp;
|
||||||
|
|
||||||
|
async fn call(
|
||||||
|
&self,
|
||||||
|
ctx: context::Context,
|
||||||
|
request_name: &'static str,
|
||||||
|
request: Self::Req,
|
||||||
|
) -> Result<Stub::Resp, RpcError> {
|
||||||
|
let request = Arc::new(request);
|
||||||
|
for i in 1.. {
|
||||||
|
let result = self
|
||||||
|
.stub
|
||||||
|
.call(ctx, request_name, Arc::clone(&request))
|
||||||
|
.await;
|
||||||
|
if (self.should_retry)(&result, i) {
|
||||||
|
tracing::trace!("Retrying on attempt {i}");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
unreachable!("Wow, that was a lot of attempts!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Stub that retries requests based on response contents.
|
||||||
|
/// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct Retry<F, Stub> {
|
||||||
|
should_retry: F,
|
||||||
|
stub: Stub,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Stub, Req, F> Retry<F, Stub>
|
||||||
|
where
|
||||||
|
Stub: stub::Stub<Req = Arc<Req>>,
|
||||||
|
F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool,
|
||||||
|
{
|
||||||
|
/// Creates a new Retry stub that delegates calls to the underlying `stub`.
|
||||||
|
pub fn new(stub: Stub, should_retry: F) -> Self {
|
||||||
|
Self { stub, should_retry }
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -28,6 +28,8 @@ pub struct Context {
|
|||||||
/// When the client expects the request to be complete by. The server should cancel the request
|
/// When the client expects the request to be complete by. The server should cancel the request
|
||||||
/// if it is not complete by this time.
|
/// if it is not complete by this time.
|
||||||
#[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))]
|
#[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))]
|
||||||
|
// Serialized as a Duration to prevent clock skew issues.
|
||||||
|
#[cfg_attr(feature = "serde1", serde(with = "absolute_to_relative_time"))]
|
||||||
pub deadline: SystemTime,
|
pub deadline: SystemTime,
|
||||||
/// Uniquely identifies requests originating from the same source.
|
/// Uniquely identifies requests originating from the same source.
|
||||||
/// When a service handles a request by making requests itself, those requests should
|
/// When a service handles a request by making requests itself, those requests should
|
||||||
@@ -36,6 +38,54 @@ pub struct Context {
|
|||||||
pub trace_context: trace::Context,
|
pub trace_context: trace::Context,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "serde1")]
|
||||||
|
mod absolute_to_relative_time {
|
||||||
|
pub use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||||
|
pub use std::time::{Duration, SystemTime};
|
||||||
|
|
||||||
|
pub fn serialize<S>(deadline: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: Serializer,
|
||||||
|
{
|
||||||
|
let deadline = deadline
|
||||||
|
.duration_since(SystemTime::now())
|
||||||
|
.unwrap_or(Duration::ZERO);
|
||||||
|
deadline.serialize(serializer)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn deserialize<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let deadline = Duration::deserialize(deserializer)?;
|
||||||
|
Ok(SystemTime::now() + deadline)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[derive(serde::Serialize, serde::Deserialize)]
|
||||||
|
struct AbsoluteToRelative(#[serde(with = "self")] SystemTime);
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_serialize() {
|
||||||
|
let now = SystemTime::now();
|
||||||
|
let deadline = now + Duration::from_secs(10);
|
||||||
|
let serialized_deadline = bincode::serialize(&AbsoluteToRelative(deadline)).unwrap();
|
||||||
|
let deserialized_deadline: Duration = bincode::deserialize(&serialized_deadline).unwrap();
|
||||||
|
// TODO: how to avoid flakiness?
|
||||||
|
assert!(deserialized_deadline > Duration::from_secs(9));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_deserialize() {
|
||||||
|
let deadline = Duration::from_secs(10);
|
||||||
|
let serialized_deadline = bincode::serialize(&deadline).unwrap();
|
||||||
|
let AbsoluteToRelative(deserialized_deadline) =
|
||||||
|
bincode::deserialize(&serialized_deadline).unwrap();
|
||||||
|
// TODO: how to avoid flakiness?
|
||||||
|
assert!(deserialized_deadline > SystemTime::now() + Duration::from_secs(9));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
assert_impl_all!(Context: Send, Sync);
|
assert_impl_all!(Context: Send, Sync);
|
||||||
|
|
||||||
fn ten_seconds_from_now() -> SystemTime {
|
fn ten_seconds_from_now() -> SystemTime {
|
||||||
|
|||||||
128
tarpc/src/lib.rs
128
tarpc/src/lib.rs
@@ -27,7 +27,7 @@
|
|||||||
//! process, and no context switching between different languages.
|
//! process, and no context switching between different languages.
|
||||||
//!
|
//!
|
||||||
//! Some other features of tarpc:
|
//! Some other features of tarpc:
|
||||||
//! - Pluggable transport: any type impling `Stream<Item = Request> + Sink<Response>` can be
|
//! - Pluggable transport: any type implementing `Stream<Item = Request> + Sink<Response>` can be
|
||||||
//! used as a transport to connect the client and server.
|
//! used as a transport to connect the client and server.
|
||||||
//! - `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
|
//! - `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
|
||||||
//! - Cascading cancellation: dropping a request will send a cancellation message to the server.
|
//! - Cascading cancellation: dropping a request will send a cancellation message to the server.
|
||||||
@@ -42,7 +42,7 @@
|
|||||||
//! [tracing](https://github.com/tokio-rs/tracing) primitives extended with
|
//! [tracing](https://github.com/tokio-rs/tracing) primitives extended with
|
||||||
//! [OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
|
//! [OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
|
||||||
//! [Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
|
//! [Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
|
||||||
//! each RPC can be traced through the client, server, amd other dependencies downstream of the
|
//! each RPC can be traced through the client, server, and other dependencies downstream of the
|
||||||
//! server. Even for applications not connected to a distributed tracing collector, the
|
//! server. Even for applications not connected to a distributed tracing collector, the
|
||||||
//! instrumentation can also be ingested by regular loggers like
|
//! instrumentation can also be ingested by regular loggers like
|
||||||
//! [env_logger](https://github.com/env-logger-rs/env_logger/).
|
//! [env_logger](https://github.com/env-logger-rs/env_logger/).
|
||||||
@@ -54,7 +54,7 @@
|
|||||||
//! Add to your `Cargo.toml` dependencies:
|
//! Add to your `Cargo.toml` dependencies:
|
||||||
//!
|
//!
|
||||||
//! ```toml
|
//! ```toml
|
||||||
//! tarpc = "0.27"
|
//! tarpc = "0.29"
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||||
@@ -69,7 +69,7 @@
|
|||||||
//! ```toml
|
//! ```toml
|
||||||
//! anyhow = "1.0"
|
//! anyhow = "1.0"
|
||||||
//! futures = "0.3"
|
//! futures = "0.3"
|
||||||
//! tarpc = { version = "0.27", features = ["tokio1"] }
|
//! tarpc = { version = "0.29", features = ["tokio1"] }
|
||||||
//! tokio = { version = "1.0", features = ["macros"] }
|
//! tokio = { version = "1.0", features = ["macros"] }
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
@@ -88,7 +88,7 @@
|
|||||||
//! };
|
//! };
|
||||||
//! use tarpc::{
|
//! use tarpc::{
|
||||||
//! client, context,
|
//! client, context,
|
||||||
//! server::{self, incoming::Incoming},
|
//! server::{self, incoming::Incoming, Channel},
|
||||||
//! };
|
//! };
|
||||||
//!
|
//!
|
||||||
//! // This is the service definition. It looks a lot like a trait definition.
|
//! // This is the service definition. It looks a lot like a trait definition.
|
||||||
@@ -126,13 +126,9 @@
|
|||||||
//! struct HelloServer;
|
//! struct HelloServer;
|
||||||
//!
|
//!
|
||||||
//! impl World for HelloServer {
|
//! impl World for HelloServer {
|
||||||
//! // Each defined rpc generates two items in the trait, a fn that serves the RPC, and
|
//! // Each defined rpc generates an async fn that serves the RPC
|
||||||
//! // an associated type representing the future output by the fn.
|
//! async fn hello(self, _: context::Context, name: String) -> String {
|
||||||
//!
|
//! format!("Hello, {name}!")
|
||||||
//! type HelloFut = Ready<String>;
|
|
||||||
//!
|
|
||||||
//! fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
|
||||||
//! future::ready(format!("Hello, {}!", name))
|
|
||||||
//! }
|
//! }
|
||||||
//! }
|
//! }
|
||||||
//! ```
|
//! ```
|
||||||
@@ -164,11 +160,9 @@
|
|||||||
//! # #[derive(Clone)]
|
//! # #[derive(Clone)]
|
||||||
//! # struct HelloServer;
|
//! # struct HelloServer;
|
||||||
//! # impl World for HelloServer {
|
//! # impl World for HelloServer {
|
||||||
//! # // Each defined rpc generates two items in the trait, a fn that serves the RPC, and
|
//! // Each defined rpc generates an async fn that serves the RPC
|
||||||
//! # // an associated type representing the future output by the fn.
|
//! # async fn hello(self, _: context::Context, name: String) -> String {
|
||||||
//! # type HelloFut = Ready<String>;
|
//! # format!("Hello, {name}!")
|
||||||
//! # fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
|
||||||
//! # future::ready(format!("Hello, {}!", name))
|
|
||||||
//! # }
|
//! # }
|
||||||
//! # }
|
//! # }
|
||||||
//! # #[cfg(not(feature = "tokio1"))]
|
//! # #[cfg(not(feature = "tokio1"))]
|
||||||
@@ -179,7 +173,12 @@
|
|||||||
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||||
//!
|
//!
|
||||||
//! let server = server::BaseChannel::with_defaults(server_transport);
|
//! let server = server::BaseChannel::with_defaults(server_transport);
|
||||||
//! tokio::spawn(server.execute(HelloServer.serve()));
|
//! tokio::spawn(
|
||||||
|
//! server.execute(HelloServer.serve())
|
||||||
|
//! // Handle all requests concurrently.
|
||||||
|
//! .for_each(|response| async move {
|
||||||
|
//! tokio::spawn(response);
|
||||||
|
//! }));
|
||||||
//!
|
//!
|
||||||
//! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
//! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||||
//! // that takes a config and any Transport as input.
|
//! // that takes a config and any Transport as input.
|
||||||
@@ -190,7 +189,7 @@
|
|||||||
//! // specifies a deadline and trace information which can be helpful in debugging requests.
|
//! // specifies a deadline and trace information which can be helpful in debugging requests.
|
||||||
//! let hello = client.hello(context::current(), "Stim".to_string()).await?;
|
//! let hello = client.hello(context::current(), "Stim".to_string()).await?;
|
||||||
//!
|
//!
|
||||||
//! println!("{}", hello);
|
//! println!("{hello}");
|
||||||
//!
|
//!
|
||||||
//! Ok(())
|
//! Ok(())
|
||||||
//! }
|
//! }
|
||||||
@@ -200,6 +199,7 @@
|
|||||||
//!
|
//!
|
||||||
//! Use `cargo doc` as you normally would to see the documentation created for all
|
//! Use `cargo doc` as you normally would to see the documentation created for all
|
||||||
//! items expanded by a `service!` invocation.
|
//! items expanded by a `service!` invocation.
|
||||||
|
|
||||||
#![deny(missing_docs)]
|
#![deny(missing_docs)]
|
||||||
#![allow(clippy::type_complexity)]
|
#![allow(clippy::type_complexity)]
|
||||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||||
@@ -209,7 +209,7 @@
|
|||||||
pub use serde;
|
pub use serde;
|
||||||
|
|
||||||
#[cfg(feature = "serde-transport")]
|
#[cfg(feature = "serde-transport")]
|
||||||
pub use tokio_serde;
|
pub use {tokio_serde, tokio_util};
|
||||||
|
|
||||||
#[cfg(feature = "serde-transport")]
|
#[cfg(feature = "serde-transport")]
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "serde-transport")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "serde-transport")))]
|
||||||
@@ -244,62 +244,7 @@ pub use tarpc_plugins::derive_serde;
|
|||||||
/// * `fn new_stub` -- creates a new Client stub.
|
/// * `fn new_stub` -- creates a new Client stub.
|
||||||
pub use tarpc_plugins::service;
|
pub use tarpc_plugins::service;
|
||||||
|
|
||||||
/// A utility macro that can be used for RPC server implementations.
|
pub(crate) mod cancellations;
|
||||||
///
|
|
||||||
/// Syntactic sugar to make using async functions in the server implementation
|
|
||||||
/// easier. It does this by rewriting code like this, which would normally not
|
|
||||||
/// compile because async functions are disallowed in trait implementations:
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// # use tarpc::context;
|
|
||||||
/// # use std::net::SocketAddr;
|
|
||||||
/// #[tarpc::service]
|
|
||||||
/// trait World {
|
|
||||||
/// async fn hello(name: String) -> String;
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// #[derive(Clone)]
|
|
||||||
/// struct HelloServer(SocketAddr);
|
|
||||||
///
|
|
||||||
/// #[tarpc::server]
|
|
||||||
/// impl World for HelloServer {
|
|
||||||
/// async fn hello(self, _: context::Context, name: String) -> String {
|
|
||||||
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// Into code like this, which matches the service trait definition:
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// # use tarpc::context;
|
|
||||||
/// # use std::pin::Pin;
|
|
||||||
/// # use futures::Future;
|
|
||||||
/// # use std::net::SocketAddr;
|
|
||||||
/// #[derive(Clone)]
|
|
||||||
/// struct HelloServer(SocketAddr);
|
|
||||||
///
|
|
||||||
/// #[tarpc::service]
|
|
||||||
/// trait World {
|
|
||||||
/// async fn hello(name: String) -> String;
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// impl World for HelloServer {
|
|
||||||
/// type HelloFut = Pin<Box<dyn Future<Output = String> + Send>>;
|
|
||||||
///
|
|
||||||
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
|
|
||||||
/// + Send>> {
|
|
||||||
/// Box::pin(async move {
|
|
||||||
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
|
|
||||||
/// })
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// Note that this won't touch functions unless they have been annotated with
|
|
||||||
/// `async`, meaning that this should not break existing code.
|
|
||||||
pub use tarpc_plugins::server;
|
|
||||||
|
|
||||||
pub mod client;
|
pub mod client;
|
||||||
pub mod context;
|
pub mod context;
|
||||||
pub mod server;
|
pub mod server;
|
||||||
@@ -310,6 +255,7 @@ pub use crate::transport::sealed::Transport;
|
|||||||
|
|
||||||
use anyhow::Context as _;
|
use anyhow::Context as _;
|
||||||
use futures::task::*;
|
use futures::task::*;
|
||||||
|
use std::sync::Arc;
|
||||||
use std::{error::Error, fmt::Display, io, time::SystemTime};
|
use std::{error::Error, fmt::Display, io, time::SystemTime};
|
||||||
|
|
||||||
/// A message from a client to a server.
|
/// A message from a client to a server.
|
||||||
@@ -382,6 +328,36 @@ pub struct ServerError {
|
|||||||
pub detail: String,
|
pub detail: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Critical errors that result in a Channel disconnecting.
|
||||||
|
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
|
||||||
|
pub enum ChannelError<E>
|
||||||
|
where
|
||||||
|
E: Error + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
/// Could not read from the transport.
|
||||||
|
#[error("could not read from the transport")]
|
||||||
|
Read(#[source] Arc<E>),
|
||||||
|
/// Could not ready the transport for writes.
|
||||||
|
#[error("could not ready the transport for writes")]
|
||||||
|
Ready(#[source] E),
|
||||||
|
/// Could not write to the transport.
|
||||||
|
#[error("could not write to the transport")]
|
||||||
|
Write(#[source] E),
|
||||||
|
/// Could not flush the transport.
|
||||||
|
#[error("could not flush the transport")]
|
||||||
|
Flush(#[source] E),
|
||||||
|
/// Could not close the write end of the transport.
|
||||||
|
#[error("could not close the write end of the transport")]
|
||||||
|
Close(#[source] E),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServerError {
|
||||||
|
/// Returns a new server error with `kind` and `detail`.
|
||||||
|
pub fn new(kind: io::ErrorKind, detail: String) -> ServerError {
|
||||||
|
Self { kind, detail }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T> Request<T> {
|
impl<T> Request<T> {
|
||||||
/// Returns the deadline for this request.
|
/// Returns the deadline for this request.
|
||||||
pub fn deadline(&self) -> &SystemTime {
|
pub fn deadline(&self) -> &SystemTime {
|
||||||
|
|||||||
@@ -129,14 +129,6 @@ pub mod tcp {
|
|||||||
tokio_util::codec::length_delimited,
|
tokio_util::codec::length_delimited,
|
||||||
};
|
};
|
||||||
|
|
||||||
mod private {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
pub trait Sealed {}
|
|
||||||
|
|
||||||
impl<Item, SinkItem, Codec> Sealed for Transport<TcpStream, Item, SinkItem, Codec> {}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<Item, SinkItem, Codec> Transport<TcpStream, Item, SinkItem, Codec> {
|
impl<Item, SinkItem, Codec> Transport<TcpStream, Item, SinkItem, Codec> {
|
||||||
/// Returns the peer address of the underlying TcpStream.
|
/// Returns the peer address of the underlying TcpStream.
|
||||||
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
|
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
|
||||||
@@ -149,6 +141,7 @@ pub mod tcp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// A connection Future that also exposes the length-delimited framing config.
|
/// A connection Future that also exposes the length-delimited framing config.
|
||||||
|
#[must_use]
|
||||||
#[pin_project]
|
#[pin_project]
|
||||||
pub struct Connect<T, Item, SinkItem, CodecFn> {
|
pub struct Connect<T, Item, SinkItem, CodecFn> {
|
||||||
#[pin]
|
#[pin]
|
||||||
@@ -276,6 +269,270 @@ pub mod tcp {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(all(unix, feature = "unix"))]
|
||||||
|
#[cfg_attr(docsrs, doc(cfg(all(unix, feature = "unix"))))]
|
||||||
|
/// Unix Domain Socket support for generic transport using Tokio.
|
||||||
|
pub mod unix {
|
||||||
|
use {
|
||||||
|
super::*,
|
||||||
|
futures::ready,
|
||||||
|
std::{marker::PhantomData, path::Path},
|
||||||
|
tokio::net::{unix::SocketAddr, UnixListener, UnixStream},
|
||||||
|
tokio_util::codec::length_delimited,
|
||||||
|
};
|
||||||
|
|
||||||
|
impl<Item, SinkItem, Codec> Transport<UnixStream, Item, SinkItem, Codec> {
|
||||||
|
/// Returns the socket address of the remote half of the underlying [`UnixStream`].
|
||||||
|
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
|
||||||
|
self.inner.get_ref().get_ref().peer_addr()
|
||||||
|
}
|
||||||
|
/// Returns the socket address of the local half of the underlying [`UnixStream`].
|
||||||
|
pub fn local_addr(&self) -> io::Result<SocketAddr> {
|
||||||
|
self.inner.get_ref().get_ref().local_addr()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A connection Future that also exposes the length-delimited framing config.
|
||||||
|
#[must_use]
|
||||||
|
#[pin_project]
|
||||||
|
pub struct Connect<T, Item, SinkItem, CodecFn> {
|
||||||
|
#[pin]
|
||||||
|
inner: T,
|
||||||
|
codec_fn: CodecFn,
|
||||||
|
config: length_delimited::Builder,
|
||||||
|
ghost: PhantomData<(fn(SinkItem), fn() -> Item)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, Item, SinkItem, Codec, CodecFn> Future for Connect<T, Item, SinkItem, CodecFn>
|
||||||
|
where
|
||||||
|
T: Future<Output = io::Result<UnixStream>>,
|
||||||
|
Item: for<'de> Deserialize<'de>,
|
||||||
|
SinkItem: Serialize,
|
||||||
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
|
CodecFn: Fn() -> Codec,
|
||||||
|
{
|
||||||
|
type Output = io::Result<Transport<UnixStream, Item, SinkItem, Codec>>;
|
||||||
|
|
||||||
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
|
||||||
|
let io = ready!(self.as_mut().project().inner.poll(cx))?;
|
||||||
|
Poll::Ready(Ok(new(self.config.new_framed(io), (self.codec_fn)())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, Item, SinkItem, CodecFn> Connect<T, Item, SinkItem, CodecFn> {
|
||||||
|
/// Returns an immutable reference to the length-delimited codec's config.
|
||||||
|
pub fn config(&self) -> &length_delimited::Builder {
|
||||||
|
&self.config
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a mutable reference to the length-delimited codec's config.
|
||||||
|
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
|
||||||
|
&mut self.config
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Connects to socket named by `path`, wrapping the connection in a Unix Domain Socket
|
||||||
|
/// transport.
|
||||||
|
pub fn connect<P, Item, SinkItem, Codec, CodecFn>(
|
||||||
|
path: P,
|
||||||
|
codec_fn: CodecFn,
|
||||||
|
) -> Connect<impl Future<Output = io::Result<UnixStream>>, Item, SinkItem, CodecFn>
|
||||||
|
where
|
||||||
|
P: AsRef<Path>,
|
||||||
|
Item: for<'de> Deserialize<'de>,
|
||||||
|
SinkItem: Serialize,
|
||||||
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
|
CodecFn: Fn() -> Codec,
|
||||||
|
{
|
||||||
|
Connect {
|
||||||
|
inner: UnixStream::connect(path),
|
||||||
|
codec_fn,
|
||||||
|
config: LengthDelimitedCodec::builder(),
|
||||||
|
ghost: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Listens on the socket named by `path`, wrapping accepted connections in Unix Domain Socket
|
||||||
|
/// transports.
|
||||||
|
pub async fn listen<P, Item, SinkItem, Codec, CodecFn>(
|
||||||
|
path: P,
|
||||||
|
codec_fn: CodecFn,
|
||||||
|
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
|
||||||
|
where
|
||||||
|
P: AsRef<Path>,
|
||||||
|
Item: for<'de> Deserialize<'de>,
|
||||||
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
|
CodecFn: Fn() -> Codec,
|
||||||
|
{
|
||||||
|
let listener = UnixListener::bind(path)?;
|
||||||
|
let local_addr = listener.local_addr()?;
|
||||||
|
Ok(Incoming {
|
||||||
|
listener,
|
||||||
|
codec_fn,
|
||||||
|
local_addr,
|
||||||
|
config: LengthDelimitedCodec::builder(),
|
||||||
|
ghost: PhantomData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A [`UnixListener`] that wraps connections in [transports](Transport).
|
||||||
|
#[pin_project]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Incoming<Item, SinkItem, Codec, CodecFn> {
|
||||||
|
listener: UnixListener,
|
||||||
|
local_addr: SocketAddr,
|
||||||
|
codec_fn: CodecFn,
|
||||||
|
config: length_delimited::Builder,
|
||||||
|
ghost: PhantomData<(fn() -> Item, fn(SinkItem), Codec)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Item, SinkItem, Codec, CodecFn> Incoming<Item, SinkItem, Codec, CodecFn> {
|
||||||
|
/// Returns the the socket address being listened on.
|
||||||
|
pub fn local_addr(&self) -> &SocketAddr {
|
||||||
|
&self.local_addr
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an immutable reference to the length-delimited codec's config.
|
||||||
|
pub fn config(&self) -> &length_delimited::Builder {
|
||||||
|
&self.config
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a mutable reference to the length-delimited codec's config.
|
||||||
|
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
|
||||||
|
&mut self.config
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Item, SinkItem, Codec, CodecFn> Stream for Incoming<Item, SinkItem, Codec, CodecFn>
|
||||||
|
where
|
||||||
|
Item: for<'de> Deserialize<'de>,
|
||||||
|
SinkItem: Serialize,
|
||||||
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
|
CodecFn: Fn() -> Codec,
|
||||||
|
{
|
||||||
|
type Item = io::Result<Transport<UnixStream, Item, SinkItem, Codec>>;
|
||||||
|
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
let conn: UnixStream = ready!(self.as_mut().project().listener.poll_accept(cx)?).0;
|
||||||
|
Poll::Ready(Some(Ok(new(
|
||||||
|
self.config.new_framed(conn),
|
||||||
|
(self.codec_fn)(),
|
||||||
|
))))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A temporary `PathBuf` that lives in `std::env::temp_dir` and is removed on drop.
|
||||||
|
pub struct TempPathBuf(std::path::PathBuf);
|
||||||
|
|
||||||
|
impl TempPathBuf {
|
||||||
|
/// A named socket that results in `<tempdir>/<name>`
|
||||||
|
pub fn new<S: AsRef<str>>(name: S) -> Self {
|
||||||
|
let mut sock = std::env::temp_dir();
|
||||||
|
sock.push(name.as_ref());
|
||||||
|
Self(sock)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Appends a random hex string to the socket name resulting in
|
||||||
|
/// `<tempdir>/<name>_<xxxxx>`
|
||||||
|
pub fn with_random<S: AsRef<str>>(name: S) -> Self {
|
||||||
|
Self::new(format!("{}_{:016x}", name.as_ref(), rand::random::<u64>()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsRef<std::path::Path> for TempPathBuf {
|
||||||
|
fn as_ref(&self) -> &std::path::Path {
|
||||||
|
self.0.as_path()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for TempPathBuf {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
// This will remove the file pointed to by this PathBuf if it exists, however Err's can
|
||||||
|
// be returned such as attempting to remove a non-existing file, or one which we don't
|
||||||
|
// have permission to remove. In these cases the Err is swallowed
|
||||||
|
let _ = std::fs::remove_file(&self.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tokio_serde::formats::SymmetricalJson;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn temp_path_buf_non_random() {
|
||||||
|
let sock = TempPathBuf::new("test");
|
||||||
|
let mut good = std::env::temp_dir();
|
||||||
|
good.push("test");
|
||||||
|
assert_eq!(sock.as_ref(), good);
|
||||||
|
assert_eq!(sock.as_ref().file_name().unwrap(), "test");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn temp_path_buf_random() {
|
||||||
|
let sock = TempPathBuf::with_random("test");
|
||||||
|
let good = std::env::temp_dir();
|
||||||
|
assert!(sock.as_ref().starts_with(good));
|
||||||
|
// Since there are 16 random characters we just assert the file_name has the right name
|
||||||
|
// and starts with the correct string 'test_'
|
||||||
|
// file name: test_xxxxxxxxxxxxxxxx
|
||||||
|
// test = 4
|
||||||
|
// _ = 1
|
||||||
|
// <hex> = 16
|
||||||
|
// total = 21
|
||||||
|
let fname = sock.as_ref().file_name().unwrap().to_string_lossy();
|
||||||
|
assert!(fname.starts_with("test_"));
|
||||||
|
assert_eq!(fname.len(), 21);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn temp_path_buf_non_existing() {
|
||||||
|
let sock = TempPathBuf::with_random("test");
|
||||||
|
let sock_path = std::path::PathBuf::from(sock.as_ref());
|
||||||
|
|
||||||
|
// No actual file has been created yet
|
||||||
|
assert!(!sock_path.exists());
|
||||||
|
// Should not panic
|
||||||
|
std::mem::drop(sock);
|
||||||
|
assert!(!sock_path.exists());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn temp_path_buf_existing_file() {
|
||||||
|
let sock = TempPathBuf::with_random("test");
|
||||||
|
let sock_path = std::path::PathBuf::from(sock.as_ref());
|
||||||
|
let _file = std::fs::File::create(&sock).unwrap();
|
||||||
|
assert!(sock_path.exists());
|
||||||
|
std::mem::drop(sock);
|
||||||
|
assert!(!sock_path.exists());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn temp_path_buf_preexisting_file() {
|
||||||
|
let mut pre_existing = std::env::temp_dir();
|
||||||
|
pre_existing.push("test");
|
||||||
|
let _file = std::fs::File::create(&pre_existing).unwrap();
|
||||||
|
let sock = TempPathBuf::new("test");
|
||||||
|
let sock_path = std::path::PathBuf::from(sock.as_ref());
|
||||||
|
assert!(sock_path.exists());
|
||||||
|
std::mem::drop(sock);
|
||||||
|
assert!(!sock_path.exists());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn temp_path_buf_for_socket() {
|
||||||
|
let sock = TempPathBuf::with_random("test");
|
||||||
|
// Save path for testing after drop
|
||||||
|
let sock_path = std::path::PathBuf::from(sock.as_ref());
|
||||||
|
// create the actual socket
|
||||||
|
let _ = listen(&sock, SymmetricalJson::<String>::default).await;
|
||||||
|
assert!(sock_path.exists());
|
||||||
|
std::mem::drop(sock);
|
||||||
|
assert!(!sock_path.exists());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::Transport;
|
use super::Transport;
|
||||||
@@ -290,7 +547,7 @@ mod tests {
|
|||||||
use tokio_serde::formats::SymmetricalJson;
|
use tokio_serde::formats::SymmetricalJson;
|
||||||
|
|
||||||
fn ctx() -> Context<'static> {
|
fn ctx() -> Context<'static> {
|
||||||
Context::from_waker(&noop_waker_ref())
|
Context::from_waker(noop_waker_ref())
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TestIo(Cursor<Vec<u8>>);
|
struct TestIo(Cursor<Vec<u8>>);
|
||||||
@@ -392,4 +649,24 @@ mod tests {
|
|||||||
assert_matches!(transport.next().await, None);
|
assert_matches!(transport.next().await, None);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(all(unix, feature = "unix"))]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn uds() -> io::Result<()> {
|
||||||
|
use super::unix;
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
let sock = unix::TempPathBuf::with_random("uds");
|
||||||
|
let mut listener = unix::listen(&sock, SymmetricalJson::<String>::default).await?;
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut transport = listener.next().await.unwrap().unwrap();
|
||||||
|
let message = transport.next().await.unwrap().unwrap();
|
||||||
|
transport.send(message).await.unwrap();
|
||||||
|
});
|
||||||
|
let mut transport = unix::connect(&sock, SymmetricalJson::<String>::default).await?;
|
||||||
|
transport.send(String::from("test")).await?;
|
||||||
|
assert_matches!(transport.next().await, Some(Ok(s)) if s == "test");
|
||||||
|
assert_matches!(transport.next().await, None);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -94,16 +94,14 @@ impl InFlightRequests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Yields a request that has expired, aborting any ongoing processing of that request.
|
/// Yields a request that has expired, aborting any ongoing processing of that request.
|
||||||
pub fn poll_expired(
|
pub fn poll_expired(&mut self, cx: &mut Context) -> Poll<Option<u64>> {
|
||||||
&mut self,
|
|
||||||
cx: &mut Context,
|
|
||||||
) -> Poll<Option<Result<u64, tokio::time::error::Error>>> {
|
|
||||||
if self.deadlines.is_empty() {
|
if self.deadlines.is_empty() {
|
||||||
// TODO(https://github.com/tokio-rs/tokio/issues/4161)
|
// TODO(https://github.com/tokio-rs/tokio/issues/4161)
|
||||||
// This is a workaround for DelayQueue not always treating this case correctly.
|
// This is a workaround for DelayQueue not always treating this case correctly.
|
||||||
return Poll::Ready(None);
|
return Poll::Ready(None);
|
||||||
}
|
}
|
||||||
self.deadlines.poll_expired(cx).map_ok(|expired| {
|
self.deadlines.poll_expired(cx).map(|expired| {
|
||||||
|
let expired = expired?;
|
||||||
if let Some(RequestData {
|
if let Some(RequestData {
|
||||||
abort_handle, span, ..
|
abort_handle, span, ..
|
||||||
}) = self.request_data.remove(expired.get_ref())
|
}) = self.request_data.remove(expired.get_ref())
|
||||||
@@ -113,7 +111,7 @@ impl InFlightRequests {
|
|||||||
abort_handle.abort();
|
abort_handle.abort();
|
||||||
tracing::error!("DeadlineExceeded");
|
tracing::error!("DeadlineExceeded");
|
||||||
}
|
}
|
||||||
expired.into_inner()
|
Some(expired.into_inner())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -161,7 +159,7 @@ mod tests {
|
|||||||
|
|
||||||
assert_matches!(
|
assert_matches!(
|
||||||
in_flight_requests.poll_expired(&mut noop_context()),
|
in_flight_requests.poll_expired(&mut noop_context()),
|
||||||
Poll::Ready(Some(Ok(_)))
|
Poll::Ready(Some(_))
|
||||||
);
|
);
|
||||||
assert_matches!(
|
assert_matches!(
|
||||||
abortable_future.poll_unpin(&mut noop_context()),
|
abortable_future.poll_unpin(&mut noop_context()),
|
||||||
@@ -178,7 +176,7 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||||
|
|
||||||
assert_eq!(in_flight_requests.cancel_request(0), true);
|
assert!(in_flight_requests.cancel_request(0));
|
||||||
assert_matches!(
|
assert_matches!(
|
||||||
abortable_future.poll_unpin(&mut noop_context()),
|
abortable_future.poll_unpin(&mut noop_context()),
|
||||||
Poll::Ready(Err(_))
|
Poll::Ready(Err(_))
|
||||||
|
|||||||
@@ -1,13 +1,10 @@
|
|||||||
use super::{
|
use super::{
|
||||||
limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel},
|
limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel},
|
||||||
Channel,
|
Channel, Serve,
|
||||||
};
|
};
|
||||||
use futures::prelude::*;
|
use futures::prelude::*;
|
||||||
use std::{fmt, hash::Hash};
|
use std::{fmt, hash::Hash};
|
||||||
|
|
||||||
#[cfg(feature = "tokio1")]
|
|
||||||
use super::{tokio::TokioServerExecutor, Serve};
|
|
||||||
|
|
||||||
/// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel).
|
/// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel).
|
||||||
pub trait Incoming<C>
|
pub trait Incoming<C>
|
||||||
where
|
where
|
||||||
@@ -28,16 +25,62 @@ where
|
|||||||
MaxRequestsPerChannel::new(self, n)
|
MaxRequestsPerChannel::new(self, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled
|
/// Returns a stream of channels in execution. Each channel in execution is a stream of
|
||||||
/// concurrently by spawning on tokio's default executor, and each request will be also
|
/// futures, where each future is an in-flight request being rsponded to.
|
||||||
/// be spawned on tokio's default executor.
|
fn execute<S>(
|
||||||
#[cfg(feature = "tokio1")]
|
self,
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
serve: S,
|
||||||
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
|
) -> impl Stream<Item = impl Stream<Item = impl Future<Output = ()>>>
|
||||||
where
|
where
|
||||||
S: Serve<C::Req, Resp = C::Resp>,
|
S: Serve<Req = C::Req, Resp = C::Resp> + Clone,
|
||||||
{
|
{
|
||||||
TokioServerExecutor::new(self, serve)
|
self.map(move |channel| channel.execute(serve.clone()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "tokio1")]
|
||||||
|
/// Spawns all channels-in-execution, delegating to the tokio runtime to manage their completion.
|
||||||
|
/// Each channel is spawned, and each request from each channel is spawned.
|
||||||
|
/// Note that this function is generic over any stream-of-streams-of-futures, but it is intended
|
||||||
|
/// for spawning streams of channels.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
/// ```rust
|
||||||
|
/// use tarpc::{
|
||||||
|
/// context,
|
||||||
|
/// client::{self, NewClient},
|
||||||
|
/// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve},
|
||||||
|
/// transport,
|
||||||
|
/// };
|
||||||
|
/// use futures::prelude::*;
|
||||||
|
///
|
||||||
|
/// #[tokio::main]
|
||||||
|
/// async fn main() {
|
||||||
|
/// let (tx, rx) = transport::channel::unbounded();
|
||||||
|
/// let NewClient { client, dispatch } = client::new(client::Config::default(), tx);
|
||||||
|
/// tokio::spawn(dispatch);
|
||||||
|
///
|
||||||
|
/// let incoming = stream::once(async move {
|
||||||
|
/// BaseChannel::new(server::Config::default(), rx)
|
||||||
|
/// }).execute(serve(|_, i| async move { Ok(i + 1) }));
|
||||||
|
/// tokio::spawn(spawn_incoming(incoming));
|
||||||
|
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
pub async fn spawn_incoming(
|
||||||
|
incoming: impl Stream<
|
||||||
|
Item = impl Stream<Item = impl Future<Output = ()> + Send + 'static> + Send + 'static,
|
||||||
|
>,
|
||||||
|
) {
|
||||||
|
use futures::pin_mut;
|
||||||
|
pin_mut!(incoming);
|
||||||
|
while let Some(channel) = incoming.next().await {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
pin_mut!(channel);
|
||||||
|
while let Some(request) = channel.next().await {
|
||||||
|
tokio::spawn(request);
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -282,7 +282,7 @@ where
|
|||||||
fn ctx() -> Context<'static> {
|
fn ctx() -> Context<'static> {
|
||||||
use futures::task::*;
|
use futures::task::*;
|
||||||
|
|
||||||
Context::from_waker(&noop_waker_ref())
|
Context::from_waker(noop_waker_ref())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
25
tarpc/src/server/request_hook.rs
Normal file
25
tarpc/src/server/request_hook.rs
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
// Copyright 2022 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
//! Hooks for horizontal functionality that can run either before or after a request is executed.
|
||||||
|
|
||||||
|
/// A request hook that runs before a request is executed.
|
||||||
|
mod before;
|
||||||
|
|
||||||
|
/// A request hook that runs after a request is completed.
|
||||||
|
mod after;
|
||||||
|
|
||||||
|
/// A request hook that runs both before a request is executed and after it is completed.
|
||||||
|
mod before_and_after;
|
||||||
|
|
||||||
|
pub use {
|
||||||
|
after::{AfterRequest, ServeThenHook},
|
||||||
|
before::{
|
||||||
|
before, BeforeRequest, BeforeRequestCons, BeforeRequestList, BeforeRequestNil,
|
||||||
|
HookThenServe,
|
||||||
|
},
|
||||||
|
before_and_after::HookThenServeThenHook,
|
||||||
|
};
|
||||||
72
tarpc/src/server/request_hook/after.rs
Normal file
72
tarpc/src/server/request_hook/after.rs
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
// Copyright 2022 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
//! Provides a hook that runs after request execution.
|
||||||
|
|
||||||
|
use crate::{context, server::Serve, ServerError};
|
||||||
|
use futures::prelude::*;
|
||||||
|
|
||||||
|
/// A hook that runs after request execution.
|
||||||
|
#[allow(async_fn_in_trait)]
|
||||||
|
pub trait AfterRequest<Resp> {
|
||||||
|
/// The function that is called after request execution.
|
||||||
|
///
|
||||||
|
/// The hook can modify the request context and the response.
|
||||||
|
async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result<Resp, ServerError>);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F, Fut, Resp> AfterRequest<Resp> for F
|
||||||
|
where
|
||||||
|
F: FnMut(&mut context::Context, &mut Result<Resp, ServerError>) -> Fut,
|
||||||
|
Fut: Future<Output = ()>,
|
||||||
|
{
|
||||||
|
async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result<Resp, ServerError>) {
|
||||||
|
self(ctx, resp).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Service function that runs a hook after request execution.
|
||||||
|
pub struct ServeThenHook<Serv, Hook> {
|
||||||
|
serve: Serv,
|
||||||
|
hook: Hook,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Serv, Hook> ServeThenHook<Serv, Hook> {
|
||||||
|
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
|
||||||
|
Self { serve, hook }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Serv: Clone, Hook: Clone> Clone for ServeThenHook<Serv, Hook> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
serve: self.serve.clone(),
|
||||||
|
hook: self.hook.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Serv, Hook> Serve for ServeThenHook<Serv, Hook>
|
||||||
|
where
|
||||||
|
Serv: Serve,
|
||||||
|
Hook: AfterRequest<Serv::Resp>,
|
||||||
|
{
|
||||||
|
type Req = Serv::Req;
|
||||||
|
type Resp = Serv::Resp;
|
||||||
|
|
||||||
|
async fn serve(
|
||||||
|
self,
|
||||||
|
mut ctx: context::Context,
|
||||||
|
req: Serv::Req,
|
||||||
|
) -> Result<Serv::Resp, ServerError> {
|
||||||
|
let ServeThenHook {
|
||||||
|
serve, mut hook, ..
|
||||||
|
} = self;
|
||||||
|
let mut resp = serve.serve(ctx, req).await;
|
||||||
|
hook.after(&mut ctx, &mut resp).await;
|
||||||
|
resp
|
||||||
|
}
|
||||||
|
}
|
||||||
210
tarpc/src/server/request_hook/before.rs
Normal file
210
tarpc/src/server/request_hook/before.rs
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
// Copyright 2022 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
//! Provides a hook that runs before request execution.
|
||||||
|
|
||||||
|
use crate::{context, server::Serve, ServerError};
|
||||||
|
use futures::prelude::*;
|
||||||
|
|
||||||
|
/// A hook that runs before request execution.
|
||||||
|
#[allow(async_fn_in_trait)]
|
||||||
|
pub trait BeforeRequest<Req> {
|
||||||
|
/// The function that is called before request execution.
|
||||||
|
///
|
||||||
|
/// If this function returns an error, the request will not be executed and the error will be
|
||||||
|
/// returned instead.
|
||||||
|
///
|
||||||
|
/// This function can also modify the request context. This could be used, for example, to
|
||||||
|
/// enforce a maximum deadline on all requests.
|
||||||
|
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A list of hooks that run in order before request execution.
|
||||||
|
pub trait BeforeRequestList<Req>: BeforeRequest<Req> {
|
||||||
|
/// The hook returned by `BeforeRequestList::then`.
|
||||||
|
type Then<Next>: BeforeRequest<Req>
|
||||||
|
where
|
||||||
|
Next: BeforeRequest<Req>;
|
||||||
|
|
||||||
|
/// Returns a hook that, when run, runs two hooks, first `self` and then `next`.
|
||||||
|
fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next>;
|
||||||
|
|
||||||
|
/// Same as `then`, but helps the compiler with type inference when Next is a closure.
|
||||||
|
fn then_fn<
|
||||||
|
Next: FnMut(&mut context::Context, &Req) -> Fut,
|
||||||
|
Fut: Future<Output = Result<(), ServerError>>,
|
||||||
|
>(
|
||||||
|
self,
|
||||||
|
next: Next,
|
||||||
|
) -> Self::Then<Next>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
self.then(next)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The service fn returned by `BeforeRequestList::serving`.
|
||||||
|
type Serve<S: Serve<Req = Req>>: Serve<Req = Req>;
|
||||||
|
|
||||||
|
/// Runs the list of request hooks before execution of the given serve fn.
|
||||||
|
/// This is equivalent to `serve.before(before_request_chain)` but may be syntactically nicer.
|
||||||
|
fn serving<S: Serve<Req = Req>>(self, serve: S) -> Self::Serve<S>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F, Fut, Req> BeforeRequest<Req> for F
|
||||||
|
where
|
||||||
|
F: FnMut(&mut context::Context, &Req) -> Fut,
|
||||||
|
Fut: Future<Output = Result<(), ServerError>>,
|
||||||
|
{
|
||||||
|
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> {
|
||||||
|
self(ctx, req).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Service function that runs a hook before request execution.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct HookThenServe<Serv, Hook> {
|
||||||
|
serve: Serv,
|
||||||
|
hook: Hook,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Serv, Hook> HookThenServe<Serv, Hook> {
|
||||||
|
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
|
||||||
|
Self { serve, hook }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Serv, Hook> Serve for HookThenServe<Serv, Hook>
|
||||||
|
where
|
||||||
|
Serv: Serve,
|
||||||
|
Hook: BeforeRequest<Serv::Req>,
|
||||||
|
{
|
||||||
|
type Req = Serv::Req;
|
||||||
|
type Resp = Serv::Resp;
|
||||||
|
|
||||||
|
async fn serve(
|
||||||
|
self,
|
||||||
|
mut ctx: context::Context,
|
||||||
|
req: Self::Req,
|
||||||
|
) -> Result<Serv::Resp, ServerError> {
|
||||||
|
let HookThenServe {
|
||||||
|
serve, mut hook, ..
|
||||||
|
} = self;
|
||||||
|
hook.before(&mut ctx, &req).await?;
|
||||||
|
serve.serve(ctx, req).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a request hook builder that runs a series of hooks before request execution.
|
||||||
|
///
|
||||||
|
/// Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use futures::{executor::block_on, future};
|
||||||
|
/// use tarpc::{context, ServerError, server::{Serve, serve, request_hook::{self,
|
||||||
|
/// BeforeRequest, BeforeRequestList}}};
|
||||||
|
/// use std::{cell::Cell, io};
|
||||||
|
///
|
||||||
|
/// let i = Cell::new(0);
|
||||||
|
/// let serve = request_hook::before()
|
||||||
|
/// .then_fn(|_, _| async {
|
||||||
|
/// assert!(i.get() == 0);
|
||||||
|
/// i.set(1);
|
||||||
|
/// Ok(())
|
||||||
|
/// })
|
||||||
|
/// .then_fn(|_, _| async {
|
||||||
|
/// assert!(i.get() == 1);
|
||||||
|
/// i.set(2);
|
||||||
|
/// Ok(())
|
||||||
|
/// })
|
||||||
|
/// .serving(serve(|_ctx, i| async move { Ok(i + 1) }));
|
||||||
|
/// let response = serve.clone().serve(context::current(), 1);
|
||||||
|
/// assert!(block_on(response).is_ok());
|
||||||
|
/// assert!(i.get() == 2);
|
||||||
|
/// ```
|
||||||
|
pub fn before() -> BeforeRequestNil {
|
||||||
|
BeforeRequestNil
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A list of hooks that run in order before a request is executed.
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub struct BeforeRequestCons<First, Rest>(First, Rest);
|
||||||
|
|
||||||
|
/// A noop hook that runs before a request is executed.
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub struct BeforeRequestNil;
|
||||||
|
|
||||||
|
impl<Req, First: BeforeRequest<Req>, Rest: BeforeRequest<Req>> BeforeRequest<Req>
|
||||||
|
for BeforeRequestCons<First, Rest>
|
||||||
|
{
|
||||||
|
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> {
|
||||||
|
let BeforeRequestCons(first, rest) = self;
|
||||||
|
first.before(ctx, req).await?;
|
||||||
|
rest.before(ctx, req).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req> BeforeRequest<Req> for BeforeRequestNil {
|
||||||
|
async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req, First: BeforeRequest<Req>, Rest: BeforeRequestList<Req>> BeforeRequestList<Req>
|
||||||
|
for BeforeRequestCons<First, Rest>
|
||||||
|
{
|
||||||
|
type Then<Next> = BeforeRequestCons<First, Rest::Then<Next>> where Next: BeforeRequest<Req>;
|
||||||
|
|
||||||
|
fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next> {
|
||||||
|
let BeforeRequestCons(first, rest) = self;
|
||||||
|
BeforeRequestCons(first, rest.then(next))
|
||||||
|
}
|
||||||
|
|
||||||
|
type Serve<S: Serve<Req = Req>> = HookThenServe<S, Self>;
|
||||||
|
|
||||||
|
fn serving<S: Serve<Req = Req>>(self, serve: S) -> Self::Serve<S> {
|
||||||
|
HookThenServe::new(serve, self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req> BeforeRequestList<Req> for BeforeRequestNil {
|
||||||
|
type Then<Next> = BeforeRequestCons<Next, BeforeRequestNil> where Next: BeforeRequest<Req>;
|
||||||
|
|
||||||
|
fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next> {
|
||||||
|
BeforeRequestCons(next, BeforeRequestNil)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Serve<S: Serve<Req = Req>> = S;
|
||||||
|
|
||||||
|
fn serving<S: Serve<Req = Req>>(self, serve: S) -> S {
|
||||||
|
serve
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn before_request_list() {
|
||||||
|
use crate::server::serve;
|
||||||
|
use futures::executor::block_on;
|
||||||
|
use std::cell::Cell;
|
||||||
|
|
||||||
|
let i = Cell::new(0);
|
||||||
|
let serve = before()
|
||||||
|
.then_fn(|_, _| async {
|
||||||
|
assert!(i.get() == 0);
|
||||||
|
i.set(1);
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.then_fn(|_, _| async {
|
||||||
|
assert!(i.get() == 1);
|
||||||
|
i.set(2);
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.serving(serve(|_ctx, i| async move { Ok(i + 1) }));
|
||||||
|
let response = serve.clone().serve(context::current(), 1);
|
||||||
|
assert!(block_on(response).is_ok());
|
||||||
|
assert!(i.get() == 2);
|
||||||
|
}
|
||||||
57
tarpc/src/server/request_hook/before_and_after.rs
Normal file
57
tarpc/src/server/request_hook/before_and_after.rs
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
// Copyright 2022 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
//! Provides a hook that runs both before and after request execution.
|
||||||
|
|
||||||
|
use super::{after::AfterRequest, before::BeforeRequest};
|
||||||
|
use crate::{context, server::Serve, ServerError};
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
/// A Service function that runs a hook both before and after request execution.
|
||||||
|
pub struct HookThenServeThenHook<Req, Resp, Serv, Hook> {
|
||||||
|
serve: Serv,
|
||||||
|
hook: Hook,
|
||||||
|
fns: PhantomData<(fn(Req), fn(Resp))>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req, Resp, Serv, Hook> HookThenServeThenHook<Req, Resp, Serv, Hook> {
|
||||||
|
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
|
||||||
|
Self {
|
||||||
|
serve,
|
||||||
|
hook,
|
||||||
|
fns: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req, Resp, Serv: Clone, Hook: Clone> Clone for HookThenServeThenHook<Req, Resp, Serv, Hook> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
serve: self.serve.clone(),
|
||||||
|
hook: self.hook.clone(),
|
||||||
|
fns: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req, Resp, Serv, Hook> Serve for HookThenServeThenHook<Req, Resp, Serv, Hook>
|
||||||
|
where
|
||||||
|
Serv: Serve<Req = Req, Resp = Resp>,
|
||||||
|
Hook: BeforeRequest<Req> + AfterRequest<Resp>,
|
||||||
|
{
|
||||||
|
type Req = Req;
|
||||||
|
type Resp = Resp;
|
||||||
|
|
||||||
|
async fn serve(self, mut ctx: context::Context, req: Req) -> Result<Serv::Resp, ServerError> {
|
||||||
|
let HookThenServeThenHook {
|
||||||
|
serve, mut hook, ..
|
||||||
|
} = self;
|
||||||
|
hook.before(&mut ctx, &req).await?;
|
||||||
|
let mut resp = serve.serve(ctx, req).await;
|
||||||
|
hook.after(&mut ctx, &mut resp).await;
|
||||||
|
resp
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,8 +5,9 @@
|
|||||||
// https://opensource.org/licenses/MIT.
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
cancellations::{cancellations, CanceledRequests, RequestCancellation},
|
||||||
context,
|
context,
|
||||||
server::{Channel, Config, TrackedRequest},
|
server::{Channel, Config, ResponseGuard, TrackedRequest},
|
||||||
Request, Response,
|
Request, Response,
|
||||||
};
|
};
|
||||||
use futures::{task::*, Sink, Stream};
|
use futures::{task::*, Sink, Stream};
|
||||||
@@ -22,6 +23,8 @@ pub(crate) struct FakeChannel<In, Out> {
|
|||||||
pub sink: VecDeque<Out>,
|
pub sink: VecDeque<Out>,
|
||||||
pub config: Config,
|
pub config: Config,
|
||||||
pub in_flight_requests: super::in_flight_requests::InFlightRequests,
|
pub in_flight_requests: super::in_flight_requests::InFlightRequests,
|
||||||
|
pub request_cancellation: RequestCancellation,
|
||||||
|
pub canceled_requests: CanceledRequests,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<In, Out> Stream for FakeChannel<In, Out>
|
impl<In, Out> Stream for FakeChannel<In, Out>
|
||||||
@@ -86,6 +89,7 @@ where
|
|||||||
impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||||
pub fn push_req(&mut self, id: u64, message: Req) {
|
pub fn push_req(&mut self, id: u64, message: Req) {
|
||||||
let (_, abort_registration) = futures::future::AbortHandle::new_pair();
|
let (_, abort_registration) = futures::future::AbortHandle::new_pair();
|
||||||
|
let (request_cancellation, _) = cancellations();
|
||||||
self.stream.push_back(Ok(TrackedRequest {
|
self.stream.push_back(Ok(TrackedRequest {
|
||||||
request: Request {
|
request: Request {
|
||||||
context: context::Context {
|
context: context::Context {
|
||||||
@@ -97,17 +101,25 @@ impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
|||||||
},
|
},
|
||||||
abort_registration,
|
abort_registration,
|
||||||
span: Span::none(),
|
span: Span::none(),
|
||||||
|
response_guard: ResponseGuard {
|
||||||
|
request_cancellation,
|
||||||
|
request_id: id,
|
||||||
|
cancel: false,
|
||||||
|
},
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FakeChannel<(), ()> {
|
impl FakeChannel<(), ()> {
|
||||||
pub fn default<Req, Resp>() -> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
pub fn default<Req, Resp>() -> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||||
|
let (request_cancellation, canceled_requests) = cancellations();
|
||||||
FakeChannel {
|
FakeChannel {
|
||||||
stream: Default::default(),
|
stream: Default::default(),
|
||||||
sink: Default::default(),
|
sink: Default::default(),
|
||||||
config: Default::default(),
|
config: Default::default(),
|
||||||
in_flight_requests: Default::default(),
|
in_flight_requests: Default::default(),
|
||||||
|
request_cancellation,
|
||||||
|
canceled_requests,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -123,5 +135,5 @@ impl<T> PollExt for Poll<Option<T>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn cx() -> Context<'static> {
|
pub fn cx() -> Context<'static> {
|
||||||
Context::from_waker(&noop_waker_ref())
|
Context::from_waker(noop_waker_ref())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,111 +0,0 @@
|
|||||||
use super::{Channel, Requests, Serve};
|
|
||||||
use futures::{prelude::*, ready, task::*};
|
|
||||||
use pin_project::pin_project;
|
|
||||||
use std::pin::Pin;
|
|
||||||
|
|
||||||
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
|
|
||||||
/// for each new channel. Returned by
|
|
||||||
/// [`Incoming::execute`](crate::server::incoming::Incoming::execute).
|
|
||||||
#[pin_project]
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct TokioServerExecutor<T, S> {
|
|
||||||
#[pin]
|
|
||||||
inner: T,
|
|
||||||
serve: S,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, S> TokioServerExecutor<T, S> {
|
|
||||||
pub(crate) fn new(inner: T, serve: S) -> Self {
|
|
||||||
Self { inner, serve }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A future that drives the server by [spawning](tokio::spawn) each [response
|
|
||||||
/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by
|
|
||||||
/// [`Channel::execute`](crate::server::Channel::execute).
|
|
||||||
#[pin_project]
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct TokioChannelExecutor<T, S> {
|
|
||||||
#[pin]
|
|
||||||
inner: T,
|
|
||||||
serve: S,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, S> TokioServerExecutor<T, S> {
|
|
||||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
|
||||||
self.as_mut().project().inner
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, S> TokioChannelExecutor<T, S> {
|
|
||||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
|
||||||
self.as_mut().project().inner
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send + 'static execution helper methods.
|
|
||||||
|
|
||||||
impl<C> Requests<C>
|
|
||||||
where
|
|
||||||
C: Channel,
|
|
||||||
C::Req: Send + 'static,
|
|
||||||
C::Resp: Send + 'static,
|
|
||||||
{
|
|
||||||
/// Executes all requests using the given service function. Requests are handled concurrently
|
|
||||||
/// by [spawning](::tokio::spawn) each handler on tokio's default executor.
|
|
||||||
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
|
|
||||||
where
|
|
||||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
|
|
||||||
{
|
|
||||||
TokioChannelExecutor { inner: self, serve }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
|
|
||||||
where
|
|
||||||
St: Sized + Stream<Item = C>,
|
|
||||||
C: Channel + Send + 'static,
|
|
||||||
C::Req: Send + 'static,
|
|
||||||
C::Resp: Send + 'static,
|
|
||||||
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
|
||||||
Se::Fut: Send,
|
|
||||||
{
|
|
||||||
type Output = ();
|
|
||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
|
||||||
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
|
||||||
tokio::spawn(channel.execute(self.serve.clone()));
|
|
||||||
}
|
|
||||||
tracing::info!("Server shutting down.");
|
|
||||||
Poll::Ready(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
|
|
||||||
where
|
|
||||||
C: Channel + 'static,
|
|
||||||
C::Req: Send + 'static,
|
|
||||||
C::Resp: Send + 'static,
|
|
||||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
|
||||||
S::Fut: Send,
|
|
||||||
{
|
|
||||||
type Output = ();
|
|
||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
||||||
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
|
||||||
match response_handler {
|
|
||||||
Ok(resp) => {
|
|
||||||
let server = self.serve.clone();
|
|
||||||
tokio::spawn(async move {
|
|
||||||
resp.execute(server).await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::warn!("Requests stream errored out: {}", e);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Poll::Ready(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -138,25 +138,25 @@ impl From<u64> for SpanId {
|
|||||||
|
|
||||||
impl From<opentelemetry::trace::TraceId> for TraceId {
|
impl From<opentelemetry::trace::TraceId> for TraceId {
|
||||||
fn from(trace_id: opentelemetry::trace::TraceId) -> Self {
|
fn from(trace_id: opentelemetry::trace::TraceId) -> Self {
|
||||||
Self::from(trace_id.to_u128())
|
Self::from(u128::from_be_bytes(trace_id.to_bytes()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<TraceId> for opentelemetry::trace::TraceId {
|
impl From<TraceId> for opentelemetry::trace::TraceId {
|
||||||
fn from(trace_id: TraceId) -> Self {
|
fn from(trace_id: TraceId) -> Self {
|
||||||
Self::from_u128(trace_id.into())
|
Self::from_bytes(u128::from(trace_id).to_be_bytes())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<opentelemetry::trace::SpanId> for SpanId {
|
impl From<opentelemetry::trace::SpanId> for SpanId {
|
||||||
fn from(span_id: opentelemetry::trace::SpanId) -> Self {
|
fn from(span_id: opentelemetry::trace::SpanId) -> Self {
|
||||||
Self::from(span_id.to_u64())
|
Self::from(u64::from_be_bytes(span_id.to_bytes()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<SpanId> for opentelemetry::trace::SpanId {
|
impl From<SpanId> for opentelemetry::trace::SpanId {
|
||||||
fn from(span_id: SpanId) -> Self {
|
fn from(span_id: SpanId) -> Self {
|
||||||
Self::from_u64(span_id.0)
|
Self::from_bytes(u64::from(span_id).to_be_bytes())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,9 +14,15 @@ use tokio::sync::mpsc;
|
|||||||
/// Errors that occur in the sending or receiving of messages over a channel.
|
/// Errors that occur in the sending or receiving of messages over a channel.
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
pub enum ChannelError {
|
pub enum ChannelError {
|
||||||
/// An error occurred sending over the channel.
|
/// An error occurred readying to send into the channel.
|
||||||
#[error("an error occurred sending over the channel")]
|
#[error("an error occurred readying to send into the channel")]
|
||||||
|
Ready(#[source] Box<dyn Error + Send + Sync + 'static>),
|
||||||
|
/// An error occurred sending into the channel.
|
||||||
|
#[error("an error occurred sending into the channel")]
|
||||||
Send(#[source] Box<dyn Error + Send + Sync + 'static>),
|
Send(#[source] Box<dyn Error + Send + Sync + 'static>),
|
||||||
|
/// An error occurred receiving from the channel.
|
||||||
|
#[error("an error occurred receiving from the channel")]
|
||||||
|
Receive(#[source] Box<dyn Error + Send + Sync + 'static>),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
|
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
|
||||||
@@ -48,7 +54,10 @@ impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
|
|||||||
mut self: Pin<&mut Self>,
|
mut self: Pin<&mut Self>,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
) -> Poll<Option<Result<Item, ChannelError>>> {
|
) -> Poll<Option<Result<Item, ChannelError>>> {
|
||||||
self.rx.poll_recv(cx).map(|option| option.map(Ok))
|
self.rx
|
||||||
|
.poll_recv(cx)
|
||||||
|
.map(|option| option.map(Ok))
|
||||||
|
.map_err(ChannelError::Receive)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,7 +68,7 @@ impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
|
|||||||
|
|
||||||
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
Poll::Ready(if self.tx.is_closed() {
|
Poll::Ready(if self.tx.is_closed() {
|
||||||
Err(ChannelError::Send(CLOSED_MESSAGE.into()))
|
Err(ChannelError::Ready(CLOSED_MESSAGE.into()))
|
||||||
} else {
|
} else {
|
||||||
Ok(())
|
Ok(())
|
||||||
})
|
})
|
||||||
@@ -110,7 +119,11 @@ impl<Item, SinkItem> Stream for Channel<Item, SinkItem> {
|
|||||||
self: Pin<&mut Self>,
|
self: Pin<&mut Self>,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
) -> Poll<Option<Result<Item, ChannelError>>> {
|
) -> Poll<Option<Result<Item, ChannelError>>> {
|
||||||
self.project().rx.poll_next(cx).map(|option| option.map(Ok))
|
self.project()
|
||||||
|
.rx
|
||||||
|
.poll_next(cx)
|
||||||
|
.map(|option| option.map(Ok))
|
||||||
|
.map_err(ChannelError::Receive)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,7 +134,7 @@ impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
|
|||||||
self.project()
|
self.project()
|
||||||
.tx
|
.tx
|
||||||
.poll_ready(cx)
|
.poll_ready(cx)
|
||||||
.map_err(|e| ChannelError::Send(Box::new(e)))
|
.map_err(|e| ChannelError::Ready(Box::new(e)))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||||
@@ -146,16 +159,17 @@ impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(all(test, feature = "tokio1"))]
|
||||||
#[cfg(feature = "tokio1")]
|
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::{
|
use crate::{
|
||||||
client, context,
|
client::{self, RpcError},
|
||||||
server::{incoming::Incoming, BaseChannel},
|
context,
|
||||||
|
server::{incoming::Incoming, serve, BaseChannel},
|
||||||
transport::{
|
transport::{
|
||||||
self,
|
self,
|
||||||
channel::{Channel, UnboundedChannel},
|
channel::{Channel, UnboundedChannel},
|
||||||
},
|
},
|
||||||
|
ServerError,
|
||||||
};
|
};
|
||||||
use assert_matches::assert_matches;
|
use assert_matches::assert_matches;
|
||||||
use futures::{prelude::*, stream};
|
use futures::{prelude::*, stream};
|
||||||
@@ -177,25 +191,28 @@ mod tests {
|
|||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
stream::once(future::ready(server_channel))
|
stream::once(future::ready(server_channel))
|
||||||
.map(BaseChannel::with_defaults)
|
.map(BaseChannel::with_defaults)
|
||||||
.execute(|_ctx, request: String| {
|
.execute(serve(|_ctx, request: String| async move {
|
||||||
future::ready(request.parse::<u64>().map_err(|_| {
|
request.parse::<u64>().map_err(|_| {
|
||||||
io::Error::new(
|
ServerError::new(
|
||||||
io::ErrorKind::InvalidInput,
|
io::ErrorKind::InvalidInput,
|
||||||
format!("{:?} is not an int", request),
|
format!("{request:?} is not an int"),
|
||||||
)
|
)
|
||||||
}))
|
})
|
||||||
|
}))
|
||||||
|
.for_each(|channel| async move {
|
||||||
|
tokio::spawn(channel.for_each(|response| response));
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
|
|
||||||
let client = client::new(client::Config::default(), client_channel).spawn();
|
let client = client::new(client::Config::default(), client_channel).spawn();
|
||||||
|
|
||||||
let response1 = client.call(context::current(), "", "123".into()).await?;
|
let response1 = client.call(context::current(), "", "123".into()).await;
|
||||||
let response2 = client.call(context::current(), "", "abc".into()).await?;
|
let response2 = client.call(context::current(), "", "abc".into()).await;
|
||||||
|
|
||||||
trace!("response1: {:?}, response2: {:?}", response1, response2);
|
trace!("response1: {:?}, response2: {:?}", response1, response2);
|
||||||
|
|
||||||
assert_matches!(response1, Ok(123));
|
assert_matches!(response1, Ok(123));
|
||||||
assert_matches!(response2, Err(ref e) if e.kind() == io::ErrorKind::InvalidInput);
|
assert_matches!(response2, Err(RpcError::Server(e)) if e.kind == io::ErrorKind::InvalidInput);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,11 +38,34 @@ where
|
|||||||
H: BuildHasher,
|
H: BuildHasher,
|
||||||
{
|
{
|
||||||
fn compact(&mut self, usage_ratio_threshold: f64) {
|
fn compact(&mut self, usage_ratio_threshold: f64) {
|
||||||
if self.capacity() > 1000 {
|
let usage_ratio_threshold = usage_ratio_threshold.clamp(f64::MIN_POSITIVE, 1.);
|
||||||
let usage_ratio = self.len() as f64 / self.capacity() as f64;
|
let cap = f64::max(1000., self.len() as f64 / usage_ratio_threshold);
|
||||||
if usage_ratio < usage_ratio_threshold {
|
self.shrink_to(cap as usize);
|
||||||
self.shrink_to_fit();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_compact() {
|
||||||
|
let mut map = HashMap::with_capacity(2048);
|
||||||
|
assert_eq!(map.capacity(), 3584);
|
||||||
|
|
||||||
|
// Make usage ratio 25%
|
||||||
|
for i in 0..896 {
|
||||||
|
map.insert(format!("k{i}"), "v");
|
||||||
|
}
|
||||||
|
|
||||||
|
map.compact(-1.0);
|
||||||
|
assert_eq!(map.capacity(), 3584);
|
||||||
|
|
||||||
|
map.compact(0.25);
|
||||||
|
assert_eq!(map.capacity(), 3584);
|
||||||
|
|
||||||
|
map.compact(0.50);
|
||||||
|
assert_eq!(map.capacity(), 1792);
|
||||||
|
|
||||||
|
map.compact(1.0);
|
||||||
|
assert_eq!(map.capacity(), 1792);
|
||||||
|
|
||||||
|
map.compact(2.0);
|
||||||
|
assert_eq!(map.capacity(), 1792);
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,4 +2,6 @@
|
|||||||
fn ui() {
|
fn ui() {
|
||||||
let t = trybuild::TestCases::new();
|
let t = trybuild::TestCases::new();
|
||||||
t.compile_fail("tests/compile_fail/*.rs");
|
t.compile_fail("tests/compile_fail/*.rs");
|
||||||
|
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
|
||||||
|
t.compile_fail("tests/compile_fail/serde_transport/*.rs");
|
||||||
}
|
}
|
||||||
|
|||||||
15
tarpc/tests/compile_fail/must_use_request_dispatch.rs
Normal file
15
tarpc/tests/compile_fail/must_use_request_dispatch.rs
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
use tarpc::client;
|
||||||
|
|
||||||
|
#[tarpc::service]
|
||||||
|
trait World {
|
||||||
|
async fn hello(name: String) -> String;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let (client_transport, _) = tarpc::transport::channel::unbounded();
|
||||||
|
|
||||||
|
#[deny(unused_must_use)]
|
||||||
|
{
|
||||||
|
WorldClient::new(client::Config::default(), client_transport).dispatch;
|
||||||
|
}
|
||||||
|
}
|
||||||
15
tarpc/tests/compile_fail/must_use_request_dispatch.stderr
Normal file
15
tarpc/tests/compile_fail/must_use_request_dispatch.stderr
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
error: unused `RequestDispatch` that must be used
|
||||||
|
--> tests/compile_fail/must_use_request_dispatch.rs:13:9
|
||||||
|
|
|
||||||
|
13 | WorldClient::new(client::Config::default(), client_transport).dispatch;
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
note: the lint level is defined here
|
||||||
|
--> tests/compile_fail/must_use_request_dispatch.rs:11:12
|
||||||
|
|
|
||||||
|
11 | #[deny(unused_must_use)]
|
||||||
|
| ^^^^^^^^^^^^^^^
|
||||||
|
help: use `let _ = ...` to ignore the resulting value
|
||||||
|
|
|
||||||
|
13 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch;
|
||||||
|
| +++++++
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
use tarpc::serde_transport;
|
||||||
|
use tokio_serde::formats::Json;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
#[deny(unused_must_use)]
|
||||||
|
{
|
||||||
|
serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
error: unused `tarpc::serde_transport::tcp::Connect` that must be used
|
||||||
|
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:7:9
|
||||||
|
|
|
||||||
|
7 | serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
note: the lint level is defined here
|
||||||
|
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:5:12
|
||||||
|
|
|
||||||
|
5 | #[deny(unused_must_use)]
|
||||||
|
| ^^^^^^^^^^^^^^^
|
||||||
|
help: use `let _ = ...` to ignore the resulting value
|
||||||
|
|
|
||||||
|
7 | let _ = serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
|
||||||
|
| +++++++
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
#[tarpc::service(derive_serde = false)]
|
|
||||||
trait World {
|
|
||||||
async fn hello(name: String) -> String;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct HelloServer;
|
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl World for HelloServer {
|
|
||||||
fn hello(name: String) -> String {
|
|
||||||
format!("Hello, {}!", name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() {}
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
error: not all trait items implemented, missing: `HelloFut`
|
|
||||||
--> $DIR/tarpc_server_missing_async.rs:9:1
|
|
||||||
|
|
|
||||||
9 | impl World for HelloServer {
|
|
||||||
| ^^^^
|
|
||||||
|
|
||||||
error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async
|
|
||||||
--> $DIR/tarpc_server_missing_async.rs:10:5
|
|
||||||
|
|
|
||||||
10 | fn hello(name: String) -> String {
|
|
||||||
| ^^
|
|
||||||
@@ -7,7 +7,7 @@ use tarpc::{
|
|||||||
use tokio_serde::formats::Json;
|
use tokio_serde::formats::Json;
|
||||||
|
|
||||||
#[tarpc::derive_serde]
|
#[tarpc::derive_serde]
|
||||||
#[derive(Debug, PartialEq)]
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
pub enum TestData {
|
pub enum TestData {
|
||||||
Black,
|
Black,
|
||||||
White,
|
White,
|
||||||
@@ -21,7 +21,6 @@ pub trait ColorProtocol {
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct ColorServer;
|
struct ColorServer;
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl ColorProtocol for ColorServer {
|
impl ColorProtocol for ColorServer {
|
||||||
async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData {
|
async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData {
|
||||||
match color {
|
match color {
|
||||||
@@ -31,6 +30,11 @@ impl ColorProtocol for ColorServer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_call() -> anyhow::Result<()> {
|
async fn test_call() -> anyhow::Result<()> {
|
||||||
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
|
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
|
||||||
@@ -40,7 +44,9 @@ async fn test_call() -> anyhow::Result<()> {
|
|||||||
.take(1)
|
.take(1)
|
||||||
.filter_map(|r| async { r.ok() })
|
.filter_map(|r| async { r.ok() })
|
||||||
.map(BaseChannel::with_defaults)
|
.map(BaseChannel::with_defaults)
|
||||||
.execute(ColorServer.serve()),
|
.execute(ColorServer.serve())
|
||||||
|
.map(|channel| channel.for_each(spawn))
|
||||||
|
.for_each(spawn),
|
||||||
);
|
);
|
||||||
|
|
||||||
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
use assert_matches::assert_matches;
|
use assert_matches::assert_matches;
|
||||||
use futures::{
|
use futures::{
|
||||||
future::{join_all, ready, Ready},
|
future::{join_all, ready},
|
||||||
prelude::*,
|
prelude::*,
|
||||||
};
|
};
|
||||||
use std::time::{Duration, SystemTime};
|
use std::time::{Duration, SystemTime};
|
||||||
use tarpc::{
|
use tarpc::{
|
||||||
client::{self},
|
client::{self},
|
||||||
context,
|
context,
|
||||||
server::{self, incoming::Incoming, BaseChannel, Channel},
|
server::{incoming::Incoming, BaseChannel, Channel},
|
||||||
transport::channel,
|
transport::channel,
|
||||||
};
|
};
|
||||||
use tokio::join;
|
use tokio::join;
|
||||||
@@ -22,39 +22,29 @@ trait Service {
|
|||||||
struct Server;
|
struct Server;
|
||||||
|
|
||||||
impl Service for Server {
|
impl Service for Server {
|
||||||
type AddFut = Ready<i32>;
|
async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
|
||||||
|
x + y
|
||||||
fn add(self, _: context::Context, x: i32, y: i32) -> Self::AddFut {
|
|
||||||
ready(x + y)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type HeyFut = Ready<String>;
|
async fn hey(self, _: context::Context, name: String) -> String {
|
||||||
|
format!("Hey, {name}.")
|
||||||
fn hey(self, _: context::Context, name: String) -> Self::HeyFut {
|
|
||||||
ready(format!("Hey, {}.", name))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sequential() -> anyhow::Result<()> {
|
async fn sequential() {
|
||||||
let _ = tracing_subscriber::fmt::try_init();
|
let (tx, rx) = tarpc::transport::channel::unbounded();
|
||||||
|
let client = client::new(client::Config::default(), tx).spawn();
|
||||||
let (tx, rx) = channel::unbounded();
|
let channel = BaseChannel::with_defaults(rx);
|
||||||
|
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
BaseChannel::new(server::Config::default(), rx)
|
channel
|
||||||
.requests()
|
.execute(tarpc::server::serve(|_, i| async move { Ok(i + 1) }))
|
||||||
.execute(Server.serve()),
|
.for_each(|response| response),
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
client.call(context::current(), "AddOne", 1).await.unwrap(),
|
||||||
|
2
|
||||||
);
|
);
|
||||||
|
|
||||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
|
||||||
|
|
||||||
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
|
|
||||||
assert_matches!(
|
|
||||||
client.hey(context::current(), "Tim".into()).await,
|
|
||||||
Ok(ref s) if s == "Hey, Tim.");
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -70,7 +60,6 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct AllHandlersComplete;
|
struct AllHandlersComplete;
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl Loop for LoopServer {
|
impl Loop for LoopServer {
|
||||||
async fn r#loop(self, _: context::Context) {
|
async fn r#loop(self, _: context::Context) {
|
||||||
loop {
|
loop {
|
||||||
@@ -108,7 +97,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
|
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn serde() -> anyhow::Result<()> {
|
async fn serde_tcp() -> anyhow::Result<()> {
|
||||||
use tarpc::serde_transport;
|
use tarpc::serde_transport;
|
||||||
use tokio_serde::formats::Json;
|
use tokio_serde::formats::Json;
|
||||||
|
|
||||||
@@ -121,7 +110,9 @@ async fn serde() -> anyhow::Result<()> {
|
|||||||
.take(1)
|
.take(1)
|
||||||
.filter_map(|r| async { r.ok() })
|
.filter_map(|r| async { r.ok() })
|
||||||
.map(BaseChannel::with_defaults)
|
.map(BaseChannel::with_defaults)
|
||||||
.execute(Server.serve()),
|
.execute(Server.serve())
|
||||||
|
.map(|channel| channel.for_each(spawn))
|
||||||
|
.for_each(spawn),
|
||||||
);
|
);
|
||||||
|
|
||||||
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
||||||
@@ -136,6 +127,39 @@ async fn serde() -> anyhow::Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(all(feature = "serde-transport", feature = "unix", unix))]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn serde_uds() -> anyhow::Result<()> {
|
||||||
|
use tarpc::serde_transport;
|
||||||
|
use tokio_serde::formats::Json;
|
||||||
|
|
||||||
|
let _ = tracing_subscriber::fmt::try_init();
|
||||||
|
|
||||||
|
let sock = tarpc::serde_transport::unix::TempPathBuf::with_random("uds");
|
||||||
|
let transport = tarpc::serde_transport::unix::listen(&sock, Json::default).await?;
|
||||||
|
tokio::spawn(
|
||||||
|
transport
|
||||||
|
.take(1)
|
||||||
|
.filter_map(|r| async { r.ok() })
|
||||||
|
.map(BaseChannel::with_defaults)
|
||||||
|
.execute(Server.serve())
|
||||||
|
.map(|channel| channel.for_each(spawn))
|
||||||
|
.for_each(spawn),
|
||||||
|
);
|
||||||
|
|
||||||
|
let transport = serde_transport::unix::connect(&sock, Json::default).await?;
|
||||||
|
let client = ServiceClient::new(client::Config::default(), transport).spawn();
|
||||||
|
|
||||||
|
// Save results using socket so we can clean the socket even if our test assertions fail
|
||||||
|
let res1 = client.add(context::current(), 1, 2).await;
|
||||||
|
let res2 = client.hey(context::current(), "Tim".to_string()).await;
|
||||||
|
|
||||||
|
assert_matches!(res1, Ok(3));
|
||||||
|
assert_matches!(res2, Ok(ref s) if s == "Hey, Tim.");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn concurrent() -> anyhow::Result<()> {
|
async fn concurrent() -> anyhow::Result<()> {
|
||||||
let _ = tracing_subscriber::fmt::try_init();
|
let _ = tracing_subscriber::fmt::try_init();
|
||||||
@@ -144,7 +168,9 @@ async fn concurrent() -> anyhow::Result<()> {
|
|||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
stream::once(ready(rx))
|
stream::once(ready(rx))
|
||||||
.map(BaseChannel::with_defaults)
|
.map(BaseChannel::with_defaults)
|
||||||
.execute(Server.serve()),
|
.execute(Server.serve())
|
||||||
|
.map(|channel| channel.for_each(spawn))
|
||||||
|
.for_each(spawn),
|
||||||
);
|
);
|
||||||
|
|
||||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||||
@@ -168,7 +194,9 @@ async fn concurrent_join() -> anyhow::Result<()> {
|
|||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
stream::once(ready(rx))
|
stream::once(ready(rx))
|
||||||
.map(BaseChannel::with_defaults)
|
.map(BaseChannel::with_defaults)
|
||||||
.execute(Server.serve()),
|
.execute(Server.serve())
|
||||||
|
.map(|channel| channel.for_each(spawn))
|
||||||
|
.for_each(spawn),
|
||||||
);
|
);
|
||||||
|
|
||||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||||
@@ -185,15 +213,20 @@ async fn concurrent_join() -> anyhow::Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn concurrent_join_all() -> anyhow::Result<()> {
|
async fn concurrent_join_all() -> anyhow::Result<()> {
|
||||||
let _ = tracing_subscriber::fmt::try_init();
|
let _ = tracing_subscriber::fmt::try_init();
|
||||||
|
|
||||||
let (tx, rx) = channel::unbounded();
|
let (tx, rx) = channel::unbounded();
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
stream::once(ready(rx))
|
BaseChannel::with_defaults(rx)
|
||||||
.map(BaseChannel::with_defaults)
|
.execute(Server.serve())
|
||||||
.execute(Server.serve()),
|
.for_each(spawn),
|
||||||
);
|
);
|
||||||
|
|
||||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||||
@@ -218,11 +251,9 @@ async fn counter() -> anyhow::Result<()> {
|
|||||||
struct CountService(u32);
|
struct CountService(u32);
|
||||||
|
|
||||||
impl Counter for &mut CountService {
|
impl Counter for &mut CountService {
|
||||||
type CountFut = futures::future::Ready<u32>;
|
async fn count(self, _: context::Context) -> u32 {
|
||||||
|
|
||||||
fn count(self, _: context::Context) -> Self::CountFut {
|
|
||||||
self.0 += 1;
|
self.0 += 1;
|
||||||
futures::future::ready(self.0)
|
self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user