9 Commits

Author SHA1 Message Date
Tim Kuehn
bed85e2827 Prepare release of v0.32.0 2023-03-24 15:04:06 -07:00
Bruno
93f3880025 Return transport errors to the caller (#399)
* Make client::InFlightRequests generic over result.

Previously, InFlightRequests required the client response type to be a
server response. However, this prevented injection of non-server
responses: for example, if the client fails to send a request, it should
complete the request with an IO error rather than a server error.

* Gracefully handle client-side send errors.

Previously, a client channel would immediately disconnect when
encountering an error in Transport::try_send. One kind of error that can
occur in try_send is a message-validation failure, e.g. a message
exceeding a configured frame size. Shutting down the client immediately
hurt debuggability: it could be hard to understand what caused the
client to fail. These errors are also not always fatal (as with frame
size limits), so a complete shutdown was extreme.

By bubbling up errors, it's now possible for the caller to
programmatically handle them. For example, the error could be walked
via anyhow::Error:

```
    2023-01-10T02:49:32.528939Z  WARN client: the client failed to send the request

    Caused by:
        0: could not write to the transport
        1: frame size too big
```
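The same chain can also be walked programmatically. A minimal sketch of that walk, using only std's `Error::source` (the error types here are hypothetical stand-ins for the ones in the log above, not tarpc's real types):

```rust
use std::error::Error;
use std::fmt;

// Hypothetical stand-ins for the errors in the log above.
#[derive(Debug)]
struct FrameTooBig;
impl fmt::Display for FrameTooBig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("frame size too big")
    }
}
impl Error for FrameTooBig {}

#[derive(Debug)]
struct WriteError(FrameTooBig);
impl fmt::Display for WriteError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("could not write to the transport")
    }
}
impl Error for WriteError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        Some(&self.0)
    }
}

#[derive(Debug)]
struct SendError(WriteError);
impl fmt::Display for SendError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("the client failed to send the request")
    }
}
impl Error for SendError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        Some(&self.0)
    }
}

/// Collects every message in an error chain, outermost first; this is
/// the same walk that anyhow::Error::chain() performs.
fn error_chain(top: &(dyn Error + 'static)) -> Vec<String> {
    let mut messages = Vec::new();
    let mut cur: Option<&(dyn Error + 'static)> = Some(top);
    while let Some(e) = cur {
        messages.push(e.to_string());
        cur = e.source();
    }
    messages
}

fn main() {
    let err = SendError(WriteError(FrameTooBig));
    let chain = error_chain(&err);
    assert_eq!(chain.len(), 3);
    assert_eq!(chain[2], "frame size too big");
    for msg in &chain {
        println!("{msg}");
    }
}
```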

* Some follow-up work: right now, read errors will bubble up to all pending RPCs. However, on the write side, only `start_send` bubbles up. `poll_ready`, `poll_flush`, and `poll_close` do not propagate back to pending RPCs. This is probably okay in most circumstances, because fatal write errors likely coincide with fatal read errors, which *do* propagate back to clients. But it might still be worth unifying this logic.
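Because the failure modes are now distinct variants, callers can branch on them. A sketch (the enum here is a local stand-in mirroring the variant names in this change, not an import of tarpc's `client::RpcError`; `should_retry` and its policy are illustrative):

```rust
use std::{error::Error, sync::Arc};

// Local stand-in mirroring the RpcError variants introduced by this change.
#[derive(Debug)]
enum RpcError {
    Shutdown,
    Send(Box<dyn Error + Send + Sync>),
    Receive(Arc<dyn Error + Send + Sync>),
    DeadlineExceeded,
}

/// Illustrative retry policy built on the new variants.
fn should_retry(e: &RpcError) -> bool {
    match e {
        // Only the one RPC that failed to send sees Send; it may be
        // retryable (e.g. after shrinking the payload below the frame
        // size limit).
        RpcError::Send(_) => true,
        RpcError::DeadlineExceeded => true,
        // Shutdown and Receive mean the channel itself is gone.
        RpcError::Shutdown | RpcError::Receive(_) => false,
    }
}

fn main() {
    let send_err = RpcError::Send(Box::new(std::io::Error::new(
        std::io::ErrorKind::Other,
        "frame size too big",
    )));
    assert!(should_retry(&send_err));
    assert!(should_retry(&RpcError::DeadlineExceeded));
    assert!(!should_retry(&RpcError::Shutdown));
}
```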

---------

Co-authored-by: Tim Kuehn <tikue@google.com>
2023-03-24 14:31:25 -07:00
cguentherTUChemnitz
878f594d5b Feature/tls over tcp example (#398)
An example tarpc service that encodes messages with bincode and writes them to a TLS-over-TCP transport.

Certs were generated with openssl 3 using https://github.com/rustls/rustls/tree/main/test-ca

New dependencies:
- tokio-rustls to set up the TLS connections
- rustls-pemfile to load certs from .pem files
2023-03-22 10:35:21 -07:00
Tim Kuehn
aa9bbad109 Fix compile_fail tests for formatting changes on stable 2023-03-17 10:07:54 -07:00
Tim Kuehn
7e872ce925 Remove bad mem::forget usage.
mem::forget is a dangerous tool, and it was being used carelessly for
things that have safer alternatives. There was at least one bug where a
cloned tokio::sync::mpsc::UnboundedSender used for request cancellation
was being leaked on every successful server response, so its refcounts
were never decremented. Because these are atomic refcounts, they'll wrap
around rather than overflow when reaching the maximum value, so I don't
believe this could lead to panics or unsoundness.
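The safer alternative used by the fix is a disarm flag checked in Drop: Drop still runs, so every other field is cleaned up normally, and only the cancellation is skipped. A minimal sketch of the pattern (the names echo the diff below, but the types and the counter are stand-ins):

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Stand-in for the cancellation channel: counts cancellations sent.
static CANCELLATIONS: AtomicUsize = AtomicUsize::new(0);

struct ResponseGuard {
    request_id: u64,
    cancel: bool,
}

impl ResponseGuard {
    fn complete(mut self) {
        // Disarm instead of mem::forget(self): Drop still runs, so any
        // other fields (senders, buffers) are released normally.
        self.cancel = false;
    }
}

impl Drop for ResponseGuard {
    fn drop(&mut self) {
        if self.cancel {
            eprintln!("sending cancellation for request {}", self.request_id);
            CANCELLATIONS.fetch_add(1, Ordering::Relaxed);
        }
    }
}

fn main() {
    // A completed request sends no cancellation.
    ResponseGuard { request_id: 1, cancel: true }.complete();
    assert_eq!(CANCELLATIONS.load(Ordering::Relaxed), 0);
    // A guard dropped before completion does.
    drop(ResponseGuard { request_id: 2, cancel: true });
    assert_eq!(CANCELLATIONS.load(Ordering::Relaxed), 1);
}
```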
2022-11-23 18:01:12 -08:00
Tim Kuehn
62541b709d Replace actions-rs 2022-11-23 16:39:29 -08:00
Tim Kuehn
8c43f94fb6 Remove unused Sealed trait 2022-11-17 00:57:56 -08:00
Tim Kuehn
7fa4e5064d Ignore clippy false positive 2022-11-13 00:25:07 -08:00
Tim Kuehn
94db7610bb Require a static lifetime for request_name. 2022-11-05 11:43:03 -07:00
23 changed files with 611 additions and 211 deletions


@@ -18,20 +18,11 @@ jobs:
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
- uses: dtolnay/rust-toolchain@stable
with:
profile: minimal
toolchain: stable
target: mipsel-unknown-linux-gnu
override: true
- uses: actions-rs/cargo@v1
with:
command: check
args: --all-features
- uses: actions-rs/cargo@v1
with:
command: check
args: --all-features --target mipsel-unknown-linux-gnu
targets: mipsel-unknown-linux-gnu
- run: cargo check --all-features
- run: cargo check --all-features --target mipsel-unknown-linux-gnu
test:
name: Test Suite
@@ -42,34 +33,13 @@ jobs:
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: actions-rs/cargo@v1
with:
command: test
- uses: actions-rs/cargo@v1
with:
command: test
args: --manifest-path tarpc/Cargo.toml --features serde1
- uses: actions-rs/cargo@v1
with:
command: test
args: --manifest-path tarpc/Cargo.toml --features tokio1
- uses: actions-rs/cargo@v1
with:
command: test
args: --manifest-path tarpc/Cargo.toml --features serde-transport
- uses: actions-rs/cargo@v1
with:
command: test
args: --manifest-path tarpc/Cargo.toml --features tcp
- uses: actions-rs/cargo@v1
with:
command: test
args: --all-features
- uses: dtolnay/rust-toolchain@stable
- run: cargo test
- run: cargo test --manifest-path tarpc/Cargo.toml --features serde1
- run: cargo test --manifest-path tarpc/Cargo.toml --features tokio1
- run: cargo test --manifest-path tarpc/Cargo.toml --features serde-transport
- run: cargo test --manifest-path tarpc/Cargo.toml --features tcp
- run: cargo test --all-features
fmt:
name: Rustfmt
@@ -80,16 +50,10 @@ jobs:
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
- uses: dtolnay/rust-toolchain@stable
with:
profile: minimal
toolchain: stable
override: true
- run: rustup component add rustfmt
- uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check
components: rustfmt
- run: cargo fmt --all -- --check
clippy:
name: Clippy
@@ -100,13 +64,7 @@ jobs:
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
- uses: dtolnay/rust-toolchain@stable
with:
profile: minimal
toolchain: stable
override: true
- run: rustup component add clippy
- uses: actions-rs/cargo@v1
with:
command: clippy
args: --all-features -- -D warnings
components: clippy
- run: cargo clippy --all-features -- -D warnings


@@ -67,7 +67,7 @@ Some other features of tarpc:
Add to your `Cargo.toml` dependencies:
```toml
tarpc = "0.31"
tarpc = "0.32"
```
The `tarpc::service` attribute expands to a collection of items that form an RPC service.


@@ -1,3 +1,21 @@
## 0.32.0 (2023-03-24)
### Breaking Changes
- As part of a fix to return more channel errors in RPC results, a few error types have changed:
0. `client::RpcError::Disconnected` was split into the following errors:
- Shutdown: the client was shutdown, either intentionally or due to an error. If due to an
error, pending RPCs should see the more specific errors below.
- Send: an RPC message failed to send over the transport. Only the RPC that failed to be sent
will see this error.
- Receive: a fatal error occurred while receiving from the transport. All in-flight RPCs will
receive this error.
0. `client::ChannelError` and `server::ChannelError` are unified in `tarpc::ChannelError`.
Previously, server transport errors did not indicate during which activity the transport
error occurred. Now, as client errors already did, they are specific to the activity:
reading, readying, sending, flushing, or closing.
## 0.31.0 (2022-11-03)
### New Features


@@ -1,6 +1,6 @@
[package]
name = "tarpc-example-service"
version = "0.13.0"
version = "0.14.0"
rust-version = "1.56"
authors = ["Tim Kuehn <tikue@google.com>"]
edition = "2021"
@@ -21,7 +21,7 @@ futures = "0.3"
opentelemetry = { version = "0.17", features = ["rt-tokio"] }
opentelemetry-jaeger = { version = "0.16", features = ["rt-tokio"] }
rand = "0.8"
tarpc = { version = "0.31", path = "../tarpc", features = ["full"] }
tarpc = { version = "0.32", path = "../tarpc", features = ["full"] }
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
tracing = { version = "0.1" }
tracing-opentelemetry = "0.17"


@@ -26,7 +26,8 @@ async fn main() -> anyhow::Result<()> {
let flags = Flags::parse();
init_tracing("Tarpc Example Client")?;
let transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
transport.config_mut().max_frame_length(usize::MAX);
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
// config and any Transport as input.
@@ -42,7 +43,10 @@ async fn main() -> anyhow::Result<()> {
.instrument(tracing::info_span!("Two Hellos"))
.await;
tracing::info!("{:?}", hello);
match hello {
Ok(hello) => tracing::info!("{hello:?}"),
Err(e) => tracing::warn!("{:?}", anyhow::Error::from(e)),
}
// Let the background span processor finish.
sleep(Duration::from_micros(1)).await;


@@ -1,6 +1,6 @@
[package]
name = "tarpc"
version = "0.31.0"
version = "0.32.0"
rust-version = "1.58.0"
authors = [
"Adam Wright <adam.austin.wright@gmail.com>",
@@ -78,6 +78,8 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tokio = { version = "1", features = ["full", "test-util"] }
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
trybuild = "1.0"
tokio-rustls = "0.23"
rustls-pemfile = "1.0"
[package.metadata.docs.rs]
all-features = true
@@ -103,6 +105,10 @@ required-features = ["full"]
name = "custom_transport"
required-features = ["serde1", "tokio1", "serde-transport"]
[[example]]
name = "tls_over_tcp"
required-features = ["full"]
[[test]]
name = "service_functional"
required-features = ["serde-transport"]


@@ -0,0 +1,11 @@
-----BEGIN CERTIFICATE-----
MIIBlDCCAUagAwIBAgICAxUwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk
RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw
NjEwMTEwNFowGjEYMBYGA1UEAwwPcG9ueXRvd24gY2xpZW50MCowBQYDK2VwAyEA
NTKuLume19IhJfEFd/5OZUuYDKZH6xvy4AGver17OoejgZswgZgwDAYDVR0TAQH/
BAIwADALBgNVHQ8EBAMCBsAwFgYDVR0lAQH/BAwwCgYIKwYBBQUHAwIwHQYDVR0O
BBYEFDjdrlMu4tyw5MHtbg7WnzSGRBpFMEQGA1UdIwQ9MDuAFHIl7fHKWP6/l8FE
fI2YEIM3oHxKoSCkHjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQYIBezAF
BgMrZXADQQCaahfj/QLxoCOpvl6y0ZQ9CpojPqBnxV3460j5nUOp040Va2MpF137
izCBY7LwgUE/YG6E+kH30G4jMEnqVEYK
-----END CERTIFICATE-----


@@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE
U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD
DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh
AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU
ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG
AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU
oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc
zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG
A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0
MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh
ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU
phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR
W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC
t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB
-----END CERTIFICATE-----


@@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIIJX9ThTHpVS1SNZb6HP4myg4fRInIVGunTRdgnc+weH
-----END PRIVATE KEY-----


@@ -0,0 +1,12 @@
-----BEGIN CERTIFICATE-----
MIIBuDCCAWqgAwIBAgICAcgwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk
RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw
NjEwMTEwNFowGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wKjAFBgMrZXADIQDc
RLl3/N2tPoWnzBV3noVn/oheEl8IUtiY11Vg/QXTUKOBwDCBvTAMBgNVHRMBAf8E
AjAAMAsGA1UdDwQEAwIGwDAdBgNVHQ4EFgQUk7U2mnxedNWBAH84BsNy5si3ZQow
RAYDVR0jBD0wO4AUciXt8cpY/r+XwUR8jZgQgzegfEqhIKQeMBwxGjAYBgNVBAMM
EXBvbnl0b3duIEVkRFNBIENBggF7MDsGA1UdEQQ0MDKCDnRlc3RzZXJ2ZXIuY29t
ghVzZWNvbmQudGVzdHNlcnZlci5jb22CCWxvY2FsaG9zdDAFBgMrZXADQQCFWIcF
9FiztCuUNzgXDNu5kshuflt0RjkjWpGlWzQjGoYM2IvYhNVPeqnCiY92gqwDSBtq
amD2TBup4eNUCsQB
-----END CERTIFICATE-----


@@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE
U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD
DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh
AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU
ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG
AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU
oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc
zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG
A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0
MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh
ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU
phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR
W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC
t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB
-----END CERTIFICATE-----


@@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIMU6xGVe8JTpZ3bN/wajHfw6pEHt0Rd7wPBxds9eEFy2
-----END PRIVATE KEY-----


@@ -0,0 +1,152 @@
use rustls_pemfile::certs;
use std::io::{BufReader, Cursor};
use std::net::{IpAddr, Ipv4Addr};
use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio_rustls::rustls::{self, Certificate, OwnedTrustAnchor, RootCertStore};
use tokio_rustls::{webpki, TlsAcceptor, TlsConnector};
use tarpc::context::Context;
use tarpc::serde_transport as transport;
use tarpc::server::{BaseChannel, Channel};
use tarpc::tokio_serde::formats::Bincode;
use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec;
#[tarpc::service]
pub trait PingService {
async fn ping() -> String;
}
#[derive(Clone)]
struct Service;
#[tarpc::server]
impl PingService for Service {
async fn ping(self, _: Context) -> String {
"🔒".to_owned()
}
}
// certs were generated with openssl 3 using https://github.com/rustls/rustls/tree/main/test-ca
// used on client-side for server tls
const END_CHAIN: &[u8] = include_bytes!("certs/eddsa/end.chain");
// used on client-side for client-auth
const CLIENT_PRIVATEKEY_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.key");
const CLIENT_CERT_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.cert");
// used on server-side for server tls
const END_CERT: &str = include_str!("certs/eddsa/end.cert");
const END_PRIVATEKEY: &str = include_str!("certs/eddsa/end.key");
// used on server-side for client-auth
const CLIENT_CHAIN_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.chain");
pub fn load_private_key(key: &str) -> rustls::PrivateKey {
let mut reader = BufReader::new(Cursor::new(key));
loop {
match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") {
Some(rustls_pemfile::Item::RSAKey(key)) => return rustls::PrivateKey(key),
Some(rustls_pemfile::Item::PKCS8Key(key)) => return rustls::PrivateKey(key),
Some(rustls_pemfile::Item::ECKey(key)) => return rustls::PrivateKey(key),
None => break,
_ => {}
}
}
panic!("no keys found in {:?} (encrypted keys not supported)", key);
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// -------------------- start here to setup tls tcp tokio stream --------------------------
// ref certs and loading from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/tests/test.rs
// ref basic tls server setup from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/server/src/main.rs
let cert = certs(&mut BufReader::new(Cursor::new(END_CERT)))
.unwrap()
.into_iter()
.map(rustls::Certificate)
.collect();
let key = load_private_key(END_PRIVATEKEY);
let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), 5000);
// ------------- server side client_auth cert loading start
let roots: Vec<Certificate> = certs(&mut BufReader::new(Cursor::new(CLIENT_CHAIN_CLIENT_AUTH)))
.unwrap()
.into_iter()
.map(rustls::Certificate)
.collect();
let mut client_auth_roots = RootCertStore::empty();
for root in roots {
client_auth_roots.add(&root).unwrap();
}
let client_auth = AllowAnyAuthenticatedClient::new(client_auth_roots);
// ------------- server side client_auth cert loading end
let config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_client_cert_verifier(client_auth) // use .with_no_client_auth() instead if you don't want client-auth
.with_single_cert(cert, key)
.unwrap();
let acceptor = TlsAcceptor::from(Arc::new(config));
let listener = TcpListener::bind(&server_addr).await.unwrap();
let codec_builder = LengthDelimitedCodec::builder();
// ref ./custom_transport.rs server side
tokio::spawn(async move {
loop {
let (stream, _peer_addr) = listener.accept().await.unwrap();
let acceptor = acceptor.clone();
let tls_stream = acceptor.accept(stream).await.unwrap();
let framed = codec_builder.new_framed(tls_stream);
let transport = transport::new(framed, Bincode::default());
let fut = BaseChannel::with_defaults(transport).execute(Service.serve());
tokio::spawn(fut);
}
});
// ---------------------- client connection ---------------------
// cert loading from: https://github.com/tokio-rs/tls/blob/357bc562483dcf04c1f8d08bd1a831b144bf7d4c/tokio-rustls/tests/test.rs#L113
// tls client connection from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
let chain = certs(&mut std::io::Cursor::new(END_CHAIN)).unwrap();
let mut root_store = rustls::RootCertStore::empty();
root_store.add_server_trust_anchors(chain.iter().map(|cert| {
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let client_auth_private_key = load_private_key(CLIENT_PRIVATEKEY_CLIENT_AUTH);
let client_auth_certs: Vec<Certificate> =
certs(&mut BufReader::new(Cursor::new(CLIENT_CERT_CLIENT_AUTH)))
.unwrap()
.into_iter()
.map(rustls::Certificate)
.collect();
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_single_cert(client_auth_certs, client_auth_private_key)?; // use .with_no_client_auth() instead if you don't want client-auth
let domain = rustls::ServerName::try_from("localhost")?;
let connector = TlsConnector::from(Arc::new(config));
let stream = TcpStream::connect(server_addr).await?;
let stream = connector.connect(domain, stream).await?;
let transport = transport::new(codec_builder.new_framed(stream), Bincode::default());
let answer = PingServiceClient::new(Default::default(), transport)
.spawn()
.ping(tarpc::context::current())
.await?;
println!("ping answer: {answer}");
Ok(())
}


@@ -10,15 +10,14 @@ mod in_flight_requests;
use crate::{
cancellations::{cancellations, CanceledRequests, RequestCancellation},
context, trace, ClientMessage, Request, Response, ServerError, Transport,
context, trace, ChannelError, ClientMessage, Request, Response, ServerError, Transport,
};
use futures::{prelude::*, ready, stream::Fuse, task::*};
use in_flight_requests::{DeadlineExceededError, InFlightRequests};
use in_flight_requests::InFlightRequests;
use pin_project::pin_project;
use std::{
convert::TryFrom,
error::Error,
fmt, mem,
fmt,
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
@@ -124,7 +123,7 @@ impl<Req, Resp> Channel<Req, Resp> {
pub async fn call(
&self,
mut ctx: context::Context,
request_name: &str,
request_name: &'static str,
request: Req,
) -> Result<Resp, RpcError> {
let span = Span::current();
@@ -147,6 +146,7 @@ impl<Req, Resp> Channel<Req, Resp> {
response: &mut response,
request_id,
cancellation: &self.cancellation,
cancel: true,
};
self.to_dispatch
.send(DispatchRequest {
@@ -157,7 +157,7 @@ impl<Req, Resp> Channel<Req, Resp> {
response_completion,
})
.await
.map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?;
.map_err(|mpsc::error::SendError(_)| RpcError::Shutdown)?;
response_guard.response().await
}
}
@@ -165,19 +165,25 @@ impl<Req, Resp> Channel<Req, Resp> {
/// A server response that is completed by request dispatch when the corresponding response
/// arrives off the wire.
struct ResponseGuard<'a, Resp> {
response: &'a mut oneshot::Receiver<Result<Response<Resp>, DeadlineExceededError>>,
response: &'a mut oneshot::Receiver<Result<Resp, RpcError>>,
cancellation: &'a RequestCancellation,
request_id: u64,
cancel: bool,
}
/// An error that can occur in the processing of an RPC. This is not request-specific errors but
/// rather cross-cutting errors that can always occur.
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[derive(thiserror::Error, Debug)]
pub enum RpcError {
/// The client disconnected from the server.
#[error("the client disconnected from the server")]
Disconnected,
#[error("the connection to the server was already shutdown")]
Shutdown,
/// The client failed to send the request.
#[error("the client failed to send the request")]
Send(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
/// An error occurred while waiting for the server response.
#[error("an error occurred while waiting for the server response")]
Receive(#[source] Arc<dyn std::error::Error + Send + Sync + 'static>),
/// The request exceeded its deadline.
#[error("the request exceeded its deadline")]
DeadlineExceeded,
@@ -186,24 +192,18 @@ pub enum RpcError {
Server(#[from] ServerError),
}
impl From<DeadlineExceededError> for RpcError {
fn from(_: DeadlineExceededError) -> Self {
RpcError::DeadlineExceeded
}
}
impl<Resp> ResponseGuard<'_, Resp> {
async fn response(mut self) -> Result<Resp, RpcError> {
let response = (&mut self.response).await;
// Cancel drop logic once a response has been received.
mem::forget(self);
self.cancel = false;
match response {
Ok(resp) => Ok(resp?.message?),
Ok(response) => response,
Err(oneshot::error::RecvError { .. }) => {
// The oneshot is Canceled when the dispatch task ends. In that case,
// there's nothing listening on the other side, so there's no point in
// propagating cancellation.
Err(RpcError::Disconnected)
Err(RpcError::Shutdown)
}
}
}
@@ -223,7 +223,9 @@ impl<Resp> Drop for ResponseGuard<'_, Resp> {
// dispatch task misses an early-arriving cancellation message, then it will see the
// receiver as closed.
self.response.close();
self.cancellation.cancel(self.request_id);
if self.cancel {
self.cancellation.cancel(self.request_id);
}
}
}
@@ -238,7 +240,6 @@ where
{
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
let (cancellation, canceled_requests) = cancellations();
let canceled_requests = canceled_requests;
NewClient {
client: Channel {
@@ -270,42 +271,18 @@ pub struct RequestDispatch<Req, Resp, C> {
/// Requests that were dropped.
canceled_requests: CanceledRequests,
/// Requests already written to the wire that haven't yet received responses.
in_flight_requests: InFlightRequests<Resp>,
in_flight_requests: InFlightRequests<Result<Resp, RpcError>>,
/// Configures limits to prevent unlimited resource usage.
config: Config,
}
/// Critical errors that result in a Channel disconnecting.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError<E>
where
E: Error + Send + Sync + 'static,
{
/// Could not read from the transport.
#[error("could not read from the transport")]
Read(#[source] E),
/// Could not ready the transport for writes.
#[error("could not ready the transport for writes")]
Ready(#[source] E),
/// Could not write to the transport.
#[error("could not write to the transport")]
Write(#[source] E),
/// Could not flush the transport.
#[error("could not flush the transport")]
Flush(#[source] E),
/// Could not close the write end of the transport.
#[error("could not close the write end of the transport")]
Close(#[source] E),
/// Could not poll expired requests.
#[error("could not poll expired requests")]
Timer(#[source] tokio::time::error::Error),
}
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests<Resp> {
fn in_flight_requests<'a>(
self: &'a mut Pin<&mut Self>,
) -> &'a mut InFlightRequests<Result<Resp, RpcError>> {
self.as_mut().project().in_flight_requests
}
@@ -365,7 +342,17 @@ where
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
self.transport_pin_mut()
.poll_next(cx)
.map_err(ChannelError::Read)
.map_err(|e| {
let e = Arc::new(e);
for span in self
.in_flight_requests()
.complete_all_requests(|| Err(RpcError::Receive(e.clone())))
{
let _entered = span.enter();
tracing::info!("ReceiveError");
}
ChannelError::Read(e)
})
.map_ok(|response| {
self.complete(response);
})
@@ -395,7 +382,10 @@ where
// Receiving Poll::Ready(None) when polling expired requests never indicates "Closed",
because there can temporarily be zero in-flight requests. Therefore, there is no need to
// track the status like is done with pending and cancelled requests.
if let Poll::Ready(Some(_)) = self.in_flight_requests().poll_expired(cx) {
if let Poll::Ready(Some(_)) = self
.in_flight_requests()
.poll_expired(cx, || Err(RpcError::DeadlineExceeded))
{
// Expired requests are considered complete; there is no compelling reason to send a
// cancellation message to the server, since it will have already exhausted its
// allotted processing time.
@@ -506,7 +496,7 @@ where
Some(dispatch_request) => dispatch_request,
None => return Poll::Ready(None),
};
let entered = span.enter();
let _entered = span.enter();
// poll_next_request only returns Ready if there is room to buffer another request.
// Therefore, we can call write_request without fear of erroring due to a full
// buffer.
@@ -519,13 +509,16 @@ where
trace_context: ctx.trace_context,
},
});
self.start_send(request)?;
tracing::info!("SendRequest");
drop(entered);
self.in_flight_requests()
.insert_request(request_id, ctx, span, response_completion)
.insert_request(request_id, ctx, span.clone(), response_completion)
.expect("Request IDs should be unique");
match self.start_send(request) {
Ok(()) => tracing::info!("SendRequest"),
Err(e) => {
self.in_flight_requests()
.complete_request(request_id, Err(RpcError::Send(Box::new(e))));
}
}
Poll::Ready(Some(Ok(())))
}
@@ -550,7 +543,10 @@ where
/// Sends a server response to the client task that initiated the associated request.
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
self.in_flight_requests().complete_request(response)
self.in_flight_requests().complete_request(
response.request_id,
response.message.map_err(RpcError::Server),
)
}
}
@@ -599,30 +595,37 @@ struct DispatchRequest<Req, Resp> {
pub span: Span,
pub request_id: u64,
pub request: Req,
pub response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
pub response_completion: oneshot::Sender<Result<Resp, RpcError>>,
}
#[cfg(test)]
mod tests {
use super::{cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard};
use super::{
cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError,
};
use crate::{
client::{
in_flight_requests::{DeadlineExceededError, InFlightRequests},
Config,
},
context,
client::{in_flight_requests::InFlightRequests, Config},
context::{self, current},
transport::{self, channel::UnboundedChannel},
ClientMessage, Response,
ChannelError, ClientMessage, Response,
};
use assert_matches::assert_matches;
use futures::{prelude::*, task::*};
use std::{
convert::TryFrom,
fmt::Display,
marker::PhantomData,
pin::Pin,
sync::atomic::{AtomicUsize, Ordering},
sync::Arc,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use thiserror::Error;
use tokio::sync::{
mpsc::{self},
oneshot,
};
use tokio::sync::{mpsc, oneshot};
use tracing::Span;
#[tokio::test]
@@ -643,7 +646,7 @@ mod tests {
.await
.unwrap();
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp");
assert_matches!(rx.try_recv(), Ok(Ok(resp)) if resp == "Resp");
}
#[tokio::test]
@@ -654,6 +657,7 @@ mod tests {
response: &mut response,
cancellation: &cancellation,
request_id: 3,
cancel: true,
});
// resp's drop() is run, which should send a cancel message.
let cx = &mut Context::from_waker(noop_waker_ref());
@@ -674,6 +678,7 @@ mod tests {
response: &mut response,
cancellation: &cancellation,
request_id: 3,
cancel: true,
}
.response()
.await
@@ -775,6 +780,185 @@ mod tests {
assert!(dispatch.as_mut().poll_next_request(cx).is_pending());
}
#[tokio::test]
async fn test_shutdown_error() {
let _ = tracing_subscriber::fmt().with_test_writer().try_init();
let (dispatch, mut channel, _) = set_up();
let (tx, mut rx) = oneshot::channel();
// send succeeds
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
drop(dispatch);
// error on receive
assert_matches!(resp.response().await, Err(RpcError::Shutdown));
let (dispatch, channel, _) = set_up();
drop(dispatch);
// error on send
let resp = channel
.call(current(), "test_request", "hi".to_string())
.await;
assert_matches!(resp, Err(RpcError::Shutdown));
}
#[tokio::test]
async fn test_transport_error_write() {
let cause = TransportError::Write;
let (mut dispatch, mut channel, mut cx) = setup_always_err(cause);
let (tx, mut rx) = oneshot::channel();
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
assert!(dispatch.as_mut().poll(&mut cx).is_pending());
let res = resp.response().await;
assert_matches!(res, Err(RpcError::Send(_)));
let client_error: anyhow::Error = res.unwrap_err().into();
let mut chain = client_error.chain();
chain.next(); // original RpcError
assert_eq!(
chain
.next()
.unwrap()
.downcast_ref::<ChannelError<TransportError>>(),
Some(&ChannelError::Write(cause))
);
assert_eq!(
client_error.root_cause().downcast_ref::<TransportError>(),
Some(&cause)
);
}
#[tokio::test]
async fn test_transport_error_read() {
let cause = TransportError::Read;
let (mut dispatch, mut channel, mut cx) = setup_always_err(cause);
let (tx, mut rx) = oneshot::channel();
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
assert_eq!(
dispatch.as_mut().pump_write(&mut cx),
Poll::Ready(Some(Ok(())))
);
assert_eq!(
dispatch.as_mut().pump_read(&mut cx),
Poll::Ready(Some(Err(ChannelError::Read(Arc::new(cause)))))
);
assert_matches!(resp.response().await, Err(RpcError::Receive(_)));
}
#[tokio::test]
async fn test_transport_error_ready() {
let cause = TransportError::Ready;
let (mut dispatch, _, mut cx) = setup_always_err(cause);
assert_eq!(
dispatch.as_mut().poll(&mut cx),
Poll::Ready(Err(ChannelError::Ready(cause)))
);
}
#[tokio::test]
async fn test_transport_error_flush() {
let cause = TransportError::Flush;
let (mut dispatch, _, mut cx) = setup_always_err(cause);
assert_eq!(
dispatch.as_mut().poll(&mut cx),
Poll::Ready(Err(ChannelError::Flush(cause)))
);
}
#[tokio::test]
async fn test_transport_error_close() {
let cause = TransportError::Close;
let (mut dispatch, channel, mut cx) = setup_always_err(cause);
drop(channel);
assert_eq!(
dispatch.as_mut().poll(&mut cx),
Poll::Ready(Err(ChannelError::Close(cause)))
);
}
fn setup_always_err(
cause: TransportError,
) -> (
Pin<Box<RequestDispatch<String, String, AlwaysErrorTransport<String>>>>,
Channel<String, String>,
Context<'static>,
) {
let (to_dispatch, pending_requests) = mpsc::channel(1);
let (cancellation, canceled_requests) = cancellations();
let transport: AlwaysErrorTransport<String> = AlwaysErrorTransport(cause, PhantomData);
let dispatch = Box::pin(RequestDispatch::<String, String, _> {
transport: transport.fuse(),
pending_requests,
canceled_requests,
in_flight_requests: InFlightRequests::default(),
config: Config::default(),
});
let channel = Channel {
to_dispatch,
cancellation,
next_request_id: Arc::new(AtomicUsize::new(0)),
};
let cx = Context::from_waker(noop_waker_ref());
(dispatch, channel, cx)
}
struct AlwaysErrorTransport<I>(TransportError, PhantomData<I>);
#[derive(Debug, Error, PartialEq, Eq, Clone, Copy)]
enum TransportError {
Read,
Ready,
Write,
Flush,
Close,
}
impl Display for TransportError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&format!("{self:?}"))
}
}
impl<I: Clone, S> Sink<S> for AlwaysErrorTransport<I> {
type Error = TransportError;
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.0 {
TransportError::Ready => Poll::Ready(Err(self.0)),
TransportError::Flush => Poll::Pending,
_ => Poll::Ready(Ok(())),
}
}
fn start_send(self: Pin<&mut Self>, _: S) -> Result<(), Self::Error> {
if matches!(self.0, TransportError::Write) {
Err(self.0)
} else {
Ok(())
}
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if matches!(self.0, TransportError::Flush) {
Poll::Ready(Err(self.0))
} else {
Poll::Ready(Ok(()))
}
}
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if matches!(self.0, TransportError::Close) {
Poll::Ready(Err(self.0))
} else {
Poll::Ready(Ok(()))
}
}
}
impl<I: Clone> Stream for AlwaysErrorTransport<I> {
type Item = Result<Response<I>, TransportError>;
fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if matches!(self.0, TransportError::Read) {
Poll::Ready(Some(Err(self.0)))
} else {
Poll::Pending
}
}
}
fn set_up() -> (
Pin<
Box<
@@ -814,8 +998,8 @@ mod tests {
async fn send_request<'a>(
channel: &'a mut Channel<String, String>,
request: &str,
response_completion: oneshot::Sender<Result<Response<String>, DeadlineExceededError>>,
response: &'a mut oneshot::Receiver<Result<Response<String>, DeadlineExceededError>>,
response_completion: oneshot::Sender<Result<String, RpcError>>,
response: &'a mut oneshot::Receiver<Result<String, RpcError>>,
) -> ResponseGuard<'a, String> {
let request_id =
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
@@ -830,6 +1014,7 @@ mod tests {
response,
cancellation: &channel.cancellation,
request_id,
cancel: true,
};
channel.to_dispatch.send(request).await.unwrap();
response_guard

View File

@@ -1,7 +1,6 @@
use crate::{
context,
util::{Compact, TimeUntil},
Response,
};
use fnv::FnvHashMap;
use std::{
@@ -28,17 +27,11 @@ impl<Resp> Default for InFlightRequests<Resp> {
}
}
/// The request exceeded its deadline.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
#[error("the request exceeded its deadline")]
pub struct DeadlineExceededError;
#[derive(Debug)]
struct RequestData<Resp> {
struct RequestData<Res> {
ctx: context::Context,
span: Span,
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
response_completion: oneshot::Sender<Res>,
/// The key to remove the timer for the request's deadline.
deadline_key: delay_queue::Key,
}
@@ -48,7 +41,7 @@ struct RequestData<Resp> {
#[derive(Debug)]
pub struct AlreadyExistsError;
impl<Resp> InFlightRequests<Resp> {
impl<Res> InFlightRequests<Res> {
/// Returns the number of in-flight requests.
pub fn len(&self) -> usize {
self.request_data.len()
@@ -65,7 +58,7 @@ impl<Resp> InFlightRequests<Resp> {
request_id: u64,
ctx: context::Context,
span: Span,
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
response_completion: oneshot::Sender<Res>,
) -> Result<(), AlreadyExistsError> {
match self.request_data.entry(request_id) {
hash_map::Entry::Vacant(vacant) => {
@@ -84,25 +77,35 @@ impl<Resp> InFlightRequests<Resp> {
}
/// Removes a request without aborting. Returns true iff the request was found.
pub fn complete_request(&mut self, response: Response<Resp>) -> bool {
if let Some(request_data) = self.request_data.remove(&response.request_id) {
pub fn complete_request(&mut self, request_id: u64, result: Res) -> bool {
if let Some(request_data) = self.request_data.remove(&request_id) {
let _entered = request_data.span.enter();
tracing::info!("ReceiveResponse");
self.request_data.compact(0.1);
self.deadlines.remove(&request_data.deadline_key);
let _ = request_data.response_completion.send(Ok(response));
let _ = request_data.response_completion.send(result);
return true;
}
tracing::debug!(
"No in-flight request found for request_id = {}.",
response.request_id
);
tracing::debug!("No in-flight request found for request_id = {request_id}.");
// If the response completion was absent, then the request was already canceled.
false
}
/// Completes all requests using the provided function.
/// Returns Spans for all completed requests.
pub fn complete_all_requests<'a>(
&'a mut self,
mut result: impl FnMut() -> Res + 'a,
) -> impl Iterator<Item = Span> + 'a {
self.deadlines.clear();
self.request_data.drain().map(move |(_, request_data)| {
let _ = request_data.response_completion.send(result());
request_data.span
})
}
/// Cancels a request without completing (typically used when a request handle was dropped
/// before the request completed).
pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> {
@@ -117,16 +120,18 @@ impl<Resp> InFlightRequests<Resp> {
/// Yields a request that has expired, completing it with a TimedOut error.
/// The caller should send cancellation messages for any yielded request ID.
pub fn poll_expired(&mut self, cx: &mut Context) -> Poll<Option<u64>> {
pub fn poll_expired(
&mut self,
cx: &mut Context,
expired_error: impl Fn() -> Res,
) -> Poll<Option<u64>> {
self.deadlines.poll_expired(cx).map(|expired| {
let request_id = expired?.into_inner();
if let Some(request_data) = self.request_data.remove(&request_id) {
let _entered = request_data.span.enter();
tracing::error!("DeadlineExceeded");
self.request_data.compact(0.1);
let _ = request_data
.response_completion
.send(Err(DeadlineExceededError));
let _ = request_data.response_completion.send(expired_error());
}
Some(request_id)
})
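The hunk above makes `InFlightRequests` generic over the completion result and lets the caller supply a closure that produces the value used to complete every outstanding request. The pattern can be sketched in isolation; the types below are hypothetical simplifications for illustration, not tarpc's actual definitions:

```rust
use std::collections::HashMap;
use std::sync::mpsc;

// Hypothetical, simplified stand-in for tarpc's InFlightRequests: generic over
// the result type Res, so callers can inject any value (e.g. a transport error).
struct InFlight<Res> {
    pending: HashMap<u64, mpsc::Sender<Res>>,
}

impl<Res> InFlight<Res> {
    fn insert(&mut self, id: u64, tx: mpsc::Sender<Res>) {
        self.pending.insert(id, tx);
    }

    // Completes every outstanding request with a closure-produced result,
    // mirroring the complete_all_requests signature in the diff above.
    fn complete_all(&mut self, mut result: impl FnMut() -> Res) -> usize {
        self.pending
            .drain()
            .map(|(_, tx)| drop(tx.send(result())))
            .count()
    }
}

fn main() {
    let mut in_flight = InFlight { pending: HashMap::new() };
    let (tx, rx) = mpsc::channel();
    in_flight.insert(7, tx);
    let completed = in_flight.complete_all(|| "connection reset".to_string());
    assert_eq!(completed, 1);
    assert_eq!(rx.recv().unwrap(), "connection reset");
}
```

Because the closure runs once per request, each pending RPC receives its own owned copy of the result, which is why the real change wraps shared read errors in an `Arc`.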

View File

@@ -311,6 +311,7 @@ pub use crate::transport::sealed::Transport;
use anyhow::Context as _;
use futures::task::*;
use std::sync::Arc;
use std::{error::Error, fmt::Display, io, time::SystemTime};
/// A message from a client to a server.
@@ -383,6 +384,29 @@ pub struct ServerError {
pub detail: String,
}
/// Critical errors that result in a Channel disconnecting.
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum ChannelError<E>
where
E: Error + Send + Sync + 'static,
{
/// Could not read from the transport.
#[error("could not read from the transport")]
Read(#[source] Arc<E>),
/// Could not ready the transport for writes.
#[error("could not ready the transport for writes")]
Ready(#[source] E),
/// Could not write to the transport.
#[error("could not write to the transport")]
Write(#[source] E),
/// Could not flush the transport.
#[error("could not flush the transport")]
Flush(#[source] E),
/// Could not close the write end of the transport.
#[error("could not close the write end of the transport")]
Close(#[source] E),
}
impl<T> Request<T> {
/// Returns the deadline for this request.
pub fn deadline(&self) -> &SystemTime {
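The per-phase `ChannelError` variants introduced above let a caller tell where the transport failed and walk the underlying cause. A minimal self-contained sketch of that shape (a hypothetical mirror, not tarpc's generic definition; it hardcodes `io::Error` as the transport error):

```rust
use std::{error::Error, fmt, io, sync::Arc};

// Hypothetical, simplified mirror of the ChannelError shape in the diff above.
// Read errors are wrapped in Arc so one error can fan out to every pending RPC.
#[derive(Debug)]
enum ChannelError {
    Read(Arc<io::Error>),
    Ready(io::Error),
    Write(io::Error),
    Flush(io::Error),
    Close(io::Error),
}

impl fmt::Display for ChannelError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let phase = match self {
            ChannelError::Read(_) => "read from",
            ChannelError::Ready(_) => "ready",
            ChannelError::Write(_) => "write to",
            ChannelError::Flush(_) => "flush",
            ChannelError::Close(_) => "close",
        };
        write!(f, "could not {phase} the transport")
    }
}

impl Error for ChannelError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            ChannelError::Read(e) => Some(e.as_ref()),
            ChannelError::Ready(e)
            | ChannelError::Write(e)
            | ChannelError::Flush(e)
            | ChannelError::Close(e) => Some(e),
        }
    }
}

fn main() {
    let err = ChannelError::Write(io::Error::new(io::ErrorKind::Other, "frame size too big"));
    // Walking the source chain reproduces the layered message from the commit text.
    assert_eq!(err.to_string(), "could not write to the transport");
    assert_eq!(err.source().unwrap().to_string(), "frame size too big");
}
```

In the real crate the `#[error(...)]`/`#[source]` attributes from `thiserror` generate the equivalent `Display` and `Error` impls.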

View File

@@ -129,14 +129,6 @@ pub mod tcp {
tokio_util::codec::length_delimited,
};
mod private {
use super::*;
pub trait Sealed {}
impl<Item, SinkItem, Codec> Sealed for Transport<TcpStream, Item, SinkItem, Codec> {}
}
impl<Item, SinkItem, Codec> Transport<TcpStream, Item, SinkItem, Codec> {
/// Returns the peer address of the underlying TcpStream.
pub fn peer_addr(&self) -> io::Result<SocketAddr> {

View File

@@ -9,7 +9,7 @@
use crate::{
cancellations::{cancellations, CanceledRequests, RequestCancellation},
context::{self, SpanExt},
trace, ClientMessage, Request, Response, Transport,
trace, ChannelError, ClientMessage, Request, Response, Transport,
};
use ::tokio::sync::mpsc;
use futures::{
@@ -21,14 +21,7 @@ use futures::{
};
use in_flight_requests::{AlreadyExistsError, InFlightRequests};
use pin_project::pin_project;
use std::{
convert::TryFrom,
error::Error,
fmt,
marker::PhantomData,
mem::{self, ManuallyDrop},
pin::Pin,
};
use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc};
use tracing::{info_span, instrument::Instrument, Span};
mod in_flight_requests;
@@ -208,10 +201,11 @@ where
Ok(TrackedRequest {
abort_registration,
span,
response_guard: ManuallyDrop::new(ResponseGuard {
response_guard: ResponseGuard {
request_id: request.id,
request_cancellation: self.request_cancellation.clone(),
}),
cancel: false,
},
request,
})
}
@@ -240,7 +234,7 @@ pub struct TrackedRequest<Req> {
/// A span representing the server processing of this request.
pub span: Span,
/// An inert response guard. Becomes active in an InFlightRequest.
pub response_guard: ManuallyDrop<ResponseGuard>,
pub response_guard: ResponseGuard,
}
/// The server end of an open connection with a client, receiving requests from, and sending
@@ -343,20 +337,6 @@ where
}
}
/// Critical errors that result in a Channel disconnecting.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError<E>
where
E: Error + Send + Sync + 'static,
{
/// An error occurred reading from, or writing to, the transport.
#[error("an error occurred in the transport: {0}")]
Transport(#[source] E),
/// An error occurred while polling expired requests.
#[error("an error occurred while polling expired requests: {0}")]
Timer(#[source] ::tokio::time::error::Error),
}
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
@@ -413,7 +393,7 @@ where
let request_status = match self
.transport_pin_mut()
.poll_next(cx)
.map_err(ChannelError::Transport)?
.map_err(|e| ChannelError::Read(Arc::new(e)))?
{
Poll::Ready(Some(message)) => match message {
ClientMessage::Request(request) => {
@@ -473,7 +453,7 @@ where
self.project()
.transport
.poll_ready(cx)
.map_err(ChannelError::Transport)
.map_err(ChannelError::Ready)
}
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
@@ -486,7 +466,7 @@ where
self.project()
.transport
.start_send(response)
.map_err(ChannelError::Transport)
.map_err(ChannelError::Write)
} else {
// If the request isn't tracked anymore, there's no need to send the response.
Ok(())
@@ -498,14 +478,14 @@ where
self.project()
.transport
.poll_flush(cx)
.map_err(ChannelError::Transport)
.map_err(ChannelError::Flush)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project()
.transport
.poll_close(cx)
.map_err(ChannelError::Transport)
.map_err(ChannelError::Close)
}
}
@@ -581,13 +561,15 @@ where
request,
abort_registration,
span,
response_guard,
mut response_guard,
}| {
// The response guard becomes active once in an InFlightRequest.
response_guard.cancel = true;
InFlightRequest {
request,
abort_registration,
span,
response_guard: ManuallyDrop::into_inner(response_guard),
response_guard,
response_tx: self.responses_tx.clone(),
}
},
@@ -674,11 +656,14 @@ where
pub struct ResponseGuard {
request_cancellation: RequestCancellation,
request_id: u64,
cancel: bool,
}
impl Drop for ResponseGuard {
fn drop(&mut self) {
self.request_cancellation.cancel(self.request_id);
if self.cancel {
self.request_cancellation.cancel(self.request_id);
}
}
}
@@ -721,7 +706,7 @@ impl<Req, Res> InFlightRequest<Req, Res> {
{
let Self {
response_tx,
response_guard,
mut response_guard,
abort_registration,
span,
request:
@@ -732,6 +717,9 @@ impl<Req, Res> InFlightRequest<Req, Res> {
},
} = self;
let method = serve.method(&message);
// TODO(https://github.com/rust-lang/rust-clippy/issues/9111)
// remove when clippy is fixed
#[allow(clippy::needless_borrow)]
span.record("otel.name", &method.unwrap_or(""));
let _ = Abortable::new(
async move {
@@ -752,7 +740,7 @@ impl<Req, Res> InFlightRequest<Req, Res> {
// Request processing has completed, meaning either the channel canceled the request or
// a request was sent back to the channel. Either way, the channel will clean up the
// request data, so the request does not need to be canceled.
mem::forget(response_guard);
response_guard.cancel = false;
}
}
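The hunks above replace `ManuallyDrop` + `mem::forget` with a plain `cancel: bool` checked in `Drop`: the guard is armed when the request becomes in-flight and disarmed once a response is sent. A hypothetical standalone sketch of that cancel-flag pattern (none of these types are tarpc's):

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Hypothetical stand-in for ResponseGuard: cancels on drop only while armed.
struct ResponseGuard {
    cancel: bool,
    canceled: Arc<AtomicBool>,
}

impl Drop for ResponseGuard {
    fn drop(&mut self) {
        // Mirrors the diff: the cancellation message is sent only if the guard
        // is still armed when dropped.
        if self.cancel {
            self.canceled.store(true, Ordering::SeqCst);
        }
    }
}

fn main() {
    let flag = Arc::new(AtomicBool::new(false));

    // Request completed normally: the guard is disarmed, drop does nothing.
    let mut guard = ResponseGuard { cancel: true, canceled: Arc::clone(&flag) };
    guard.cancel = false;
    drop(guard);
    assert!(!flag.load(Ordering::SeqCst));

    // Request handle dropped while still armed: cancellation fires.
    drop(ResponseGuard { cancel: true, canceled: Arc::clone(&flag) });
    assert!(flag.load(Ordering::SeqCst));
}
```

Compared to `mem::forget`, flipping a flag keeps the destructor running (so any other cleanup still happens) and avoids the footguns of leaking a value that owns resources.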

View File

@@ -12,7 +12,7 @@ use crate::{
};
use futures::{task::*, Sink, Stream};
use pin_project::pin_project;
use std::{collections::VecDeque, io, mem::ManuallyDrop, pin::Pin, time::SystemTime};
use std::{collections::VecDeque, io, pin::Pin, time::SystemTime};
use tracing::Span;
#[pin_project]
@@ -101,10 +101,11 @@ impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
},
abort_registration,
span: Span::none(),
response_guard: ManuallyDrop::new(ResponseGuard {
response_guard: ResponseGuard {
request_cancellation,
request_id: id,
}),
cancel: false,
},
}));
}
}

View File

@@ -2,7 +2,7 @@ error: unused `RequestDispatch` that must be used
--> tests/compile_fail/must_use_request_dispatch.rs:13:9
|
13 | WorldClient::new(client::Config::default(), client_transport).dispatch;
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/must_use_request_dispatch.rs:11:12

View File

@@ -2,7 +2,7 @@ error: unused `tarpc::serde_transport::tcp::Connect` that must be used
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:7:9
|
7 | serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:5:12

View File

@@ -2,7 +2,7 @@ error: unused `TokioChannelExecutor` that must be used
--> tests/compile_fail/tokio/must_use_channel_executor.rs:27:9
|
27 | server.execute(HelloServer.serve());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/tokio/must_use_channel_executor.rs:25:12

View File

@@ -2,7 +2,7 @@ error: unused `TokioServerExecutor` that must be used
--> tests/compile_fail/tokio/must_use_server_executor.rs:28:9
|
28 | server.execute(HelloServer.serve());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/tokio/must_use_server_executor.rs:26:12