30 Commits

Author SHA1 Message Date
Tim Kuehn
bed85e2827 Prepare release of v0.32.0 2023-03-24 15:04:06 -07:00
Bruno
93f3880025 Return transport errors to the caller (#399)
* Make client::InFlightRequests generic over result.

Previously, InFlightRequests required the client response type to be a
server response. However, this prevented injection of non-server
responses: for example, if the client fails to send a request, it should
complete the request with an IO error rather than a server error.

* Gracefully handle client-side send errors.

Previously, a client channel would immediately disconnect when
encountering an error in Transport::try_send. One kind of error that can
occur in try_send is message validation, e.g. validating a message is
not larger than a configured frame size. The problem with shutting down
the client immediately is that debuggability suffers: it can be hard to
understand what caused the client to fail. Also, these errors are not
always fatal, as with frame size limits, so complete shutdown was
extreme.

By bubbling up errors, it's now possible for the caller to
programmatically handle them. For example, the error could be walked
via anyhow::Error:

```
    2023-01-10T02:49:32.528939Z  WARN client: the client failed to send the request

    Caused by:
        0: could not write to the transport
        1: frame size too big
```

* Some follow-up work: right now, read errors will bubble up to all pending RPCs. However, on the write side, only `start_send` bubbles up. `poll_ready`, `poll_flush`, and `poll_close` do not propagate back to pending RPCs. This is probably okay in most circumstances, because fatal write errors likely coincide with fatal read errors, which *do* propagate back to clients. But it might still be worth unifying this logic.

---------

Co-authored-by: Tim Kuehn <tikue@google.com>
2023-03-24 14:31:25 -07:00
cguentherTUChemnitz
878f594d5b Feature/tls over tcp example (#398)
Example tarpc service that encodes messages with bincode written to a TLS-over-TCP transport.

Certs were generated with openssl 3 using https://github.com/rustls/rustls/tree/main/test-ca

New dependencies:
- tokio-rustls to set up the TLS connections
- rustls-pemfile to load certs from .pem files
2023-03-22 10:35:21 -07:00
Tim Kuehn
aa9bbad109 Fix compile_fail tests for formatting changes on stable 2023-03-17 10:07:54 -07:00
Tim Kuehn
7e872ce925 Remove bad mem::forget usage.
mem::forget is a dangerous tool, and it was being used carelessly for
things that have safer alternatives. There was at least one bug where a
cloned tokio::sync::mpsc::UnboundedSender used for request cancellation
was being leaked on every successful server response, so its refcounts
were never decremented. Because these are atomic refcounts, they'll wrap
around rather than overflow when reaching the maximum value, so I don't
believe this could lead to panics or unsoundness.
2022-11-23 18:01:12 -08:00
Tim Kuehn
62541b709d Replace actions-rs 2022-11-23 16:39:29 -08:00
Tim Kuehn
8c43f94fb6 Remove unused Sealed trait 2022-11-17 00:57:56 -08:00
Tim Kuehn
7fa4e5064d Ignore clippy false positive 2022-11-13 00:25:07 -08:00
Tim Kuehn
94db7610bb Require a static lifetime for request_name. 2022-11-05 11:43:03 -07:00
Tim Kuehn
0c08d5e8ca Prepare release of v0.31.0 2022-11-03 13:29:46 -07:00
Tim Kuehn
75b15fe2aa Address clippy lint 2022-10-07 10:51:45 -07:00
Tim Kuehn
863a08d87e In example-service, print the port the server is listened on.
This is helpful when passing starting the server with --port 0.
2022-10-06 20:58:54 -07:00
Tim Kuehn
49ba8f8b1b Zero-pad the random number suffix of TempPathBufs.
This way, the hex number is always 16 digits, which is helpful for test
verification as well as simple consistency.
2022-10-03 18:50:50 -07:00
Kevin K
d832209da3 feat: Unix domain sockets with serde transports (#380)
* adds support for Unix Domain Socket generic transports
* adds a TempPathBuf that lives in temp and is removed on drop
2022-10-03 18:07:29 -07:00
royrustdev
584426d414 fix clippy warnings #378 2022-09-19 23:26:07 -07:00
royrustdev
50eb80c883 reference latest tarpc version in readme 2022-09-19 21:58:21 -07:00
royrustdev
1f0c80d8c9 bump github actions 2022-09-15 11:17:58 -07:00
Tim Kuehn
99bf3e62a3 Prepare release of 0.30.0 2022-08-12 16:08:33 -07:00
Tim Kuehn
68863e3db0 Remove Channel::request_cancellation.
This trait fn returns a private type, which means it's useless for
anyone using the Channel.

Instead, add an inert (now-public) ResponseGuard to TrackedRequest that,
when taken out of the ManuallyDrop, ensures a Channel's request state is
cleaned up. It's preferable to make ResponseGuard public instead of
RequestCancellations because it's a smaller API surface (no public
methods, just a Drop fn) and harder to misuse, because it is already
associated with the correct request ID to cancel.
2022-08-12 16:08:33 -07:00
Tim Kuehn
453ba1c074 Lower log level of log in the RPC callpath 2022-08-12 09:04:47 -07:00
Makro
e3eac1b4f5 Add LICENSE files to crates (#372) 2022-08-10 17:11:50 -07:00
kkharji
0e102288a5 feat: re-export used packages (#371)
## Problem
Library users might get stuck with or ran into issues while using tarpc because of incompatible third party libraries. in particular, tokio_serde and tokio_util.

## Solution
This PR does the following:

1. re-export tokio_serde as part of feature serde-transport, because the end user imports it to use some serde-transport APIs.
2. Update third library packages to latest release and fix resulting issues from that.

## Important Notes
tokio_util 7.3 DelayQueue::poll_expired API changed [0] therefore, InFlightRequests::poll_expired now returns Poll<Option<u64>>

[0] https://docs.rs/tokio-util/latest/tokio_util/time/delay_queue/struct.DelayQueue.html#method.poll_expired
2022-07-15 10:14:49 -07:00
Tim Kuehn
4c8ba41b2f #[allow(unstable_name_collisions)] for .ready()
.ready() is being added to std, but in the meantime, I don't want to stop using PollTest::ready.
2022-06-07 01:29:14 -07:00
Tim Kuehn
946c627579 Remove unused field 2022-06-07 01:29:14 -07:00
Tim Kuehn
104dd71bba Clean up Channel request data more reliably.
When an InFlightRequest is dropped before response completion, request
data in the Channel persists until either the request expires or the
client cancels the request. In rare cases, requests with very large
deadlines could clog up the Channel long after request processing
ceases.

This commit adds a drop hook to InFlightRequest so that if it is dropped
before execution completes, a cancellation message is sent to the
Channel so that it can clean up the associated request data.

This only works for when using `InFlightRequest::execute` or
`Channel::execute`. However, users of raw `Channel` have access
to the `RequestCancellation` handle via `Channel::request_cancellation`,
so they can implement a similar method if they wish to manually clean up
request data.

Note that once a Channel's request data is cleaned up, that request can
never be responded to, even if a response is produced afterward.

Fixes https://github.com/google/tarpc/issues/314
2022-06-07 01:29:04 -07:00
Tim Kuehn
012c481861 Move cancellation types into a dedicated module.
Cancellation utilities could be useful for both client and server code.
2022-06-05 18:54:52 -07:00
Tim Kuehn
dc12bd09aa Annotate types that impl Future with #[must_use].
These types do nothing unless polled / .awaited.
Annotating them with #[must_use] helps prevent a common class of coding errors.

Fixes https://github.com/google/tarpc/issues/368.
2022-06-05 18:54:52 -07:00
Tim Kuehn
2594ea8ce9 Prepare release of 0.29.0 2022-06-05 15:26:33 -07:00
Tim Kuehn
839b87e394 Serialize RPC deadline as a Duration.
Duration was previously serialized as SystemTime. However, absolute
times run into problems with clock skew: if the remote machine's clock
is too far in the future, the RPC deadline will be exceeded before
request processing can begin. Conversely, if the remote machine's clock
is too far in the past, the RPC deadline will not be enforced.

By converting the absolute deadline to a relative duration, clock skew
is no longer relevant, as the remote machine will convert the deadline
into a time relative to its own clock. This mirrors how the gRPC HTTP2
protocol includes a Timeout in the request headers [0] but the SDK uses
timestamps [1]. Keeping the absolute time in the core APIs maintains all
the benefits of today, namely, natural deadline propagation between RPC
hops when using the current context.

This serialization strategy means that, generally, the remote machine's
deadline will be slightly in the future compared to the local machine.
Depending on network transfer latencies, this could be microseconds to
milliseconds, or worse in the worst case. Because the deadline is not
intended for high-precision scenarios, I don't view this is as
problematic.

Because this change only affects the serialization layer, local
transports that bypass serialization are not affected.

[0] https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
[1] https://grpc.io/blog/deadlines/#setting-a-deadline
2022-05-26 15:18:49 -07:00
Tim Kuehn
57d0638a99 Add rpc.deadline tag to Opentelemetry traces. 2022-05-26 15:18:49 -07:00
41 changed files with 1377 additions and 329 deletions

View File

@@ -14,99 +14,57 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.7.0
uses: styfle/cancel-workflow-action@0.10.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v1
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
with:
profile: minimal
toolchain: stable
target: mipsel-unknown-linux-gnu
override: true
- uses: actions-rs/cargo@v1
with:
command: check
args: --all-features
- uses: actions-rs/cargo@v1
with:
command: check
args: --all-features --target mipsel-unknown-linux-gnu
targets: mipsel-unknown-linux-gnu
- run: cargo check --all-features
- run: cargo check --all-features --target mipsel-unknown-linux-gnu
test:
name: Test Suite
runs-on: ubuntu-latest
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.7.0
uses: styfle/cancel-workflow-action@0.10.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v1
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: actions-rs/cargo@v1
with:
command: test
- uses: actions-rs/cargo@v1
with:
command: test
args: --manifest-path tarpc/Cargo.toml --features serde1
- uses: actions-rs/cargo@v1
with:
command: test
args: --manifest-path tarpc/Cargo.toml --features tokio1
- uses: actions-rs/cargo@v1
with:
command: test
args: --manifest-path tarpc/Cargo.toml --features serde-transport
- uses: actions-rs/cargo@v1
with:
command: test
args: --manifest-path tarpc/Cargo.toml --features tcp
- uses: actions-rs/cargo@v1
with:
command: test
args: --all-features
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
- run: cargo test
- run: cargo test --manifest-path tarpc/Cargo.toml --features serde1
- run: cargo test --manifest-path tarpc/Cargo.toml --features tokio1
- run: cargo test --manifest-path tarpc/Cargo.toml --features serde-transport
- run: cargo test --manifest-path tarpc/Cargo.toml --features tcp
- run: cargo test --all-features
fmt:
name: Rustfmt
runs-on: ubuntu-latest
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.7.0
uses: styfle/cancel-workflow-action@0.10.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v1
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
with:
profile: minimal
toolchain: stable
override: true
- run: rustup component add rustfmt
- uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check
components: rustfmt
- run: cargo fmt --all -- --check
clippy:
name: Clippy
runs-on: ubuntu-latest
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.7.0
uses: styfle/cancel-workflow-action@0.10.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v1
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
with:
profile: minimal
toolchain: stable
override: true
- run: rustup component add clippy
- uses: actions-rs/cargo@v1
with:
command: clippy
args: --all-features -- -D warnings
components: clippy
- run: cargo clippy --all-features -- -D warnings

View File

@@ -67,7 +67,7 @@ Some other features of tarpc:
Add to your `Cargo.toml` dependencies:
```toml
tarpc = "0.29"
tarpc = "0.32"
```
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
@@ -82,7 +82,7 @@ your `Cargo.toml`:
```toml
anyhow = "1.0"
futures = "0.3"
tarpc = { version = "0.29", features = ["tokio1"] }
tarpc = { version = "0.31", features = ["tokio1"] }
tokio = { version = "1.0", features = ["macros"] }
```

View File

@@ -1,3 +1,47 @@
## 0.32.0 (2023-03-24)
### Breaking Changes
- As part of a fix to return more channel errors in RPC results, a few error types have changed:
0. `client::RpcError::Disconnected` was split into the following errors:
- Shutdown: the client was shutdown, either intentionally or due to an error. If due to an
error, pending RPCs should see the more specific errors below.
- Send: an RPC message failed to send over the transport. Only the RPC that failed to be sent
will see this error.
- Receive: a fatal error occurred while receiving from the transport. All in-flight RPCs will
receive this error.
0. `client::ChannelError` and `server::ChannelError` are unified in `tarpc::ChannelError`.
Previously, server transport errors would not indicate during which activity the transport
error occurred. Now, just like the client already was, it will be specific: reading, readying,
sending, flushing, or closing.
## 0.31.0 (2022-11-03)
### New Features
This release adds Unix Domain Sockets to the `serde_transport` module.
To use it, enable the "unix" feature. See the docs for more information.
## 0.30.0 (2022-08-12)
### Breaking Changes
- Some types that impl Future are now annotated with `#[must_use]`. Code that previously created
these types but did not use them will now receive a warning. Code that disallows warnings will
receive a compilation error.
### Fixes
- Servers will more reliably clean up request state for requests with long deadlines when response
processing is aborted without sending a response.
### Other Changes
- `TrackedRequest` now contains a response guard that can be used to ensure state cleanup for
aborted requests. (This was already handled automatically by `InFlightRequests`).
- When the feature serde-transport is enabled, the crate tokio_serde is now re-exported.
## 0.29.0 (2022-05-26)
### Breaking Changes

View File

@@ -1,6 +1,6 @@
[package]
name = "tarpc-example-service"
version = "0.11.0"
version = "0.14.0"
rust-version = "1.56"
authors = ["Tim Kuehn <tikue@google.com>"]
edition = "2021"
@@ -18,14 +18,14 @@ anyhow = "1.0"
clap = { version = "3.0.0-rc.9", features = ["derive"] }
log = "0.4"
futures = "0.3"
opentelemetry = { version = "0.16", features = ["rt-tokio"] }
opentelemetry-jaeger = { version = "0.15", features = ["rt-tokio"] }
opentelemetry = { version = "0.17", features = ["rt-tokio"] }
opentelemetry-jaeger = { version = "0.16", features = ["rt-tokio"] }
rand = "0.8"
tarpc = { version = "0.29", path = "../tarpc", features = ["full"] }
tarpc = { version = "0.32", path = "../tarpc", features = ["full"] }
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
tracing = { version = "0.1" }
tracing-opentelemetry = "0.15"
tracing-subscriber = "0.2"
tracing-opentelemetry = "0.17"
tracing-subscriber = {version = "0.3", features = ["env-filter"]}
[lib]
name = "service"

View File

@@ -26,7 +26,8 @@ async fn main() -> anyhow::Result<()> {
let flags = Flags::parse();
init_tracing("Tarpc Example Client")?;
let transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
transport.config_mut().max_frame_length(usize::MAX);
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
// config and any Transport as input.
@@ -42,7 +43,10 @@ async fn main() -> anyhow::Result<()> {
.instrument(tracing::info_span!("Two Hellos"))
.await;
tracing::info!("{:?}", hello);
match hello {
Ok(hello) => tracing::info!("{hello:?}"),
Err(e) => tracing::warn!("{:?}", anyhow::Error::from(e)),
}
// Let the background span processor finish.
sleep(Duration::from_micros(1)).await;

View File

@@ -54,6 +54,7 @@ async fn main() -> anyhow::Result<()> {
// JSON transport is provided by the json_transport tarpc module. It makes it easy
// to start up a serde-powered json serialization strategy over TCP.
let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?;
tracing::info!("Listening on port {}", listener.local_addr().port());
listener.config_mut().max_frame_length(usize::MAX);
listener
// Ignore accept errors.

9
plugins/LICENSE Normal file
View File

@@ -0,0 +1,9 @@
The MIT License (MIT)
Copyright 2016 Google Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@@ -285,7 +285,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
args,
method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
method_idents: &methods,
request_names: &*request_names,
request_names: &request_names,
attrs,
rpcs,
return_types: &rpcs

View File

@@ -1,6 +1,6 @@
[package]
name = "tarpc"
version = "0.29.0"
version = "0.32.0"
rust-version = "1.58.0"
authors = [
"Adam Wright <adam.austin.wright@gmail.com>",
@@ -13,7 +13,7 @@ homepage = "https://github.com/google/tarpc"
repository = "https://github.com/google/tarpc"
keywords = ["rpc", "network", "server", "api", "microservices"]
categories = ["asynchronous", "network-programming"]
readme = "../README.md"
readme = "README.md"
description = "An RPC framework for Rust with a focus on ease of use."
[features]
@@ -25,6 +25,7 @@ serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
serde-transport-json = ["tokio-serde/json"]
serde-transport-bincode = ["tokio-serde/bincode"]
tcp = ["tokio/net"]
unix = ["tokio/net"]
full = [
"serde1",
@@ -33,6 +34,7 @@ full = [
"serde-transport-json",
"serde-transport-bincode",
"tcp",
"unix",
]
[badges]
@@ -50,7 +52,7 @@ static_assertions = "1.1.0"
tarpc-plugins = { path = "../plugins", version = "0.12" }
thiserror = "1.0"
tokio = { version = "1", features = ["time"] }
tokio-util = { version = "0.6.9", features = ["time"] }
tokio-util = { version = "0.7.3", features = ["time"] }
tokio-serde = { optional = true, version = "0.8" }
tracing = { version = "0.1", default-features = false, features = [
"attributes",
@@ -76,6 +78,8 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tokio = { version = "1", features = ["full", "test-util"] }
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
trybuild = "1.0"
tokio-rustls = "0.23"
rustls-pemfile = "1.0"
[package.metadata.docs.rs]
all-features = true
@@ -101,6 +105,10 @@ required-features = ["full"]
name = "custom_transport"
required-features = ["serde1", "tokio1", "serde-transport"]
[[example]]
name = "tls_over_tcp"
required-features = ["full"]
[[test]]
name = "service_functional"
required-features = ["serde-transport"]

9
tarpc/LICENSE Normal file
View File

@@ -0,0 +1,9 @@
The MIT License (MIT)
Copyright 2016 Google Inc. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@@ -0,0 +1,11 @@
-----BEGIN CERTIFICATE-----
MIIBlDCCAUagAwIBAgICAxUwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk
RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw
NjEwMTEwNFowGjEYMBYGA1UEAwwPcG9ueXRvd24gY2xpZW50MCowBQYDK2VwAyEA
NTKuLume19IhJfEFd/5OZUuYDKZH6xvy4AGver17OoejgZswgZgwDAYDVR0TAQH/
BAIwADALBgNVHQ8EBAMCBsAwFgYDVR0lAQH/BAwwCgYIKwYBBQUHAwIwHQYDVR0O
BBYEFDjdrlMu4tyw5MHtbg7WnzSGRBpFMEQGA1UdIwQ9MDuAFHIl7fHKWP6/l8FE
fI2YEIM3oHxKoSCkHjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQYIBezAF
BgMrZXADQQCaahfj/QLxoCOpvl6y0ZQ9CpojPqBnxV3460j5nUOp040Va2MpF137
izCBY7LwgUE/YG6E+kH30G4jMEnqVEYK
-----END CERTIFICATE-----

View File

@@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE
U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD
DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh
AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU
ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG
AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU
oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc
zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG
A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0
MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh
ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU
phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR
W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC
t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB
-----END CERTIFICATE-----

View File

@@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIIJX9ThTHpVS1SNZb6HP4myg4fRInIVGunTRdgnc+weH
-----END PRIVATE KEY-----

View File

@@ -0,0 +1,12 @@
-----BEGIN CERTIFICATE-----
MIIBuDCCAWqgAwIBAgICAcgwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk
RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw
NjEwMTEwNFowGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wKjAFBgMrZXADIQDc
RLl3/N2tPoWnzBV3noVn/oheEl8IUtiY11Vg/QXTUKOBwDCBvTAMBgNVHRMBAf8E
AjAAMAsGA1UdDwQEAwIGwDAdBgNVHQ4EFgQUk7U2mnxedNWBAH84BsNy5si3ZQow
RAYDVR0jBD0wO4AUciXt8cpY/r+XwUR8jZgQgzegfEqhIKQeMBwxGjAYBgNVBAMM
EXBvbnl0b3duIEVkRFNBIENBggF7MDsGA1UdEQQ0MDKCDnRlc3RzZXJ2ZXIuY29t
ghVzZWNvbmQudGVzdHNlcnZlci5jb22CCWxvY2FsaG9zdDAFBgMrZXADQQCFWIcF
9FiztCuUNzgXDNu5kshuflt0RjkjWpGlWzQjGoYM2IvYhNVPeqnCiY92gqwDSBtq
amD2TBup4eNUCsQB
-----END CERTIFICATE-----

View File

@@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE
U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD
DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh
AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU
ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG
AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU
oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc
zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG
A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0
MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh
ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU
phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR
W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC
t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB
-----END CERTIFICATE-----

View File

@@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIMU6xGVe8JTpZ3bN/wajHfw6pEHt0Rd7wPBxds9eEFy2
-----END PRIVATE KEY-----

View File

@@ -1,8 +1,9 @@
use tarpc::context::Context;
use tarpc::serde_transport as transport;
use tarpc::server::{BaseChannel, Channel};
use tarpc::{context::Context, tokio_serde::formats::Bincode};
use tarpc::tokio_serde::formats::Bincode;
use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec;
use tokio::net::{UnixListener, UnixStream};
use tokio_util::codec::length_delimited::LengthDelimitedCodec;
#[tarpc::service]
pub trait PingService {

View File

@@ -52,9 +52,9 @@ use tarpc::{
client, context,
serde_transport::tcp,
server::{self, Channel},
tokio_serde::formats::Json,
};
use tokio::net::ToSocketAddrs;
use tokio_serde::formats::Json;
use tracing::info;
use tracing_subscriber::prelude::*;
@@ -129,7 +129,6 @@ impl Subscriber {
#[derive(Debug)]
struct Subscription {
subscriber: subscriber::SubscriberClient,
topics: Vec<String>,
}
@@ -210,7 +209,6 @@ impl Publisher {
self.clients.lock().unwrap().insert(
subscriber_addr,
Subscription {
subscriber: subscriber.clone(),
topics: topics.clone(),
},
);

View File

@@ -0,0 +1,152 @@
use rustls_pemfile::certs;
use std::io::{BufReader, Cursor};
use std::net::{IpAddr, Ipv4Addr};
use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio_rustls::rustls::{self, Certificate, OwnedTrustAnchor, RootCertStore};
use tokio_rustls::{webpki, TlsAcceptor, TlsConnector};
use tarpc::context::Context;
use tarpc::serde_transport as transport;
use tarpc::server::{BaseChannel, Channel};
use tarpc::tokio_serde::formats::Bincode;
use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec;
#[tarpc::service]
pub trait PingService {
async fn ping() -> String;
}
#[derive(Clone)]
struct Service;
#[tarpc::server]
impl PingService for Service {
async fn ping(self, _: Context) -> String {
"🔒".to_owned()
}
}
// certs were generated with openssl 3 https://github.com/rustls/rustls/tree/main/test-ca
// used on client-side for server tls
const END_CHAIN: &[u8] = include_bytes!("certs/eddsa/end.chain");
// used on client-side for client-auth
const CLIENT_PRIVATEKEY_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.key");
const CLIENT_CERT_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.cert");
// used on server-side for server tls
const END_CERT: &str = include_str!("certs/eddsa/end.cert");
const END_PRIVATEKEY: &str = include_str!("certs/eddsa/end.key");
// used on server-side for client-auth
const CLIENT_CHAIN_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.chain");
pub fn load_private_key(key: &str) -> rustls::PrivateKey {
let mut reader = BufReader::new(Cursor::new(key));
loop {
match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") {
Some(rustls_pemfile::Item::RSAKey(key)) => return rustls::PrivateKey(key),
Some(rustls_pemfile::Item::PKCS8Key(key)) => return rustls::PrivateKey(key),
Some(rustls_pemfile::Item::ECKey(key)) => return rustls::PrivateKey(key),
None => break,
_ => {}
}
}
panic!("no keys found in {:?} (encrypted keys not supported)", key);
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// -------------------- start here to setup tls tcp tokio stream --------------------------
// ref certs and loading from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/tests/test.rs
// ref basic tls server setup from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/server/src/main.rs
let cert = certs(&mut BufReader::new(Cursor::new(END_CERT)))
.unwrap()
.into_iter()
.map(rustls::Certificate)
.collect();
let key = load_private_key(END_PRIVATEKEY);
let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), 5000);
// ------------- server side client_auth cert loading start
let roots: Vec<Certificate> = certs(&mut BufReader::new(Cursor::new(CLIENT_CHAIN_CLIENT_AUTH)))
.unwrap()
.into_iter()
.map(rustls::Certificate)
.collect();
let mut client_auth_roots = RootCertStore::empty();
for root in roots {
client_auth_roots.add(&root).unwrap();
}
let client_auth = AllowAnyAuthenticatedClient::new(client_auth_roots);
// ------------- server side client_auth cert loading end
let config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_client_cert_verifier(client_auth) // use .with_no_client_auth() instead if you don't want client-auth
.with_single_cert(cert, key)
.unwrap();
let acceptor = TlsAcceptor::from(Arc::new(config));
let listener = TcpListener::bind(&server_addr).await.unwrap();
let codec_builder = LengthDelimitedCodec::builder();
// ref ./custom_transport.rs server side
tokio::spawn(async move {
loop {
let (stream, _peer_addr) = listener.accept().await.unwrap();
let acceptor = acceptor.clone();
let tls_stream = acceptor.accept(stream).await.unwrap();
let framed = codec_builder.new_framed(tls_stream);
let transport = transport::new(framed, Bincode::default());
let fut = BaseChannel::with_defaults(transport).execute(Service.serve());
tokio::spawn(fut);
}
});
// ---------------------- client connection ---------------------
// cert loading from: https://github.com/tokio-rs/tls/blob/357bc562483dcf04c1f8d08bd1a831b144bf7d4c/tokio-rustls/tests/test.rs#L113
// tls client connection from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
let chain = certs(&mut std::io::Cursor::new(END_CHAIN)).unwrap();
let mut root_store = rustls::RootCertStore::empty();
root_store.add_server_trust_anchors(chain.iter().map(|cert| {
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let client_auth_private_key = load_private_key(CLIENT_PRIVATEKEY_CLIENT_AUTH);
let client_auth_certs: Vec<Certificate> =
certs(&mut BufReader::new(Cursor::new(CLIENT_CERT_CLIENT_AUTH)))
.unwrap()
.into_iter()
.map(rustls::Certificate)
.collect();
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_single_cert(client_auth_certs, client_auth_private_key)?; // use .with_no_client_auth() instead if you don't want client-auth
let domain = rustls::ServerName::try_from("localhost")?;
let connector = TlsConnector::from(Arc::new(config));
let stream = TcpStream::connect(server_addr).await?;
let stream = connector.connect(domain, stream).await?;
let transport = transport::new(codec_builder.new_framed(stream), Bincode::default());
let answer = PingServiceClient::new(Default::default(), transport)
.spawn()
.ping(tarpc::context::current())
.await?;
println!("ping answer: {answer}");
Ok(())
}

View File

@@ -9,8 +9,8 @@ use futures::{future, prelude::*};
use tarpc::{
client, context,
server::{incoming::Incoming, BaseChannel},
tokio_serde::formats::Json,
};
use tokio_serde::formats::Json;
use tracing_subscriber::prelude::*;
pub mod add {

View File

@@ -0,0 +1,49 @@
use futures::{prelude::*, task::*};
use std::pin::Pin;
use tokio::sync::mpsc;
/// Sends request cancellation signals.
#[derive(Debug, Clone)]
pub struct RequestCancellation(mpsc::UnboundedSender<u64>);
/// A stream of IDs of requests that have been canceled.
#[derive(Debug)]
pub struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
/// Returns a channel to send request cancellation messages.
pub fn cancellations() -> (RequestCancellation, CanceledRequests) {
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
// bounded by the number of in-flight requests.
let (tx, rx) = mpsc::unbounded_channel();
(RequestCancellation(tx), CanceledRequests(rx))
}
impl RequestCancellation {
/// Cancels the request with ID `request_id`.
///
/// No validation is done of `request_id`. There is no way to know if the request id provided
/// corresponds to a request actually tracked by the backing channel. `RequestCancellation` is
/// a one-way communication channel.
///
/// Once request data is cleaned up, a response will never be received by the client. This is
/// useful primarily when request processing ends prematurely for requests with long deadlines
/// which would otherwise continue to be tracked by the backing channel—a kind of leak.
pub fn cancel(&self, request_id: u64) {
let _ = self.0.send(request_id);
}
}
impl CanceledRequests {
/// Polls for a cancelled request.
pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<u64>> {
self.0.poll_recv(cx)
}
}
impl Stream for CanceledRequests {
type Item = u64;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
self.poll_recv(cx)
}
}

View File

@@ -8,14 +8,16 @@
mod in_flight_requests;
use crate::{context, trace, ClientMessage, Request, Response, ServerError, Transport};
use crate::{
cancellations::{cancellations, CanceledRequests, RequestCancellation},
context, trace, ChannelError, ClientMessage, Request, Response, ServerError, Transport,
};
use futures::{prelude::*, ready, stream::Fuse, task::*};
use in_flight_requests::{DeadlineExceededError, InFlightRequests};
use in_flight_requests::InFlightRequests;
use pin_project::pin_project;
use std::{
convert::TryFrom,
error::Error,
fmt, mem,
fmt,
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
@@ -121,12 +123,12 @@ impl<Req, Resp> Channel<Req, Resp> {
pub async fn call(
&self,
mut ctx: context::Context,
request_name: &str,
request_name: &'static str,
request: Req,
) -> Result<Resp, RpcError> {
let span = Span::current();
ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
tracing::warn!(
tracing::trace!(
"OpenTelemetry subscriber not installed; making unsampled child context."
);
ctx.trace_context.new_child()
@@ -144,6 +146,7 @@ impl<Req, Resp> Channel<Req, Resp> {
response: &mut response,
request_id,
cancellation: &self.cancellation,
cancel: true,
};
self.to_dispatch
.send(DispatchRequest {
@@ -154,7 +157,7 @@ impl<Req, Resp> Channel<Req, Resp> {
response_completion,
})
.await
.map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?;
.map_err(|mpsc::error::SendError(_)| RpcError::Shutdown)?;
response_guard.response().await
}
}
@@ -162,19 +165,25 @@ impl<Req, Resp> Channel<Req, Resp> {
/// A server response that is completed by request dispatch when the corresponding response
/// arrives off the wire.
struct ResponseGuard<'a, Resp> {
response: &'a mut oneshot::Receiver<Result<Response<Resp>, DeadlineExceededError>>,
response: &'a mut oneshot::Receiver<Result<Resp, RpcError>>,
cancellation: &'a RequestCancellation,
request_id: u64,
cancel: bool,
}
/// An error that can occur in the processing of an RPC. This is not request-specific errors but
/// rather cross-cutting errors that can always occur.
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[derive(thiserror::Error, Debug)]
pub enum RpcError {
/// The client disconnected from the server.
#[error("the client disconnected from the server")]
Disconnected,
#[error("the connection to the server was already shutdown")]
Shutdown,
/// The client failed to send the request.
#[error("the client failed to send the request")]
Send(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
/// An error occurred while waiting for the server response.
#[error("an error occurred while waiting for the server response")]
Receive(#[source] Arc<dyn std::error::Error + Send + Sync + 'static>),
/// The request exceeded its deadline.
#[error("the request exceeded its deadline")]
DeadlineExceeded,
@@ -183,24 +192,18 @@ pub enum RpcError {
Server(#[from] ServerError),
}
impl From<DeadlineExceededError> for RpcError {
fn from(_: DeadlineExceededError) -> Self {
RpcError::DeadlineExceeded
}
}
impl<Resp> ResponseGuard<'_, Resp> {
async fn response(mut self) -> Result<Resp, RpcError> {
let response = (&mut self.response).await;
// Cancel drop logic once a response has been received.
mem::forget(self);
self.cancel = false;
match response {
Ok(resp) => Ok(resp?.message?),
Ok(response) => response,
Err(oneshot::error::RecvError { .. }) => {
// The oneshot is Canceled when the dispatch task ends. In that case,
// there's nothing listening on the other side, so there's no point in
// propagating cancellation.
Err(RpcError::Disconnected)
Err(RpcError::Shutdown)
}
}
}
@@ -220,7 +223,9 @@ impl<Resp> Drop for ResponseGuard<'_, Resp> {
// dispatch task misses an early-arriving cancellation message, then it will see the
// receiver as closed.
self.response.close();
self.cancellation.cancel(self.request_id);
if self.cancel {
self.cancellation.cancel(self.request_id);
}
}
}
@@ -235,7 +240,6 @@ where
{
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
let (cancellation, canceled_requests) = cancellations();
let canceled_requests = canceled_requests;
NewClient {
client: Channel {
@@ -255,6 +259,7 @@ where
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
/// and dispatching responses to the appropriate channel.
#[must_use]
#[pin_project]
#[derive(Debug)]
pub struct RequestDispatch<Req, Resp, C> {
@@ -266,42 +271,18 @@ pub struct RequestDispatch<Req, Resp, C> {
/// Requests that were dropped.
canceled_requests: CanceledRequests,
/// Requests already written to the wire that haven't yet received responses.
in_flight_requests: InFlightRequests<Resp>,
in_flight_requests: InFlightRequests<Result<Resp, RpcError>>,
/// Configures limits to prevent unlimited resource usage.
config: Config,
}
/// Critical errors that result in a Channel disconnecting.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError<E>
where
E: Error + Send + Sync + 'static,
{
/// Could not read from the transport.
#[error("could not read from the transport")]
Read(#[source] E),
/// Could not ready the transport for writes.
#[error("could not ready the transport for writes")]
Ready(#[source] E),
/// Could not write to the transport.
#[error("could not write to the transport")]
Write(#[source] E),
/// Could not flush the transport.
#[error("could not flush the transport")]
Flush(#[source] E),
/// Could not close the write end of the transport.
#[error("could not close the write end of the transport")]
Close(#[source] E),
/// Could not poll expired requests.
#[error("could not poll expired requests")]
Timer(#[source] tokio::time::error::Error),
}
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests<Resp> {
fn in_flight_requests<'a>(
self: &'a mut Pin<&mut Self>,
) -> &'a mut InFlightRequests<Result<Resp, RpcError>> {
self.as_mut().project().in_flight_requests
}
@@ -361,7 +342,17 @@ where
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
self.transport_pin_mut()
.poll_next(cx)
.map_err(ChannelError::Read)
.map_err(|e| {
let e = Arc::new(e);
for span in self
.in_flight_requests()
.complete_all_requests(|| Err(RpcError::Receive(e.clone())))
{
let _entered = span.enter();
tracing::info!("ReceiveError");
}
ChannelError::Read(e)
})
.map_ok(|response| {
self.complete(response);
})
@@ -393,8 +384,7 @@ where
// track the status like is done with pending and cancelled requests.
if let Poll::Ready(Some(_)) = self
.in_flight_requests()
.poll_expired(cx)
.map_err(ChannelError::Timer)?
.poll_expired(cx, || Err(RpcError::DeadlineExceeded))
{
// Expired requests are considered complete; there is no compelling reason to send a
// cancellation message to the server, since it will have already exhausted its
@@ -506,7 +496,7 @@ where
Some(dispatch_request) => dispatch_request,
None => return Poll::Ready(None),
};
let entered = span.enter();
let _entered = span.enter();
// poll_next_request only returns Ready if there is room to buffer another request.
// Therefore, we can call write_request without fear of erroring due to a full
// buffer.
@@ -519,13 +509,16 @@ where
trace_context: ctx.trace_context,
},
});
self.start_send(request)?;
tracing::info!("SendRequest");
drop(entered);
self.in_flight_requests()
.insert_request(request_id, ctx, span, response_completion)
.insert_request(request_id, ctx, span.clone(), response_completion)
.expect("Request IDs should be unique");
match self.start_send(request) {
Ok(()) => tracing::info!("SendRequest"),
Err(e) => {
self.in_flight_requests()
.complete_request(request_id, Err(RpcError::Send(Box::new(e))));
}
}
Poll::Ready(Some(Ok(())))
}
@@ -550,7 +543,10 @@ where
/// Sends a server response to the client task that initiated the associated request.
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
self.in_flight_requests().complete_request(response)
self.in_flight_requests().complete_request(
response.request_id,
response.message.map_err(RpcError::Server),
)
}
}
@@ -599,76 +595,43 @@ struct DispatchRequest<Req, Resp> {
pub span: Span,
pub request_id: u64,
pub request: Req,
pub response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
}
/// Sends request cancellation signals.
#[derive(Debug, Clone)]
struct RequestCancellation(mpsc::UnboundedSender<u64>);
/// A stream of IDs of requests that have been canceled.
#[derive(Debug)]
struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
/// Returns a channel to send request cancellation messages.
fn cancellations() -> (RequestCancellation, CanceledRequests) {
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
// bounded by the number of in-flight requests.
let (tx, rx) = mpsc::unbounded_channel();
(RequestCancellation(tx), CanceledRequests(rx))
}
impl RequestCancellation {
/// Cancels the request with ID `request_id`.
fn cancel(&self, request_id: u64) {
let _ = self.0.send(request_id);
}
}
impl CanceledRequests {
fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<u64>> {
self.0.poll_recv(cx)
}
}
impl Stream for CanceledRequests {
type Item = u64;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
self.poll_recv(cx)
}
pub response_completion: oneshot::Sender<Result<Resp, RpcError>>,
}
#[cfg(test)]
mod tests {
use super::{
cancellations, CanceledRequests, Channel, DispatchRequest, RequestCancellation,
RequestDispatch, ResponseGuard,
cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError,
};
use crate::{
client::{
in_flight_requests::{DeadlineExceededError, InFlightRequests},
Config,
},
context,
client::{in_flight_requests::InFlightRequests, Config},
context::{self, current},
transport::{self, channel::UnboundedChannel},
ClientMessage, Response,
ChannelError, ClientMessage, Response,
};
use assert_matches::assert_matches;
use futures::{prelude::*, task::*};
use std::{
convert::TryFrom,
fmt::Display,
marker::PhantomData,
pin::Pin,
sync::atomic::{AtomicUsize, Ordering},
sync::Arc,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use thiserror::Error;
use tokio::sync::{
mpsc::{self},
oneshot,
};
use tokio::sync::{mpsc, oneshot};
use tracing::Span;
#[tokio::test]
async fn response_completes_request_future() {
let (mut dispatch, mut _channel, mut server_channel) = set_up();
let cx = &mut Context::from_waker(&noop_waker_ref());
let cx = &mut Context::from_waker(noop_waker_ref());
let (tx, mut rx) = oneshot::channel();
dispatch
@@ -683,7 +646,7 @@ mod tests {
.await
.unwrap();
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp");
assert_matches!(rx.try_recv(), Ok(Ok(resp)) if resp == "Resp");
}
#[tokio::test]
@@ -694,10 +657,11 @@ mod tests {
response: &mut response,
cancellation: &cancellation,
request_id: 3,
cancel: true,
});
// resp's drop() is run, which should send a cancel message.
let cx = &mut Context::from_waker(&noop_waker_ref());
assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(Some(3)));
let cx = &mut Context::from_waker(noop_waker_ref());
assert_eq!(canceled_requests.poll_recv(cx), Poll::Ready(Some(3)));
}
#[tokio::test]
@@ -714,23 +678,25 @@ mod tests {
response: &mut response,
cancellation: &cancellation,
request_id: 3,
cancel: true,
}
.response()
.await
.unwrap();
drop(cancellation);
let cx = &mut Context::from_waker(&noop_waker_ref());
assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(None));
let cx = &mut Context::from_waker(noop_waker_ref());
assert_eq!(canceled_requests.poll_recv(cx), Poll::Ready(None));
}
#[tokio::test]
async fn stage_request() {
let (mut dispatch, mut channel, _server_channel) = set_up();
let cx = &mut Context::from_waker(&noop_waker_ref());
let cx = &mut Context::from_waker(noop_waker_ref());
let (tx, mut rx) = oneshot::channel();
let _resp = send_request(&mut channel, "hi", tx, &mut rx).await;
#[allow(unstable_name_collisions)]
let req = dispatch.as_mut().poll_next_request(cx).ready();
assert!(req.is_some());
@@ -743,7 +709,7 @@ mod tests {
#[tokio::test]
async fn stage_request_channel_dropped_doesnt_panic() {
let (mut dispatch, mut channel, mut server_channel) = set_up();
let cx = &mut Context::from_waker(&noop_waker_ref());
let cx = &mut Context::from_waker(noop_waker_ref());
let (tx, mut rx) = oneshot::channel();
let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
@@ -761,10 +727,11 @@ mod tests {
dispatch.await.unwrap();
}
#[allow(unstable_name_collisions)]
#[tokio::test]
async fn stage_request_response_future_dropped_is_canceled_before_sending() {
let (mut dispatch, mut channel, _server_channel) = set_up();
let cx = &mut Context::from_waker(&noop_waker_ref());
let cx = &mut Context::from_waker(noop_waker_ref());
let (tx, mut rx) = oneshot::channel();
let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
@@ -776,10 +743,11 @@ mod tests {
assert!(dispatch.as_mut().poll_next_request(cx).ready().is_none());
}
#[allow(unstable_name_collisions)]
#[tokio::test]
async fn stage_request_response_future_dropped_is_canceled_after_sending() {
let (mut dispatch, mut channel, _server_channel) = set_up();
let cx = &mut Context::from_waker(&noop_waker_ref());
let cx = &mut Context::from_waker(noop_waker_ref());
let (tx, mut rx) = oneshot::channel();
let req = send_request(&mut channel, "hi", tx, &mut rx).await;
@@ -800,7 +768,7 @@ mod tests {
#[tokio::test]
async fn stage_request_response_closed_skipped() {
let (mut dispatch, mut channel, _server_channel) = set_up();
let cx = &mut Context::from_waker(&noop_waker_ref());
let cx = &mut Context::from_waker(noop_waker_ref());
let (tx, mut rx) = oneshot::channel();
// Test that a request future that's closed its receiver but not yet canceled its request --
@@ -812,6 +780,185 @@ mod tests {
assert!(dispatch.as_mut().poll_next_request(cx).is_pending());
}
#[tokio::test]
async fn test_shutdown_error() {
let _ = tracing_subscriber::fmt().with_test_writer().try_init();
let (dispatch, mut channel, _) = set_up();
let (tx, mut rx) = oneshot::channel();
// send succeeds
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
drop(dispatch);
// error on receive
assert_matches!(resp.response().await, Err(RpcError::Shutdown));
let (dispatch, channel, _) = set_up();
drop(dispatch);
// error on send
let resp = channel
.call(current(), "test_request", "hi".to_string())
.await;
assert_matches!(resp, Err(RpcError::Shutdown));
}
#[tokio::test]
async fn test_transport_error_write() {
let cause = TransportError::Write;
let (mut dispatch, mut channel, mut cx) = setup_always_err(cause);
let (tx, mut rx) = oneshot::channel();
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
assert!(dispatch.as_mut().poll(&mut cx).is_pending());
let res = resp.response().await;
assert_matches!(res, Err(RpcError::Send(_)));
let client_error: anyhow::Error = res.unwrap_err().into();
let mut chain = client_error.chain();
chain.next(); // original RpcError
assert_eq!(
chain
.next()
.unwrap()
.downcast_ref::<ChannelError<TransportError>>(),
Some(&ChannelError::Write(cause))
);
assert_eq!(
client_error.root_cause().downcast_ref::<TransportError>(),
Some(&cause)
);
}
#[tokio::test]
async fn test_transport_error_read() {
let cause = TransportError::Read;
let (mut dispatch, mut channel, mut cx) = setup_always_err(cause);
let (tx, mut rx) = oneshot::channel();
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
assert_eq!(
dispatch.as_mut().pump_write(&mut cx),
Poll::Ready(Some(Ok(())))
);
assert_eq!(
dispatch.as_mut().pump_read(&mut cx),
Poll::Ready(Some(Err(ChannelError::Read(Arc::new(cause)))))
);
assert_matches!(resp.response().await, Err(RpcError::Receive(_)));
}
#[tokio::test]
async fn test_transport_error_ready() {
let cause = TransportError::Ready;
let (mut dispatch, _, mut cx) = setup_always_err(cause);
assert_eq!(
dispatch.as_mut().poll(&mut cx),
Poll::Ready(Err(ChannelError::Ready(cause)))
);
}
#[tokio::test]
async fn test_transport_error_flush() {
let cause = TransportError::Flush;
let (mut dispatch, _, mut cx) = setup_always_err(cause);
assert_eq!(
dispatch.as_mut().poll(&mut cx),
Poll::Ready(Err(ChannelError::Flush(cause)))
);
}
#[tokio::test]
async fn test_transport_error_close() {
let cause = TransportError::Close;
let (mut dispatch, channel, mut cx) = setup_always_err(cause);
drop(channel);
assert_eq!(
dispatch.as_mut().poll(&mut cx),
Poll::Ready(Err(ChannelError::Close(cause)))
);
}
fn setup_always_err(
cause: TransportError,
) -> (
Pin<Box<RequestDispatch<String, String, AlwaysErrorTransport<String>>>>,
Channel<String, String>,
Context<'static>,
) {
let (to_dispatch, pending_requests) = mpsc::channel(1);
let (cancellation, canceled_requests) = cancellations();
let transport: AlwaysErrorTransport<String> = AlwaysErrorTransport(cause, PhantomData);
let dispatch = Box::pin(RequestDispatch::<String, String, _> {
transport: transport.fuse(),
pending_requests,
canceled_requests,
in_flight_requests: InFlightRequests::default(),
config: Config::default(),
});
let channel = Channel {
to_dispatch,
cancellation,
next_request_id: Arc::new(AtomicUsize::new(0)),
};
let cx = Context::from_waker(noop_waker_ref());
(dispatch, channel, cx)
}
struct AlwaysErrorTransport<I>(TransportError, PhantomData<I>);
#[derive(Debug, Error, PartialEq, Eq, Clone, Copy)]
enum TransportError {
Read,
Ready,
Write,
Flush,
Close,
}
impl Display for TransportError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&format!("{self:?}"))
}
}
impl<I: Clone, S> Sink<S> for AlwaysErrorTransport<I> {
type Error = TransportError;
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.0 {
TransportError::Ready => Poll::Ready(Err(self.0)),
TransportError::Flush => Poll::Pending,
_ => Poll::Ready(Ok(())),
}
}
fn start_send(self: Pin<&mut Self>, _: S) -> Result<(), Self::Error> {
if matches!(self.0, TransportError::Write) {
Err(self.0)
} else {
Ok(())
}
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if matches!(self.0, TransportError::Flush) {
Poll::Ready(Err(self.0))
} else {
Poll::Ready(Ok(()))
}
}
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if matches!(self.0, TransportError::Close) {
Poll::Ready(Err(self.0))
} else {
Poll::Ready(Ok(()))
}
}
}
impl<I: Clone> Stream for AlwaysErrorTransport<I> {
type Item = Result<Response<I>, TransportError>;
fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if matches!(self.0, TransportError::Read) {
Poll::Ready(Some(Err(self.0)))
} else {
Poll::Pending
}
}
}
fn set_up() -> (
Pin<
Box<
@@ -828,18 +975,17 @@ mod tests {
let _ = tracing_subscriber::fmt().with_test_writer().try_init();
let (to_dispatch, pending_requests) = mpsc::channel(1);
let (cancel_tx, canceled_requests) = mpsc::unbounded_channel();
let (cancellation, canceled_requests) = cancellations();
let (client_channel, server_channel) = transport::channel::unbounded();
let dispatch = RequestDispatch::<String, String, _> {
transport: client_channel.fuse(),
pending_requests: pending_requests,
canceled_requests: CanceledRequests(canceled_requests),
pending_requests,
canceled_requests,
in_flight_requests: InFlightRequests::default(),
config: Config::default(),
};
let cancellation = RequestCancellation(cancel_tx);
let channel = Channel {
to_dispatch,
cancellation,
@@ -852,8 +998,8 @@ mod tests {
async fn send_request<'a>(
channel: &'a mut Channel<String, String>,
request: &str,
response_completion: oneshot::Sender<Result<Response<String>, DeadlineExceededError>>,
response: &'a mut oneshot::Receiver<Result<Response<String>, DeadlineExceededError>>,
response_completion: oneshot::Sender<Result<String, RpcError>>,
response: &'a mut oneshot::Receiver<Result<String, RpcError>>,
) -> ResponseGuard<'a, String> {
let request_id =
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
@@ -864,13 +1010,14 @@ mod tests {
request: request.to_string(),
response_completion,
};
channel.to_dispatch.send(request).await.unwrap();
ResponseGuard {
let response_guard = ResponseGuard {
response,
cancellation: &channel.cancellation,
request_id,
}
cancel: true,
};
channel.to_dispatch.send(request).await.unwrap();
response_guard
}
async fn send_response(

View File

@@ -1,7 +1,6 @@
use crate::{
context,
util::{Compact, TimeUntil},
Response,
};
use fnv::FnvHashMap;
use std::{
@@ -28,17 +27,11 @@ impl<Resp> Default for InFlightRequests<Resp> {
}
}
/// The request exceeded its deadline.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
#[error("the request exceeded its deadline")]
pub struct DeadlineExceededError;
#[derive(Debug)]
struct RequestData<Resp> {
struct RequestData<Res> {
ctx: context::Context,
span: Span,
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
response_completion: oneshot::Sender<Res>,
/// The key to remove the timer for the request's deadline.
deadline_key: delay_queue::Key,
}
@@ -48,7 +41,7 @@ struct RequestData<Resp> {
#[derive(Debug)]
pub struct AlreadyExistsError;
impl<Resp> InFlightRequests<Resp> {
impl<Res> InFlightRequests<Res> {
/// Returns the number of in-flight requests.
pub fn len(&self) -> usize {
self.request_data.len()
@@ -65,7 +58,7 @@ impl<Resp> InFlightRequests<Resp> {
request_id: u64,
ctx: context::Context,
span: Span,
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
response_completion: oneshot::Sender<Res>,
) -> Result<(), AlreadyExistsError> {
match self.request_data.entry(request_id) {
hash_map::Entry::Vacant(vacant) => {
@@ -84,25 +77,35 @@ impl<Resp> InFlightRequests<Resp> {
}
/// Removes a request without aborting. Returns true iff the request was found.
pub fn complete_request(&mut self, response: Response<Resp>) -> bool {
if let Some(request_data) = self.request_data.remove(&response.request_id) {
pub fn complete_request(&mut self, request_id: u64, result: Res) -> bool {
if let Some(request_data) = self.request_data.remove(&request_id) {
let _entered = request_data.span.enter();
tracing::info!("ReceiveResponse");
self.request_data.compact(0.1);
self.deadlines.remove(&request_data.deadline_key);
let _ = request_data.response_completion.send(Ok(response));
let _ = request_data.response_completion.send(result);
return true;
}
tracing::debug!(
"No in-flight request found for request_id = {}.",
response.request_id
);
tracing::debug!("No in-flight request found for request_id = {request_id}.");
// If the response completion was absent, then the request was already canceled.
false
}
/// Completes all requests using the provided function.
/// Returns Spans for all completes requests.
pub fn complete_all_requests<'a>(
&'a mut self,
mut result: impl FnMut() -> Res + 'a,
) -> impl Iterator<Item = Span> + 'a {
self.deadlines.clear();
self.request_data.drain().map(move |(_, request_data)| {
let _ = request_data.response_completion.send(result());
request_data.span
})
}
/// Cancels a request without completing (typically used when a request handle was dropped
/// before the request completed).
pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> {
@@ -120,18 +123,17 @@ impl<Resp> InFlightRequests<Resp> {
pub fn poll_expired(
&mut self,
cx: &mut Context,
) -> Poll<Option<Result<u64, tokio::time::error::Error>>> {
self.deadlines.poll_expired(cx).map_ok(|expired| {
let request_id = expired.into_inner();
expired_error: impl Fn() -> Res,
) -> Poll<Option<u64>> {
self.deadlines.poll_expired(cx).map(|expired| {
let request_id = expired?.into_inner();
if let Some(request_data) = self.request_data.remove(&request_id) {
let _entered = request_data.span.enter();
tracing::error!("DeadlineExceeded");
self.request_data.compact(0.1);
let _ = request_data
.response_completion
.send(Err(DeadlineExceededError));
let _ = request_data.response_completion.send(expired_error());
}
request_id
Some(request_id)
})
}
}

View File

@@ -209,7 +209,7 @@
pub use serde;
#[cfg(feature = "serde-transport")]
pub use tokio_serde;
pub use {tokio_serde, tokio_util};
#[cfg(feature = "serde-transport")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-transport")))]
@@ -300,6 +300,7 @@ pub use tarpc_plugins::service;
/// `async`, meaning that this should not break existing code.
pub use tarpc_plugins::server;
pub(crate) mod cancellations;
pub mod client;
pub mod context;
pub mod server;
@@ -310,6 +311,7 @@ pub use crate::transport::sealed::Transport;
use anyhow::Context as _;
use futures::task::*;
use std::sync::Arc;
use std::{error::Error, fmt::Display, io, time::SystemTime};
/// A message from a client to a server.
@@ -382,6 +384,29 @@ pub struct ServerError {
pub detail: String,
}
/// Critical errors that result in a Channel disconnecting.
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum ChannelError<E>
where
E: Error + Send + Sync + 'static,
{
/// Could not read from the transport.
#[error("could not read from the transport")]
Read(#[source] Arc<E>),
/// Could not ready the transport for writes.
#[error("could not ready the transport for writes")]
Ready(#[source] E),
/// Could not write to the transport.
#[error("could not write to the transport")]
Write(#[source] E),
/// Could not flush the transport.
#[error("could not flush the transport")]
Flush(#[source] E),
/// Could not close the write end of the transport.
#[error("could not close the write end of the transport")]
Close(#[source] E),
}
impl<T> Request<T> {
/// Returns the deadline for this request.
pub fn deadline(&self) -> &SystemTime {

View File

@@ -129,14 +129,6 @@ pub mod tcp {
tokio_util::codec::length_delimited,
};
mod private {
use super::*;
pub trait Sealed {}
impl<Item, SinkItem, Codec> Sealed for Transport<TcpStream, Item, SinkItem, Codec> {}
}
impl<Item, SinkItem, Codec> Transport<TcpStream, Item, SinkItem, Codec> {
/// Returns the peer address of the underlying TcpStream.
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
@@ -149,6 +141,7 @@ pub mod tcp {
}
/// A connection Future that also exposes the length-delimited framing config.
#[must_use]
#[pin_project]
pub struct Connect<T, Item, SinkItem, CodecFn> {
#[pin]
@@ -276,6 +269,270 @@ pub mod tcp {
}
}
#[cfg(all(unix, feature = "unix"))]
#[cfg_attr(docsrs, doc(cfg(all(unix, feature = "unix"))))]
/// Unix Domain Socket support for generic transport using Tokio.
pub mod unix {
use {
super::*,
futures::ready,
std::{marker::PhantomData, path::Path},
tokio::net::{unix::SocketAddr, UnixListener, UnixStream},
tokio_util::codec::length_delimited,
};
impl<Item, SinkItem, Codec> Transport<UnixStream, Item, SinkItem, Codec> {
/// Returns the socket address of the remote half of the underlying [`UnixStream`].
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().peer_addr()
}
/// Returns the socket address of the local half of the underlying [`UnixStream`].
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner.get_ref().get_ref().local_addr()
}
}
/// A connection Future that also exposes the length-delimited framing config.
#[must_use]
#[pin_project]
pub struct Connect<T, Item, SinkItem, CodecFn> {
#[pin]
inner: T,
codec_fn: CodecFn,
config: length_delimited::Builder,
ghost: PhantomData<(fn(SinkItem), fn() -> Item)>,
}
impl<T, Item, SinkItem, Codec, CodecFn> Future for Connect<T, Item, SinkItem, CodecFn>
where
T: Future<Output = io::Result<UnixStream>>,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
type Output = io::Result<Transport<UnixStream, Item, SinkItem, Codec>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let io = ready!(self.as_mut().project().inner.poll(cx))?;
Poll::Ready(Ok(new(self.config.new_framed(io), (self.codec_fn)())))
}
}
impl<T, Item, SinkItem, CodecFn> Connect<T, Item, SinkItem, CodecFn> {
/// Returns an immutable reference to the length-delimited codec's config.
pub fn config(&self) -> &length_delimited::Builder {
&self.config
}
/// Returns a mutable reference to the length-delimited codec's config.
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
&mut self.config
}
}
/// Connects to socket named by `path`, wrapping the connection in a Unix Domain Socket
/// transport.
pub fn connect<P, Item, SinkItem, Codec, CodecFn>(
path: P,
codec_fn: CodecFn,
) -> Connect<impl Future<Output = io::Result<UnixStream>>, Item, SinkItem, CodecFn>
where
P: AsRef<Path>,
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
Connect {
inner: UnixStream::connect(path),
codec_fn,
config: LengthDelimitedCodec::builder(),
ghost: PhantomData,
}
}
/// Listens on the socket named by `path`, wrapping accepted connections in Unix Domain Socket
/// transports.
pub async fn listen<P, Item, SinkItem, Codec, CodecFn>(
path: P,
codec_fn: CodecFn,
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
where
P: AsRef<Path>,
Item: for<'de> Deserialize<'de>,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
let listener = UnixListener::bind(path)?;
let local_addr = listener.local_addr()?;
Ok(Incoming {
listener,
codec_fn,
local_addr,
config: LengthDelimitedCodec::builder(),
ghost: PhantomData,
})
}
/// A [`UnixListener`] that wraps connections in [transports](Transport).
#[pin_project]
#[derive(Debug)]
pub struct Incoming<Item, SinkItem, Codec, CodecFn> {
listener: UnixListener,
local_addr: SocketAddr,
codec_fn: CodecFn,
config: length_delimited::Builder,
ghost: PhantomData<(fn() -> Item, fn(SinkItem), Codec)>,
}
impl<Item, SinkItem, Codec, CodecFn> Incoming<Item, SinkItem, Codec, CodecFn> {
/// Returns the the socket address being listened on.
pub fn local_addr(&self) -> &SocketAddr {
&self.local_addr
}
/// Returns an immutable reference to the length-delimited codec's config.
pub fn config(&self) -> &length_delimited::Builder {
&self.config
}
/// Returns a mutable reference to the length-delimited codec's config.
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
&mut self.config
}
}
impl<Item, SinkItem, Codec, CodecFn> Stream for Incoming<Item, SinkItem, Codec, CodecFn>
where
Item: for<'de> Deserialize<'de>,
SinkItem: Serialize,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
type Item = io::Result<Transport<UnixStream, Item, SinkItem, Codec>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let conn: UnixStream = ready!(self.as_mut().project().listener.poll_accept(cx)?).0;
Poll::Ready(Some(Ok(new(
self.config.new_framed(conn),
(self.codec_fn)(),
))))
}
}
/// A temporary `PathBuf` that lives in `std::env::temp_dir` and is removed on drop.
pub struct TempPathBuf(std::path::PathBuf);
impl TempPathBuf {
/// A named socket that results in `<tempdir>/<name>`
pub fn new<S: AsRef<str>>(name: S) -> Self {
let mut sock = std::env::temp_dir();
sock.push(name.as_ref());
Self(sock)
}
/// Appends a random hex string to the socket name resulting in
/// `<tempdir>/<name>_<xxxxx>`
pub fn with_random<S: AsRef<str>>(name: S) -> Self {
Self::new(format!("{}_{:016x}", name.as_ref(), rand::random::<u64>()))
}
}
impl AsRef<std::path::Path> for TempPathBuf {
fn as_ref(&self) -> &std::path::Path {
self.0.as_path()
}
}
impl Drop for TempPathBuf {
fn drop(&mut self) {
// This will remove the file pointed to by this PathBuf if it exists, however Err's can
// be returned such as attempting to remove a non-existing file, or one which we don't
// have permission to remove. In these cases the Err is swallowed
let _ = std::fs::remove_file(&self.0);
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio_serde::formats::SymmetricalJson;
#[test]
fn temp_path_buf_non_random() {
let sock = TempPathBuf::new("test");
let mut good = std::env::temp_dir();
good.push("test");
assert_eq!(sock.as_ref(), good);
assert_eq!(sock.as_ref().file_name().unwrap(), "test");
}
#[test]
fn temp_path_buf_random() {
let sock = TempPathBuf::with_random("test");
let good = std::env::temp_dir();
assert!(sock.as_ref().starts_with(good));
// Since there are 16 random characters we just assert the file_name has the right name
// and starts with the correct string 'test_'
// file name: test_xxxxxxxxxxxxxxxx
// test = 4
// _ = 1
// <hex> = 16
// total = 21
let fname = sock.as_ref().file_name().unwrap().to_string_lossy();
assert!(fname.starts_with("test_"));
assert_eq!(fname.len(), 21);
}
#[test]
fn temp_path_buf_non_existing() {
let sock = TempPathBuf::with_random("test");
let sock_path = std::path::PathBuf::from(sock.as_ref());
// No actual file has been created yet
assert!(!sock_path.exists());
// Should not panic
std::mem::drop(sock);
assert!(!sock_path.exists());
}
#[test]
fn temp_path_buf_existing_file() {
let sock = TempPathBuf::with_random("test");
let sock_path = std::path::PathBuf::from(sock.as_ref());
let _file = std::fs::File::create(&sock).unwrap();
assert!(sock_path.exists());
std::mem::drop(sock);
assert!(!sock_path.exists());
}
#[test]
fn temp_path_buf_preexisting_file() {
let mut pre_existing = std::env::temp_dir();
pre_existing.push("test");
let _file = std::fs::File::create(&pre_existing).unwrap();
let sock = TempPathBuf::new("test");
let sock_path = std::path::PathBuf::from(sock.as_ref());
assert!(sock_path.exists());
std::mem::drop(sock);
assert!(!sock_path.exists());
}
#[tokio::test]
async fn temp_path_buf_for_socket() {
let sock = TempPathBuf::with_random("test");
// Save path for testing after drop
let sock_path = std::path::PathBuf::from(sock.as_ref());
// create the actual socket
let _ = listen(&sock, SymmetricalJson::<String>::default).await;
assert!(sock_path.exists());
std::mem::drop(sock);
assert!(!sock_path.exists());
}
}
}
#[cfg(test)]
mod tests {
use super::Transport;
@@ -290,7 +547,7 @@ mod tests {
use tokio_serde::formats::SymmetricalJson;
fn ctx() -> Context<'static> {
Context::from_waker(&noop_waker_ref())
Context::from_waker(noop_waker_ref())
}
struct TestIo(Cursor<Vec<u8>>);
@@ -392,4 +649,24 @@ mod tests {
assert_matches!(transport.next().await, None);
Ok(())
}
#[cfg(all(unix, feature = "unix"))]
#[tokio::test]
async fn uds() -> io::Result<()> {
use super::unix;
use super::*;
let sock = unix::TempPathBuf::with_random("uds");
let mut listener = unix::listen(&sock, SymmetricalJson::<String>::default).await?;
tokio::spawn(async move {
let mut transport = listener.next().await.unwrap().unwrap();
let message = transport.next().await.unwrap().unwrap();
transport.send(message).await.unwrap();
});
let mut transport = unix::connect(&sock, SymmetricalJson::<String>::default).await?;
transport.send(String::from("test")).await?;
assert_matches!(transport.next().await, Some(Ok(s)) if s == "test");
assert_matches!(transport.next().await, None);
Ok(())
}
}

View File

@@ -7,8 +7,9 @@
//! Provides a server that concurrently handles many connections sending multiplexed requests.
use crate::{
cancellations::{cancellations, CanceledRequests, RequestCancellation},
context::{self, SpanExt},
trace, ClientMessage, Request, Response, Transport,
trace, ChannelError, ClientMessage, Request, Response, Transport,
};
use ::tokio::sync::mpsc;
use futures::{
@@ -20,7 +21,7 @@ use futures::{
};
use in_flight_requests::{AlreadyExistsError, InFlightRequests};
use pin_project::pin_project;
use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin};
use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc};
use tracing::{info_span, instrument::Instrument, Span};
mod in_flight_requests;
@@ -111,6 +112,11 @@ pub struct BaseChannel<Req, Resp, T> {
/// Writes responses to the wire and reads requests off the wire.
#[pin]
transport: Fuse<T>,
/// In-flight requests that were dropped by the server before completion.
#[pin]
canceled_requests: CanceledRequests,
/// Notifies `canceled_requests` when a request is canceled.
request_cancellation: RequestCancellation,
/// Holds data necessary to clean up in-flight requests.
in_flight_requests: InFlightRequests,
/// Types the request and response.
@@ -123,9 +129,12 @@ where
{
/// Creates a new channel backed by `transport` and configured with `config`.
pub fn new(config: Config, transport: T) -> Self {
let (request_cancellation, canceled_requests) = cancellations();
BaseChannel {
config,
transport: transport.fuse(),
canceled_requests,
request_cancellation,
in_flight_requests: InFlightRequests::default(),
ghost: PhantomData,
}
@@ -150,12 +159,18 @@ where
self.as_mut().project().in_flight_requests
}
fn canceled_requests_pin_mut<'a>(
self: &'a mut Pin<&mut Self>,
) -> Pin<&'a mut CanceledRequests> {
self.as_mut().project().canceled_requests
}
fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<T>> {
self.as_mut().project().transport
}
fn start_request(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
mut request: Request<Req>,
) -> Result<TrackedRequest<Req>, AlreadyExistsError> {
let span = info_span!(
@@ -175,7 +190,7 @@ where
});
let entered = span.enter();
tracing::info!("ReceiveRequest");
let start = self.project().in_flight_requests.start_request(
let start = self.in_flight_requests_mut().start_request(
request.id,
request.context.deadline,
span.clone(),
@@ -184,9 +199,14 @@ where
Ok(abort_registration) => {
drop(entered);
Ok(TrackedRequest {
request,
abort_registration,
span,
response_guard: ResponseGuard {
request_id: request.id,
request_cancellation: self.request_cancellation.clone(),
cancel: false,
},
request,
})
}
Err(AlreadyExistsError) => {
@@ -213,6 +233,8 @@ pub struct TrackedRequest<Req> {
pub abort_registration: AbortRegistration,
/// A span representing the server processing of this request.
pub span: Span,
/// An inert response guard. Becomes active in an InFlightRequest.
pub response_guard: ResponseGuard,
}
/// The server end of an open connection with a client, receiving requests from, and sending
@@ -231,13 +253,15 @@ pub struct TrackedRequest<Req> {
/// [`Sink::send`](futures::sink::SinkExt::send) - A user is free to manually read requests
/// from, and send responses into, a Channel in lieu of the previous methods. Channels stream
/// [`TrackedRequests`](TrackedRequest), which, in addition to the request itself, contains the
/// server [`Span`] and request lifetime [`AbortRegistration`]. Wrapping response
/// logic in an [`Abortable`] future using the abort registration will ensure that the response
/// does not execute longer than the request deadline. The `Channel` itself will clean up
/// request state once either the deadline expires, or a cancellation message is received, or a
/// response is sent. Because there is no guarantee that a cancellation message will ever be
/// received for a request, or that requests come with reasonably short deadlines, services
/// should strive to clean up Channel resources by sending a response for every request.
/// server [`Span`], request lifetime [`AbortRegistration`], and an inert [`ResponseGuard`].
/// Wrapping response logic in an [`Abortable`] future using the abort registration will ensure
/// that the response does not execute longer than the request deadline. The `Channel` itself
/// will clean up request state once either the deadline expires, or the response guard is
/// dropped, or a response is sent.
///
/// Channels must be implemented using the decorator pattern: the only way to create a
/// `TrackedRequest` is to get one from another `Channel`. Ultimately, all `TrackedRequests` are
/// created by [`BaseChannel`].
pub trait Channel
where
Self: Transport<Response<<Self as Channel>::Resp>, TrackedRequest<<Self as Channel>::Req>>,
@@ -313,20 +337,6 @@ where
}
}
/// Critical errors that result in a Channel disconnecting.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError<E>
where
E: Error + Send + Sync + 'static,
{
/// An error occurred reading from, or writing to, the transport.
#[error("an error occurred in the transport: {0}")]
Transport(#[source] E),
/// An error occurred while polling expired requests.
#[error("an error occurred while polling expired requests: {0}")]
Timer(#[source] ::tokio::time::error::Error),
}
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
@@ -334,20 +344,45 @@ where
type Item = Result<TrackedRequest<Req>, ChannelError<T::Error>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
#[derive(Debug)]
#[derive(Clone, Copy, Debug)]
enum ReceiverStatus {
Ready,
Pending,
Closed,
}
impl ReceiverStatus {
fn combine(self, other: Self) -> Self {
use ReceiverStatus::*;
match (self, other) {
(Ready, _) | (_, Ready) => Ready,
(Closed, Closed) => Closed,
(Pending, Closed) | (Closed, Pending) | (Pending, Pending) => Pending,
}
}
}
use ReceiverStatus::*;
loop {
let expiration_status = match self
.in_flight_requests_mut()
.poll_expired(cx)
.map_err(ChannelError::Timer)?
{
let cancellation_status = match self.canceled_requests_pin_mut().poll_recv(cx) {
Poll::Ready(Some(request_id)) => {
if let Some(span) = self.in_flight_requests_mut().remove_request(request_id) {
let _entered = span.enter();
tracing::info!("ResponseCancelled");
}
Ready
}
// Pending cancellations don't block Channel closure, because all they do is ensure
// the Channel's internal state is cleaned up. But Channel closure also cleans up
// the Channel state, so there's no reason to wait on a cancellation before
// closing.
//
// Ready(None) can't happen, since `self` holds a Cancellation.
Poll::Pending | Poll::Ready(None) => Closed,
};
let expiration_status = match self.in_flight_requests_mut().poll_expired(cx) {
// No need to send a response, since the client wouldn't be waiting for one
// anymore.
Poll::Ready(Some(_)) => Ready,
@@ -358,7 +393,7 @@ where
let request_status = match self
.transport_pin_mut()
.poll_next(cx)
.map_err(ChannelError::Transport)?
.map_err(|e| ChannelError::Read(Arc::new(e)))?
{
Poll::Ready(Some(message)) => match message {
ClientMessage::Request(request) => {
@@ -395,10 +430,13 @@ where
expiration_status,
request_status
);
match (expiration_status, request_status) {
(Ready, _) | (_, Ready) => continue,
(Closed, Closed) => return Poll::Ready(None),
(Pending, Closed) | (Closed, Pending) | (Pending, Pending) => return Poll::Pending,
match cancellation_status
.combine(expiration_status)
.combine(request_status)
{
Ready => continue,
Closed => return Poll::Ready(None),
Pending => return Poll::Pending,
}
}
}
@@ -415,14 +453,12 @@ where
self.project()
.transport
.poll_ready(cx)
.map_err(ChannelError::Transport)
.map_err(ChannelError::Ready)
}
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
if let Some(span) = self
.as_mut()
.project()
.in_flight_requests
.in_flight_requests_mut()
.remove_request(response.request_id)
{
let _entered = span.enter();
@@ -430,7 +466,7 @@ where
self.project()
.transport
.start_send(response)
.map_err(ChannelError::Transport)
.map_err(ChannelError::Write)
} else {
// If the request isn't tracked anymore, there's no need to send the response.
Ok(())
@@ -442,14 +478,14 @@ where
self.project()
.transport
.poll_flush(cx)
.map_err(ChannelError::Transport)
.map_err(ChannelError::Flush)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project()
.transport
.poll_close(cx)
.map_err(ChannelError::Transport)
.map_err(ChannelError::Close)
}
}
@@ -499,6 +535,11 @@ impl<C> Requests<C>
where
C: Channel,
{
/// Returns a reference to the inner channel over which messages are sent and received.
pub fn channel(&self) -> &C {
&self.channel
}
/// Returns the inner channel over which messages are sent and received.
pub fn channel_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> {
self.as_mut().project().channel
@@ -515,12 +556,24 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<InFlightRequest<C::Req, C::Resp>, C::Error>>> {
self.channel_pin_mut()
.poll_next(cx)
.map_ok(|request| InFlightRequest {
request,
response_tx: self.responses_tx.clone(),
})
self.channel_pin_mut().poll_next(cx).map_ok(
|TrackedRequest {
request,
abort_registration,
span,
mut response_guard,
}| {
// The response guard becomes active once in an InFlightRequest.
response_guard.cancel = true;
InFlightRequest {
request,
abort_registration,
span,
response_guard,
response_tx: self.responses_tx.clone(),
}
},
)
}
fn pump_write(
@@ -597,17 +650,40 @@ where
}
}
/// A fail-safe to ensure requests are properly canceled if request processing is aborted before
/// completing.
#[derive(Debug)]
pub struct ResponseGuard {
request_cancellation: RequestCancellation,
request_id: u64,
cancel: bool,
}
impl Drop for ResponseGuard {
fn drop(&mut self) {
if self.cancel {
self.request_cancellation.cancel(self.request_id);
}
}
}
/// A request produced by [Channel::requests].
///
/// If dropped without calling [`execute`](InFlightRequest::execute), a cancellation message will
/// be sent to the Channel to clean up associated request state.
#[derive(Debug)]
pub struct InFlightRequest<Req, Res> {
request: TrackedRequest<Req>,
request: Request<Req>,
abort_registration: AbortRegistration,
response_guard: ResponseGuard,
span: Span,
response_tx: mpsc::Sender<Response<Res>>,
}
impl<Req, Res> InFlightRequest<Req, Res> {
/// Returns a reference to the request.
pub fn get(&self) -> &Request<Req> {
&self.request.request
&self.request
}
/// Returns a [future](Future) that executes the request using the given [service
@@ -621,25 +697,29 @@ impl<Req, Res> InFlightRequest<Req, Res> {
/// message](ClientMessage::Cancel) for this request.
/// 2. The request [deadline](crate::context::Context::deadline) is reached.
/// 3. The service function completes.
///
/// If the returned Future is dropped before completion, a cancellation message will be sent to
/// the Channel to clean up associated request state.
pub async fn execute<S>(self, serve: S)
where
S: Serve<Req, Resp = Res>,
{
let Self {
response_tx,
mut response_guard,
abort_registration,
span,
request:
TrackedRequest {
abort_registration,
span,
request:
Request {
context,
message,
id: request_id,
},
Request {
context,
message,
id: request_id,
},
} = self;
let method = serve.method(&message);
// TODO(https://github.com/rust-lang/rust-clippy/issues/9111)
// remove when clippy is fixed
#[allow(clippy::needless_borrow)]
span.record("otel.name", &method.unwrap_or(""));
let _ = Abortable::new(
async move {
@@ -657,6 +737,10 @@ impl<Req, Res> InFlightRequest<Req, Res> {
)
.instrument(span)
.await;
// Request processing has completed, meaning either the channel canceled the request or
// a request was sent back to the channel. Either way, the channel will clean up the
// request data, so the request does not need to be canceled.
response_guard.cancel = false;
}
}
@@ -741,9 +825,10 @@ mod tests {
channel::Channel<Response<Resp>, ClientMessage<Req>>,
) {
let (tx, rx) = crate::transport::channel::bounded(capacity);
let mut config = Config::default();
// Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded).
config.pending_response_buffer = capacity + 1;
let config = Config {
pending_response_buffer: capacity + 1,
};
(Box::pin(BaseChannel::new(config, rx).requests()), tx)
}
@@ -932,6 +1017,44 @@ mod tests {
assert_eq!(channel.in_flight_requests(), 0);
}
#[tokio::test]
async fn in_flight_request_drop_cancels_request() {
let (mut requests, mut tx) = test_requests::<(), ()>();
tx.send(fake_request(())).await.unwrap();
let request = match requests.as_mut().poll_next(&mut noop_context()) {
Poll::Ready(Some(Ok(request))) => request,
result => panic!("Unexpected result: {:?}", result),
};
drop(request);
let poll = requests
.as_mut()
.channel_pin_mut()
.poll_next(&mut noop_context());
assert!(poll.is_pending());
let in_flight_requests = requests.channel().in_flight_requests();
assert_eq!(in_flight_requests, 0);
}
#[tokio::test]
async fn in_flight_requests_successful_execute_doesnt_cancel_request() {
let (mut requests, mut tx) = test_requests::<(), ()>();
tx.send(fake_request(())).await.unwrap();
let request = match requests.as_mut().poll_next(&mut noop_context()) {
Poll::Ready(Some(Ok(request))) => request,
result => panic!("Unexpected result: {:?}", result),
};
request.execute(|_, _| async {}).await;
assert!(requests
.as_mut()
.channel_pin_mut()
.canceled_requests
.poll_recv(&mut noop_context())
.is_pending());
}
#[tokio::test]
async fn requests_poll_next_response_returns_pending_when_buffer_full() {
let (mut requests, _tx) = test_bounded_requests::<(), ()>(0);

View File

@@ -94,16 +94,14 @@ impl InFlightRequests {
}
/// Yields a request that has expired, aborting any ongoing processing of that request.
pub fn poll_expired(
&mut self,
cx: &mut Context,
) -> Poll<Option<Result<u64, tokio::time::error::Error>>> {
pub fn poll_expired(&mut self, cx: &mut Context) -> Poll<Option<u64>> {
if self.deadlines.is_empty() {
// TODO(https://github.com/tokio-rs/tokio/issues/4161)
// This is a workaround for DelayQueue not always treating this case correctly.
return Poll::Ready(None);
}
self.deadlines.poll_expired(cx).map_ok(|expired| {
self.deadlines.poll_expired(cx).map(|expired| {
let expired = expired?;
if let Some(RequestData {
abort_handle, span, ..
}) = self.request_data.remove(expired.get_ref())
@@ -113,7 +111,7 @@ impl InFlightRequests {
abort_handle.abort();
tracing::error!("DeadlineExceeded");
}
expired.into_inner()
Some(expired.into_inner())
})
}
}
@@ -161,7 +159,7 @@ mod tests {
assert_matches!(
in_flight_requests.poll_expired(&mut noop_context()),
Poll::Ready(Some(Ok(_)))
Poll::Ready(Some(_))
);
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
@@ -178,7 +176,7 @@ mod tests {
.unwrap();
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
assert_eq!(in_flight_requests.cancel_request(0), true);
assert!(in_flight_requests.cancel_request(0));
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
Poll::Ready(Err(_))

View File

@@ -282,7 +282,7 @@ where
fn ctx() -> Context<'static> {
use futures::task::*;
Context::from_waker(&noop_waker_ref())
Context::from_waker(noop_waker_ref())
}
#[test]

View File

@@ -5,8 +5,9 @@
// https://opensource.org/licenses/MIT.
use crate::{
cancellations::{cancellations, CanceledRequests, RequestCancellation},
context,
server::{Channel, Config, TrackedRequest},
server::{Channel, Config, ResponseGuard, TrackedRequest},
Request, Response,
};
use futures::{task::*, Sink, Stream};
@@ -22,6 +23,8 @@ pub(crate) struct FakeChannel<In, Out> {
pub sink: VecDeque<Out>,
pub config: Config,
pub in_flight_requests: super::in_flight_requests::InFlightRequests,
pub request_cancellation: RequestCancellation,
pub canceled_requests: CanceledRequests,
}
impl<In, Out> Stream for FakeChannel<In, Out>
@@ -86,6 +89,7 @@ where
impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
pub fn push_req(&mut self, id: u64, message: Req) {
let (_, abort_registration) = futures::future::AbortHandle::new_pair();
let (request_cancellation, _) = cancellations();
self.stream.push_back(Ok(TrackedRequest {
request: Request {
context: context::Context {
@@ -97,17 +101,25 @@ impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
},
abort_registration,
span: Span::none(),
response_guard: ResponseGuard {
request_cancellation,
request_id: id,
cancel: false,
},
}));
}
}
impl FakeChannel<(), ()> {
pub fn default<Req, Resp>() -> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
let (request_cancellation, canceled_requests) = cancellations();
FakeChannel {
stream: Default::default(),
sink: Default::default(),
config: Default::default(),
in_flight_requests: Default::default(),
request_cancellation,
canceled_requests,
}
}
}
@@ -123,5 +135,5 @@ impl<T> PollExt for Poll<Option<T>> {
}
pub fn cx() -> Context<'static> {
Context::from_waker(&noop_waker_ref())
Context::from_waker(noop_waker_ref())
}

View File

@@ -6,6 +6,7 @@ use std::pin::Pin;
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
/// for each new channel. Returned by
/// [`Incoming::execute`](crate::server::incoming::Incoming::execute).
#[must_use]
#[pin_project]
#[derive(Debug)]
pub struct TokioServerExecutor<T, S> {
@@ -23,6 +24,7 @@ impl<T, S> TokioServerExecutor<T, S> {
/// A future that drives the server by [spawning](tokio::spawn) each [response
/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by
/// [`Channel::execute`](crate::server::Channel::execute).
#[must_use]
#[pin_project]
#[derive(Debug)]
pub struct TokioChannelExecutor<T, S> {

View File

@@ -2,4 +2,8 @@
fn ui() {
let t = trybuild::TestCases::new();
t.compile_fail("tests/compile_fail/*.rs");
#[cfg(feature = "tokio1")]
t.compile_fail("tests/compile_fail/tokio/*.rs");
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
t.compile_fail("tests/compile_fail/serde_transport/*.rs");
}

View File

@@ -0,0 +1,15 @@
use tarpc::client;
#[tarpc::service]
trait World {
async fn hello(name: String) -> String;
}
fn main() {
let (client_transport, _) = tarpc::transport::channel::unbounded();
#[deny(unused_must_use)]
{
WorldClient::new(client::Config::default(), client_transport).dispatch;
}
}

View File

@@ -0,0 +1,11 @@
error: unused `RequestDispatch` that must be used
--> tests/compile_fail/must_use_request_dispatch.rs:13:9
|
13 | WorldClient::new(client::Config::default(), client_transport).dispatch;
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/must_use_request_dispatch.rs:11:12
|
11 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^

View File

@@ -0,0 +1,9 @@
use tarpc::serde_transport;
use tokio_serde::formats::Json;
fn main() {
#[deny(unused_must_use)]
{
serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
}
}

View File

@@ -0,0 +1,11 @@
error: unused `tarpc::serde_transport::tcp::Connect` that must be used
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:7:9
|
7 | serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:5:12
|
5 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^

View File

@@ -0,0 +1,29 @@
use tarpc::{
context,
server::{self, Channel},
};
#[tarpc::service]
trait World {
async fn hello(name: String) -> String;
}
#[derive(Clone)]
struct HelloServer;
#[tarpc::server]
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
format!("Hello, {name}!")
}
}
fn main() {
let (_, server_transport) = tarpc::transport::channel::unbounded();
let server = server::BaseChannel::with_defaults(server_transport);
#[deny(unused_must_use)]
{
server.execute(HelloServer.serve());
}
}

View File

@@ -0,0 +1,11 @@
error: unused `TokioChannelExecutor` that must be used
--> tests/compile_fail/tokio/must_use_channel_executor.rs:27:9
|
27 | server.execute(HelloServer.serve());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/tokio/must_use_channel_executor.rs:25:12
|
25 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^

View File

@@ -0,0 +1,30 @@
use futures::stream::once;
use tarpc::{
context,
server::{self, incoming::Incoming},
};
#[tarpc::service]
trait World {
async fn hello(name: String) -> String;
}
#[derive(Clone)]
struct HelloServer;
#[tarpc::server]
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
format!("Hello, {name}!")
}
}
fn main() {
let (_, server_transport) = tarpc::transport::channel::unbounded();
let server = once(async move { server::BaseChannel::with_defaults(server_transport) });
#[deny(unused_must_use)]
{
server.execute(HelloServer.serve());
}
}

View File

@@ -0,0 +1,11 @@
error: unused `TokioServerExecutor` that must be used
--> tests/compile_fail/tokio/must_use_server_executor.rs:28:9
|
28 | server.execute(HelloServer.serve());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/tokio/must_use_server_executor.rs:26:12
|
26 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^

View File

@@ -7,7 +7,7 @@ use tarpc::{
use tokio_serde::formats::Json;
#[tarpc::derive_serde]
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Eq)]
pub enum TestData {
Black,
White,

View File

@@ -108,7 +108,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> {
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
#[tokio::test]
async fn serde() -> anyhow::Result<()> {
async fn serde_tcp() -> anyhow::Result<()> {
use tarpc::serde_transport;
use tokio_serde::formats::Json;
@@ -136,6 +136,37 @@ async fn serde() -> anyhow::Result<()> {
Ok(())
}
#[cfg(all(feature = "serde-transport", feature = "unix", unix))]
#[tokio::test]
async fn serde_uds() -> anyhow::Result<()> {
use tarpc::serde_transport;
use tokio_serde::formats::Json;
let _ = tracing_subscriber::fmt::try_init();
let sock = tarpc::serde_transport::unix::TempPathBuf::with_random("uds");
let transport = tarpc::serde_transport::unix::listen(&sock, Json::default).await?;
tokio::spawn(
transport
.take(1)
.filter_map(|r| async { r.ok() })
.map(BaseChannel::with_defaults)
.execute(Server.serve()),
);
let transport = serde_transport::unix::connect(&sock, Json::default).await?;
let client = ServiceClient::new(client::Config::default(), transport).spawn();
// Save results using socket so we can clean the socket even if our test assertions fail
let res1 = client.add(context::current(), 1, 2).await;
let res2 = client.hey(context::current(), "Tim".to_string()).await;
assert_matches!(res1, Ok(3));
assert_matches!(res2, Ok(ref s) if s == "Hey, Tim.");
Ok(())
}
#[tokio::test]
async fn concurrent() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();