mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-02-23 15:49:54 +01:00
Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bed85e2827 | ||
|
|
93f3880025 | ||
|
|
878f594d5b | ||
|
|
aa9bbad109 | ||
|
|
7e872ce925 | ||
|
|
62541b709d | ||
|
|
8c43f94fb6 | ||
|
|
7fa4e5064d | ||
|
|
94db7610bb | ||
|
|
0c08d5e8ca | ||
|
|
75b15fe2aa | ||
|
|
863a08d87e | ||
|
|
49ba8f8b1b | ||
|
|
d832209da3 | ||
|
|
584426d414 | ||
|
|
50eb80c883 | ||
|
|
1f0c80d8c9 |
92
.github/workflows/main.yml
vendored
92
.github/workflows/main.yml
vendored
@@ -14,99 +14,57 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cancel previous
|
||||
uses: styfle/cancel-workflow-action@0.7.0
|
||||
uses: styfle/cancel-workflow-action@0.10.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v1
|
||||
- uses: actions-rs/toolchain@v1
|
||||
- uses: actions/checkout@v3
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
target: mipsel-unknown-linux-gnu
|
||||
override: true
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: check
|
||||
args: --all-features
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: check
|
||||
args: --all-features --target mipsel-unknown-linux-gnu
|
||||
targets: mipsel-unknown-linux-gnu
|
||||
- run: cargo check --all-features
|
||||
- run: cargo check --all-features --target mipsel-unknown-linux-gnu
|
||||
|
||||
test:
|
||||
name: Test Suite
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cancel previous
|
||||
uses: styfle/cancel-workflow-action@0.7.0
|
||||
uses: styfle/cancel-workflow-action@0.10.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v1
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
override: true
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features serde1
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features tokio1
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features serde-transport
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features tcp
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --all-features
|
||||
- uses: actions/checkout@v3
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- run: cargo test
|
||||
- run: cargo test --manifest-path tarpc/Cargo.toml --features serde1
|
||||
- run: cargo test --manifest-path tarpc/Cargo.toml --features tokio1
|
||||
- run: cargo test --manifest-path tarpc/Cargo.toml --features serde-transport
|
||||
- run: cargo test --manifest-path tarpc/Cargo.toml --features tcp
|
||||
- run: cargo test --all-features
|
||||
|
||||
fmt:
|
||||
name: Rustfmt
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cancel previous
|
||||
uses: styfle/cancel-workflow-action@0.7.0
|
||||
uses: styfle/cancel-workflow-action@0.10.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v1
|
||||
- uses: actions-rs/toolchain@v1
|
||||
- uses: actions/checkout@v3
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
override: true
|
||||
- run: rustup component add rustfmt
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: fmt
|
||||
args: --all -- --check
|
||||
components: rustfmt
|
||||
- run: cargo fmt --all -- --check
|
||||
|
||||
clippy:
|
||||
name: Clippy
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cancel previous
|
||||
uses: styfle/cancel-workflow-action@0.7.0
|
||||
uses: styfle/cancel-workflow-action@0.10.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v1
|
||||
- uses: actions-rs/toolchain@v1
|
||||
- uses: actions/checkout@v3
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
override: true
|
||||
- run: rustup component add clippy
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: clippy
|
||||
args: --all-features -- -D warnings
|
||||
components: clippy
|
||||
- run: cargo clippy --all-features -- -D warnings
|
||||
|
||||
@@ -67,7 +67,7 @@ Some other features of tarpc:
|
||||
Add to your `Cargo.toml` dependencies:
|
||||
|
||||
```toml
|
||||
tarpc = "0.29"
|
||||
tarpc = "0.32"
|
||||
```
|
||||
|
||||
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||
@@ -82,7 +82,7 @@ your `Cargo.toml`:
|
||||
```toml
|
||||
anyhow = "1.0"
|
||||
futures = "0.3"
|
||||
tarpc = { version = "0.29", features = ["tokio1"] }
|
||||
tarpc = { version = "0.31", features = ["tokio1"] }
|
||||
tokio = { version = "1.0", features = ["macros"] }
|
||||
```
|
||||
|
||||
|
||||
25
RELEASES.md
25
RELEASES.md
@@ -1,3 +1,28 @@
|
||||
## 0.32.0 (2023-03-24)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- As part of a fix to return more channel errors in RPC results, a few error types have changed:
|
||||
|
||||
0. `client::RpcError::Disconnected` was split into the following errors:
|
||||
- Shutdown: the client was shutdown, either intentionally or due to an error. If due to an
|
||||
error, pending RPCs should see the more specific errors below.
|
||||
- Send: an RPC message failed to send over the transport. Only the RPC that failed to be sent
|
||||
will see this error.
|
||||
- Receive: a fatal error occurred while receiving from the transport. All in-flight RPCs will
|
||||
receive this error.
|
||||
0. `client::ChannelError` and `server::ChannelError` are unified in `tarpc::ChannelError`.
|
||||
Previously, server transport errors would not indicate during which activity the transport
|
||||
error occurred. Now, just like the client already was, it will be specific: reading, readying,
|
||||
sending, flushing, or closing.
|
||||
|
||||
## 0.31.0 (2022-11-03)
|
||||
|
||||
### New Features
|
||||
|
||||
This release adds Unix Domain Sockets to the `serde_transport` module.
|
||||
To use it, enable the "unix" feature. See the docs for more information.
|
||||
|
||||
## 0.30.0 (2022-08-12)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc-example-service"
|
||||
version = "0.12.0"
|
||||
version = "0.14.0"
|
||||
rust-version = "1.56"
|
||||
authors = ["Tim Kuehn <tikue@google.com>"]
|
||||
edition = "2021"
|
||||
@@ -21,7 +21,7 @@ futures = "0.3"
|
||||
opentelemetry = { version = "0.17", features = ["rt-tokio"] }
|
||||
opentelemetry-jaeger = { version = "0.16", features = ["rt-tokio"] }
|
||||
rand = "0.8"
|
||||
tarpc = { version = "0.30", path = "../tarpc", features = ["full"] }
|
||||
tarpc = { version = "0.32", path = "../tarpc", features = ["full"] }
|
||||
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
|
||||
tracing = { version = "0.1" }
|
||||
tracing-opentelemetry = "0.17"
|
||||
|
||||
@@ -26,7 +26,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
let flags = Flags::parse();
|
||||
init_tracing("Tarpc Example Client")?;
|
||||
|
||||
let transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
|
||||
let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
|
||||
transport.config_mut().max_frame_length(usize::MAX);
|
||||
|
||||
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
|
||||
// config and any Transport as input.
|
||||
@@ -42,7 +43,10 @@ async fn main() -> anyhow::Result<()> {
|
||||
.instrument(tracing::info_span!("Two Hellos"))
|
||||
.await;
|
||||
|
||||
tracing::info!("{:?}", hello);
|
||||
match hello {
|
||||
Ok(hello) => tracing::info!("{hello:?}"),
|
||||
Err(e) => tracing::warn!("{:?}", anyhow::Error::from(e)),
|
||||
}
|
||||
|
||||
// Let the background span processor finish.
|
||||
sleep(Duration::from_micros(1)).await;
|
||||
|
||||
@@ -54,6 +54,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
// JSON transport is provided by the json_transport tarpc module. It makes it easy
|
||||
// to start up a serde-powered json serialization strategy over TCP.
|
||||
let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?;
|
||||
tracing::info!("Listening on port {}", listener.local_addr().port());
|
||||
listener.config_mut().max_frame_length(usize::MAX);
|
||||
listener
|
||||
// Ignore accept errors.
|
||||
|
||||
@@ -285,7 +285,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
args,
|
||||
method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
|
||||
method_idents: &methods,
|
||||
request_names: &*request_names,
|
||||
request_names: &request_names,
|
||||
attrs,
|
||||
rpcs,
|
||||
return_types: &rpcs
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc"
|
||||
version = "0.30.0"
|
||||
version = "0.32.0"
|
||||
rust-version = "1.58.0"
|
||||
authors = [
|
||||
"Adam Wright <adam.austin.wright@gmail.com>",
|
||||
@@ -13,7 +13,7 @@ homepage = "https://github.com/google/tarpc"
|
||||
repository = "https://github.com/google/tarpc"
|
||||
keywords = ["rpc", "network", "server", "api", "microservices"]
|
||||
categories = ["asynchronous", "network-programming"]
|
||||
readme = "../README.md"
|
||||
readme = "README.md"
|
||||
description = "An RPC framework for Rust with a focus on ease of use."
|
||||
|
||||
[features]
|
||||
@@ -25,6 +25,7 @@ serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
|
||||
serde-transport-json = ["tokio-serde/json"]
|
||||
serde-transport-bincode = ["tokio-serde/bincode"]
|
||||
tcp = ["tokio/net"]
|
||||
unix = ["tokio/net"]
|
||||
|
||||
full = [
|
||||
"serde1",
|
||||
@@ -33,6 +34,7 @@ full = [
|
||||
"serde-transport-json",
|
||||
"serde-transport-bincode",
|
||||
"tcp",
|
||||
"unix",
|
||||
]
|
||||
|
||||
[badges]
|
||||
@@ -76,6 +78,8 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
tokio = { version = "1", features = ["full", "test-util"] }
|
||||
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
|
||||
trybuild = "1.0"
|
||||
tokio-rustls = "0.23"
|
||||
rustls-pemfile = "1.0"
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
@@ -101,6 +105,10 @@ required-features = ["full"]
|
||||
name = "custom_transport"
|
||||
required-features = ["serde1", "tokio1", "serde-transport"]
|
||||
|
||||
[[example]]
|
||||
name = "tls_over_tcp"
|
||||
required-features = ["full"]
|
||||
|
||||
[[test]]
|
||||
name = "service_functional"
|
||||
required-features = ["serde-transport"]
|
||||
|
||||
11
tarpc/examples/certs/eddsa/client.cert
Normal file
11
tarpc/examples/certs/eddsa/client.cert
Normal file
@@ -0,0 +1,11 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBlDCCAUagAwIBAgICAxUwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk
|
||||
RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw
|
||||
NjEwMTEwNFowGjEYMBYGA1UEAwwPcG9ueXRvd24gY2xpZW50MCowBQYDK2VwAyEA
|
||||
NTKuLume19IhJfEFd/5OZUuYDKZH6xvy4AGver17OoejgZswgZgwDAYDVR0TAQH/
|
||||
BAIwADALBgNVHQ8EBAMCBsAwFgYDVR0lAQH/BAwwCgYIKwYBBQUHAwIwHQYDVR0O
|
||||
BBYEFDjdrlMu4tyw5MHtbg7WnzSGRBpFMEQGA1UdIwQ9MDuAFHIl7fHKWP6/l8FE
|
||||
fI2YEIM3oHxKoSCkHjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQYIBezAF
|
||||
BgMrZXADQQCaahfj/QLxoCOpvl6y0ZQ9CpojPqBnxV3460j5nUOp040Va2MpF137
|
||||
izCBY7LwgUE/YG6E+kH30G4jMEnqVEYK
|
||||
-----END CERTIFICATE-----
|
||||
19
tarpc/examples/certs/eddsa/client.chain
Normal file
19
tarpc/examples/certs/eddsa/client.chain
Normal file
@@ -0,0 +1,19 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE
|
||||
U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD
|
||||
DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh
|
||||
AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU
|
||||
ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG
|
||||
AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU
|
||||
oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc
|
||||
zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg=
|
||||
-----END CERTIFICATE-----
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG
|
||||
A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0
|
||||
MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh
|
||||
ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU
|
||||
phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR
|
||||
W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC
|
||||
t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB
|
||||
-----END CERTIFICATE-----
|
||||
3
tarpc/examples/certs/eddsa/client.key
Normal file
3
tarpc/examples/certs/eddsa/client.key
Normal file
@@ -0,0 +1,3 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MC4CAQAwBQYDK2VwBCIEIIJX9ThTHpVS1SNZb6HP4myg4fRInIVGunTRdgnc+weH
|
||||
-----END PRIVATE KEY-----
|
||||
12
tarpc/examples/certs/eddsa/end.cert
Normal file
12
tarpc/examples/certs/eddsa/end.cert
Normal file
@@ -0,0 +1,12 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBuDCCAWqgAwIBAgICAcgwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk
|
||||
RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw
|
||||
NjEwMTEwNFowGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wKjAFBgMrZXADIQDc
|
||||
RLl3/N2tPoWnzBV3noVn/oheEl8IUtiY11Vg/QXTUKOBwDCBvTAMBgNVHRMBAf8E
|
||||
AjAAMAsGA1UdDwQEAwIGwDAdBgNVHQ4EFgQUk7U2mnxedNWBAH84BsNy5si3ZQow
|
||||
RAYDVR0jBD0wO4AUciXt8cpY/r+XwUR8jZgQgzegfEqhIKQeMBwxGjAYBgNVBAMM
|
||||
EXBvbnl0b3duIEVkRFNBIENBggF7MDsGA1UdEQQ0MDKCDnRlc3RzZXJ2ZXIuY29t
|
||||
ghVzZWNvbmQudGVzdHNlcnZlci5jb22CCWxvY2FsaG9zdDAFBgMrZXADQQCFWIcF
|
||||
9FiztCuUNzgXDNu5kshuflt0RjkjWpGlWzQjGoYM2IvYhNVPeqnCiY92gqwDSBtq
|
||||
amD2TBup4eNUCsQB
|
||||
-----END CERTIFICATE-----
|
||||
19
tarpc/examples/certs/eddsa/end.chain
Normal file
19
tarpc/examples/certs/eddsa/end.chain
Normal file
@@ -0,0 +1,19 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE
|
||||
U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD
|
||||
DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh
|
||||
AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU
|
||||
ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG
|
||||
AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU
|
||||
oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc
|
||||
zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg=
|
||||
-----END CERTIFICATE-----
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG
|
||||
A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0
|
||||
MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh
|
||||
ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU
|
||||
phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR
|
||||
W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC
|
||||
t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB
|
||||
-----END CERTIFICATE-----
|
||||
3
tarpc/examples/certs/eddsa/end.key
Normal file
3
tarpc/examples/certs/eddsa/end.key
Normal file
@@ -0,0 +1,3 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MC4CAQAwBQYDK2VwBCIEIMU6xGVe8JTpZ3bN/wajHfw6pEHt0Rd7wPBxds9eEFy2
|
||||
-----END PRIVATE KEY-----
|
||||
152
tarpc/examples/tls_over_tcp.rs
Normal file
152
tarpc/examples/tls_over_tcp.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
use rustls_pemfile::certs;
|
||||
use std::io::{BufReader, Cursor};
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient;
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::rustls::{self, Certificate, OwnedTrustAnchor, RootCertStore};
|
||||
use tokio_rustls::{webpki, TlsAcceptor, TlsConnector};
|
||||
|
||||
use tarpc::context::Context;
|
||||
use tarpc::serde_transport as transport;
|
||||
use tarpc::server::{BaseChannel, Channel};
|
||||
use tarpc::tokio_serde::formats::Bincode;
|
||||
use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec;
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait PingService {
|
||||
async fn ping() -> String;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Service;
|
||||
|
||||
#[tarpc::server]
|
||||
impl PingService for Service {
|
||||
async fn ping(self, _: Context) -> String {
|
||||
"🔒".to_owned()
|
||||
}
|
||||
}
|
||||
|
||||
// certs were generated with openssl 3 https://github.com/rustls/rustls/tree/main/test-ca
|
||||
// used on client-side for server tls
|
||||
const END_CHAIN: &[u8] = include_bytes!("certs/eddsa/end.chain");
|
||||
// used on client-side for client-auth
|
||||
const CLIENT_PRIVATEKEY_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.key");
|
||||
const CLIENT_CERT_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.cert");
|
||||
|
||||
// used on server-side for server tls
|
||||
const END_CERT: &str = include_str!("certs/eddsa/end.cert");
|
||||
const END_PRIVATEKEY: &str = include_str!("certs/eddsa/end.key");
|
||||
// used on server-side for client-auth
|
||||
const CLIENT_CHAIN_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.chain");
|
||||
|
||||
pub fn load_private_key(key: &str) -> rustls::PrivateKey {
|
||||
let mut reader = BufReader::new(Cursor::new(key));
|
||||
loop {
|
||||
match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") {
|
||||
Some(rustls_pemfile::Item::RSAKey(key)) => return rustls::PrivateKey(key),
|
||||
Some(rustls_pemfile::Item::PKCS8Key(key)) => return rustls::PrivateKey(key),
|
||||
Some(rustls_pemfile::Item::ECKey(key)) => return rustls::PrivateKey(key),
|
||||
None => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
panic!("no keys found in {:?} (encrypted keys not supported)", key);
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// -------------------- start here to setup tls tcp tokio stream --------------------------
|
||||
// ref certs and loading from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/tests/test.rs
|
||||
// ref basic tls server setup from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/server/src/main.rs
|
||||
let cert = certs(&mut BufReader::new(Cursor::new(END_CERT)))
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect();
|
||||
let key = load_private_key(END_PRIVATEKEY);
|
||||
let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), 5000);
|
||||
|
||||
// ------------- server side client_auth cert loading start
|
||||
let roots: Vec<Certificate> = certs(&mut BufReader::new(Cursor::new(CLIENT_CHAIN_CLIENT_AUTH)))
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect();
|
||||
let mut client_auth_roots = RootCertStore::empty();
|
||||
for root in roots {
|
||||
client_auth_roots.add(&root).unwrap();
|
||||
}
|
||||
let client_auth = AllowAnyAuthenticatedClient::new(client_auth_roots);
|
||||
// ------------- server side client_auth cert loading end
|
||||
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_client_cert_verifier(client_auth) // use .with_no_client_auth() instead if you don't want client-auth
|
||||
.with_single_cert(cert, key)
|
||||
.unwrap();
|
||||
let acceptor = TlsAcceptor::from(Arc::new(config));
|
||||
let listener = TcpListener::bind(&server_addr).await.unwrap();
|
||||
let codec_builder = LengthDelimitedCodec::builder();
|
||||
|
||||
// ref ./custom_transport.rs server side
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (stream, _peer_addr) = listener.accept().await.unwrap();
|
||||
let acceptor = acceptor.clone();
|
||||
let tls_stream = acceptor.accept(stream).await.unwrap();
|
||||
let framed = codec_builder.new_framed(tls_stream);
|
||||
|
||||
let transport = transport::new(framed, Bincode::default());
|
||||
|
||||
let fut = BaseChannel::with_defaults(transport).execute(Service.serve());
|
||||
tokio::spawn(fut);
|
||||
}
|
||||
});
|
||||
|
||||
// ---------------------- client connection ---------------------
|
||||
// cert loading from: https://github.com/tokio-rs/tls/blob/357bc562483dcf04c1f8d08bd1a831b144bf7d4c/tokio-rustls/tests/test.rs#L113
|
||||
// tls client connection from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
|
||||
let chain = certs(&mut std::io::Cursor::new(END_CHAIN)).unwrap();
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
root_store.add_server_trust_anchors(chain.iter().map(|cert| {
|
||||
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
|
||||
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||
ta.subject,
|
||||
ta.spki,
|
||||
ta.name_constraints,
|
||||
)
|
||||
}));
|
||||
|
||||
let client_auth_private_key = load_private_key(CLIENT_PRIVATEKEY_CLIENT_AUTH);
|
||||
let client_auth_certs: Vec<Certificate> =
|
||||
certs(&mut BufReader::new(Cursor::new(CLIENT_CERT_CLIENT_AUTH)))
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect();
|
||||
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_single_cert(client_auth_certs, client_auth_private_key)?; // use .with_no_client_auth() instead if you don't want client-auth
|
||||
|
||||
let domain = rustls::ServerName::try_from("localhost")?;
|
||||
let connector = TlsConnector::from(Arc::new(config));
|
||||
|
||||
let stream = TcpStream::connect(server_addr).await?;
|
||||
let stream = connector.connect(domain, stream).await?;
|
||||
|
||||
let transport = transport::new(codec_builder.new_framed(stream), Bincode::default());
|
||||
let answer = PingServiceClient::new(Default::default(), transport)
|
||||
.spawn()
|
||||
.ping(tarpc::context::current())
|
||||
.await?;
|
||||
|
||||
println!("ping answer: {answer}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -10,15 +10,14 @@ mod in_flight_requests;
|
||||
|
||||
use crate::{
|
||||
cancellations::{cancellations, CanceledRequests, RequestCancellation},
|
||||
context, trace, ClientMessage, Request, Response, ServerError, Transport,
|
||||
context, trace, ChannelError, ClientMessage, Request, Response, ServerError, Transport,
|
||||
};
|
||||
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
||||
use in_flight_requests::{DeadlineExceededError, InFlightRequests};
|
||||
use in_flight_requests::InFlightRequests;
|
||||
use pin_project::pin_project;
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
error::Error,
|
||||
fmt, mem,
|
||||
fmt,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
@@ -124,7 +123,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
pub async fn call(
|
||||
&self,
|
||||
mut ctx: context::Context,
|
||||
request_name: &str,
|
||||
request_name: &'static str,
|
||||
request: Req,
|
||||
) -> Result<Resp, RpcError> {
|
||||
let span = Span::current();
|
||||
@@ -147,6 +146,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
response: &mut response,
|
||||
request_id,
|
||||
cancellation: &self.cancellation,
|
||||
cancel: true,
|
||||
};
|
||||
self.to_dispatch
|
||||
.send(DispatchRequest {
|
||||
@@ -157,7 +157,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
response_completion,
|
||||
})
|
||||
.await
|
||||
.map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?;
|
||||
.map_err(|mpsc::error::SendError(_)| RpcError::Shutdown)?;
|
||||
response_guard.response().await
|
||||
}
|
||||
}
|
||||
@@ -165,19 +165,25 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
/// A server response that is completed by request dispatch when the corresponding response
|
||||
/// arrives off the wire.
|
||||
struct ResponseGuard<'a, Resp> {
|
||||
response: &'a mut oneshot::Receiver<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
response: &'a mut oneshot::Receiver<Result<Resp, RpcError>>,
|
||||
cancellation: &'a RequestCancellation,
|
||||
request_id: u64,
|
||||
cancel: bool,
|
||||
}
|
||||
|
||||
/// An error that can occur in the processing of an RPC. This is not request-specific errors but
|
||||
/// rather cross-cutting errors that can always occur.
|
||||
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum RpcError {
|
||||
/// The client disconnected from the server.
|
||||
#[error("the client disconnected from the server")]
|
||||
Disconnected,
|
||||
#[error("the connection to the server was already shutdown")]
|
||||
Shutdown,
|
||||
/// The client failed to send the request.
|
||||
#[error("the client failed to send the request")]
|
||||
Send(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
/// An error occurred while waiting for the server response.
|
||||
#[error("an error occurred while waiting for the server response")]
|
||||
Receive(#[source] Arc<dyn std::error::Error + Send + Sync + 'static>),
|
||||
/// The request exceeded its deadline.
|
||||
#[error("the request exceeded its deadline")]
|
||||
DeadlineExceeded,
|
||||
@@ -186,24 +192,18 @@ pub enum RpcError {
|
||||
Server(#[from] ServerError),
|
||||
}
|
||||
|
||||
impl From<DeadlineExceededError> for RpcError {
|
||||
fn from(_: DeadlineExceededError) -> Self {
|
||||
RpcError::DeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
impl<Resp> ResponseGuard<'_, Resp> {
|
||||
async fn response(mut self) -> Result<Resp, RpcError> {
|
||||
let response = (&mut self.response).await;
|
||||
// Cancel drop logic once a response has been received.
|
||||
mem::forget(self);
|
||||
self.cancel = false;
|
||||
match response {
|
||||
Ok(resp) => Ok(resp?.message?),
|
||||
Ok(response) => response,
|
||||
Err(oneshot::error::RecvError { .. }) => {
|
||||
// The oneshot is Canceled when the dispatch task ends. In that case,
|
||||
// there's nothing listening on the other side, so there's no point in
|
||||
// propagating cancellation.
|
||||
Err(RpcError::Disconnected)
|
||||
Err(RpcError::Shutdown)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -223,7 +223,9 @@ impl<Resp> Drop for ResponseGuard<'_, Resp> {
|
||||
// dispatch task misses an early-arriving cancellation message, then it will see the
|
||||
// receiver as closed.
|
||||
self.response.close();
|
||||
self.cancellation.cancel(self.request_id);
|
||||
if self.cancel {
|
||||
self.cancellation.cancel(self.request_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,7 +240,6 @@ where
|
||||
{
|
||||
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
|
||||
let (cancellation, canceled_requests) = cancellations();
|
||||
let canceled_requests = canceled_requests;
|
||||
|
||||
NewClient {
|
||||
client: Channel {
|
||||
@@ -270,42 +271,18 @@ pub struct RequestDispatch<Req, Resp, C> {
|
||||
/// Requests that were dropped.
|
||||
canceled_requests: CanceledRequests,
|
||||
/// Requests already written to the wire that haven't yet received responses.
|
||||
in_flight_requests: InFlightRequests<Resp>,
|
||||
in_flight_requests: InFlightRequests<Result<Resp, RpcError>>,
|
||||
/// Configures limits to prevent unlimited resource usage.
|
||||
config: Config,
|
||||
}
|
||||
|
||||
/// Critical errors that result in a Channel disconnecting.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ChannelError<E>
|
||||
where
|
||||
E: Error + Send + Sync + 'static,
|
||||
{
|
||||
/// Could not read from the transport.
|
||||
#[error("could not read from the transport")]
|
||||
Read(#[source] E),
|
||||
/// Could not ready the transport for writes.
|
||||
#[error("could not ready the transport for writes")]
|
||||
Ready(#[source] E),
|
||||
/// Could not write to the transport.
|
||||
#[error("could not write to the transport")]
|
||||
Write(#[source] E),
|
||||
/// Could not flush the transport.
|
||||
#[error("could not flush the transport")]
|
||||
Flush(#[source] E),
|
||||
/// Could not close the write end of the transport.
|
||||
#[error("could not close the write end of the transport")]
|
||||
Close(#[source] E),
|
||||
/// Could not poll expired requests.
|
||||
#[error("could not poll expired requests")]
|
||||
Timer(#[source] tokio::time::error::Error),
|
||||
}
|
||||
|
||||
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
|
||||
where
|
||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
||||
{
|
||||
fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests<Resp> {
|
||||
fn in_flight_requests<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
) -> &'a mut InFlightRequests<Result<Resp, RpcError>> {
|
||||
self.as_mut().project().in_flight_requests
|
||||
}
|
||||
|
||||
@@ -365,7 +342,17 @@ where
|
||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||
self.transport_pin_mut()
|
||||
.poll_next(cx)
|
||||
.map_err(ChannelError::Read)
|
||||
.map_err(|e| {
|
||||
let e = Arc::new(e);
|
||||
for span in self
|
||||
.in_flight_requests()
|
||||
.complete_all_requests(|| Err(RpcError::Receive(e.clone())))
|
||||
{
|
||||
let _entered = span.enter();
|
||||
tracing::info!("ReceiveError");
|
||||
}
|
||||
ChannelError::Read(e)
|
||||
})
|
||||
.map_ok(|response| {
|
||||
self.complete(response);
|
||||
})
|
||||
@@ -395,7 +382,10 @@ where
|
||||
// Receiving Poll::Ready(None) when polling expired requests never indicates "Closed",
|
||||
// because there can temporarily be zero in-flight rquests. Therefore, there is no need to
|
||||
// track the status like is done with pending and cancelled requests.
|
||||
if let Poll::Ready(Some(_)) = self.in_flight_requests().poll_expired(cx) {
|
||||
if let Poll::Ready(Some(_)) = self
|
||||
.in_flight_requests()
|
||||
.poll_expired(cx, || Err(RpcError::DeadlineExceeded))
|
||||
{
|
||||
// Expired requests are considered complete; there is no compelling reason to send a
|
||||
// cancellation message to the server, since it will have already exhausted its
|
||||
// allotted processing time.
|
||||
@@ -506,7 +496,7 @@ where
|
||||
Some(dispatch_request) => dispatch_request,
|
||||
None => return Poll::Ready(None),
|
||||
};
|
||||
let entered = span.enter();
|
||||
let _entered = span.enter();
|
||||
// poll_next_request only returns Ready if there is room to buffer another request.
|
||||
// Therefore, we can call write_request without fear of erroring due to a full
|
||||
// buffer.
|
||||
@@ -519,13 +509,16 @@ where
|
||||
trace_context: ctx.trace_context,
|
||||
},
|
||||
});
|
||||
self.start_send(request)?;
|
||||
tracing::info!("SendRequest");
|
||||
drop(entered);
|
||||
|
||||
self.in_flight_requests()
|
||||
.insert_request(request_id, ctx, span, response_completion)
|
||||
.insert_request(request_id, ctx, span.clone(), response_completion)
|
||||
.expect("Request IDs should be unique");
|
||||
match self.start_send(request) {
|
||||
Ok(()) => tracing::info!("SendRequest"),
|
||||
Err(e) => {
|
||||
self.in_flight_requests()
|
||||
.complete_request(request_id, Err(RpcError::Send(Box::new(e))));
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
|
||||
@@ -550,7 +543,10 @@ where
|
||||
|
||||
/// Sends a server response to the client task that initiated the associated request.
|
||||
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
|
||||
self.in_flight_requests().complete_request(response)
|
||||
self.in_flight_requests().complete_request(
|
||||
response.request_id,
|
||||
response.message.map_err(RpcError::Server),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -599,36 +595,43 @@ struct DispatchRequest<Req, Resp> {
|
||||
pub span: Span,
|
||||
pub request_id: u64,
|
||||
pub request: Req,
|
||||
pub response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
pub response_completion: oneshot::Sender<Result<Resp, RpcError>>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard};
|
||||
use super::{
|
||||
cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError,
|
||||
};
|
||||
use crate::{
|
||||
client::{
|
||||
in_flight_requests::{DeadlineExceededError, InFlightRequests},
|
||||
Config,
|
||||
},
|
||||
context,
|
||||
client::{in_flight_requests::InFlightRequests, Config},
|
||||
context::{self, current},
|
||||
transport::{self, channel::UnboundedChannel},
|
||||
ClientMessage, Response,
|
||||
ChannelError, ClientMessage, Response,
|
||||
};
|
||||
use assert_matches::assert_matches;
|
||||
use futures::{prelude::*, task::*};
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
fmt::Display,
|
||||
marker::PhantomData,
|
||||
pin::Pin,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
sync::Arc,
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{
|
||||
mpsc::{self},
|
||||
oneshot,
|
||||
};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tracing::Span;
|
||||
|
||||
#[tokio::test]
|
||||
async fn response_completes_request_future() {
|
||||
let (mut dispatch, mut _channel, mut server_channel) = set_up();
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
dispatch
|
||||
@@ -643,7 +646,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
|
||||
assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp");
|
||||
assert_matches!(rx.try_recv(), Ok(Ok(resp)) if resp == "Resp");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -654,9 +657,10 @@ mod tests {
|
||||
response: &mut response,
|
||||
cancellation: &cancellation,
|
||||
request_id: 3,
|
||||
cancel: true,
|
||||
});
|
||||
// resp's drop() is run, which should send a cancel message.
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||
assert_eq!(canceled_requests.poll_recv(cx), Poll::Ready(Some(3)));
|
||||
}
|
||||
|
||||
@@ -674,19 +678,20 @@ mod tests {
|
||||
response: &mut response,
|
||||
cancellation: &cancellation,
|
||||
request_id: 3,
|
||||
cancel: true,
|
||||
}
|
||||
.response()
|
||||
.await
|
||||
.unwrap();
|
||||
drop(cancellation);
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||
assert_eq!(canceled_requests.poll_recv(cx), Poll::Ready(None));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stage_request() {
|
||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
let _resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
@@ -704,7 +709,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn stage_request_channel_dropped_doesnt_panic() {
|
||||
let (mut dispatch, mut channel, mut server_channel) = set_up();
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
@@ -726,7 +731,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn stage_request_response_future_dropped_is_canceled_before_sending() {
|
||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
@@ -742,7 +747,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn stage_request_response_future_dropped_is_canceled_after_sending() {
|
||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
let req = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
@@ -763,7 +768,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn stage_request_response_closed_skipped() {
|
||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
// Test that a request future that's closed its receiver but not yet canceled its request --
|
||||
@@ -775,6 +780,185 @@ mod tests {
|
||||
assert!(dispatch.as_mut().poll_next_request(cx).is_pending());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_shutdown_error() {
|
||||
let _ = tracing_subscriber::fmt().with_test_writer().try_init();
|
||||
let (dispatch, mut channel, _) = set_up();
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
// send succeeds
|
||||
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
drop(dispatch);
|
||||
// error on receive
|
||||
assert_matches!(resp.response().await, Err(RpcError::Shutdown));
|
||||
let (dispatch, channel, _) = set_up();
|
||||
drop(dispatch);
|
||||
// error on send
|
||||
let resp = channel
|
||||
.call(current(), "test_request", "hi".to_string())
|
||||
.await;
|
||||
assert_matches!(resp, Err(RpcError::Shutdown));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transport_error_write() {
|
||||
let cause = TransportError::Write;
|
||||
let (mut dispatch, mut channel, mut cx) = setup_always_err(cause);
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
assert!(dispatch.as_mut().poll(&mut cx).is_pending());
|
||||
let res = resp.response().await;
|
||||
assert_matches!(res, Err(RpcError::Send(_)));
|
||||
let client_error: anyhow::Error = res.unwrap_err().into();
|
||||
let mut chain = client_error.chain();
|
||||
chain.next(); // original RpcError
|
||||
assert_eq!(
|
||||
chain
|
||||
.next()
|
||||
.unwrap()
|
||||
.downcast_ref::<ChannelError<TransportError>>(),
|
||||
Some(&ChannelError::Write(cause))
|
||||
);
|
||||
assert_eq!(
|
||||
client_error.root_cause().downcast_ref::<TransportError>(),
|
||||
Some(&cause)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transport_error_read() {
|
||||
let cause = TransportError::Read;
|
||||
let (mut dispatch, mut channel, mut cx) = setup_always_err(cause);
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
assert_eq!(
|
||||
dispatch.as_mut().pump_write(&mut cx),
|
||||
Poll::Ready(Some(Ok(())))
|
||||
);
|
||||
assert_eq!(
|
||||
dispatch.as_mut().pump_read(&mut cx),
|
||||
Poll::Ready(Some(Err(ChannelError::Read(Arc::new(cause)))))
|
||||
);
|
||||
assert_matches!(resp.response().await, Err(RpcError::Receive(_)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transport_error_ready() {
|
||||
let cause = TransportError::Ready;
|
||||
let (mut dispatch, _, mut cx) = setup_always_err(cause);
|
||||
assert_eq!(
|
||||
dispatch.as_mut().poll(&mut cx),
|
||||
Poll::Ready(Err(ChannelError::Ready(cause)))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transport_error_flush() {
|
||||
let cause = TransportError::Flush;
|
||||
let (mut dispatch, _, mut cx) = setup_always_err(cause);
|
||||
assert_eq!(
|
||||
dispatch.as_mut().poll(&mut cx),
|
||||
Poll::Ready(Err(ChannelError::Flush(cause)))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transport_error_close() {
|
||||
let cause = TransportError::Close;
|
||||
let (mut dispatch, channel, mut cx) = setup_always_err(cause);
|
||||
drop(channel);
|
||||
assert_eq!(
|
||||
dispatch.as_mut().poll(&mut cx),
|
||||
Poll::Ready(Err(ChannelError::Close(cause)))
|
||||
);
|
||||
}
|
||||
|
||||
fn setup_always_err(
|
||||
cause: TransportError,
|
||||
) -> (
|
||||
Pin<Box<RequestDispatch<String, String, AlwaysErrorTransport<String>>>>,
|
||||
Channel<String, String>,
|
||||
Context<'static>,
|
||||
) {
|
||||
let (to_dispatch, pending_requests) = mpsc::channel(1);
|
||||
let (cancellation, canceled_requests) = cancellations();
|
||||
let transport: AlwaysErrorTransport<String> = AlwaysErrorTransport(cause, PhantomData);
|
||||
let dispatch = Box::pin(RequestDispatch::<String, String, _> {
|
||||
transport: transport.fuse(),
|
||||
pending_requests,
|
||||
canceled_requests,
|
||||
in_flight_requests: InFlightRequests::default(),
|
||||
config: Config::default(),
|
||||
});
|
||||
let channel = Channel {
|
||||
to_dispatch,
|
||||
cancellation,
|
||||
next_request_id: Arc::new(AtomicUsize::new(0)),
|
||||
};
|
||||
let cx = Context::from_waker(noop_waker_ref());
|
||||
(dispatch, channel, cx)
|
||||
}
|
||||
|
||||
struct AlwaysErrorTransport<I>(TransportError, PhantomData<I>);
|
||||
|
||||
#[derive(Debug, Error, PartialEq, Eq, Clone, Copy)]
|
||||
enum TransportError {
|
||||
Read,
|
||||
Ready,
|
||||
Write,
|
||||
Flush,
|
||||
Close,
|
||||
}
|
||||
|
||||
impl Display for TransportError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(&format!("{self:?}"))
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: Clone, S> Sink<S> for AlwaysErrorTransport<I> {
|
||||
type Error = TransportError;
|
||||
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
match self.0 {
|
||||
TransportError::Ready => Poll::Ready(Err(self.0)),
|
||||
TransportError::Flush => Poll::Pending,
|
||||
_ => Poll::Ready(Ok(())),
|
||||
}
|
||||
}
|
||||
fn start_send(self: Pin<&mut Self>, _: S) -> Result<(), Self::Error> {
|
||||
if matches!(self.0, TransportError::Write) {
|
||||
Err(self.0)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
if matches!(self.0, TransportError::Flush) {
|
||||
Poll::Ready(Err(self.0))
|
||||
} else {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
if matches!(self.0, TransportError::Close) {
|
||||
Poll::Ready(Err(self.0))
|
||||
} else {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: Clone> Stream for AlwaysErrorTransport<I> {
|
||||
type Item = Result<Response<I>, TransportError>;
|
||||
fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
if matches!(self.0, TransportError::Read) {
|
||||
Poll::Ready(Some(Err(self.0)))
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn set_up() -> (
|
||||
Pin<
|
||||
Box<
|
||||
@@ -796,7 +980,7 @@ mod tests {
|
||||
|
||||
let dispatch = RequestDispatch::<String, String, _> {
|
||||
transport: client_channel.fuse(),
|
||||
pending_requests: pending_requests,
|
||||
pending_requests,
|
||||
canceled_requests,
|
||||
in_flight_requests: InFlightRequests::default(),
|
||||
config: Config::default(),
|
||||
@@ -814,8 +998,8 @@ mod tests {
|
||||
async fn send_request<'a>(
|
||||
channel: &'a mut Channel<String, String>,
|
||||
request: &str,
|
||||
response_completion: oneshot::Sender<Result<Response<String>, DeadlineExceededError>>,
|
||||
response: &'a mut oneshot::Receiver<Result<Response<String>, DeadlineExceededError>>,
|
||||
response_completion: oneshot::Sender<Result<String, RpcError>>,
|
||||
response: &'a mut oneshot::Receiver<Result<String, RpcError>>,
|
||||
) -> ResponseGuard<'a, String> {
|
||||
let request_id =
|
||||
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
|
||||
@@ -830,6 +1014,7 @@ mod tests {
|
||||
response,
|
||||
cancellation: &channel.cancellation,
|
||||
request_id,
|
||||
cancel: true,
|
||||
};
|
||||
channel.to_dispatch.send(request).await.unwrap();
|
||||
response_guard
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use crate::{
|
||||
context,
|
||||
util::{Compact, TimeUntil},
|
||||
Response,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use std::{
|
||||
@@ -28,17 +27,11 @@ impl<Resp> Default for InFlightRequests<Resp> {
|
||||
}
|
||||
}
|
||||
|
||||
/// The request exceeded its deadline.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
#[non_exhaustive]
|
||||
#[error("the request exceeded its deadline")]
|
||||
pub struct DeadlineExceededError;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RequestData<Resp> {
|
||||
struct RequestData<Res> {
|
||||
ctx: context::Context,
|
||||
span: Span,
|
||||
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
response_completion: oneshot::Sender<Res>,
|
||||
/// The key to remove the timer for the request's deadline.
|
||||
deadline_key: delay_queue::Key,
|
||||
}
|
||||
@@ -48,7 +41,7 @@ struct RequestData<Resp> {
|
||||
#[derive(Debug)]
|
||||
pub struct AlreadyExistsError;
|
||||
|
||||
impl<Resp> InFlightRequests<Resp> {
|
||||
impl<Res> InFlightRequests<Res> {
|
||||
/// Returns the number of in-flight requests.
|
||||
pub fn len(&self) -> usize {
|
||||
self.request_data.len()
|
||||
@@ -65,7 +58,7 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
request_id: u64,
|
||||
ctx: context::Context,
|
||||
span: Span,
|
||||
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
response_completion: oneshot::Sender<Res>,
|
||||
) -> Result<(), AlreadyExistsError> {
|
||||
match self.request_data.entry(request_id) {
|
||||
hash_map::Entry::Vacant(vacant) => {
|
||||
@@ -84,25 +77,35 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
}
|
||||
|
||||
/// Removes a request without aborting. Returns true iff the request was found.
|
||||
pub fn complete_request(&mut self, response: Response<Resp>) -> bool {
|
||||
if let Some(request_data) = self.request_data.remove(&response.request_id) {
|
||||
pub fn complete_request(&mut self, request_id: u64, result: Res) -> bool {
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
let _entered = request_data.span.enter();
|
||||
tracing::info!("ReceiveResponse");
|
||||
self.request_data.compact(0.1);
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
let _ = request_data.response_completion.send(Ok(response));
|
||||
let _ = request_data.response_completion.send(result);
|
||||
return true;
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
"No in-flight request found for request_id = {}.",
|
||||
response.request_id
|
||||
);
|
||||
tracing::debug!("No in-flight request found for request_id = {request_id}.");
|
||||
|
||||
// If the response completion was absent, then the request was already canceled.
|
||||
false
|
||||
}
|
||||
|
||||
/// Completes all requests using the provided function.
|
||||
/// Returns Spans for all completes requests.
|
||||
pub fn complete_all_requests<'a>(
|
||||
&'a mut self,
|
||||
mut result: impl FnMut() -> Res + 'a,
|
||||
) -> impl Iterator<Item = Span> + 'a {
|
||||
self.deadlines.clear();
|
||||
self.request_data.drain().map(move |(_, request_data)| {
|
||||
let _ = request_data.response_completion.send(result());
|
||||
request_data.span
|
||||
})
|
||||
}
|
||||
|
||||
/// Cancels a request without completing (typically used when a request handle was dropped
|
||||
/// before the request completed).
|
||||
pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> {
|
||||
@@ -117,16 +120,18 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
|
||||
/// Yields a request that has expired, completing it with a TimedOut error.
|
||||
/// The caller should send cancellation messages for any yielded request ID.
|
||||
pub fn poll_expired(&mut self, cx: &mut Context) -> Poll<Option<u64>> {
|
||||
pub fn poll_expired(
|
||||
&mut self,
|
||||
cx: &mut Context,
|
||||
expired_error: impl Fn() -> Res,
|
||||
) -> Poll<Option<u64>> {
|
||||
self.deadlines.poll_expired(cx).map(|expired| {
|
||||
let request_id = expired?.into_inner();
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
let _entered = request_data.span.enter();
|
||||
tracing::error!("DeadlineExceeded");
|
||||
self.request_data.compact(0.1);
|
||||
let _ = request_data
|
||||
.response_completion
|
||||
.send(Err(DeadlineExceededError));
|
||||
let _ = request_data.response_completion.send(expired_error());
|
||||
}
|
||||
Some(request_id)
|
||||
})
|
||||
|
||||
@@ -311,6 +311,7 @@ pub use crate::transport::sealed::Transport;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use futures::task::*;
|
||||
use std::sync::Arc;
|
||||
use std::{error::Error, fmt::Display, io, time::SystemTime};
|
||||
|
||||
/// A message from a client to a server.
|
||||
@@ -383,6 +384,29 @@ pub struct ServerError {
|
||||
pub detail: String,
|
||||
}
|
||||
|
||||
/// Critical errors that result in a Channel disconnecting.
|
||||
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
|
||||
pub enum ChannelError<E>
|
||||
where
|
||||
E: Error + Send + Sync + 'static,
|
||||
{
|
||||
/// Could not read from the transport.
|
||||
#[error("could not read from the transport")]
|
||||
Read(#[source] Arc<E>),
|
||||
/// Could not ready the transport for writes.
|
||||
#[error("could not ready the transport for writes")]
|
||||
Ready(#[source] E),
|
||||
/// Could not write to the transport.
|
||||
#[error("could not write to the transport")]
|
||||
Write(#[source] E),
|
||||
/// Could not flush the transport.
|
||||
#[error("could not flush the transport")]
|
||||
Flush(#[source] E),
|
||||
/// Could not close the write end of the transport.
|
||||
#[error("could not close the write end of the transport")]
|
||||
Close(#[source] E),
|
||||
}
|
||||
|
||||
impl<T> Request<T> {
|
||||
/// Returns the deadline for this request.
|
||||
pub fn deadline(&self) -> &SystemTime {
|
||||
|
||||
@@ -129,14 +129,6 @@ pub mod tcp {
|
||||
tokio_util::codec::length_delimited,
|
||||
};
|
||||
|
||||
mod private {
|
||||
use super::*;
|
||||
|
||||
pub trait Sealed {}
|
||||
|
||||
impl<Item, SinkItem, Codec> Sealed for Transport<TcpStream, Item, SinkItem, Codec> {}
|
||||
}
|
||||
|
||||
impl<Item, SinkItem, Codec> Transport<TcpStream, Item, SinkItem, Codec> {
|
||||
/// Returns the peer address of the underlying TcpStream.
|
||||
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
|
||||
@@ -277,6 +269,270 @@ pub mod tcp {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(unix, feature = "unix"))]
|
||||
#[cfg_attr(docsrs, doc(cfg(all(unix, feature = "unix"))))]
|
||||
/// Unix Domain Socket support for generic transport using Tokio.
|
||||
pub mod unix {
|
||||
use {
|
||||
super::*,
|
||||
futures::ready,
|
||||
std::{marker::PhantomData, path::Path},
|
||||
tokio::net::{unix::SocketAddr, UnixListener, UnixStream},
|
||||
tokio_util::codec::length_delimited,
|
||||
};
|
||||
|
||||
impl<Item, SinkItem, Codec> Transport<UnixStream, Item, SinkItem, Codec> {
|
||||
/// Returns the socket address of the remote half of the underlying [`UnixStream`].
|
||||
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
|
||||
self.inner.get_ref().get_ref().peer_addr()
|
||||
}
|
||||
/// Returns the socket address of the local half of the underlying [`UnixStream`].
|
||||
pub fn local_addr(&self) -> io::Result<SocketAddr> {
|
||||
self.inner.get_ref().get_ref().local_addr()
|
||||
}
|
||||
}
|
||||
|
||||
/// A connection Future that also exposes the length-delimited framing config.
|
||||
#[must_use]
|
||||
#[pin_project]
|
||||
pub struct Connect<T, Item, SinkItem, CodecFn> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
codec_fn: CodecFn,
|
||||
config: length_delimited::Builder,
|
||||
ghost: PhantomData<(fn(SinkItem), fn() -> Item)>,
|
||||
}
|
||||
|
||||
impl<T, Item, SinkItem, Codec, CodecFn> Future for Connect<T, Item, SinkItem, CodecFn>
|
||||
where
|
||||
T: Future<Output = io::Result<UnixStream>>,
|
||||
Item: for<'de> Deserialize<'de>,
|
||||
SinkItem: Serialize,
|
||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||
CodecFn: Fn() -> Codec,
|
||||
{
|
||||
type Output = io::Result<Transport<UnixStream, Item, SinkItem, Codec>>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
|
||||
let io = ready!(self.as_mut().project().inner.poll(cx))?;
|
||||
Poll::Ready(Ok(new(self.config.new_framed(io), (self.codec_fn)())))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, Item, SinkItem, CodecFn> Connect<T, Item, SinkItem, CodecFn> {
|
||||
/// Returns an immutable reference to the length-delimited codec's config.
|
||||
pub fn config(&self) -> &length_delimited::Builder {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the length-delimited codec's config.
|
||||
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
|
||||
&mut self.config
|
||||
}
|
||||
}
|
||||
|
||||
/// Connects to socket named by `path`, wrapping the connection in a Unix Domain Socket
|
||||
/// transport.
|
||||
pub fn connect<P, Item, SinkItem, Codec, CodecFn>(
|
||||
path: P,
|
||||
codec_fn: CodecFn,
|
||||
) -> Connect<impl Future<Output = io::Result<UnixStream>>, Item, SinkItem, CodecFn>
|
||||
where
|
||||
P: AsRef<Path>,
|
||||
Item: for<'de> Deserialize<'de>,
|
||||
SinkItem: Serialize,
|
||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||
CodecFn: Fn() -> Codec,
|
||||
{
|
||||
Connect {
|
||||
inner: UnixStream::connect(path),
|
||||
codec_fn,
|
||||
config: LengthDelimitedCodec::builder(),
|
||||
ghost: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Listens on the socket named by `path`, wrapping accepted connections in Unix Domain Socket
|
||||
/// transports.
|
||||
pub async fn listen<P, Item, SinkItem, Codec, CodecFn>(
|
||||
path: P,
|
||||
codec_fn: CodecFn,
|
||||
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
|
||||
where
|
||||
P: AsRef<Path>,
|
||||
Item: for<'de> Deserialize<'de>,
|
||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||
CodecFn: Fn() -> Codec,
|
||||
{
|
||||
let listener = UnixListener::bind(path)?;
|
||||
let local_addr = listener.local_addr()?;
|
||||
Ok(Incoming {
|
||||
listener,
|
||||
codec_fn,
|
||||
local_addr,
|
||||
config: LengthDelimitedCodec::builder(),
|
||||
ghost: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
/// A [`UnixListener`] that wraps connections in [transports](Transport).
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct Incoming<Item, SinkItem, Codec, CodecFn> {
|
||||
listener: UnixListener,
|
||||
local_addr: SocketAddr,
|
||||
codec_fn: CodecFn,
|
||||
config: length_delimited::Builder,
|
||||
ghost: PhantomData<(fn() -> Item, fn(SinkItem), Codec)>,
|
||||
}
|
||||
|
||||
impl<Item, SinkItem, Codec, CodecFn> Incoming<Item, SinkItem, Codec, CodecFn> {
|
||||
/// Returns the the socket address being listened on.
|
||||
pub fn local_addr(&self) -> &SocketAddr {
|
||||
&self.local_addr
|
||||
}
|
||||
|
||||
/// Returns an immutable reference to the length-delimited codec's config.
|
||||
pub fn config(&self) -> &length_delimited::Builder {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the length-delimited codec's config.
|
||||
pub fn config_mut(&mut self) -> &mut length_delimited::Builder {
|
||||
&mut self.config
|
||||
}
|
||||
}
|
||||
|
||||
impl<Item, SinkItem, Codec, CodecFn> Stream for Incoming<Item, SinkItem, Codec, CodecFn>
|
||||
where
|
||||
Item: for<'de> Deserialize<'de>,
|
||||
SinkItem: Serialize,
|
||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||
CodecFn: Fn() -> Codec,
|
||||
{
|
||||
type Item = io::Result<Transport<UnixStream, Item, SinkItem, Codec>>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let conn: UnixStream = ready!(self.as_mut().project().listener.poll_accept(cx)?).0;
|
||||
Poll::Ready(Some(Ok(new(
|
||||
self.config.new_framed(conn),
|
||||
(self.codec_fn)(),
|
||||
))))
|
||||
}
|
||||
}
|
||||
|
||||
/// A temporary `PathBuf` that lives in `std::env::temp_dir` and is removed on drop.
|
||||
pub struct TempPathBuf(std::path::PathBuf);
|
||||
|
||||
impl TempPathBuf {
|
||||
/// A named socket that results in `<tempdir>/<name>`
|
||||
pub fn new<S: AsRef<str>>(name: S) -> Self {
|
||||
let mut sock = std::env::temp_dir();
|
||||
sock.push(name.as_ref());
|
||||
Self(sock)
|
||||
}
|
||||
|
||||
/// Appends a random hex string to the socket name resulting in
|
||||
/// `<tempdir>/<name>_<xxxxx>`
|
||||
pub fn with_random<S: AsRef<str>>(name: S) -> Self {
|
||||
Self::new(format!("{}_{:016x}", name.as_ref(), rand::random::<u64>()))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<std::path::Path> for TempPathBuf {
|
||||
fn as_ref(&self) -> &std::path::Path {
|
||||
self.0.as_path()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TempPathBuf {
|
||||
fn drop(&mut self) {
|
||||
// This will remove the file pointed to by this PathBuf if it exists, however Err's can
|
||||
// be returned such as attempting to remove a non-existing file, or one which we don't
|
||||
// have permission to remove. In these cases the Err is swallowed
|
||||
let _ = std::fs::remove_file(&self.0);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio_serde::formats::SymmetricalJson;
|
||||
|
||||
#[test]
|
||||
fn temp_path_buf_non_random() {
|
||||
let sock = TempPathBuf::new("test");
|
||||
let mut good = std::env::temp_dir();
|
||||
good.push("test");
|
||||
assert_eq!(sock.as_ref(), good);
|
||||
assert_eq!(sock.as_ref().file_name().unwrap(), "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn temp_path_buf_random() {
|
||||
let sock = TempPathBuf::with_random("test");
|
||||
let good = std::env::temp_dir();
|
||||
assert!(sock.as_ref().starts_with(good));
|
||||
// Since there are 16 random characters we just assert the file_name has the right name
|
||||
// and starts with the correct string 'test_'
|
||||
// file name: test_xxxxxxxxxxxxxxxx
|
||||
// test = 4
|
||||
// _ = 1
|
||||
// <hex> = 16
|
||||
// total = 21
|
||||
let fname = sock.as_ref().file_name().unwrap().to_string_lossy();
|
||||
assert!(fname.starts_with("test_"));
|
||||
assert_eq!(fname.len(), 21);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn temp_path_buf_non_existing() {
|
||||
let sock = TempPathBuf::with_random("test");
|
||||
let sock_path = std::path::PathBuf::from(sock.as_ref());
|
||||
|
||||
// No actual file has been created yet
|
||||
assert!(!sock_path.exists());
|
||||
// Should not panic
|
||||
std::mem::drop(sock);
|
||||
assert!(!sock_path.exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn temp_path_buf_existing_file() {
|
||||
let sock = TempPathBuf::with_random("test");
|
||||
let sock_path = std::path::PathBuf::from(sock.as_ref());
|
||||
let _file = std::fs::File::create(&sock).unwrap();
|
||||
assert!(sock_path.exists());
|
||||
std::mem::drop(sock);
|
||||
assert!(!sock_path.exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn temp_path_buf_preexisting_file() {
|
||||
let mut pre_existing = std::env::temp_dir();
|
||||
pre_existing.push("test");
|
||||
let _file = std::fs::File::create(&pre_existing).unwrap();
|
||||
let sock = TempPathBuf::new("test");
|
||||
let sock_path = std::path::PathBuf::from(sock.as_ref());
|
||||
assert!(sock_path.exists());
|
||||
std::mem::drop(sock);
|
||||
assert!(!sock_path.exists());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn temp_path_buf_for_socket() {
|
||||
let sock = TempPathBuf::with_random("test");
|
||||
// Save path for testing after drop
|
||||
let sock_path = std::path::PathBuf::from(sock.as_ref());
|
||||
// create the actual socket
|
||||
let _ = listen(&sock, SymmetricalJson::<String>::default).await;
|
||||
assert!(sock_path.exists());
|
||||
std::mem::drop(sock);
|
||||
assert!(!sock_path.exists());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::Transport;
|
||||
@@ -291,7 +547,7 @@ mod tests {
|
||||
use tokio_serde::formats::SymmetricalJson;
|
||||
|
||||
fn ctx() -> Context<'static> {
|
||||
Context::from_waker(&noop_waker_ref())
|
||||
Context::from_waker(noop_waker_ref())
|
||||
}
|
||||
|
||||
struct TestIo(Cursor<Vec<u8>>);
|
||||
@@ -393,4 +649,24 @@ mod tests {
|
||||
assert_matches!(transport.next().await, None);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(all(unix, feature = "unix"))]
|
||||
#[tokio::test]
|
||||
async fn uds() -> io::Result<()> {
|
||||
use super::unix;
|
||||
use super::*;
|
||||
|
||||
let sock = unix::TempPathBuf::with_random("uds");
|
||||
let mut listener = unix::listen(&sock, SymmetricalJson::<String>::default).await?;
|
||||
tokio::spawn(async move {
|
||||
let mut transport = listener.next().await.unwrap().unwrap();
|
||||
let message = transport.next().await.unwrap().unwrap();
|
||||
transport.send(message).await.unwrap();
|
||||
});
|
||||
let mut transport = unix::connect(&sock, SymmetricalJson::<String>::default).await?;
|
||||
transport.send(String::from("test")).await?;
|
||||
assert_matches!(transport.next().await, Some(Ok(s)) if s == "test");
|
||||
assert_matches!(transport.next().await, None);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
use crate::{
|
||||
cancellations::{cancellations, CanceledRequests, RequestCancellation},
|
||||
context::{self, SpanExt},
|
||||
trace, ClientMessage, Request, Response, Transport,
|
||||
trace, ChannelError, ClientMessage, Request, Response, Transport,
|
||||
};
|
||||
use ::tokio::sync::mpsc;
|
||||
use futures::{
|
||||
@@ -21,14 +21,7 @@ use futures::{
|
||||
};
|
||||
use in_flight_requests::{AlreadyExistsError, InFlightRequests};
|
||||
use pin_project::pin_project;
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
error::Error,
|
||||
fmt,
|
||||
marker::PhantomData,
|
||||
mem::{self, ManuallyDrop},
|
||||
pin::Pin,
|
||||
};
|
||||
use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc};
|
||||
use tracing::{info_span, instrument::Instrument, Span};
|
||||
|
||||
mod in_flight_requests;
|
||||
@@ -208,10 +201,11 @@ where
|
||||
Ok(TrackedRequest {
|
||||
abort_registration,
|
||||
span,
|
||||
response_guard: ManuallyDrop::new(ResponseGuard {
|
||||
response_guard: ResponseGuard {
|
||||
request_id: request.id,
|
||||
request_cancellation: self.request_cancellation.clone(),
|
||||
}),
|
||||
cancel: false,
|
||||
},
|
||||
request,
|
||||
})
|
||||
}
|
||||
@@ -240,7 +234,7 @@ pub struct TrackedRequest<Req> {
|
||||
/// A span representing the server processing of this request.
|
||||
pub span: Span,
|
||||
/// An inert response guard. Becomes active in an InFlightRequest.
|
||||
pub response_guard: ManuallyDrop<ResponseGuard>,
|
||||
pub response_guard: ResponseGuard,
|
||||
}
|
||||
|
||||
/// The server end of an open connection with a client, receiving requests from, and sending
|
||||
@@ -343,20 +337,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// Critical errors that result in a Channel disconnecting.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ChannelError<E>
|
||||
where
|
||||
E: Error + Send + Sync + 'static,
|
||||
{
|
||||
/// An error occurred reading from, or writing to, the transport.
|
||||
#[error("an error occurred in the transport: {0}")]
|
||||
Transport(#[source] E),
|
||||
/// An error occurred while polling expired requests.
|
||||
#[error("an error occurred while polling expired requests: {0}")]
|
||||
Timer(#[source] ::tokio::time::error::Error),
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
@@ -413,7 +393,7 @@ where
|
||||
let request_status = match self
|
||||
.transport_pin_mut()
|
||||
.poll_next(cx)
|
||||
.map_err(ChannelError::Transport)?
|
||||
.map_err(|e| ChannelError::Read(Arc::new(e)))?
|
||||
{
|
||||
Poll::Ready(Some(message)) => match message {
|
||||
ClientMessage::Request(request) => {
|
||||
@@ -473,7 +453,7 @@ where
|
||||
self.project()
|
||||
.transport
|
||||
.poll_ready(cx)
|
||||
.map_err(ChannelError::Transport)
|
||||
.map_err(ChannelError::Ready)
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
|
||||
@@ -486,7 +466,7 @@ where
|
||||
self.project()
|
||||
.transport
|
||||
.start_send(response)
|
||||
.map_err(ChannelError::Transport)
|
||||
.map_err(ChannelError::Write)
|
||||
} else {
|
||||
// If the request isn't tracked anymore, there's no need to send the response.
|
||||
Ok(())
|
||||
@@ -498,14 +478,14 @@ where
|
||||
self.project()
|
||||
.transport
|
||||
.poll_flush(cx)
|
||||
.map_err(ChannelError::Transport)
|
||||
.map_err(ChannelError::Flush)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.project()
|
||||
.transport
|
||||
.poll_close(cx)
|
||||
.map_err(ChannelError::Transport)
|
||||
.map_err(ChannelError::Close)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -581,13 +561,15 @@ where
|
||||
request,
|
||||
abort_registration,
|
||||
span,
|
||||
response_guard,
|
||||
mut response_guard,
|
||||
}| {
|
||||
// The response guard becomes active once in an InFlightRequest.
|
||||
response_guard.cancel = true;
|
||||
InFlightRequest {
|
||||
request,
|
||||
abort_registration,
|
||||
span,
|
||||
response_guard: ManuallyDrop::into_inner(response_guard),
|
||||
response_guard,
|
||||
response_tx: self.responses_tx.clone(),
|
||||
}
|
||||
},
|
||||
@@ -674,11 +656,14 @@ where
|
||||
pub struct ResponseGuard {
|
||||
request_cancellation: RequestCancellation,
|
||||
request_id: u64,
|
||||
cancel: bool,
|
||||
}
|
||||
|
||||
impl Drop for ResponseGuard {
|
||||
fn drop(&mut self) {
|
||||
self.request_cancellation.cancel(self.request_id);
|
||||
if self.cancel {
|
||||
self.request_cancellation.cancel(self.request_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -721,7 +706,7 @@ impl<Req, Res> InFlightRequest<Req, Res> {
|
||||
{
|
||||
let Self {
|
||||
response_tx,
|
||||
response_guard,
|
||||
mut response_guard,
|
||||
abort_registration,
|
||||
span,
|
||||
request:
|
||||
@@ -732,6 +717,9 @@ impl<Req, Res> InFlightRequest<Req, Res> {
|
||||
},
|
||||
} = self;
|
||||
let method = serve.method(&message);
|
||||
// TODO(https://github.com/rust-lang/rust-clippy/issues/9111)
|
||||
// remove when clippy is fixed
|
||||
#[allow(clippy::needless_borrow)]
|
||||
span.record("otel.name", &method.unwrap_or(""));
|
||||
let _ = Abortable::new(
|
||||
async move {
|
||||
@@ -752,7 +740,7 @@ impl<Req, Res> InFlightRequest<Req, Res> {
|
||||
// Request processing has completed, meaning either the channel canceled the request or
|
||||
// a request was sent back to the channel. Either way, the channel will clean up the
|
||||
// request data, so the request does not need to be canceled.
|
||||
mem::forget(response_guard);
|
||||
response_guard.cancel = false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -837,9 +825,10 @@ mod tests {
|
||||
channel::Channel<Response<Resp>, ClientMessage<Req>>,
|
||||
) {
|
||||
let (tx, rx) = crate::transport::channel::bounded(capacity);
|
||||
let mut config = Config::default();
|
||||
// Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded).
|
||||
config.pending_response_buffer = capacity + 1;
|
||||
let config = Config {
|
||||
pending_response_buffer: capacity + 1,
|
||||
};
|
||||
(Box::pin(BaseChannel::new(config, rx).requests()), tx)
|
||||
}
|
||||
|
||||
|
||||
@@ -176,7 +176,7 @@ mod tests {
|
||||
.unwrap();
|
||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||
|
||||
assert_eq!(in_flight_requests.cancel_request(0), true);
|
||||
assert!(in_flight_requests.cancel_request(0));
|
||||
assert_matches!(
|
||||
abortable_future.poll_unpin(&mut noop_context()),
|
||||
Poll::Ready(Err(_))
|
||||
|
||||
@@ -282,7 +282,7 @@ where
|
||||
fn ctx() -> Context<'static> {
|
||||
use futures::task::*;
|
||||
|
||||
Context::from_waker(&noop_waker_ref())
|
||||
Context::from_waker(noop_waker_ref())
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -12,7 +12,7 @@ use crate::{
|
||||
};
|
||||
use futures::{task::*, Sink, Stream};
|
||||
use pin_project::pin_project;
|
||||
use std::{collections::VecDeque, io, mem::ManuallyDrop, pin::Pin, time::SystemTime};
|
||||
use std::{collections::VecDeque, io, pin::Pin, time::SystemTime};
|
||||
use tracing::Span;
|
||||
|
||||
#[pin_project]
|
||||
@@ -101,10 +101,11 @@ impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||
},
|
||||
abort_registration,
|
||||
span: Span::none(),
|
||||
response_guard: ManuallyDrop::new(ResponseGuard {
|
||||
response_guard: ResponseGuard {
|
||||
request_cancellation,
|
||||
request_id: id,
|
||||
}),
|
||||
cancel: false,
|
||||
},
|
||||
}));
|
||||
}
|
||||
}
|
||||
@@ -134,5 +135,5 @@ impl<T> PollExt for Poll<Option<T>> {
|
||||
}
|
||||
|
||||
pub fn cx() -> Context<'static> {
|
||||
Context::from_waker(&noop_waker_ref())
|
||||
Context::from_waker(noop_waker_ref())
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ error: unused `RequestDispatch` that must be used
|
||||
--> tests/compile_fail/must_use_request_dispatch.rs:13:9
|
||||
|
|
||||
13 | WorldClient::new(client::Config::default(), client_transport).dispatch;
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
note: the lint level is defined here
|
||||
--> tests/compile_fail/must_use_request_dispatch.rs:11:12
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
error: unused `Connect` that must be used
|
||||
error: unused `tarpc::serde_transport::tcp::Connect` that must be used
|
||||
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:7:9
|
||||
|
|
||||
7 | serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
note: the lint level is defined here
|
||||
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:5:12
|
||||
|
||||
@@ -2,7 +2,7 @@ error: unused `TokioChannelExecutor` that must be used
|
||||
--> tests/compile_fail/tokio/must_use_channel_executor.rs:27:9
|
||||
|
|
||||
27 | server.execute(HelloServer.serve());
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
note: the lint level is defined here
|
||||
--> tests/compile_fail/tokio/must_use_channel_executor.rs:25:12
|
||||
|
||||
@@ -2,7 +2,7 @@ error: unused `TokioServerExecutor` that must be used
|
||||
--> tests/compile_fail/tokio/must_use_server_executor.rs:28:9
|
||||
|
|
||||
28 | server.execute(HelloServer.serve());
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
note: the lint level is defined here
|
||||
--> tests/compile_fail/tokio/must_use_server_executor.rs:26:12
|
||||
|
||||
@@ -7,7 +7,7 @@ use tarpc::{
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
#[tarpc::derive_serde]
|
||||
#[derive(Debug, PartialEq)]
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum TestData {
|
||||
Black,
|
||||
White,
|
||||
|
||||
@@ -108,7 +108,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> {
|
||||
|
||||
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
|
||||
#[tokio::test]
|
||||
async fn serde() -> anyhow::Result<()> {
|
||||
async fn serde_tcp() -> anyhow::Result<()> {
|
||||
use tarpc::serde_transport;
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
@@ -136,6 +136,37 @@ async fn serde() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "serde-transport", feature = "unix", unix))]
|
||||
#[tokio::test]
|
||||
async fn serde_uds() -> anyhow::Result<()> {
|
||||
use tarpc::serde_transport;
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let sock = tarpc::serde_transport::unix::TempPathBuf::with_random("uds");
|
||||
let transport = tarpc::serde_transport::unix::listen(&sock, Json::default).await?;
|
||||
tokio::spawn(
|
||||
transport
|
||||
.take(1)
|
||||
.filter_map(|r| async { r.ok() })
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let transport = serde_transport::unix::connect(&sock, Json::default).await?;
|
||||
let client = ServiceClient::new(client::Config::default(), transport).spawn();
|
||||
|
||||
// Save results using socket so we can clean the socket even if our test assertions fail
|
||||
let res1 = client.add(context::current(), 1, 2).await;
|
||||
let res2 = client.hey(context::current(), "Tim".to_string()).await;
|
||||
|
||||
assert_matches!(res1, Ok(3));
|
||||
assert_matches!(res2, Ok(ref s) if s == "Hey, Tim.");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn concurrent() -> anyhow::Result<()> {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
Reference in New Issue
Block a user