mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-02-23 15:49:54 +01:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bed85e2827 | ||
|
|
93f3880025 | ||
|
|
878f594d5b | ||
|
|
aa9bbad109 | ||
|
|
7e872ce925 | ||
|
|
62541b709d | ||
|
|
8c43f94fb6 | ||
|
|
7fa4e5064d | ||
|
|
94db7610bb |
76
.github/workflows/main.yml
vendored
76
.github/workflows/main.yml
vendored
@@ -18,20 +18,11 @@ jobs:
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions-rs/toolchain@v1
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
target: mipsel-unknown-linux-gnu
|
||||
override: true
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: check
|
||||
args: --all-features
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: check
|
||||
args: --all-features --target mipsel-unknown-linux-gnu
|
||||
targets: mipsel-unknown-linux-gnu
|
||||
- run: cargo check --all-features
|
||||
- run: cargo check --all-features --target mipsel-unknown-linux-gnu
|
||||
|
||||
test:
|
||||
name: Test Suite
|
||||
@@ -42,34 +33,13 @@ jobs:
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
override: true
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features serde1
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features tokio1
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features serde-transport
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features tcp
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --all-features
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- run: cargo test
|
||||
- run: cargo test --manifest-path tarpc/Cargo.toml --features serde1
|
||||
- run: cargo test --manifest-path tarpc/Cargo.toml --features tokio1
|
||||
- run: cargo test --manifest-path tarpc/Cargo.toml --features serde-transport
|
||||
- run: cargo test --manifest-path tarpc/Cargo.toml --features tcp
|
||||
- run: cargo test --all-features
|
||||
|
||||
fmt:
|
||||
name: Rustfmt
|
||||
@@ -80,16 +50,10 @@ jobs:
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions-rs/toolchain@v1
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
override: true
|
||||
- run: rustup component add rustfmt
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: fmt
|
||||
args: --all -- --check
|
||||
components: rustfmt
|
||||
- run: cargo fmt --all -- --check
|
||||
|
||||
clippy:
|
||||
name: Clippy
|
||||
@@ -100,13 +64,7 @@ jobs:
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions-rs/toolchain@v1
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
override: true
|
||||
- run: rustup component add clippy
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: clippy
|
||||
args: --all-features -- -D warnings
|
||||
components: clippy
|
||||
- run: cargo clippy --all-features -- -D warnings
|
||||
|
||||
@@ -67,7 +67,7 @@ Some other features of tarpc:
|
||||
Add to your `Cargo.toml` dependencies:
|
||||
|
||||
```toml
|
||||
tarpc = "0.31"
|
||||
tarpc = "0.32"
|
||||
```
|
||||
|
||||
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||
|
||||
18
RELEASES.md
18
RELEASES.md
@@ -1,3 +1,21 @@
|
||||
## 0.32.0 (2023-03-24)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- As part of a fix to return more channel errors in RPC results, a few error types have changed:
|
||||
|
||||
0. `client::RpcError::Disconnected` was split into the following errors:
|
||||
- Shutdown: the client was shutdown, either intentionally or due to an error. If due to an
|
||||
error, pending RPCs should see the more specific errors below.
|
||||
- Send: an RPC message failed to send over the transport. Only the RPC that failed to be sent
|
||||
will see this error.
|
||||
- Receive: a fatal error occurred while receiving from the transport. All in-flight RPCs will
|
||||
receive this error.
|
||||
0. `client::ChannelError` and `server::ChannelError` are unified in `tarpc::ChannelError`.
|
||||
Previously, server transport errors would not indicate during which activity the transport
|
||||
error occurred. Now, just like the client already was, it will be specific: reading, readying,
|
||||
sending, flushing, or closing.
|
||||
|
||||
## 0.31.0 (2022-11-03)
|
||||
|
||||
### New Features
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc-example-service"
|
||||
version = "0.13.0"
|
||||
version = "0.14.0"
|
||||
rust-version = "1.56"
|
||||
authors = ["Tim Kuehn <tikue@google.com>"]
|
||||
edition = "2021"
|
||||
@@ -21,7 +21,7 @@ futures = "0.3"
|
||||
opentelemetry = { version = "0.17", features = ["rt-tokio"] }
|
||||
opentelemetry-jaeger = { version = "0.16", features = ["rt-tokio"] }
|
||||
rand = "0.8"
|
||||
tarpc = { version = "0.31", path = "../tarpc", features = ["full"] }
|
||||
tarpc = { version = "0.32", path = "../tarpc", features = ["full"] }
|
||||
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
|
||||
tracing = { version = "0.1" }
|
||||
tracing-opentelemetry = "0.17"
|
||||
|
||||
@@ -26,7 +26,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
let flags = Flags::parse();
|
||||
init_tracing("Tarpc Example Client")?;
|
||||
|
||||
let transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
|
||||
let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
|
||||
transport.config_mut().max_frame_length(usize::MAX);
|
||||
|
||||
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
|
||||
// config and any Transport as input.
|
||||
@@ -42,7 +43,10 @@ async fn main() -> anyhow::Result<()> {
|
||||
.instrument(tracing::info_span!("Two Hellos"))
|
||||
.await;
|
||||
|
||||
tracing::info!("{:?}", hello);
|
||||
match hello {
|
||||
Ok(hello) => tracing::info!("{hello:?}"),
|
||||
Err(e) => tracing::warn!("{:?}", anyhow::Error::from(e)),
|
||||
}
|
||||
|
||||
// Let the background span processor finish.
|
||||
sleep(Duration::from_micros(1)).await;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc"
|
||||
version = "0.31.0"
|
||||
version = "0.32.0"
|
||||
rust-version = "1.58.0"
|
||||
authors = [
|
||||
"Adam Wright <adam.austin.wright@gmail.com>",
|
||||
@@ -78,6 +78,8 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
tokio = { version = "1", features = ["full", "test-util"] }
|
||||
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
|
||||
trybuild = "1.0"
|
||||
tokio-rustls = "0.23"
|
||||
rustls-pemfile = "1.0"
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
@@ -103,6 +105,10 @@ required-features = ["full"]
|
||||
name = "custom_transport"
|
||||
required-features = ["serde1", "tokio1", "serde-transport"]
|
||||
|
||||
[[example]]
|
||||
name = "tls_over_tcp"
|
||||
required-features = ["full"]
|
||||
|
||||
[[test]]
|
||||
name = "service_functional"
|
||||
required-features = ["serde-transport"]
|
||||
|
||||
11
tarpc/examples/certs/eddsa/client.cert
Normal file
11
tarpc/examples/certs/eddsa/client.cert
Normal file
@@ -0,0 +1,11 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBlDCCAUagAwIBAgICAxUwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk
|
||||
RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw
|
||||
NjEwMTEwNFowGjEYMBYGA1UEAwwPcG9ueXRvd24gY2xpZW50MCowBQYDK2VwAyEA
|
||||
NTKuLume19IhJfEFd/5OZUuYDKZH6xvy4AGver17OoejgZswgZgwDAYDVR0TAQH/
|
||||
BAIwADALBgNVHQ8EBAMCBsAwFgYDVR0lAQH/BAwwCgYIKwYBBQUHAwIwHQYDVR0O
|
||||
BBYEFDjdrlMu4tyw5MHtbg7WnzSGRBpFMEQGA1UdIwQ9MDuAFHIl7fHKWP6/l8FE
|
||||
fI2YEIM3oHxKoSCkHjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQYIBezAF
|
||||
BgMrZXADQQCaahfj/QLxoCOpvl6y0ZQ9CpojPqBnxV3460j5nUOp040Va2MpF137
|
||||
izCBY7LwgUE/YG6E+kH30G4jMEnqVEYK
|
||||
-----END CERTIFICATE-----
|
||||
19
tarpc/examples/certs/eddsa/client.chain
Normal file
19
tarpc/examples/certs/eddsa/client.chain
Normal file
@@ -0,0 +1,19 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE
|
||||
U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD
|
||||
DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh
|
||||
AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU
|
||||
ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG
|
||||
AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU
|
||||
oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc
|
||||
zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg=
|
||||
-----END CERTIFICATE-----
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG
|
||||
A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0
|
||||
MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh
|
||||
ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU
|
||||
phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR
|
||||
W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC
|
||||
t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB
|
||||
-----END CERTIFICATE-----
|
||||
3
tarpc/examples/certs/eddsa/client.key
Normal file
3
tarpc/examples/certs/eddsa/client.key
Normal file
@@ -0,0 +1,3 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MC4CAQAwBQYDK2VwBCIEIIJX9ThTHpVS1SNZb6HP4myg4fRInIVGunTRdgnc+weH
|
||||
-----END PRIVATE KEY-----
|
||||
12
tarpc/examples/certs/eddsa/end.cert
Normal file
12
tarpc/examples/certs/eddsa/end.cert
Normal file
@@ -0,0 +1,12 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBuDCCAWqgAwIBAgICAcgwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk
|
||||
RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw
|
||||
NjEwMTEwNFowGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wKjAFBgMrZXADIQDc
|
||||
RLl3/N2tPoWnzBV3noVn/oheEl8IUtiY11Vg/QXTUKOBwDCBvTAMBgNVHRMBAf8E
|
||||
AjAAMAsGA1UdDwQEAwIGwDAdBgNVHQ4EFgQUk7U2mnxedNWBAH84BsNy5si3ZQow
|
||||
RAYDVR0jBD0wO4AUciXt8cpY/r+XwUR8jZgQgzegfEqhIKQeMBwxGjAYBgNVBAMM
|
||||
EXBvbnl0b3duIEVkRFNBIENBggF7MDsGA1UdEQQ0MDKCDnRlc3RzZXJ2ZXIuY29t
|
||||
ghVzZWNvbmQudGVzdHNlcnZlci5jb22CCWxvY2FsaG9zdDAFBgMrZXADQQCFWIcF
|
||||
9FiztCuUNzgXDNu5kshuflt0RjkjWpGlWzQjGoYM2IvYhNVPeqnCiY92gqwDSBtq
|
||||
amD2TBup4eNUCsQB
|
||||
-----END CERTIFICATE-----
|
||||
19
tarpc/examples/certs/eddsa/end.chain
Normal file
19
tarpc/examples/certs/eddsa/end.chain
Normal file
@@ -0,0 +1,19 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE
|
||||
U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD
|
||||
DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh
|
||||
AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU
|
||||
ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG
|
||||
AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU
|
||||
oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc
|
||||
zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg=
|
||||
-----END CERTIFICATE-----
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG
|
||||
A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0
|
||||
MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh
|
||||
ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU
|
||||
phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR
|
||||
W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC
|
||||
t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB
|
||||
-----END CERTIFICATE-----
|
||||
3
tarpc/examples/certs/eddsa/end.key
Normal file
3
tarpc/examples/certs/eddsa/end.key
Normal file
@@ -0,0 +1,3 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MC4CAQAwBQYDK2VwBCIEIMU6xGVe8JTpZ3bN/wajHfw6pEHt0Rd7wPBxds9eEFy2
|
||||
-----END PRIVATE KEY-----
|
||||
152
tarpc/examples/tls_over_tcp.rs
Normal file
152
tarpc/examples/tls_over_tcp.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
use rustls_pemfile::certs;
|
||||
use std::io::{BufReader, Cursor};
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient;
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::rustls::{self, Certificate, OwnedTrustAnchor, RootCertStore};
|
||||
use tokio_rustls::{webpki, TlsAcceptor, TlsConnector};
|
||||
|
||||
use tarpc::context::Context;
|
||||
use tarpc::serde_transport as transport;
|
||||
use tarpc::server::{BaseChannel, Channel};
|
||||
use tarpc::tokio_serde::formats::Bincode;
|
||||
use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec;
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait PingService {
|
||||
async fn ping() -> String;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Service;
|
||||
|
||||
#[tarpc::server]
|
||||
impl PingService for Service {
|
||||
async fn ping(self, _: Context) -> String {
|
||||
"🔒".to_owned()
|
||||
}
|
||||
}
|
||||
|
||||
// certs were generated with openssl 3 https://github.com/rustls/rustls/tree/main/test-ca
|
||||
// used on client-side for server tls
|
||||
const END_CHAIN: &[u8] = include_bytes!("certs/eddsa/end.chain");
|
||||
// used on client-side for client-auth
|
||||
const CLIENT_PRIVATEKEY_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.key");
|
||||
const CLIENT_CERT_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.cert");
|
||||
|
||||
// used on server-side for server tls
|
||||
const END_CERT: &str = include_str!("certs/eddsa/end.cert");
|
||||
const END_PRIVATEKEY: &str = include_str!("certs/eddsa/end.key");
|
||||
// used on server-side for client-auth
|
||||
const CLIENT_CHAIN_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.chain");
|
||||
|
||||
pub fn load_private_key(key: &str) -> rustls::PrivateKey {
|
||||
let mut reader = BufReader::new(Cursor::new(key));
|
||||
loop {
|
||||
match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") {
|
||||
Some(rustls_pemfile::Item::RSAKey(key)) => return rustls::PrivateKey(key),
|
||||
Some(rustls_pemfile::Item::PKCS8Key(key)) => return rustls::PrivateKey(key),
|
||||
Some(rustls_pemfile::Item::ECKey(key)) => return rustls::PrivateKey(key),
|
||||
None => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
panic!("no keys found in {:?} (encrypted keys not supported)", key);
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// -------------------- start here to setup tls tcp tokio stream --------------------------
|
||||
// ref certs and loading from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/tests/test.rs
|
||||
// ref basic tls server setup from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/server/src/main.rs
|
||||
let cert = certs(&mut BufReader::new(Cursor::new(END_CERT)))
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect();
|
||||
let key = load_private_key(END_PRIVATEKEY);
|
||||
let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), 5000);
|
||||
|
||||
// ------------- server side client_auth cert loading start
|
||||
let roots: Vec<Certificate> = certs(&mut BufReader::new(Cursor::new(CLIENT_CHAIN_CLIENT_AUTH)))
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect();
|
||||
let mut client_auth_roots = RootCertStore::empty();
|
||||
for root in roots {
|
||||
client_auth_roots.add(&root).unwrap();
|
||||
}
|
||||
let client_auth = AllowAnyAuthenticatedClient::new(client_auth_roots);
|
||||
// ------------- server side client_auth cert loading end
|
||||
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_client_cert_verifier(client_auth) // use .with_no_client_auth() instead if you don't want client-auth
|
||||
.with_single_cert(cert, key)
|
||||
.unwrap();
|
||||
let acceptor = TlsAcceptor::from(Arc::new(config));
|
||||
let listener = TcpListener::bind(&server_addr).await.unwrap();
|
||||
let codec_builder = LengthDelimitedCodec::builder();
|
||||
|
||||
// ref ./custom_transport.rs server side
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (stream, _peer_addr) = listener.accept().await.unwrap();
|
||||
let acceptor = acceptor.clone();
|
||||
let tls_stream = acceptor.accept(stream).await.unwrap();
|
||||
let framed = codec_builder.new_framed(tls_stream);
|
||||
|
||||
let transport = transport::new(framed, Bincode::default());
|
||||
|
||||
let fut = BaseChannel::with_defaults(transport).execute(Service.serve());
|
||||
tokio::spawn(fut);
|
||||
}
|
||||
});
|
||||
|
||||
// ---------------------- client connection ---------------------
|
||||
// cert loading from: https://github.com/tokio-rs/tls/blob/357bc562483dcf04c1f8d08bd1a831b144bf7d4c/tokio-rustls/tests/test.rs#L113
|
||||
// tls client connection from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
|
||||
let chain = certs(&mut std::io::Cursor::new(END_CHAIN)).unwrap();
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
root_store.add_server_trust_anchors(chain.iter().map(|cert| {
|
||||
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
|
||||
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||
ta.subject,
|
||||
ta.spki,
|
||||
ta.name_constraints,
|
||||
)
|
||||
}));
|
||||
|
||||
let client_auth_private_key = load_private_key(CLIENT_PRIVATEKEY_CLIENT_AUTH);
|
||||
let client_auth_certs: Vec<Certificate> =
|
||||
certs(&mut BufReader::new(Cursor::new(CLIENT_CERT_CLIENT_AUTH)))
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect();
|
||||
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_single_cert(client_auth_certs, client_auth_private_key)?; // use .with_no_client_auth() instead if you don't want client-auth
|
||||
|
||||
let domain = rustls::ServerName::try_from("localhost")?;
|
||||
let connector = TlsConnector::from(Arc::new(config));
|
||||
|
||||
let stream = TcpStream::connect(server_addr).await?;
|
||||
let stream = connector.connect(domain, stream).await?;
|
||||
|
||||
let transport = transport::new(codec_builder.new_framed(stream), Bincode::default());
|
||||
let answer = PingServiceClient::new(Default::default(), transport)
|
||||
.spawn()
|
||||
.ping(tarpc::context::current())
|
||||
.await?;
|
||||
|
||||
println!("ping answer: {answer}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -10,15 +10,14 @@ mod in_flight_requests;
|
||||
|
||||
use crate::{
|
||||
cancellations::{cancellations, CanceledRequests, RequestCancellation},
|
||||
context, trace, ClientMessage, Request, Response, ServerError, Transport,
|
||||
context, trace, ChannelError, ClientMessage, Request, Response, ServerError, Transport,
|
||||
};
|
||||
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
||||
use in_flight_requests::{DeadlineExceededError, InFlightRequests};
|
||||
use in_flight_requests::InFlightRequests;
|
||||
use pin_project::pin_project;
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
error::Error,
|
||||
fmt, mem,
|
||||
fmt,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
@@ -124,7 +123,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
pub async fn call(
|
||||
&self,
|
||||
mut ctx: context::Context,
|
||||
request_name: &str,
|
||||
request_name: &'static str,
|
||||
request: Req,
|
||||
) -> Result<Resp, RpcError> {
|
||||
let span = Span::current();
|
||||
@@ -147,6 +146,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
response: &mut response,
|
||||
request_id,
|
||||
cancellation: &self.cancellation,
|
||||
cancel: true,
|
||||
};
|
||||
self.to_dispatch
|
||||
.send(DispatchRequest {
|
||||
@@ -157,7 +157,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
response_completion,
|
||||
})
|
||||
.await
|
||||
.map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?;
|
||||
.map_err(|mpsc::error::SendError(_)| RpcError::Shutdown)?;
|
||||
response_guard.response().await
|
||||
}
|
||||
}
|
||||
@@ -165,19 +165,25 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
/// A server response that is completed by request dispatch when the corresponding response
|
||||
/// arrives off the wire.
|
||||
struct ResponseGuard<'a, Resp> {
|
||||
response: &'a mut oneshot::Receiver<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
response: &'a mut oneshot::Receiver<Result<Resp, RpcError>>,
|
||||
cancellation: &'a RequestCancellation,
|
||||
request_id: u64,
|
||||
cancel: bool,
|
||||
}
|
||||
|
||||
/// An error that can occur in the processing of an RPC. This is not request-specific errors but
|
||||
/// rather cross-cutting errors that can always occur.
|
||||
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum RpcError {
|
||||
/// The client disconnected from the server.
|
||||
#[error("the client disconnected from the server")]
|
||||
Disconnected,
|
||||
#[error("the connection to the server was already shutdown")]
|
||||
Shutdown,
|
||||
/// The client failed to send the request.
|
||||
#[error("the client failed to send the request")]
|
||||
Send(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
/// An error occurred while waiting for the server response.
|
||||
#[error("an error occurred while waiting for the server response")]
|
||||
Receive(#[source] Arc<dyn std::error::Error + Send + Sync + 'static>),
|
||||
/// The request exceeded its deadline.
|
||||
#[error("the request exceeded its deadline")]
|
||||
DeadlineExceeded,
|
||||
@@ -186,24 +192,18 @@ pub enum RpcError {
|
||||
Server(#[from] ServerError),
|
||||
}
|
||||
|
||||
impl From<DeadlineExceededError> for RpcError {
|
||||
fn from(_: DeadlineExceededError) -> Self {
|
||||
RpcError::DeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
impl<Resp> ResponseGuard<'_, Resp> {
|
||||
async fn response(mut self) -> Result<Resp, RpcError> {
|
||||
let response = (&mut self.response).await;
|
||||
// Cancel drop logic once a response has been received.
|
||||
mem::forget(self);
|
||||
self.cancel = false;
|
||||
match response {
|
||||
Ok(resp) => Ok(resp?.message?),
|
||||
Ok(response) => response,
|
||||
Err(oneshot::error::RecvError { .. }) => {
|
||||
// The oneshot is Canceled when the dispatch task ends. In that case,
|
||||
// there's nothing listening on the other side, so there's no point in
|
||||
// propagating cancellation.
|
||||
Err(RpcError::Disconnected)
|
||||
Err(RpcError::Shutdown)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -223,7 +223,9 @@ impl<Resp> Drop for ResponseGuard<'_, Resp> {
|
||||
// dispatch task misses an early-arriving cancellation message, then it will see the
|
||||
// receiver as closed.
|
||||
self.response.close();
|
||||
self.cancellation.cancel(self.request_id);
|
||||
if self.cancel {
|
||||
self.cancellation.cancel(self.request_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,7 +240,6 @@ where
|
||||
{
|
||||
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
|
||||
let (cancellation, canceled_requests) = cancellations();
|
||||
let canceled_requests = canceled_requests;
|
||||
|
||||
NewClient {
|
||||
client: Channel {
|
||||
@@ -270,42 +271,18 @@ pub struct RequestDispatch<Req, Resp, C> {
|
||||
/// Requests that were dropped.
|
||||
canceled_requests: CanceledRequests,
|
||||
/// Requests already written to the wire that haven't yet received responses.
|
||||
in_flight_requests: InFlightRequests<Resp>,
|
||||
in_flight_requests: InFlightRequests<Result<Resp, RpcError>>,
|
||||
/// Configures limits to prevent unlimited resource usage.
|
||||
config: Config,
|
||||
}
|
||||
|
||||
/// Critical errors that result in a Channel disconnecting.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ChannelError<E>
|
||||
where
|
||||
E: Error + Send + Sync + 'static,
|
||||
{
|
||||
/// Could not read from the transport.
|
||||
#[error("could not read from the transport")]
|
||||
Read(#[source] E),
|
||||
/// Could not ready the transport for writes.
|
||||
#[error("could not ready the transport for writes")]
|
||||
Ready(#[source] E),
|
||||
/// Could not write to the transport.
|
||||
#[error("could not write to the transport")]
|
||||
Write(#[source] E),
|
||||
/// Could not flush the transport.
|
||||
#[error("could not flush the transport")]
|
||||
Flush(#[source] E),
|
||||
/// Could not close the write end of the transport.
|
||||
#[error("could not close the write end of the transport")]
|
||||
Close(#[source] E),
|
||||
/// Could not poll expired requests.
|
||||
#[error("could not poll expired requests")]
|
||||
Timer(#[source] tokio::time::error::Error),
|
||||
}
|
||||
|
||||
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
|
||||
where
|
||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
||||
{
|
||||
fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests<Resp> {
|
||||
fn in_flight_requests<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
) -> &'a mut InFlightRequests<Result<Resp, RpcError>> {
|
||||
self.as_mut().project().in_flight_requests
|
||||
}
|
||||
|
||||
@@ -365,7 +342,17 @@ where
|
||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||
self.transport_pin_mut()
|
||||
.poll_next(cx)
|
||||
.map_err(ChannelError::Read)
|
||||
.map_err(|e| {
|
||||
let e = Arc::new(e);
|
||||
for span in self
|
||||
.in_flight_requests()
|
||||
.complete_all_requests(|| Err(RpcError::Receive(e.clone())))
|
||||
{
|
||||
let _entered = span.enter();
|
||||
tracing::info!("ReceiveError");
|
||||
}
|
||||
ChannelError::Read(e)
|
||||
})
|
||||
.map_ok(|response| {
|
||||
self.complete(response);
|
||||
})
|
||||
@@ -395,7 +382,10 @@ where
|
||||
// Receiving Poll::Ready(None) when polling expired requests never indicates "Closed",
|
||||
// because there can temporarily be zero in-flight rquests. Therefore, there is no need to
|
||||
// track the status like is done with pending and cancelled requests.
|
||||
if let Poll::Ready(Some(_)) = self.in_flight_requests().poll_expired(cx) {
|
||||
if let Poll::Ready(Some(_)) = self
|
||||
.in_flight_requests()
|
||||
.poll_expired(cx, || Err(RpcError::DeadlineExceeded))
|
||||
{
|
||||
// Expired requests are considered complete; there is no compelling reason to send a
|
||||
// cancellation message to the server, since it will have already exhausted its
|
||||
// allotted processing time.
|
||||
@@ -506,7 +496,7 @@ where
|
||||
Some(dispatch_request) => dispatch_request,
|
||||
None => return Poll::Ready(None),
|
||||
};
|
||||
let entered = span.enter();
|
||||
let _entered = span.enter();
|
||||
// poll_next_request only returns Ready if there is room to buffer another request.
|
||||
// Therefore, we can call write_request without fear of erroring due to a full
|
||||
// buffer.
|
||||
@@ -519,13 +509,16 @@ where
|
||||
trace_context: ctx.trace_context,
|
||||
},
|
||||
});
|
||||
self.start_send(request)?;
|
||||
tracing::info!("SendRequest");
|
||||
drop(entered);
|
||||
|
||||
self.in_flight_requests()
|
||||
.insert_request(request_id, ctx, span, response_completion)
|
||||
.insert_request(request_id, ctx, span.clone(), response_completion)
|
||||
.expect("Request IDs should be unique");
|
||||
match self.start_send(request) {
|
||||
Ok(()) => tracing::info!("SendRequest"),
|
||||
Err(e) => {
|
||||
self.in_flight_requests()
|
||||
.complete_request(request_id, Err(RpcError::Send(Box::new(e))));
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
|
||||
@@ -550,7 +543,10 @@ where
|
||||
|
||||
/// Sends a server response to the client task that initiated the associated request.
|
||||
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
|
||||
self.in_flight_requests().complete_request(response)
|
||||
self.in_flight_requests().complete_request(
|
||||
response.request_id,
|
||||
response.message.map_err(RpcError::Server),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -599,30 +595,37 @@ struct DispatchRequest<Req, Resp> {
|
||||
pub span: Span,
|
||||
pub request_id: u64,
|
||||
pub request: Req,
|
||||
pub response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
pub response_completion: oneshot::Sender<Result<Resp, RpcError>>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard};
|
||||
use super::{
|
||||
cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError,
|
||||
};
|
||||
use crate::{
|
||||
client::{
|
||||
in_flight_requests::{DeadlineExceededError, InFlightRequests},
|
||||
Config,
|
||||
},
|
||||
context,
|
||||
client::{in_flight_requests::InFlightRequests, Config},
|
||||
context::{self, current},
|
||||
transport::{self, channel::UnboundedChannel},
|
||||
ClientMessage, Response,
|
||||
ChannelError, ClientMessage, Response,
|
||||
};
|
||||
use assert_matches::assert_matches;
|
||||
use futures::{prelude::*, task::*};
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
fmt::Display,
|
||||
marker::PhantomData,
|
||||
pin::Pin,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
sync::Arc,
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{
|
||||
mpsc::{self},
|
||||
oneshot,
|
||||
};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tracing::Span;
|
||||
|
||||
#[tokio::test]
|
||||
@@ -643,7 +646,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
|
||||
assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp");
|
||||
assert_matches!(rx.try_recv(), Ok(Ok(resp)) if resp == "Resp");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -654,6 +657,7 @@ mod tests {
|
||||
response: &mut response,
|
||||
cancellation: &cancellation,
|
||||
request_id: 3,
|
||||
cancel: true,
|
||||
});
|
||||
// resp's drop() is run, which should send a cancel message.
|
||||
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||
@@ -674,6 +678,7 @@ mod tests {
|
||||
response: &mut response,
|
||||
cancellation: &cancellation,
|
||||
request_id: 3,
|
||||
cancel: true,
|
||||
}
|
||||
.response()
|
||||
.await
|
||||
@@ -775,6 +780,185 @@ mod tests {
|
||||
assert!(dispatch.as_mut().poll_next_request(cx).is_pending());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_shutdown_error() {
|
||||
let _ = tracing_subscriber::fmt().with_test_writer().try_init();
|
||||
let (dispatch, mut channel, _) = set_up();
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
// send succeeds
|
||||
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
drop(dispatch);
|
||||
// error on receive
|
||||
assert_matches!(resp.response().await, Err(RpcError::Shutdown));
|
||||
let (dispatch, channel, _) = set_up();
|
||||
drop(dispatch);
|
||||
// error on send
|
||||
let resp = channel
|
||||
.call(current(), "test_request", "hi".to_string())
|
||||
.await;
|
||||
assert_matches!(resp, Err(RpcError::Shutdown));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transport_error_write() {
|
||||
let cause = TransportError::Write;
|
||||
let (mut dispatch, mut channel, mut cx) = setup_always_err(cause);
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
assert!(dispatch.as_mut().poll(&mut cx).is_pending());
|
||||
let res = resp.response().await;
|
||||
assert_matches!(res, Err(RpcError::Send(_)));
|
||||
let client_error: anyhow::Error = res.unwrap_err().into();
|
||||
let mut chain = client_error.chain();
|
||||
chain.next(); // original RpcError
|
||||
assert_eq!(
|
||||
chain
|
||||
.next()
|
||||
.unwrap()
|
||||
.downcast_ref::<ChannelError<TransportError>>(),
|
||||
Some(&ChannelError::Write(cause))
|
||||
);
|
||||
assert_eq!(
|
||||
client_error.root_cause().downcast_ref::<TransportError>(),
|
||||
Some(&cause)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transport_error_read() {
|
||||
let cause = TransportError::Read;
|
||||
let (mut dispatch, mut channel, mut cx) = setup_always_err(cause);
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
assert_eq!(
|
||||
dispatch.as_mut().pump_write(&mut cx),
|
||||
Poll::Ready(Some(Ok(())))
|
||||
);
|
||||
assert_eq!(
|
||||
dispatch.as_mut().pump_read(&mut cx),
|
||||
Poll::Ready(Some(Err(ChannelError::Read(Arc::new(cause)))))
|
||||
);
|
||||
assert_matches!(resp.response().await, Err(RpcError::Receive(_)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transport_error_ready() {
|
||||
let cause = TransportError::Ready;
|
||||
let (mut dispatch, _, mut cx) = setup_always_err(cause);
|
||||
assert_eq!(
|
||||
dispatch.as_mut().poll(&mut cx),
|
||||
Poll::Ready(Err(ChannelError::Ready(cause)))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transport_error_flush() {
|
||||
let cause = TransportError::Flush;
|
||||
let (mut dispatch, _, mut cx) = setup_always_err(cause);
|
||||
assert_eq!(
|
||||
dispatch.as_mut().poll(&mut cx),
|
||||
Poll::Ready(Err(ChannelError::Flush(cause)))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transport_error_close() {
|
||||
let cause = TransportError::Close;
|
||||
let (mut dispatch, channel, mut cx) = setup_always_err(cause);
|
||||
drop(channel);
|
||||
assert_eq!(
|
||||
dispatch.as_mut().poll(&mut cx),
|
||||
Poll::Ready(Err(ChannelError::Close(cause)))
|
||||
);
|
||||
}
|
||||
|
||||
fn setup_always_err(
|
||||
cause: TransportError,
|
||||
) -> (
|
||||
Pin<Box<RequestDispatch<String, String, AlwaysErrorTransport<String>>>>,
|
||||
Channel<String, String>,
|
||||
Context<'static>,
|
||||
) {
|
||||
let (to_dispatch, pending_requests) = mpsc::channel(1);
|
||||
let (cancellation, canceled_requests) = cancellations();
|
||||
let transport: AlwaysErrorTransport<String> = AlwaysErrorTransport(cause, PhantomData);
|
||||
let dispatch = Box::pin(RequestDispatch::<String, String, _> {
|
||||
transport: transport.fuse(),
|
||||
pending_requests,
|
||||
canceled_requests,
|
||||
in_flight_requests: InFlightRequests::default(),
|
||||
config: Config::default(),
|
||||
});
|
||||
let channel = Channel {
|
||||
to_dispatch,
|
||||
cancellation,
|
||||
next_request_id: Arc::new(AtomicUsize::new(0)),
|
||||
};
|
||||
let cx = Context::from_waker(noop_waker_ref());
|
||||
(dispatch, channel, cx)
|
||||
}
|
||||
|
||||
struct AlwaysErrorTransport<I>(TransportError, PhantomData<I>);
|
||||
|
||||
#[derive(Debug, Error, PartialEq, Eq, Clone, Copy)]
|
||||
enum TransportError {
|
||||
Read,
|
||||
Ready,
|
||||
Write,
|
||||
Flush,
|
||||
Close,
|
||||
}
|
||||
|
||||
impl Display for TransportError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(&format!("{self:?}"))
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: Clone, S> Sink<S> for AlwaysErrorTransport<I> {
|
||||
type Error = TransportError;
|
||||
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
match self.0 {
|
||||
TransportError::Ready => Poll::Ready(Err(self.0)),
|
||||
TransportError::Flush => Poll::Pending,
|
||||
_ => Poll::Ready(Ok(())),
|
||||
}
|
||||
}
|
||||
fn start_send(self: Pin<&mut Self>, _: S) -> Result<(), Self::Error> {
|
||||
if matches!(self.0, TransportError::Write) {
|
||||
Err(self.0)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
if matches!(self.0, TransportError::Flush) {
|
||||
Poll::Ready(Err(self.0))
|
||||
} else {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
if matches!(self.0, TransportError::Close) {
|
||||
Poll::Ready(Err(self.0))
|
||||
} else {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: Clone> Stream for AlwaysErrorTransport<I> {
|
||||
type Item = Result<Response<I>, TransportError>;
|
||||
fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
if matches!(self.0, TransportError::Read) {
|
||||
Poll::Ready(Some(Err(self.0)))
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn set_up() -> (
|
||||
Pin<
|
||||
Box<
|
||||
@@ -814,8 +998,8 @@ mod tests {
|
||||
async fn send_request<'a>(
|
||||
channel: &'a mut Channel<String, String>,
|
||||
request: &str,
|
||||
response_completion: oneshot::Sender<Result<Response<String>, DeadlineExceededError>>,
|
||||
response: &'a mut oneshot::Receiver<Result<Response<String>, DeadlineExceededError>>,
|
||||
response_completion: oneshot::Sender<Result<String, RpcError>>,
|
||||
response: &'a mut oneshot::Receiver<Result<String, RpcError>>,
|
||||
) -> ResponseGuard<'a, String> {
|
||||
let request_id =
|
||||
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
|
||||
@@ -830,6 +1014,7 @@ mod tests {
|
||||
response,
|
||||
cancellation: &channel.cancellation,
|
||||
request_id,
|
||||
cancel: true,
|
||||
};
|
||||
channel.to_dispatch.send(request).await.unwrap();
|
||||
response_guard
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use crate::{
|
||||
context,
|
||||
util::{Compact, TimeUntil},
|
||||
Response,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use std::{
|
||||
@@ -28,17 +27,11 @@ impl<Resp> Default for InFlightRequests<Resp> {
|
||||
}
|
||||
}
|
||||
|
||||
/// The request exceeded its deadline.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
#[non_exhaustive]
|
||||
#[error("the request exceeded its deadline")]
|
||||
pub struct DeadlineExceededError;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RequestData<Resp> {
|
||||
struct RequestData<Res> {
|
||||
ctx: context::Context,
|
||||
span: Span,
|
||||
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
response_completion: oneshot::Sender<Res>,
|
||||
/// The key to remove the timer for the request's deadline.
|
||||
deadline_key: delay_queue::Key,
|
||||
}
|
||||
@@ -48,7 +41,7 @@ struct RequestData<Resp> {
|
||||
#[derive(Debug)]
|
||||
pub struct AlreadyExistsError;
|
||||
|
||||
impl<Resp> InFlightRequests<Resp> {
|
||||
impl<Res> InFlightRequests<Res> {
|
||||
/// Returns the number of in-flight requests.
|
||||
pub fn len(&self) -> usize {
|
||||
self.request_data.len()
|
||||
@@ -65,7 +58,7 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
request_id: u64,
|
||||
ctx: context::Context,
|
||||
span: Span,
|
||||
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
response_completion: oneshot::Sender<Res>,
|
||||
) -> Result<(), AlreadyExistsError> {
|
||||
match self.request_data.entry(request_id) {
|
||||
hash_map::Entry::Vacant(vacant) => {
|
||||
@@ -84,25 +77,35 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
}
|
||||
|
||||
/// Removes a request without aborting. Returns true iff the request was found.
|
||||
pub fn complete_request(&mut self, response: Response<Resp>) -> bool {
|
||||
if let Some(request_data) = self.request_data.remove(&response.request_id) {
|
||||
pub fn complete_request(&mut self, request_id: u64, result: Res) -> bool {
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
let _entered = request_data.span.enter();
|
||||
tracing::info!("ReceiveResponse");
|
||||
self.request_data.compact(0.1);
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
let _ = request_data.response_completion.send(Ok(response));
|
||||
let _ = request_data.response_completion.send(result);
|
||||
return true;
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
"No in-flight request found for request_id = {}.",
|
||||
response.request_id
|
||||
);
|
||||
tracing::debug!("No in-flight request found for request_id = {request_id}.");
|
||||
|
||||
// If the response completion was absent, then the request was already canceled.
|
||||
false
|
||||
}
|
||||
|
||||
/// Completes all requests using the provided function.
|
||||
/// Returns Spans for all completes requests.
|
||||
pub fn complete_all_requests<'a>(
|
||||
&'a mut self,
|
||||
mut result: impl FnMut() -> Res + 'a,
|
||||
) -> impl Iterator<Item = Span> + 'a {
|
||||
self.deadlines.clear();
|
||||
self.request_data.drain().map(move |(_, request_data)| {
|
||||
let _ = request_data.response_completion.send(result());
|
||||
request_data.span
|
||||
})
|
||||
}
|
||||
|
||||
/// Cancels a request without completing (typically used when a request handle was dropped
|
||||
/// before the request completed).
|
||||
pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> {
|
||||
@@ -117,16 +120,18 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
|
||||
/// Yields a request that has expired, completing it with a TimedOut error.
|
||||
/// The caller should send cancellation messages for any yielded request ID.
|
||||
pub fn poll_expired(&mut self, cx: &mut Context) -> Poll<Option<u64>> {
|
||||
pub fn poll_expired(
|
||||
&mut self,
|
||||
cx: &mut Context,
|
||||
expired_error: impl Fn() -> Res,
|
||||
) -> Poll<Option<u64>> {
|
||||
self.deadlines.poll_expired(cx).map(|expired| {
|
||||
let request_id = expired?.into_inner();
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
let _entered = request_data.span.enter();
|
||||
tracing::error!("DeadlineExceeded");
|
||||
self.request_data.compact(0.1);
|
||||
let _ = request_data
|
||||
.response_completion
|
||||
.send(Err(DeadlineExceededError));
|
||||
let _ = request_data.response_completion.send(expired_error());
|
||||
}
|
||||
Some(request_id)
|
||||
})
|
||||
|
||||
@@ -311,6 +311,7 @@ pub use crate::transport::sealed::Transport;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use futures::task::*;
|
||||
use std::sync::Arc;
|
||||
use std::{error::Error, fmt::Display, io, time::SystemTime};
|
||||
|
||||
/// A message from a client to a server.
|
||||
@@ -383,6 +384,29 @@ pub struct ServerError {
|
||||
pub detail: String,
|
||||
}
|
||||
|
||||
/// Critical errors that result in a Channel disconnecting.
|
||||
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
|
||||
pub enum ChannelError<E>
|
||||
where
|
||||
E: Error + Send + Sync + 'static,
|
||||
{
|
||||
/// Could not read from the transport.
|
||||
#[error("could not read from the transport")]
|
||||
Read(#[source] Arc<E>),
|
||||
/// Could not ready the transport for writes.
|
||||
#[error("could not ready the transport for writes")]
|
||||
Ready(#[source] E),
|
||||
/// Could not write to the transport.
|
||||
#[error("could not write to the transport")]
|
||||
Write(#[source] E),
|
||||
/// Could not flush the transport.
|
||||
#[error("could not flush the transport")]
|
||||
Flush(#[source] E),
|
||||
/// Could not close the write end of the transport.
|
||||
#[error("could not close the write end of the transport")]
|
||||
Close(#[source] E),
|
||||
}
|
||||
|
||||
impl<T> Request<T> {
|
||||
/// Returns the deadline for this request.
|
||||
pub fn deadline(&self) -> &SystemTime {
|
||||
|
||||
@@ -129,14 +129,6 @@ pub mod tcp {
|
||||
tokio_util::codec::length_delimited,
|
||||
};
|
||||
|
||||
mod private {
|
||||
use super::*;
|
||||
|
||||
pub trait Sealed {}
|
||||
|
||||
impl<Item, SinkItem, Codec> Sealed for Transport<TcpStream, Item, SinkItem, Codec> {}
|
||||
}
|
||||
|
||||
impl<Item, SinkItem, Codec> Transport<TcpStream, Item, SinkItem, Codec> {
|
||||
/// Returns the peer address of the underlying TcpStream.
|
||||
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
use crate::{
|
||||
cancellations::{cancellations, CanceledRequests, RequestCancellation},
|
||||
context::{self, SpanExt},
|
||||
trace, ClientMessage, Request, Response, Transport,
|
||||
trace, ChannelError, ClientMessage, Request, Response, Transport,
|
||||
};
|
||||
use ::tokio::sync::mpsc;
|
||||
use futures::{
|
||||
@@ -21,14 +21,7 @@ use futures::{
|
||||
};
|
||||
use in_flight_requests::{AlreadyExistsError, InFlightRequests};
|
||||
use pin_project::pin_project;
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
error::Error,
|
||||
fmt,
|
||||
marker::PhantomData,
|
||||
mem::{self, ManuallyDrop},
|
||||
pin::Pin,
|
||||
};
|
||||
use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc};
|
||||
use tracing::{info_span, instrument::Instrument, Span};
|
||||
|
||||
mod in_flight_requests;
|
||||
@@ -208,10 +201,11 @@ where
|
||||
Ok(TrackedRequest {
|
||||
abort_registration,
|
||||
span,
|
||||
response_guard: ManuallyDrop::new(ResponseGuard {
|
||||
response_guard: ResponseGuard {
|
||||
request_id: request.id,
|
||||
request_cancellation: self.request_cancellation.clone(),
|
||||
}),
|
||||
cancel: false,
|
||||
},
|
||||
request,
|
||||
})
|
||||
}
|
||||
@@ -240,7 +234,7 @@ pub struct TrackedRequest<Req> {
|
||||
/// A span representing the server processing of this request.
|
||||
pub span: Span,
|
||||
/// An inert response guard. Becomes active in an InFlightRequest.
|
||||
pub response_guard: ManuallyDrop<ResponseGuard>,
|
||||
pub response_guard: ResponseGuard,
|
||||
}
|
||||
|
||||
/// The server end of an open connection with a client, receiving requests from, and sending
|
||||
@@ -343,20 +337,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// Critical errors that result in a Channel disconnecting.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ChannelError<E>
|
||||
where
|
||||
E: Error + Send + Sync + 'static,
|
||||
{
|
||||
/// An error occurred reading from, or writing to, the transport.
|
||||
#[error("an error occurred in the transport: {0}")]
|
||||
Transport(#[source] E),
|
||||
/// An error occurred while polling expired requests.
|
||||
#[error("an error occurred while polling expired requests: {0}")]
|
||||
Timer(#[source] ::tokio::time::error::Error),
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
@@ -413,7 +393,7 @@ where
|
||||
let request_status = match self
|
||||
.transport_pin_mut()
|
||||
.poll_next(cx)
|
||||
.map_err(ChannelError::Transport)?
|
||||
.map_err(|e| ChannelError::Read(Arc::new(e)))?
|
||||
{
|
||||
Poll::Ready(Some(message)) => match message {
|
||||
ClientMessage::Request(request) => {
|
||||
@@ -473,7 +453,7 @@ where
|
||||
self.project()
|
||||
.transport
|
||||
.poll_ready(cx)
|
||||
.map_err(ChannelError::Transport)
|
||||
.map_err(ChannelError::Ready)
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
|
||||
@@ -486,7 +466,7 @@ where
|
||||
self.project()
|
||||
.transport
|
||||
.start_send(response)
|
||||
.map_err(ChannelError::Transport)
|
||||
.map_err(ChannelError::Write)
|
||||
} else {
|
||||
// If the request isn't tracked anymore, there's no need to send the response.
|
||||
Ok(())
|
||||
@@ -498,14 +478,14 @@ where
|
||||
self.project()
|
||||
.transport
|
||||
.poll_flush(cx)
|
||||
.map_err(ChannelError::Transport)
|
||||
.map_err(ChannelError::Flush)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.project()
|
||||
.transport
|
||||
.poll_close(cx)
|
||||
.map_err(ChannelError::Transport)
|
||||
.map_err(ChannelError::Close)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -581,13 +561,15 @@ where
|
||||
request,
|
||||
abort_registration,
|
||||
span,
|
||||
response_guard,
|
||||
mut response_guard,
|
||||
}| {
|
||||
// The response guard becomes active once in an InFlightRequest.
|
||||
response_guard.cancel = true;
|
||||
InFlightRequest {
|
||||
request,
|
||||
abort_registration,
|
||||
span,
|
||||
response_guard: ManuallyDrop::into_inner(response_guard),
|
||||
response_guard,
|
||||
response_tx: self.responses_tx.clone(),
|
||||
}
|
||||
},
|
||||
@@ -674,11 +656,14 @@ where
|
||||
pub struct ResponseGuard {
|
||||
request_cancellation: RequestCancellation,
|
||||
request_id: u64,
|
||||
cancel: bool,
|
||||
}
|
||||
|
||||
impl Drop for ResponseGuard {
|
||||
fn drop(&mut self) {
|
||||
self.request_cancellation.cancel(self.request_id);
|
||||
if self.cancel {
|
||||
self.request_cancellation.cancel(self.request_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -721,7 +706,7 @@ impl<Req, Res> InFlightRequest<Req, Res> {
|
||||
{
|
||||
let Self {
|
||||
response_tx,
|
||||
response_guard,
|
||||
mut response_guard,
|
||||
abort_registration,
|
||||
span,
|
||||
request:
|
||||
@@ -732,6 +717,9 @@ impl<Req, Res> InFlightRequest<Req, Res> {
|
||||
},
|
||||
} = self;
|
||||
let method = serve.method(&message);
|
||||
// TODO(https://github.com/rust-lang/rust-clippy/issues/9111)
|
||||
// remove when clippy is fixed
|
||||
#[allow(clippy::needless_borrow)]
|
||||
span.record("otel.name", &method.unwrap_or(""));
|
||||
let _ = Abortable::new(
|
||||
async move {
|
||||
@@ -752,7 +740,7 @@ impl<Req, Res> InFlightRequest<Req, Res> {
|
||||
// Request processing has completed, meaning either the channel canceled the request or
|
||||
// a request was sent back to the channel. Either way, the channel will clean up the
|
||||
// request data, so the request does not need to be canceled.
|
||||
mem::forget(response_guard);
|
||||
response_guard.cancel = false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ use crate::{
|
||||
};
|
||||
use futures::{task::*, Sink, Stream};
|
||||
use pin_project::pin_project;
|
||||
use std::{collections::VecDeque, io, mem::ManuallyDrop, pin::Pin, time::SystemTime};
|
||||
use std::{collections::VecDeque, io, pin::Pin, time::SystemTime};
|
||||
use tracing::Span;
|
||||
|
||||
#[pin_project]
|
||||
@@ -101,10 +101,11 @@ impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||
},
|
||||
abort_registration,
|
||||
span: Span::none(),
|
||||
response_guard: ManuallyDrop::new(ResponseGuard {
|
||||
response_guard: ResponseGuard {
|
||||
request_cancellation,
|
||||
request_id: id,
|
||||
}),
|
||||
cancel: false,
|
||||
},
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ error: unused `RequestDispatch` that must be used
|
||||
--> tests/compile_fail/must_use_request_dispatch.rs:13:9
|
||||
|
|
||||
13 | WorldClient::new(client::Config::default(), client_transport).dispatch;
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
note: the lint level is defined here
|
||||
--> tests/compile_fail/must_use_request_dispatch.rs:11:12
|
||||
|
||||
@@ -2,7 +2,7 @@ error: unused `tarpc::serde_transport::tcp::Connect` that must be used
|
||||
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:7:9
|
||||
|
|
||||
7 | serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
note: the lint level is defined here
|
||||
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:5:12
|
||||
|
||||
@@ -2,7 +2,7 @@ error: unused `TokioChannelExecutor` that must be used
|
||||
--> tests/compile_fail/tokio/must_use_channel_executor.rs:27:9
|
||||
|
|
||||
27 | server.execute(HelloServer.serve());
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
note: the lint level is defined here
|
||||
--> tests/compile_fail/tokio/must_use_channel_executor.rs:25:12
|
||||
|
||||
@@ -2,7 +2,7 @@ error: unused `TokioServerExecutor` that must be used
|
||||
--> tests/compile_fail/tokio/must_use_server_executor.rs:28:9
|
||||
|
|
||||
28 | server.execute(HelloServer.serve());
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
note: the lint level is defined here
|
||||
--> tests/compile_fail/tokio/must_use_server_executor.rs:26:12
|
||||
|
||||
Reference in New Issue
Block a user