mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-02-23 15:49:54 +01:00
Compare commits
21 Commits
v0.26.0
...
client-clo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9e4c3a2b8b | ||
|
|
d78b24b631 | ||
|
|
49900d7a35 | ||
|
|
1e680e3a5a | ||
|
|
2591d21e94 | ||
|
|
6632f68d95 | ||
|
|
25985ad56a | ||
|
|
d6a24e9420 | ||
|
|
281a78f3c7 | ||
|
|
a0787d0091 | ||
|
|
d2acba0e8a | ||
|
|
ea7b6763c4 | ||
|
|
eb67c540b9 | ||
|
|
4151d0abd3 | ||
|
|
d0c11a6efa | ||
|
|
82c4da1743 | ||
|
|
0a15e0b75c | ||
|
|
0b315c29bf | ||
|
|
56f09bf61f | ||
|
|
6d82e82419 | ||
|
|
9bebaf814a |
12
README.md
12
README.md
@@ -67,7 +67,7 @@ Some other features of tarpc:
|
|||||||
Add to your `Cargo.toml` dependencies:
|
Add to your `Cargo.toml` dependencies:
|
||||||
|
|
||||||
```toml
|
```toml
|
||||||
tarpc = "0.26"
|
tarpc = "0.27"
|
||||||
```
|
```
|
||||||
|
|
||||||
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||||
@@ -80,8 +80,9 @@ This example uses [tokio](https://tokio.rs), so add the following dependencies t
|
|||||||
your `Cargo.toml`:
|
your `Cargo.toml`:
|
||||||
|
|
||||||
```toml
|
```toml
|
||||||
futures = "1.0"
|
anyhow = "1.0"
|
||||||
tarpc = { version = "0.26", features = ["tokio1"] }
|
futures = "0.3"
|
||||||
|
tarpc = { version = "0.27", features = ["tokio1"] }
|
||||||
tokio = { version = "1.0", features = ["macros"] }
|
tokio = { version = "1.0", features = ["macros"] }
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -99,9 +100,8 @@ use futures::{
|
|||||||
};
|
};
|
||||||
use tarpc::{
|
use tarpc::{
|
||||||
client, context,
|
client, context,
|
||||||
server::{self, Incoming},
|
server::{self, incoming::Incoming},
|
||||||
};
|
};
|
||||||
use std::io;
|
|
||||||
|
|
||||||
// This is the service definition. It looks a lot like a trait definition.
|
// This is the service definition. It looks a lot like a trait definition.
|
||||||
// It defines one RPC, hello, which takes one arg, name, and returns a String.
|
// It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||||
@@ -140,7 +140,7 @@ available behind the `tcp` feature.
|
|||||||
|
|
||||||
```rust
|
```rust
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> io::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||||
|
|
||||||
let server = server::BaseChannel::with_defaults(server_transport);
|
let server = server::BaseChannel::with_defaults(server_transport);
|
||||||
|
|||||||
41
RELEASES.md
41
RELEASES.md
@@ -1,3 +1,36 @@
|
|||||||
|
## 0.27.1 (2021-09-22)
|
||||||
|
|
||||||
|
### Breaking Changes
|
||||||
|
|
||||||
|
### RPC error type is changing
|
||||||
|
|
||||||
|
RPC return types are changing from `Result<Response, io::Error>` to `Result<Response,
|
||||||
|
tarpc::client::RpcError>`.
|
||||||
|
|
||||||
|
Becaue tarpc is a library, not an application, it should strive to
|
||||||
|
use structured errors in its API so that users have maximal flexibility
|
||||||
|
in how they handle errors. io::Error makes that hard, because it is a
|
||||||
|
kitchen-sink error type.
|
||||||
|
|
||||||
|
RPCs in particular only have 3 classes of errors:
|
||||||
|
|
||||||
|
- The connection breaks.
|
||||||
|
- The request expires.
|
||||||
|
- The server decides not to process the request.
|
||||||
|
|
||||||
|
RPC responses can also contain application-specific errors, but from the
|
||||||
|
perspective of the RPC library, those are opaque to the framework, classified
|
||||||
|
as successful responsees.
|
||||||
|
|
||||||
|
### Open Telemetry
|
||||||
|
|
||||||
|
The Opentelemetry dependency is updated to version 0.16.x.
|
||||||
|
|
||||||
|
## 0.27.0 (2021-09-22)
|
||||||
|
|
||||||
|
This version was yanked due to tarpc-plugins version mismatches.
|
||||||
|
|
||||||
|
|
||||||
## 0.26.0 (2021-04-14)
|
## 0.26.0 (2021-04-14)
|
||||||
|
|
||||||
### New Features
|
### New Features
|
||||||
@@ -68,10 +101,10 @@ tracing_subscriber::fmt.
|
|||||||
|
|
||||||
### References
|
### References
|
||||||
|
|
||||||
[1] https://github.com/tokio-rs/tracing
|
1. https://github.com/tokio-rs/tracing
|
||||||
[2] https://opentelemetry.io
|
2. https://opentelemetry.io
|
||||||
[3] https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger
|
3. https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger
|
||||||
[4] https://github.com/env-logger-rs/env_logger
|
4. https://github.com/env-logger-rs/env_logger
|
||||||
|
|
||||||
## 0.25.0 (2021-03-10)
|
## 0.25.0 (2021-03-10)
|
||||||
|
|
||||||
|
|||||||
@@ -17,15 +17,13 @@ anyhow = "1.0"
|
|||||||
clap = "3.0.0-beta.2"
|
clap = "3.0.0-beta.2"
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
opentelemetry = { version = "0.13", features = ["rt-tokio"] }
|
opentelemetry = { version = "0.16", features = ["rt-tokio"] }
|
||||||
opentelemetry-jaeger = { version = "0.12", features = ["tokio"] }
|
opentelemetry-jaeger = { version = "0.15", features = ["rt-tokio"] }
|
||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
serde = { version = "1.0" }
|
tarpc = { version = "0.27", path = "../tarpc", features = ["full"] }
|
||||||
tarpc = { version = "0.26", path = "../tarpc", features = ["full"] }
|
|
||||||
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
|
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
|
||||||
tokio-serde = { version = "0.8", features = ["json"] }
|
|
||||||
tracing = { version = "0.1" }
|
tracing = { version = "0.1" }
|
||||||
tracing-opentelemetry = "0.12"
|
tracing-opentelemetry = "0.15"
|
||||||
tracing-subscriber = "0.2"
|
tracing-subscriber = "0.2"
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ use std::{
|
|||||||
};
|
};
|
||||||
use tarpc::{
|
use tarpc::{
|
||||||
context,
|
context,
|
||||||
server::{self, Channel, Incoming},
|
server::{self, incoming::Incoming, Channel},
|
||||||
tokio_serde::formats::Json,
|
tokio_serde::formats::Json,
|
||||||
};
|
};
|
||||||
use tokio::time;
|
use tokio::time;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "tarpc-plugins"
|
name = "tarpc-plugins"
|
||||||
version = "0.11.0"
|
version = "0.12.0"
|
||||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
|||||||
@@ -277,7 +277,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
|||||||
response_fut_name,
|
response_fut_name,
|
||||||
service_ident: ident,
|
service_ident: ident,
|
||||||
server_ident: &format_ident!("Serve{}", ident),
|
server_ident: &format_ident!("Serve{}", ident),
|
||||||
response_fut_ident: &Ident::new(&response_fut_name, ident.span()),
|
response_fut_ident: &Ident::new(response_fut_name, ident.span()),
|
||||||
client_ident: &format_ident!("{}Client", ident),
|
client_ident: &format_ident!("{}Client", ident),
|
||||||
request_ident: &format_ident!("{}Request", ident),
|
request_ident: &format_ident!("{}Request", ident),
|
||||||
response_ident: &format_ident!("{}Response", ident),
|
response_ident: &format_ident!("{}Response", ident),
|
||||||
@@ -406,11 +406,7 @@ fn verify_types_were_provided(
|
|||||||
) -> syn::Result<()> {
|
) -> syn::Result<()> {
|
||||||
let mut result = Ok(());
|
let mut result = Ok(());
|
||||||
for (method, expected) in expected {
|
for (method, expected) in expected {
|
||||||
if provided
|
if !provided.iter().any(|typedecl| typedecl.ident == expected) {
|
||||||
.iter()
|
|
||||||
.find(|typedecl| typedecl.ident == expected)
|
|
||||||
.is_none()
|
|
||||||
{
|
|
||||||
let mut e = syn::Error::new(
|
let mut e = syn::Error::new(
|
||||||
span,
|
span,
|
||||||
format!("not all trait items implemented, missing: `{}`", expected),
|
format!("not all trait items implemented, missing: `{}`", expected),
|
||||||
@@ -483,7 +479,7 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
),
|
),
|
||||||
output,
|
output,
|
||||||
)| {
|
)| {
|
||||||
let ty_doc = format!("The response future returned by {}.", ident);
|
let ty_doc = format!("The response future returned by [`{}::{}`].", service_ident, ident);
|
||||||
quote! {
|
quote! {
|
||||||
#[doc = #ty_doc]
|
#[doc = #ty_doc]
|
||||||
type #future_type: std::future::Future<Output = #output>;
|
type #future_type: std::future::Future<Output = #output>;
|
||||||
@@ -582,6 +578,7 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
|
|
||||||
quote! {
|
quote! {
|
||||||
/// The request sent over the wire from the client to the server.
|
/// The request sent over the wire from the client to the server.
|
||||||
|
#[allow(missing_docs)]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
#derive_serialize
|
#derive_serialize
|
||||||
#vis enum #request_ident {
|
#vis enum #request_ident {
|
||||||
@@ -602,6 +599,7 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
|
|
||||||
quote! {
|
quote! {
|
||||||
/// The response sent over the wire from the server to the client.
|
/// The response sent over the wire from the server to the client.
|
||||||
|
#[allow(missing_docs)]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
#derive_serialize
|
#derive_serialize
|
||||||
#vis enum #response_ident {
|
#vis enum #response_ident {
|
||||||
@@ -622,6 +620,7 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
|
|
||||||
quote! {
|
quote! {
|
||||||
/// A future resolving to a server response.
|
/// A future resolving to a server response.
|
||||||
|
#[allow(missing_docs)]
|
||||||
#vis enum #response_fut_ident<S: #service_ident> {
|
#vis enum #response_fut_ident<S: #service_ident> {
|
||||||
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
|
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
|
||||||
}
|
}
|
||||||
@@ -747,7 +746,7 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
#( #method_attrs )*
|
#( #method_attrs )*
|
||||||
#vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*)
|
#vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*)
|
||||||
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
|
-> impl std::future::Future<Output = Result<#return_types, tarpc::client::RpcError>> + '_ {
|
||||||
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
|
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
|
||||||
let resp = self.0.call(ctx, #request_names, request);
|
let resp = self.0.call(ctx, #request_names, request);
|
||||||
async move {
|
async move {
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "tarpc"
|
name = "tarpc"
|
||||||
version = "0.26.0"
|
version = "0.27.1"
|
||||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
@@ -18,9 +18,11 @@ default = []
|
|||||||
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"]
|
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"]
|
||||||
tokio1 = ["tokio/rt-multi-thread"]
|
tokio1 = ["tokio/rt-multi-thread"]
|
||||||
serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
|
serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
|
||||||
|
serde-transport-json = ["tokio-serde/json"]
|
||||||
|
serde-transport-bincode = ["tokio-serde/bincode"]
|
||||||
tcp = ["tokio/net"]
|
tcp = ["tokio/net"]
|
||||||
|
|
||||||
full = ["serde1", "tokio1", "serde-transport", "tcp"]
|
full = ["serde1", "tokio1", "serde-transport", "serde-transport-json", "serde-transport-bincode", "tcp"]
|
||||||
|
|
||||||
[badges]
|
[badges]
|
||||||
travis-ci = { repository = "google/tarpc" }
|
travis-ci = { repository = "google/tarpc" }
|
||||||
@@ -34,14 +36,14 @@ pin-project = "1.0"
|
|||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
serde = { optional = true, version = "1.0", features = ["derive"] }
|
serde = { optional = true, version = "1.0", features = ["derive"] }
|
||||||
static_assertions = "1.1.0"
|
static_assertions = "1.1.0"
|
||||||
tarpc-plugins = { path = "../plugins", version = "0.11" }
|
tarpc-plugins = { path = "../plugins", version = "0.12" }
|
||||||
thiserror = "1.0"
|
thiserror = "1.0"
|
||||||
tokio = { version = "1", features = ["time"] }
|
tokio = { version = "1", features = ["time"] }
|
||||||
tokio-util = { version = "0.6.3", features = ["time"] }
|
tokio-util = { version = "0.6.3", features = ["time"] }
|
||||||
tokio-serde = { optional = true, version = "0.8" }
|
tokio-serde = { optional = true, version = "0.8" }
|
||||||
tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
|
tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
|
||||||
tracing-opentelemetry = { version = "0.12", default-features = false }
|
tracing-opentelemetry = { version = "0.15", default-features = false }
|
||||||
opentelemetry = { version = "0.13", default-features = false }
|
opentelemetry = { version = "0.16", default-features = false }
|
||||||
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
@@ -50,8 +52,8 @@ bincode = "1.3"
|
|||||||
bytes = { version = "1", features = ["serde"] }
|
bytes = { version = "1", features = ["serde"] }
|
||||||
flate2 = "1.0"
|
flate2 = "1.0"
|
||||||
futures-test = "0.3"
|
futures-test = "0.3"
|
||||||
opentelemetry = { version = "0.13", default-features = false, features = ["rt-tokio"] }
|
opentelemetry = { version = "0.16", default-features = false, features = ["rt-tokio"] }
|
||||||
opentelemetry-jaeger = { version = "0.12", features = ["tokio"] }
|
opentelemetry-jaeger = { version = "0.15", features = ["rt-tokio"] }
|
||||||
pin-utils = "0.1.0-alpha"
|
pin-utils = "0.1.0-alpha"
|
||||||
serde_bytes = "0.11"
|
serde_bytes = "0.11"
|
||||||
tracing-subscriber = "0.2"
|
tracing-subscriber = "0.2"
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ use tarpc::{
|
|||||||
client, context,
|
client, context,
|
||||||
serde_transport::tcp,
|
serde_transport::tcp,
|
||||||
server::{BaseChannel, Channel},
|
server::{BaseChannel, Channel},
|
||||||
|
tokio_serde::formats::Bincode,
|
||||||
};
|
};
|
||||||
use tokio_serde::formats::Bincode;
|
|
||||||
|
|
||||||
/// Type of compression that should be enabled on the request. The transport is free to ignore this.
|
/// Type of compression that should be enabled on the request. The transport is free to ignore this.
|
||||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)]
|
#[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)]
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
use futures::future;
|
|
||||||
use tarpc::context::Context;
|
|
||||||
use tarpc::serde_transport as transport;
|
use tarpc::serde_transport as transport;
|
||||||
use tarpc::server::{BaseChannel, Channel};
|
use tarpc::server::{BaseChannel, Channel};
|
||||||
|
use tarpc::{context::Context, tokio_serde::formats::Bincode};
|
||||||
use tokio::net::{UnixListener, UnixStream};
|
use tokio::net::{UnixListener, UnixStream};
|
||||||
use tokio_serde::formats::Bincode;
|
|
||||||
use tokio_util::codec::length_delimited::LengthDelimitedCodec;
|
use tokio_util::codec::length_delimited::LengthDelimitedCodec;
|
||||||
|
|
||||||
#[tarpc::service]
|
#[tarpc::service]
|
||||||
@@ -14,16 +12,13 @@ pub trait PingService {
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct Service;
|
struct Service;
|
||||||
|
|
||||||
|
#[tarpc::server]
|
||||||
impl PingService for Service {
|
impl PingService for Service {
|
||||||
type PingFut = future::Ready<()>;
|
async fn ping(self, _: Context) {}
|
||||||
|
|
||||||
fn ping(self, _: Context) -> Self::PingFut {
|
|
||||||
future::ready(())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> std::io::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
let bind_addr = "/tmp/tarpc_on_unix_example.sock";
|
let bind_addr = "/tmp/tarpc_on_unix_example.sock";
|
||||||
|
|
||||||
let _ = std::fs::remove_file(bind_addr);
|
let _ = std::fs::remove_file(bind_addr);
|
||||||
@@ -46,5 +41,7 @@ async fn main() -> std::io::Result<()> {
|
|||||||
PingServiceClient::new(Default::default(), transport)
|
PingServiceClient::new(Default::default(), transport)
|
||||||
.spawn()
|
.spawn()
|
||||||
.ping(tarpc::context::current())
|
.ping(tarpc::context::current())
|
||||||
.await
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,6 @@
|
|||||||
// https://opensource.org/licenses/MIT.
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
use futures::future::{self, Ready};
|
use futures::future::{self, Ready};
|
||||||
use std::io;
|
|
||||||
use tarpc::{
|
use tarpc::{
|
||||||
client, context,
|
client, context,
|
||||||
server::{self, Channel},
|
server::{self, Channel},
|
||||||
@@ -35,7 +34,7 @@ impl World for HelloServer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> io::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||||
|
|
||||||
let server = server::BaseChannel::with_defaults(server_transport);
|
let server = server::BaseChannel::with_defaults(server_transport);
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ use futures::{future, prelude::*};
|
|||||||
use std::env;
|
use std::env;
|
||||||
use tarpc::{
|
use tarpc::{
|
||||||
client, context,
|
client, context,
|
||||||
server::{BaseChannel, Incoming},
|
server::{incoming::Incoming, BaseChannel},
|
||||||
};
|
};
|
||||||
use tokio_serde::formats::Json;
|
use tokio_serde::formats::Json;
|
||||||
use tracing_subscriber::prelude::*;
|
use tracing_subscriber::prelude::*;
|
||||||
|
|||||||
@@ -8,14 +8,14 @@
|
|||||||
|
|
||||||
mod in_flight_requests;
|
mod in_flight_requests;
|
||||||
|
|
||||||
use crate::{context, trace, ClientMessage, Request, Response, Transport};
|
use crate::{context, trace, ClientMessage, Request, Response, ServerError, Transport};
|
||||||
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
||||||
use in_flight_requests::InFlightRequests;
|
use in_flight_requests::{DeadlineExceededError, InFlightRequests};
|
||||||
use pin_project::pin_project;
|
use pin_project::pin_project;
|
||||||
use std::{
|
use std::{
|
||||||
convert::TryFrom,
|
convert::TryFrom,
|
||||||
error::Error,
|
error::Error,
|
||||||
fmt, io, mem,
|
fmt, mem,
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
sync::{
|
sync::{
|
||||||
atomic::{AtomicUsize, Ordering},
|
atomic::{AtomicUsize, Ordering},
|
||||||
@@ -60,15 +60,16 @@ pub struct NewClient<C, D> {
|
|||||||
impl<C, D, E> NewClient<C, D>
|
impl<C, D, E> NewClient<C, D>
|
||||||
where
|
where
|
||||||
D: Future<Output = Result<(), E>> + Send + 'static,
|
D: Future<Output = Result<(), E>> + Send + 'static,
|
||||||
E: std::fmt::Display,
|
E: std::error::Error + Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
/// Helper method to spawn the dispatch on the default executor.
|
/// Helper method to spawn the dispatch on the default executor.
|
||||||
#[cfg(feature = "tokio1")]
|
#[cfg(feature = "tokio1")]
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||||
pub fn spawn(self) -> C {
|
pub fn spawn(self) -> C {
|
||||||
let dispatch = self
|
let dispatch = self.dispatch.unwrap_or_else(move |e| {
|
||||||
.dispatch
|
let e = anyhow::Error::new(e);
|
||||||
.unwrap_or_else(move |e| tracing::warn!("Connection broken: {}", e));
|
tracing::warn!("Connection broken: {:?}", e);
|
||||||
|
});
|
||||||
tokio::spawn(dispatch);
|
tokio::spawn(dispatch);
|
||||||
self.client
|
self.client
|
||||||
}
|
}
|
||||||
@@ -125,7 +126,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
|||||||
mut ctx: context::Context,
|
mut ctx: context::Context,
|
||||||
request_name: &str,
|
request_name: &str,
|
||||||
request: Req,
|
request: Req,
|
||||||
) -> io::Result<Resp> {
|
) -> Result<Resp, RpcError> {
|
||||||
let span = Span::current();
|
let span = Span::current();
|
||||||
ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
|
ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
@@ -156,7 +157,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
|||||||
response_completion,
|
response_completion,
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
.map_err(|mpsc::error::SendError(_)| io::Error::from(io::ErrorKind::ConnectionReset))?;
|
.map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?;
|
||||||
response_guard.response().await
|
response_guard.response().await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -164,23 +165,44 @@ impl<Req, Resp> Channel<Req, Resp> {
|
|||||||
/// A server response that is completed by request dispatch when the corresponding response
|
/// A server response that is completed by request dispatch when the corresponding response
|
||||||
/// arrives off the wire.
|
/// arrives off the wire.
|
||||||
struct ResponseGuard<'a, Resp> {
|
struct ResponseGuard<'a, Resp> {
|
||||||
response: &'a mut oneshot::Receiver<Response<Resp>>,
|
response: &'a mut oneshot::Receiver<Result<Response<Resp>, DeadlineExceededError>>,
|
||||||
cancellation: &'a RequestCancellation,
|
cancellation: &'a RequestCancellation,
|
||||||
request_id: u64,
|
request_id: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// An error that can occur in the processing of an RPC. This is not request-specific errors but
|
||||||
|
/// rather cross-cutting errors that can always occur.
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
pub enum RpcError {
|
||||||
|
/// The client disconnected from the server.
|
||||||
|
#[error("the client disconnected from the server")]
|
||||||
|
Disconnected,
|
||||||
|
/// The request exceeded its deadline.
|
||||||
|
#[error("the request exceeded its deadline")]
|
||||||
|
DeadlineExceeded,
|
||||||
|
/// The server aborted request processing.
|
||||||
|
#[error("the server aborted request processing")]
|
||||||
|
Server(#[from] ServerError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<DeadlineExceededError> for RpcError {
|
||||||
|
fn from(_: DeadlineExceededError) -> Self {
|
||||||
|
RpcError::DeadlineExceeded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<Resp> ResponseGuard<'_, Resp> {
|
impl<Resp> ResponseGuard<'_, Resp> {
|
||||||
async fn response(mut self) -> io::Result<Resp> {
|
async fn response(mut self) -> Result<Resp, RpcError> {
|
||||||
let response = (&mut self.response).await;
|
let response = (&mut self.response).await;
|
||||||
// Cancel drop logic once a response has been received.
|
// Cancel drop logic once a response has been received.
|
||||||
mem::forget(self);
|
mem::forget(self);
|
||||||
match response {
|
match response {
|
||||||
Ok(resp) => Ok(resp.message?),
|
Ok(resp) => Ok(resp?.message?),
|
||||||
Err(oneshot::error::RecvError { .. }) => {
|
Err(oneshot::error::RecvError { .. }) => {
|
||||||
// The oneshot is Canceled when the dispatch task ends. In that case,
|
// The oneshot is Canceled when the dispatch task ends. In that case,
|
||||||
// there's nothing listening on the other side, so there's no point in
|
// there's nothing listening on the other side, so there's no point in
|
||||||
// propagating cancellation.
|
// propagating cancellation.
|
||||||
Err(io::Error::from(io::ErrorKind::ConnectionReset))
|
Err(RpcError::Disconnected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -257,11 +279,23 @@ pub enum ChannelError<E>
|
|||||||
where
|
where
|
||||||
E: Error + Send + Sync + 'static,
|
E: Error + Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
/// An error occurred reading from, or writing to, the transport.
|
/// Could not read from the transport.
|
||||||
#[error("an error occurred in the transport: {0}")]
|
#[error("could not read from the transport")]
|
||||||
Transport(#[source] E),
|
Read(#[source] E),
|
||||||
/// An error occurred while polling expired requests.
|
/// Could not ready the transport for writes.
|
||||||
#[error("an error occurred while polling expired requests: {0}")]
|
#[error("could not ready the transport for writes")]
|
||||||
|
Ready(#[source] E),
|
||||||
|
/// Could not write to the transport.
|
||||||
|
#[error("could not write to the transport")]
|
||||||
|
Write(#[source] E),
|
||||||
|
/// Could not flush the transport.
|
||||||
|
#[error("could not flush the transport")]
|
||||||
|
Flush(#[source] E),
|
||||||
|
/// Could not close the write end of the transport.
|
||||||
|
#[error("could not close the write end of the transport")]
|
||||||
|
Close(#[source] E),
|
||||||
|
/// Could not poll expired requests.
|
||||||
|
#[error("could not poll expired requests")]
|
||||||
Timer(#[source] tokio::time::error::Error),
|
Timer(#[source] tokio::time::error::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -277,6 +311,42 @@ where
|
|||||||
self.as_mut().project().transport
|
self.as_mut().project().transport
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn poll_ready<'a>(
|
||||||
|
self: &'a mut Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Result<(), ChannelError<C::Error>>> {
|
||||||
|
self.transport_pin_mut()
|
||||||
|
.poll_ready(cx)
|
||||||
|
.map_err(ChannelError::Ready)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_send(
|
||||||
|
self: &mut Pin<&mut Self>,
|
||||||
|
message: ClientMessage<Req>,
|
||||||
|
) -> Result<(), ChannelError<C::Error>> {
|
||||||
|
self.transport_pin_mut()
|
||||||
|
.start_send(message)
|
||||||
|
.map_err(ChannelError::Write)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush<'a>(
|
||||||
|
self: &'a mut Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Result<(), ChannelError<C::Error>>> {
|
||||||
|
self.transport_pin_mut()
|
||||||
|
.poll_flush(cx)
|
||||||
|
.map_err(ChannelError::Flush)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close<'a>(
|
||||||
|
self: &'a mut Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Result<(), ChannelError<C::Error>>> {
|
||||||
|
self.transport_pin_mut()
|
||||||
|
.poll_close(cx)
|
||||||
|
.map_err(ChannelError::Close)
|
||||||
|
}
|
||||||
|
|
||||||
fn canceled_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut CanceledRequests {
|
fn canceled_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut CanceledRequests {
|
||||||
self.as_mut().project().canceled_requests
|
self.as_mut().project().canceled_requests
|
||||||
}
|
}
|
||||||
@@ -293,7 +363,7 @@ where
|
|||||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||||
self.transport_pin_mut()
|
self.transport_pin_mut()
|
||||||
.poll_next(cx)
|
.poll_next(cx)
|
||||||
.map_err(ChannelError::Transport)
|
.map_err(ChannelError::Read)
|
||||||
.map_ok(|response| {
|
.map_ok(|response| {
|
||||||
self.complete(response);
|
self.complete(response);
|
||||||
})
|
})
|
||||||
@@ -308,21 +378,13 @@ where
|
|||||||
Closed,
|
Closed,
|
||||||
}
|
}
|
||||||
|
|
||||||
let pending_requests_status = match self
|
let pending_requests_status = match self.as_mut().poll_write_request(cx)? {
|
||||||
.as_mut()
|
|
||||||
.poll_write_request(cx)
|
|
||||||
.map_err(ChannelError::Transport)?
|
|
||||||
{
|
|
||||||
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
|
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
|
||||||
Poll::Ready(None) => ReceiverStatus::Closed,
|
Poll::Ready(None) => ReceiverStatus::Closed,
|
||||||
Poll::Pending => ReceiverStatus::Pending,
|
Poll::Pending => ReceiverStatus::Pending,
|
||||||
};
|
};
|
||||||
|
|
||||||
let canceled_requests_status = match self
|
let canceled_requests_status = match self.as_mut().poll_write_cancel(cx)? {
|
||||||
.as_mut()
|
|
||||||
.poll_write_cancel(cx)
|
|
||||||
.map_err(ChannelError::Transport)?
|
|
||||||
{
|
|
||||||
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
|
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
|
||||||
Poll::Ready(None) => ReceiverStatus::Closed,
|
Poll::Ready(None) => ReceiverStatus::Closed,
|
||||||
Poll::Pending => ReceiverStatus::Pending,
|
Poll::Pending => ReceiverStatus::Pending,
|
||||||
@@ -344,18 +406,12 @@ where
|
|||||||
|
|
||||||
match (pending_requests_status, canceled_requests_status) {
|
match (pending_requests_status, canceled_requests_status) {
|
||||||
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
|
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
|
||||||
ready!(self
|
ready!(self.poll_close(cx)?);
|
||||||
.transport_pin_mut()
|
|
||||||
.poll_flush(cx)
|
|
||||||
.map_err(ChannelError::Transport)?);
|
|
||||||
Poll::Ready(None)
|
Poll::Ready(None)
|
||||||
}
|
}
|
||||||
(ReceiverStatus::Pending, _) | (_, ReceiverStatus::Pending) => {
|
(ReceiverStatus::Pending, _) | (_, ReceiverStatus::Pending) => {
|
||||||
// No more messages to process, so flush any messages buffered in the transport.
|
// No more messages to process, so flush any messages buffered in the transport.
|
||||||
ready!(self
|
ready!(self.poll_flush(cx)?);
|
||||||
.transport_pin_mut()
|
|
||||||
.poll_flush(cx)
|
|
||||||
.map_err(ChannelError::Transport)?);
|
|
||||||
|
|
||||||
// Even if we fully-flush, we return Pending, because we have no more requests
|
// Even if we fully-flush, we return Pending, because we have no more requests
|
||||||
// or cancellations right now.
|
// or cancellations right now.
|
||||||
@@ -371,7 +427,7 @@ where
|
|||||||
fn poll_next_request(
|
fn poll_next_request(
|
||||||
mut self: Pin<&mut Self>,
|
mut self: Pin<&mut Self>,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
) -> Poll<Option<Result<DispatchRequest<Req, Resp>, C::Error>>> {
|
) -> Poll<Option<Result<DispatchRequest<Req, Resp>, ChannelError<C::Error>>>> {
|
||||||
if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
|
if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
"At in-flight request capacity ({}/{}).",
|
"At in-flight request capacity ({}/{}).",
|
||||||
@@ -409,7 +465,7 @@ where
|
|||||||
fn poll_next_cancellation(
|
fn poll_next_cancellation(
|
||||||
mut self: Pin<&mut Self>,
|
mut self: Pin<&mut Self>,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
) -> Poll<Option<Result<(context::Context, Span, u64), C::Error>>> {
|
) -> Poll<Option<Result<(context::Context, Span, u64), ChannelError<C::Error>>>> {
|
||||||
ready!(self.ensure_writeable(cx)?);
|
ready!(self.ensure_writeable(cx)?);
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
@@ -431,9 +487,9 @@ where
|
|||||||
fn ensure_writeable<'a>(
|
fn ensure_writeable<'a>(
|
||||||
self: &'a mut Pin<&mut Self>,
|
self: &'a mut Pin<&mut Self>,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
) -> Poll<Option<Result<(), C::Error>>> {
|
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||||
while self.transport_pin_mut().poll_ready(cx)?.is_pending() {
|
while self.poll_ready(cx)?.is_pending() {
|
||||||
ready!(self.transport_pin_mut().poll_flush(cx)?);
|
ready!(self.poll_flush(cx)?);
|
||||||
}
|
}
|
||||||
Poll::Ready(Some(Ok(())))
|
Poll::Ready(Some(Ok(())))
|
||||||
}
|
}
|
||||||
@@ -441,7 +497,7 @@ where
|
|||||||
fn poll_write_request<'a>(
|
fn poll_write_request<'a>(
|
||||||
self: &'a mut Pin<&mut Self>,
|
self: &'a mut Pin<&mut Self>,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
) -> Poll<Option<Result<(), C::Error>>> {
|
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||||
let DispatchRequest {
|
let DispatchRequest {
|
||||||
ctx,
|
ctx,
|
||||||
span,
|
span,
|
||||||
@@ -465,7 +521,7 @@ where
|
|||||||
trace_context: ctx.trace_context,
|
trace_context: ctx.trace_context,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
self.transport_pin_mut().start_send(request)?;
|
self.start_send(request)?;
|
||||||
let deadline = ctx.deadline;
|
let deadline = ctx.deadline;
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
tarpc.deadline = %humantime::format_rfc3339(deadline),
|
tarpc.deadline = %humantime::format_rfc3339(deadline),
|
||||||
@@ -482,7 +538,7 @@ where
|
|||||||
fn poll_write_cancel<'a>(
|
fn poll_write_cancel<'a>(
|
||||||
self: &'a mut Pin<&mut Self>,
|
self: &'a mut Pin<&mut Self>,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
) -> Poll<Option<Result<(), C::Error>>> {
|
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||||
let (context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
|
let (context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
|
||||||
Some(triple) => triple,
|
Some(triple) => triple,
|
||||||
None => return Poll::Ready(None),
|
None => return Poll::Ready(None),
|
||||||
@@ -493,7 +549,7 @@ where
|
|||||||
trace_context: context.trace_context,
|
trace_context: context.trace_context,
|
||||||
request_id,
|
request_id,
|
||||||
};
|
};
|
||||||
self.transport_pin_mut().start_send(cancel)?;
|
self.start_send(cancel)?;
|
||||||
tracing::info!("CancelRequest");
|
tracing::info!("CancelRequest");
|
||||||
Poll::Ready(Some(Ok(())))
|
Poll::Ready(Some(Ok(())))
|
||||||
}
|
}
|
||||||
@@ -549,7 +605,7 @@ struct DispatchRequest<Req, Resp> {
|
|||||||
pub span: Span,
|
pub span: Span,
|
||||||
pub request_id: u64,
|
pub request_id: u64,
|
||||||
pub request: Req,
|
pub request: Req,
|
||||||
pub response_completion: oneshot::Sender<Response<Resp>>,
|
pub response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sends request cancellation signals.
|
/// Sends request cancellation signals.
|
||||||
@@ -596,7 +652,10 @@ mod tests {
|
|||||||
RequestDispatch, ResponseGuard,
|
RequestDispatch, ResponseGuard,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
client::{in_flight_requests::InFlightRequests, Config},
|
client::{
|
||||||
|
in_flight_requests::{DeadlineExceededError, InFlightRequests},
|
||||||
|
Config,
|
||||||
|
},
|
||||||
context,
|
context,
|
||||||
transport::{self, channel::UnboundedChannel},
|
transport::{self, channel::UnboundedChannel},
|
||||||
ClientMessage, Response,
|
ClientMessage, Response,
|
||||||
@@ -630,7 +689,7 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
|
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
|
||||||
assert_matches!(rx.try_recv(), Ok(Response { request_id: 0, message: Ok(resp) }) if resp == "Resp");
|
assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -651,10 +710,10 @@ mod tests {
|
|||||||
async fn dispatch_response_doesnt_cancel_after_complete() {
|
async fn dispatch_response_doesnt_cancel_after_complete() {
|
||||||
let (cancellation, mut canceled_requests) = cancellations();
|
let (cancellation, mut canceled_requests) = cancellations();
|
||||||
let (tx, mut response) = oneshot::channel();
|
let (tx, mut response) = oneshot::channel();
|
||||||
tx.send(Response {
|
tx.send(Ok(Response {
|
||||||
request_id: 0,
|
request_id: 0,
|
||||||
message: Ok("well done"),
|
message: Ok("well done"),
|
||||||
})
|
}))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
// resp's drop() is run, but should not send a cancel message.
|
// resp's drop() is run, but should not send a cancel message.
|
||||||
ResponseGuard {
|
ResponseGuard {
|
||||||
@@ -799,8 +858,8 @@ mod tests {
|
|||||||
async fn send_request<'a>(
|
async fn send_request<'a>(
|
||||||
channel: &'a mut Channel<String, String>,
|
channel: &'a mut Channel<String, String>,
|
||||||
request: &str,
|
request: &str,
|
||||||
response_completion: oneshot::Sender<Response<String>>,
|
response_completion: oneshot::Sender<Result<Response<String>, DeadlineExceededError>>,
|
||||||
response: &'a mut oneshot::Receiver<Response<String>>,
|
response: &'a mut oneshot::Receiver<Result<Response<String>, DeadlineExceededError>>,
|
||||||
) -> ResponseGuard<'a, String> {
|
) -> ResponseGuard<'a, String> {
|
||||||
let request_id =
|
let request_id =
|
||||||
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
|
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
context,
|
context,
|
||||||
util::{Compact, TimeUntil},
|
util::{Compact, TimeUntil},
|
||||||
Response, ServerError,
|
Response,
|
||||||
};
|
};
|
||||||
use fnv::FnvHashMap;
|
use fnv::FnvHashMap;
|
||||||
use std::{
|
use std::{
|
||||||
collections::hash_map,
|
collections::hash_map,
|
||||||
io,
|
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
};
|
};
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
@@ -29,11 +28,17 @@ impl<Resp> Default for InFlightRequests<Resp> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The request exceeded its deadline.
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
#[non_exhaustive]
|
||||||
|
#[error("the request exceeded its deadline")]
|
||||||
|
pub struct DeadlineExceededError;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct RequestData<Resp> {
|
struct RequestData<Resp> {
|
||||||
ctx: context::Context,
|
ctx: context::Context,
|
||||||
span: Span,
|
span: Span,
|
||||||
response_completion: oneshot::Sender<Response<Resp>>,
|
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||||
/// The key to remove the timer for the request's deadline.
|
/// The key to remove the timer for the request's deadline.
|
||||||
deadline_key: delay_queue::Key,
|
deadline_key: delay_queue::Key,
|
||||||
}
|
}
|
||||||
@@ -60,7 +65,7 @@ impl<Resp> InFlightRequests<Resp> {
|
|||||||
request_id: u64,
|
request_id: u64,
|
||||||
ctx: context::Context,
|
ctx: context::Context,
|
||||||
span: Span,
|
span: Span,
|
||||||
response_completion: oneshot::Sender<Response<Resp>>,
|
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||||
) -> Result<(), AlreadyExistsError> {
|
) -> Result<(), AlreadyExistsError> {
|
||||||
match self.request_data.entry(request_id) {
|
match self.request_data.entry(request_id) {
|
||||||
hash_map::Entry::Vacant(vacant) => {
|
hash_map::Entry::Vacant(vacant) => {
|
||||||
@@ -85,7 +90,7 @@ impl<Resp> InFlightRequests<Resp> {
|
|||||||
tracing::info!("ReceiveResponse");
|
tracing::info!("ReceiveResponse");
|
||||||
self.request_data.compact(0.1);
|
self.request_data.compact(0.1);
|
||||||
self.deadlines.remove(&request_data.deadline_key);
|
self.deadlines.remove(&request_data.deadline_key);
|
||||||
let _ = request_data.response_completion.send(response);
|
let _ = request_data.response_completion.send(Ok(response));
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,19 +129,9 @@ impl<Resp> InFlightRequests<Resp> {
|
|||||||
self.request_data.compact(0.1);
|
self.request_data.compact(0.1);
|
||||||
let _ = request_data
|
let _ = request_data
|
||||||
.response_completion
|
.response_completion
|
||||||
.send(Self::deadline_exceeded_error(request_id));
|
.send(Err(DeadlineExceededError));
|
||||||
}
|
}
|
||||||
request_id
|
request_id
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn deadline_exceeded_error(request_id: u64) -> Response<Resp> {
|
|
||||||
Response {
|
|
||||||
request_id,
|
|
||||||
message: Err(ServerError {
|
|
||||||
kind: io::ErrorKind::TimedOut,
|
|
||||||
detail: Some("Client dropped expired request.".to_string()),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ impl SpanExt for tracing::Span {
|
|||||||
.with_remote_span_context(opentelemetry::trace::SpanContext::new(
|
.with_remote_span_context(opentelemetry::trace::SpanContext::new(
|
||||||
opentelemetry::trace::TraceId::from(context.trace_context.trace_id),
|
opentelemetry::trace::TraceId::from(context.trace_context.trace_id),
|
||||||
opentelemetry::trace::SpanId::from(context.trace_context.span_id),
|
opentelemetry::trace::SpanId::from(context.trace_context.span_id),
|
||||||
context.trace_context.sampling_decision as u8,
|
opentelemetry::trace::TraceFlags::from(context.trace_context.sampling_decision),
|
||||||
true,
|
true,
|
||||||
opentelemetry::trace::TraceState::default(),
|
opentelemetry::trace::TraceState::default(),
|
||||||
))
|
))
|
||||||
|
|||||||
@@ -54,7 +54,7 @@
|
|||||||
//! Add to your `Cargo.toml` dependencies:
|
//! Add to your `Cargo.toml` dependencies:
|
||||||
//!
|
//!
|
||||||
//! ```toml
|
//! ```toml
|
||||||
//! tarpc = "0.26"
|
//! tarpc = "0.27"
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||||
@@ -67,8 +67,9 @@
|
|||||||
//! your `Cargo.toml`:
|
//! your `Cargo.toml`:
|
||||||
//!
|
//!
|
||||||
//! ```toml
|
//! ```toml
|
||||||
//! futures = "1.0"
|
//! anyhow = "1.0"
|
||||||
//! tarpc = { version = "0.26", features = ["tokio1"] }
|
//! futures = "0.3"
|
||||||
|
//! tarpc = { version = "0.27", features = ["tokio1"] }
|
||||||
//! tokio = { version = "1.0", features = ["macros"] }
|
//! tokio = { version = "1.0", features = ["macros"] }
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
@@ -87,9 +88,8 @@
|
|||||||
//! };
|
//! };
|
||||||
//! use tarpc::{
|
//! use tarpc::{
|
||||||
//! client, context,
|
//! client, context,
|
||||||
//! server::{self, Incoming},
|
//! server::{self, incoming::Incoming},
|
||||||
//! };
|
//! };
|
||||||
//! use std::io;
|
|
||||||
//!
|
//!
|
||||||
//! // This is the service definition. It looks a lot like a trait definition.
|
//! // This is the service definition. It looks a lot like a trait definition.
|
||||||
//! // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
//! // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||||
@@ -111,9 +111,8 @@
|
|||||||
//! # };
|
//! # };
|
||||||
//! # use tarpc::{
|
//! # use tarpc::{
|
||||||
//! # client, context,
|
//! # client, context,
|
||||||
//! # server::{self, Incoming},
|
//! # server::{self, incoming::Incoming},
|
||||||
//! # };
|
//! # };
|
||||||
//! # use std::io;
|
|
||||||
//! # // This is the service definition. It looks a lot like a trait definition.
|
//! # // This is the service definition. It looks a lot like a trait definition.
|
||||||
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||||
//! # #[tarpc::service]
|
//! # #[tarpc::service]
|
||||||
@@ -153,7 +152,6 @@
|
|||||||
//! # client, context,
|
//! # client, context,
|
||||||
//! # server::{self, Channel},
|
//! # server::{self, Channel},
|
||||||
//! # };
|
//! # };
|
||||||
//! # use std::io;
|
|
||||||
//! # // This is the service definition. It looks a lot like a trait definition.
|
//! # // This is the service definition. It looks a lot like a trait definition.
|
||||||
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||||
//! # #[tarpc::service]
|
//! # #[tarpc::service]
|
||||||
@@ -177,7 +175,7 @@
|
|||||||
//! # fn main() {}
|
//! # fn main() {}
|
||||||
//! # #[cfg(feature = "tokio1")]
|
//! # #[cfg(feature = "tokio1")]
|
||||||
//! #[tokio::main]
|
//! #[tokio::main]
|
||||||
//! async fn main() -> io::Result<()> {
|
//! async fn main() -> anyhow::Result<()> {
|
||||||
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||||
//!
|
//!
|
||||||
//! let server = server::BaseChannel::with_defaults(server_transport);
|
//! let server = server::BaseChannel::with_defaults(server_transport);
|
||||||
@@ -364,8 +362,9 @@ pub struct Response<T> {
|
|||||||
pub message: Result<T, ServerError>,
|
pub message: Result<T, ServerError>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An error response from a server to a client.
|
/// An error indicating the server aborted the request early, e.g., due to request throttling.
|
||||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
|
||||||
|
#[error("{kind:?}: {detail}")]
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
pub struct ServerError {
|
pub struct ServerError {
|
||||||
@@ -380,13 +379,7 @@ pub struct ServerError {
|
|||||||
/// The type of error that occurred to fail the request.
|
/// The type of error that occurred to fail the request.
|
||||||
pub kind: io::ErrorKind,
|
pub kind: io::ErrorKind,
|
||||||
/// A message describing more detail about the error that occurred.
|
/// A message describing more detail about the error that occurred.
|
||||||
pub detail: Option<String>,
|
pub detail: String,
|
||||||
}
|
|
||||||
|
|
||||||
impl From<ServerError> for io::Error {
|
|
||||||
fn from(e: ServerError) -> io::Error {
|
|
||||||
io::Error::new(e.kind, e.detail.unwrap_or_default())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Request<T> {
|
impl<T> Request<T> {
|
||||||
|
|||||||
@@ -42,12 +42,10 @@ where
|
|||||||
type Item = io::Result<Item>;
|
type Item = io::Result<Item>;
|
||||||
|
|
||||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
|
||||||
self.project().inner.poll_next(cx).map_err(|e| {
|
self.project()
|
||||||
io::Error::new(
|
.inner
|
||||||
io::ErrorKind::Other,
|
.poll_next(cx)
|
||||||
format!("while reading from transport: {}", e.into()),
|
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,39 +61,31 @@ where
|
|||||||
type Error = io::Error;
|
type Error = io::Error;
|
||||||
|
|
||||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
self.project().inner.poll_ready(cx).map_err(|e| {
|
self.project()
|
||||||
io::Error::new(
|
.inner
|
||||||
io::ErrorKind::Other,
|
.poll_ready(cx)
|
||||||
format!("while readying write half of transport: {}", e.into()),
|
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
|
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
|
||||||
self.project().inner.start_send(item).map_err(|e| {
|
self.project()
|
||||||
io::Error::new(
|
.inner
|
||||||
io::ErrorKind::Other,
|
.start_send(item)
|
||||||
format!("while writing to transport: {}", e.into()),
|
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
self.project().inner.poll_flush(cx).map_err(|e| {
|
self.project()
|
||||||
io::Error::new(
|
.inner
|
||||||
io::ErrorKind::Other,
|
.poll_flush(cx)
|
||||||
format!("while flushing transport: {}", e.into()),
|
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
self.project().inner.poll_close(cx).map_err(|e| {
|
self.project()
|
||||||
io::Error::new(
|
.inner
|
||||||
io::ErrorKind::Other,
|
.poll_close(cx)
|
||||||
format!("while closing write half of transport: {}", e.into()),
|
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ use crate::{
|
|||||||
context::{self, SpanExt},
|
context::{self, SpanExt},
|
||||||
trace, ClientMessage, Request, Response, Transport,
|
trace, ClientMessage, Request, Response, Transport,
|
||||||
};
|
};
|
||||||
|
use ::tokio::sync::mpsc;
|
||||||
use futures::{
|
use futures::{
|
||||||
future::{AbortRegistration, Abortable},
|
future::{AbortRegistration, Abortable},
|
||||||
prelude::*,
|
prelude::*,
|
||||||
@@ -19,23 +20,23 @@ use futures::{
|
|||||||
};
|
};
|
||||||
use in_flight_requests::{AlreadyExistsError, InFlightRequests};
|
use in_flight_requests::{AlreadyExistsError, InFlightRequests};
|
||||||
use pin_project::pin_project;
|
use pin_project::pin_project;
|
||||||
use std::{
|
use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin};
|
||||||
convert::TryFrom, error::Error, fmt, hash::Hash, marker::PhantomData, pin::Pin,
|
|
||||||
time::SystemTime,
|
|
||||||
};
|
|
||||||
use tokio::sync::mpsc;
|
|
||||||
use tracing::{info_span, instrument::Instrument, Span};
|
use tracing::{info_span, instrument::Instrument, Span};
|
||||||
|
|
||||||
mod filter;
|
|
||||||
mod in_flight_requests;
|
mod in_flight_requests;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod testing;
|
mod testing;
|
||||||
mod throttle;
|
|
||||||
|
|
||||||
pub use self::{
|
/// Provides functionality to apply server limits.
|
||||||
filter::ChannelFilter,
|
pub mod limits;
|
||||||
throttle::{Throttler, ThrottlerStream},
|
|
||||||
};
|
/// Provides helper methods for streams of Channels.
|
||||||
|
pub mod incoming;
|
||||||
|
|
||||||
|
/// Provides convenience functionality for tokio-enabled applications.
|
||||||
|
#[cfg(feature = "tokio1")]
|
||||||
|
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||||
|
pub mod tokio;
|
||||||
|
|
||||||
/// Settings that control the behavior of [channels](Channel).
|
/// Settings that control the behavior of [channels](Channel).
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
@@ -94,51 +95,13 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An extension trait for [streams](Stream) of [`Channels`](Channel).
|
/// BaseChannel is the standard implementation of a [`Channel`].
|
||||||
pub trait Incoming<C>
|
|
||||||
where
|
|
||||||
Self: Sized + Stream<Item = C>,
|
|
||||||
C: Channel,
|
|
||||||
{
|
|
||||||
/// Enforces channel per-key limits.
|
|
||||||
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> filter::ChannelFilter<Self, K, KF>
|
|
||||||
where
|
|
||||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
|
||||||
KF: Fn(&C) -> K,
|
|
||||||
{
|
|
||||||
ChannelFilter::new(self, n, keymaker)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Caps the number of concurrent requests per channel.
|
|
||||||
fn max_concurrent_requests_per_channel(self, n: usize) -> ThrottlerStream<Self> {
|
|
||||||
ThrottlerStream::new(self, n)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled
|
|
||||||
/// concurrently by spawning on tokio's default executor, and each request will be also
|
|
||||||
/// be spawned on tokio's default executor.
|
|
||||||
#[cfg(feature = "tokio1")]
|
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
|
||||||
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
|
|
||||||
where
|
|
||||||
S: Serve<C::Req, Resp = C::Resp>,
|
|
||||||
{
|
|
||||||
TokioServerExecutor { inner: self, serve }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S, C> Incoming<C> for S
|
|
||||||
where
|
|
||||||
S: Sized + Stream<Item = C>,
|
|
||||||
C: Channel,
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
/// BaseChannel is a [Transport] that keeps track of in-flight requests. It converts a
|
|
||||||
/// [`Transport`](Transport) of [`ClientMessages`](ClientMessage) into a stream of
|
|
||||||
/// [requests](ClientMessage::Request).
|
|
||||||
///
|
///
|
||||||
/// Besides requests, the other type of client message is [cancellation
|
/// BaseChannel manages a [`Transport`](Transport) of client [`messages`](ClientMessage) and
|
||||||
|
/// implements a [`Stream`] of [requests](TrackedRequest). See the [`Channel`] documentation for
|
||||||
|
/// how to use channels.
|
||||||
|
///
|
||||||
|
/// Besides requests, the other type of client message handled by `BaseChannel` is [cancellation
|
||||||
/// messages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation
|
/// messages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation
|
||||||
/// messages. Instead, it internally handles them by cancelling corresponding requests (removing
|
/// messages. Instead, it internally handles them by cancelling corresponding requests (removing
|
||||||
/// the corresponding in-flight requests and aborting their handlers).
|
/// the corresponding in-flight requests and aborting their handlers).
|
||||||
@@ -190,6 +153,47 @@ where
|
|||||||
fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<T>> {
|
fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<T>> {
|
||||||
self.as_mut().project().transport
|
self.as_mut().project().transport
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn start_request(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
mut request: Request<Req>,
|
||||||
|
) -> Result<TrackedRequest<Req>, AlreadyExistsError> {
|
||||||
|
let span = info_span!(
|
||||||
|
"RPC",
|
||||||
|
rpc.trace_id = %request.context.trace_id(),
|
||||||
|
otel.kind = "server",
|
||||||
|
otel.name = tracing::field::Empty,
|
||||||
|
);
|
||||||
|
span.set_context(&request.context);
|
||||||
|
request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
|
||||||
|
tracing::trace!(
|
||||||
|
"OpenTelemetry subscriber not installed; making unsampled \
|
||||||
|
child context."
|
||||||
|
);
|
||||||
|
request.context.trace_context.new_child()
|
||||||
|
});
|
||||||
|
let entered = span.enter();
|
||||||
|
tracing::info!("ReceiveRequest");
|
||||||
|
let start = self.project().in_flight_requests.start_request(
|
||||||
|
request.id,
|
||||||
|
request.context.deadline,
|
||||||
|
span.clone(),
|
||||||
|
);
|
||||||
|
match start {
|
||||||
|
Ok(abort_registration) => {
|
||||||
|
drop(entered);
|
||||||
|
Ok(TrackedRequest {
|
||||||
|
request,
|
||||||
|
abort_registration,
|
||||||
|
span,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Err(AlreadyExistsError) => {
|
||||||
|
tracing::trace!("DuplicateRequest");
|
||||||
|
Err(AlreadyExistsError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
|
impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
|
||||||
@@ -198,8 +202,20 @@ impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The server end of an open connection with a client, streaming in requests from, and sinking
|
/// A request tracked by a [`Channel`].
|
||||||
/// responses to, the client.
|
#[derive(Debug)]
|
||||||
|
pub struct TrackedRequest<Req> {
|
||||||
|
/// The request sent by the client.
|
||||||
|
pub request: Request<Req>,
|
||||||
|
/// A registration to abort a future when the [`Channel`] that produced this request stops
|
||||||
|
/// tracking it.
|
||||||
|
pub abort_registration: AbortRegistration,
|
||||||
|
/// A span representing the server processing of this request.
|
||||||
|
pub span: Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The server end of an open connection with a client, receiving requests from, and sending
|
||||||
|
/// responses to, the client. `Channel` is a [`Transport`] with request lifecycle management.
|
||||||
///
|
///
|
||||||
/// The ways to use a Channel, in order of simplest to most complex, is:
|
/// The ways to use a Channel, in order of simplest to most complex, is:
|
||||||
/// 1. [`Channel::execute`] - Requires the `tokio1` feature. This method is best for those who
|
/// 1. [`Channel::execute`] - Requires the `tokio1` feature. This method is best for those who
|
||||||
@@ -210,18 +226,20 @@ impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
|
|||||||
/// [`execute`](InFlightRequest::execute) method. If using `execute`, request processing will
|
/// [`execute`](InFlightRequest::execute) method. If using `execute`, request processing will
|
||||||
/// automatically cease when either the request deadline is reached or when a corresponding
|
/// automatically cease when either the request deadline is reached or when a corresponding
|
||||||
/// cancellation message is received by the Channel.
|
/// cancellation message is received by the Channel.
|
||||||
/// 3. [`Sink::start_send`] - A user is free to manually send responses to requests produced by a
|
/// 3. [`Stream::next`](futures::stream::StreamExt::next) /
|
||||||
/// Channel using [`Sink::start_send`] in lieu of the previous methods. If not using one of the
|
/// [`Sink::send`](futures::sink::SinkExt::send) - A user is free to manually read requests
|
||||||
/// previous execute methods, then nothing will automatically cancel requests or set up the
|
/// from, and send responses into, a Channel in lieu of the previous methods. Channels stream
|
||||||
/// request context. However, the Channel will still clean up resources upon deadline expiration
|
/// [`TrackedRequests`](TrackedRequest), which, in addition to the request itself, contains the
|
||||||
/// or cancellation. In the case that the Channel cleans up resources related to a request
|
/// server [`Span`] and request lifetime [`AbortRegistration`]. Wrapping response
|
||||||
/// before the response is sent, the response can still be sent into the Channel later on.
|
/// logic in an [`Abortable`] future using the abort registration will ensure that the response
|
||||||
/// Because there is no guarantee that a cancellation message will ever be received for a
|
/// does not execute longer than the request deadline. The `Channel` itself will clean up
|
||||||
/// request, or that requests come with reasonably short deadlines, services should strive to
|
/// request state once either the deadline expires, or a cancellation message is received, or a
|
||||||
/// clean up Channel resources by sending a response for every request.
|
/// response is sent. Because there is no guarantee that a cancellation message will ever be
|
||||||
|
/// received for a request, or that requests come with reasonably short deadlines, services
|
||||||
|
/// should strive to clean up Channel resources by sending a response for every request.
|
||||||
pub trait Channel
|
pub trait Channel
|
||||||
where
|
where
|
||||||
Self: Transport<Response<<Self as Channel>::Resp>, Request<<Self as Channel>::Req>>,
|
Self: Transport<Response<<Self as Channel>::Resp>, TrackedRequest<<Self as Channel>::Req>>,
|
||||||
{
|
{
|
||||||
/// Type of request item.
|
/// Type of request item.
|
||||||
type Req;
|
type Req;
|
||||||
@@ -241,24 +259,23 @@ where
|
|||||||
/// Returns the transport underlying the channel.
|
/// Returns the transport underlying the channel.
|
||||||
fn transport(&self) -> &Self::Transport;
|
fn transport(&self) -> &Self::Transport;
|
||||||
|
|
||||||
/// Caps the number of concurrent requests to `limit`.
|
/// Caps the number of concurrent requests to `limit`. An error will be returned for requests
|
||||||
fn max_concurrent_requests(self, limit: usize) -> Throttler<Self>
|
/// over the concurrency limit.
|
||||||
|
///
|
||||||
|
/// Note that this is a very
|
||||||
|
/// simplistic throttling heuristic. It is easy to set a number that is too low for the
|
||||||
|
/// resources available to the server. For production use cases, a more advanced throttler is
|
||||||
|
/// likely needed.
|
||||||
|
fn max_concurrent_requests(
|
||||||
|
self,
|
||||||
|
limit: usize,
|
||||||
|
) -> limits::requests_per_channel::MaxRequests<Self>
|
||||||
where
|
where
|
||||||
Self: Sized,
|
Self: Sized,
|
||||||
{
|
{
|
||||||
Throttler::new(self, limit)
|
limits::requests_per_channel::MaxRequests::new(self, limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Tells the Channel that request with ID `request_id` is being handled.
|
|
||||||
/// The request will be tracked until a response with the same ID is sent
|
|
||||||
/// to the Channel or the deadline expires, whichever happens first.
|
|
||||||
fn start_request(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
request_id: u64,
|
|
||||||
deadline: SystemTime,
|
|
||||||
span: Span,
|
|
||||||
) -> Result<AbortRegistration, AlreadyExistsError>;
|
|
||||||
|
|
||||||
/// Returns a stream of requests that automatically handle request cancellation and response
|
/// Returns a stream of requests that automatically handle request cancellation and response
|
||||||
/// routing.
|
/// routing.
|
||||||
///
|
///
|
||||||
@@ -279,11 +296,11 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Runs the channel until completion by executing all requests using the given service
|
/// Runs the channel until completion by executing all requests using the given service
|
||||||
/// function. Request handlers are run concurrently by [spawning](tokio::spawn) on tokio's
|
/// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's
|
||||||
/// default executor.
|
/// default executor.
|
||||||
#[cfg(feature = "tokio1")]
|
#[cfg(feature = "tokio1")]
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||||
fn execute<S>(self, serve: S) -> TokioChannelExecutor<Requests<Self>, S>
|
fn execute<S>(self, serve: S) -> self::tokio::TokioChannelExecutor<Requests<Self>, S>
|
||||||
where
|
where
|
||||||
Self: Sized,
|
Self: Sized,
|
||||||
S: Serve<Self::Req, Resp = Self::Resp> + Send + 'static,
|
S: Serve<Self::Req, Resp = Self::Resp> + Send + 'static,
|
||||||
@@ -306,16 +323,17 @@ where
|
|||||||
Transport(#[source] E),
|
Transport(#[source] E),
|
||||||
/// An error occurred while polling expired requests.
|
/// An error occurred while polling expired requests.
|
||||||
#[error("an error occurred while polling expired requests: {0}")]
|
#[error("an error occurred while polling expired requests: {0}")]
|
||||||
Timer(#[source] tokio::time::error::Error),
|
Timer(#[source] ::tokio::time::error::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
|
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
|
||||||
where
|
where
|
||||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||||
{
|
{
|
||||||
type Item = Result<Request<Req>, ChannelError<T::Error>>;
|
type Item = Result<TrackedRequest<Req>, ChannelError<T::Error>>;
|
||||||
|
|
||||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||||
|
#[derive(Debug)]
|
||||||
enum ReceiverStatus {
|
enum ReceiverStatus {
|
||||||
Ready,
|
Ready,
|
||||||
Pending,
|
Pending,
|
||||||
@@ -343,7 +361,16 @@ where
|
|||||||
{
|
{
|
||||||
Poll::Ready(Some(message)) => match message {
|
Poll::Ready(Some(message)) => match message {
|
||||||
ClientMessage::Request(request) => {
|
ClientMessage::Request(request) => {
|
||||||
return Poll::Ready(Some(Ok(request)));
|
match self.as_mut().start_request(request) {
|
||||||
|
Ok(request) => return Poll::Ready(Some(Ok(request))),
|
||||||
|
Err(AlreadyExistsError) => {
|
||||||
|
// Instead of closing the channel if a duplicate request is sent,
|
||||||
|
// just ignore it, since it's already being processed. Note that we
|
||||||
|
// cannot return Poll::Pending here, since nothing has scheduled a
|
||||||
|
// wakeup yet.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ClientMessage::Cancel {
|
ClientMessage::Cancel {
|
||||||
trace_context,
|
trace_context,
|
||||||
@@ -362,6 +389,11 @@ where
|
|||||||
Poll::Pending => Pending,
|
Poll::Pending => Pending,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
tracing::trace!(
|
||||||
|
"Expired requests: {:?}, Inbound: {:?}",
|
||||||
|
expiration_status,
|
||||||
|
request_status
|
||||||
|
);
|
||||||
match (expiration_status, request_status) {
|
match (expiration_status, request_status) {
|
||||||
(Ready, _) | (_, Ready) => continue,
|
(Ready, _) | (_, Ready) => continue,
|
||||||
(Closed, Closed) => return Poll::Ready(None),
|
(Closed, Closed) => return Poll::Ready(None),
|
||||||
@@ -405,6 +437,7 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||||
|
tracing::trace!("poll_flush");
|
||||||
self.project()
|
self.project()
|
||||||
.transport
|
.transport
|
||||||
.poll_flush(cx)
|
.poll_flush(cx)
|
||||||
@@ -444,17 +477,6 @@ where
|
|||||||
fn transport(&self) -> &Self::Transport {
|
fn transport(&self) -> &Self::Transport {
|
||||||
self.get_ref()
|
self.get_ref()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_request(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
request_id: u64,
|
|
||||||
deadline: SystemTime,
|
|
||||||
span: Span,
|
|
||||||
) -> Result<AbortRegistration, AlreadyExistsError> {
|
|
||||||
self.project()
|
|
||||||
.in_flight_requests
|
|
||||||
.start_request(request_id, deadline, span)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A stream of requests coming over a channel. `Requests` also drives the sending of responses, so
|
/// A stream of requests coming over a channel. `Requests` also drives the sending of responses, so
|
||||||
@@ -492,54 +514,12 @@ where
|
|||||||
mut self: Pin<&mut Self>,
|
mut self: Pin<&mut Self>,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
) -> Poll<Option<Result<InFlightRequest<C::Req, C::Resp>, C::Error>>> {
|
) -> Poll<Option<Result<InFlightRequest<C::Req, C::Resp>, C::Error>>> {
|
||||||
loop {
|
self.channel_pin_mut()
|
||||||
match ready!(self.channel_pin_mut().poll_next(cx)?) {
|
.poll_next(cx)
|
||||||
Some(mut request) => {
|
.map_ok(|request| InFlightRequest {
|
||||||
let span = info_span!(
|
request,
|
||||||
"RPC",
|
response_tx: self.responses_tx.clone(),
|
||||||
rpc.trace_id = %request.context.trace_id(),
|
})
|
||||||
otel.kind = "server",
|
|
||||||
otel.name = tracing::field::Empty,
|
|
||||||
);
|
|
||||||
span.set_context(&request.context);
|
|
||||||
request.context.trace_context =
|
|
||||||
trace::Context::try_from(&span).unwrap_or_else(|_| {
|
|
||||||
tracing::trace!(
|
|
||||||
"OpenTelemetry subscriber not installed; making unsampled
|
|
||||||
child context."
|
|
||||||
);
|
|
||||||
request.context.trace_context.new_child()
|
|
||||||
});
|
|
||||||
let entered = span.enter();
|
|
||||||
tracing::info!("ReceiveRequest");
|
|
||||||
let start = self.channel_pin_mut().start_request(
|
|
||||||
request.id,
|
|
||||||
request.context.deadline,
|
|
||||||
span.clone(),
|
|
||||||
);
|
|
||||||
match start {
|
|
||||||
Ok(abort_registration) => {
|
|
||||||
let response_tx = self.responses_tx.clone();
|
|
||||||
drop(entered);
|
|
||||||
return Poll::Ready(Some(Ok(InFlightRequest {
|
|
||||||
request,
|
|
||||||
response_tx,
|
|
||||||
abort_registration,
|
|
||||||
span,
|
|
||||||
})));
|
|
||||||
}
|
|
||||||
// Instead of closing the channel if a duplicate request is sent, just
|
|
||||||
// ignore it, since it's already being processed. Note that we cannot
|
|
||||||
// return Poll::Pending here, since nothing has scheduled a wakeup yet.
|
|
||||||
Err(AlreadyExistsError) => {
|
|
||||||
tracing::trace!("DuplicateRequest");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => return Poll::Ready(None),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn pump_write(
|
fn pump_write(
|
||||||
@@ -619,16 +599,14 @@ where
|
|||||||
/// A request produced by [Channel::requests].
|
/// A request produced by [Channel::requests].
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct InFlightRequest<Req, Res> {
|
pub struct InFlightRequest<Req, Res> {
|
||||||
request: Request<Req>,
|
request: TrackedRequest<Req>,
|
||||||
response_tx: mpsc::Sender<Response<Res>>,
|
response_tx: mpsc::Sender<Response<Res>>,
|
||||||
abort_registration: AbortRegistration,
|
|
||||||
span: Span,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Req, Res> InFlightRequest<Req, Res> {
|
impl<Req, Res> InFlightRequest<Req, Res> {
|
||||||
/// Returns a reference to the request.
|
/// Returns a reference to the request.
|
||||||
pub fn get(&self) -> &Request<Req> {
|
pub fn get(&self) -> &Request<Req> {
|
||||||
&self.request
|
&self.request.request
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a [future](Future) that executes the request using the given [service
|
/// Returns a [future](Future) that executes the request using the given [service
|
||||||
@@ -647,15 +625,18 @@ impl<Req, Res> InFlightRequest<Req, Res> {
|
|||||||
S: Serve<Req, Resp = Res>,
|
S: Serve<Req, Resp = Res>,
|
||||||
{
|
{
|
||||||
let Self {
|
let Self {
|
||||||
abort_registration,
|
|
||||||
request:
|
|
||||||
Request {
|
|
||||||
context,
|
|
||||||
message,
|
|
||||||
id: request_id,
|
|
||||||
},
|
|
||||||
response_tx,
|
response_tx,
|
||||||
span,
|
request:
|
||||||
|
TrackedRequest {
|
||||||
|
abort_registration,
|
||||||
|
span,
|
||||||
|
request:
|
||||||
|
Request {
|
||||||
|
context,
|
||||||
|
message,
|
||||||
|
id: request_id,
|
||||||
|
},
|
||||||
|
},
|
||||||
} = self;
|
} = self;
|
||||||
let method = serve.method(&message);
|
let method = serve.method(&message);
|
||||||
span.record("otel.name", &method.unwrap_or(""));
|
span.record("otel.name", &method.unwrap_or(""));
|
||||||
@@ -704,129 +685,22 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send + 'static execution helper methods.
|
|
||||||
|
|
||||||
#[cfg(feature = "tokio1")]
|
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
|
||||||
impl<C> Requests<C>
|
|
||||||
where
|
|
||||||
C: Channel,
|
|
||||||
C::Req: Send + 'static,
|
|
||||||
C::Resp: Send + 'static,
|
|
||||||
{
|
|
||||||
/// Executes all requests using the given service function. Requests are handled concurrently
|
|
||||||
/// by [spawning](tokio::spawn) each handler on tokio's default executor.
|
|
||||||
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
|
|
||||||
where
|
|
||||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
|
|
||||||
{
|
|
||||||
TokioChannelExecutor { inner: self, serve }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
|
|
||||||
/// for each new channel.
|
|
||||||
#[pin_project]
|
|
||||||
#[derive(Debug)]
|
|
||||||
#[cfg(feature = "tokio1")]
|
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
|
||||||
pub struct TokioServerExecutor<T, S> {
|
|
||||||
#[pin]
|
|
||||||
inner: T,
|
|
||||||
serve: S,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A future that drives the server by [spawning](tokio::spawn) each [response
|
|
||||||
/// handler](InFlightRequest::execute) on tokio's default executor.
|
|
||||||
#[pin_project]
|
|
||||||
#[derive(Debug)]
|
|
||||||
#[cfg(feature = "tokio1")]
|
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
|
||||||
pub struct TokioChannelExecutor<T, S> {
|
|
||||||
#[pin]
|
|
||||||
inner: T,
|
|
||||||
serve: S,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "tokio1")]
|
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
|
||||||
impl<T, S> TokioServerExecutor<T, S> {
|
|
||||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
|
||||||
self.as_mut().project().inner
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "tokio1")]
|
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
|
||||||
impl<T, S> TokioChannelExecutor<T, S> {
|
|
||||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
|
||||||
self.as_mut().project().inner
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "tokio1")]
|
|
||||||
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
|
|
||||||
where
|
|
||||||
St: Sized + Stream<Item = C>,
|
|
||||||
C: Channel + Send + 'static,
|
|
||||||
C::Req: Send + 'static,
|
|
||||||
C::Resp: Send + 'static,
|
|
||||||
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
|
||||||
Se::Fut: Send,
|
|
||||||
{
|
|
||||||
type Output = ();
|
|
||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
|
||||||
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
|
||||||
tokio::spawn(channel.execute(self.serve.clone()));
|
|
||||||
}
|
|
||||||
tracing::info!("Server shutting down.");
|
|
||||||
Poll::Ready(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "tokio1")]
|
|
||||||
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
|
|
||||||
where
|
|
||||||
C: Channel + 'static,
|
|
||||||
C::Req: Send + 'static,
|
|
||||||
C::Resp: Send + 'static,
|
|
||||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
|
||||||
S::Fut: Send,
|
|
||||||
{
|
|
||||||
type Output = ();
|
|
||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
||||||
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
|
||||||
match response_handler {
|
|
||||||
Ok(resp) => {
|
|
||||||
let server = self.serve.clone();
|
|
||||||
tokio::spawn(async move {
|
|
||||||
resp.execute(server).await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::warn!("Requests stream errored out: {}", e);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Poll::Ready(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::{in_flight_requests::AlreadyExistsError, BaseChannel, Channel, Config, Requests};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
trace,
|
context, trace,
|
||||||
transport::channel::{self, UnboundedChannel},
|
transport::channel::{self, UnboundedChannel},
|
||||||
|
ClientMessage, Request, Response,
|
||||||
};
|
};
|
||||||
use assert_matches::assert_matches;
|
use assert_matches::assert_matches;
|
||||||
use futures::future::{pending, Aborted};
|
use futures::{
|
||||||
|
future::{pending, AbortRegistration, Abortable, Aborted},
|
||||||
|
prelude::*,
|
||||||
|
Future,
|
||||||
|
};
|
||||||
use futures_test::task::noop_context;
|
use futures_test::task::noop_context;
|
||||||
use std::time::Duration;
|
use std::{pin::Pin, task::Poll};
|
||||||
|
|
||||||
fn test_channel<Req, Resp>() -> (
|
fn test_channel<Req, Resp>() -> (
|
||||||
Pin<Box<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>,
|
Pin<Box<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>,
|
||||||
@@ -892,12 +766,18 @@ mod tests {
|
|||||||
|
|
||||||
channel
|
channel
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.start_request(0, SystemTime::now(), Span::current())
|
.start_request(Request {
|
||||||
|
id: 0,
|
||||||
|
context: context::current(),
|
||||||
|
message: (),
|
||||||
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_matches!(
|
assert_matches!(
|
||||||
channel
|
channel.as_mut().start_request(Request {
|
||||||
.as_mut()
|
id: 0,
|
||||||
.start_request(0, SystemTime::now(), Span::current()),
|
context: context::current(),
|
||||||
|
message: ()
|
||||||
|
}),
|
||||||
Err(AlreadyExistsError)
|
Err(AlreadyExistsError)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -907,13 +787,21 @@ mod tests {
|
|||||||
let (mut channel, _tx) = test_channel::<(), ()>();
|
let (mut channel, _tx) = test_channel::<(), ()>();
|
||||||
|
|
||||||
tokio::time::pause();
|
tokio::time::pause();
|
||||||
let abort_registration0 = channel
|
let req0 = channel
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.start_request(0, SystemTime::now(), Span::current())
|
.start_request(Request {
|
||||||
|
id: 0,
|
||||||
|
context: context::current(),
|
||||||
|
message: (),
|
||||||
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let abort_registration1 = channel
|
let req1 = channel
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.start_request(1, SystemTime::now(), Span::current())
|
.start_request(Request {
|
||||||
|
id: 1,
|
||||||
|
context: context::current(),
|
||||||
|
message: (),
|
||||||
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
|
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
|
||||||
|
|
||||||
@@ -921,8 +809,8 @@ mod tests {
|
|||||||
channel.as_mut().poll_next(&mut noop_context()),
|
channel.as_mut().poll_next(&mut noop_context()),
|
||||||
Poll::Pending
|
Poll::Pending
|
||||||
);
|
);
|
||||||
assert_matches!(test_abortable(abort_registration0).await, Err(Aborted));
|
assert_matches!(test_abortable(req0.abort_registration).await, Err(Aborted));
|
||||||
assert_matches!(test_abortable(abort_registration1).await, Err(Aborted));
|
assert_matches!(test_abortable(req1.abort_registration).await, Err(Aborted));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -930,13 +818,13 @@ mod tests {
|
|||||||
let (mut channel, mut tx) = test_channel::<(), ()>();
|
let (mut channel, mut tx) = test_channel::<(), ()>();
|
||||||
|
|
||||||
tokio::time::pause();
|
tokio::time::pause();
|
||||||
let abort_registration = channel
|
let req = channel
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.start_request(
|
.start_request(Request {
|
||||||
0,
|
id: 0,
|
||||||
SystemTime::now() + Duration::from_millis(100),
|
context: context::current(),
|
||||||
Span::current(),
|
message: (),
|
||||||
)
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
tx.send(ClientMessage::Cancel {
|
tx.send(ClientMessage::Cancel {
|
||||||
@@ -951,7 +839,7 @@ mod tests {
|
|||||||
Poll::Pending
|
Poll::Pending
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_matches!(test_abortable(abort_registration).await, Err(Aborted));
|
assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -961,11 +849,11 @@ mod tests {
|
|||||||
tokio::time::pause();
|
tokio::time::pause();
|
||||||
let _abort_registration = channel
|
let _abort_registration = channel
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.start_request(
|
.start_request(Request {
|
||||||
0,
|
id: 0,
|
||||||
SystemTime::now() + Duration::from_millis(100),
|
context: context::current(),
|
||||||
Span::current(),
|
message: (),
|
||||||
)
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
drop(tx);
|
drop(tx);
|
||||||
@@ -1001,9 +889,13 @@ mod tests {
|
|||||||
let (mut channel, mut tx) = test_channel::<(), ()>();
|
let (mut channel, mut tx) = test_channel::<(), ()>();
|
||||||
|
|
||||||
tokio::time::pause();
|
tokio::time::pause();
|
||||||
let abort_registration = channel
|
let req = channel
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.start_request(0, SystemTime::now(), Span::current())
|
.start_request(Request {
|
||||||
|
id: 0,
|
||||||
|
context: context::current(),
|
||||||
|
message: (),
|
||||||
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
|
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
|
||||||
|
|
||||||
@@ -1013,7 +905,7 @@ mod tests {
|
|||||||
channel.as_mut().poll_next(&mut noop_context()),
|
channel.as_mut().poll_next(&mut noop_context()),
|
||||||
Poll::Ready(Some(Ok(_)))
|
Poll::Ready(Some(Ok(_)))
|
||||||
);
|
);
|
||||||
assert_matches!(test_abortable(abort_registration).await, Err(Aborted));
|
assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -1022,7 +914,11 @@ mod tests {
|
|||||||
|
|
||||||
channel
|
channel
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.start_request(0, SystemTime::now(), Span::current())
|
.start_request(Request {
|
||||||
|
id: 0,
|
||||||
|
context: context::current(),
|
||||||
|
message: (),
|
||||||
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(channel.in_flight_requests(), 1);
|
assert_eq!(channel.in_flight_requests(), 1);
|
||||||
channel
|
channel
|
||||||
@@ -1043,7 +939,11 @@ mod tests {
|
|||||||
requests
|
requests
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.channel_pin_mut()
|
.channel_pin_mut()
|
||||||
.start_request(0, SystemTime::now(), Span::current())
|
.start_request(Request {
|
||||||
|
id: 0,
|
||||||
|
context: context::current(),
|
||||||
|
message: (),
|
||||||
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
requests
|
requests
|
||||||
.as_mut()
|
.as_mut()
|
||||||
@@ -1069,7 +969,11 @@ mod tests {
|
|||||||
requests
|
requests
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.channel_pin_mut()
|
.channel_pin_mut()
|
||||||
.start_request(1, SystemTime::now(), Span::current())
|
.start_request(Request {
|
||||||
|
id: 1,
|
||||||
|
context: context::current(),
|
||||||
|
message: (),
|
||||||
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_matches!(
|
assert_matches!(
|
||||||
@@ -1086,7 +990,11 @@ mod tests {
|
|||||||
requests
|
requests
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.channel_pin_mut()
|
.channel_pin_mut()
|
||||||
.start_request(0, SystemTime::now(), Span::current())
|
.start_request(Request {
|
||||||
|
id: 0,
|
||||||
|
context: context::current(),
|
||||||
|
message: (),
|
||||||
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
requests
|
requests
|
||||||
.as_mut()
|
.as_mut()
|
||||||
@@ -1101,7 +1009,11 @@ mod tests {
|
|||||||
requests
|
requests
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.channel_pin_mut()
|
.channel_pin_mut()
|
||||||
.start_request(1, SystemTime::now(), Span::current())
|
.start_request(Request {
|
||||||
|
id: 1,
|
||||||
|
context: context::current(),
|
||||||
|
message: (),
|
||||||
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
requests
|
requests
|
||||||
.as_mut()
|
.as_mut()
|
||||||
|
|||||||
@@ -98,6 +98,11 @@ impl InFlightRequests {
|
|||||||
&mut self,
|
&mut self,
|
||||||
cx: &mut Context,
|
cx: &mut Context,
|
||||||
) -> Poll<Option<Result<u64, tokio::time::error::Error>>> {
|
) -> Poll<Option<Result<u64, tokio::time::error::Error>>> {
|
||||||
|
if self.deadlines.is_empty() {
|
||||||
|
// TODO(https://github.com/tokio-rs/tokio/issues/4161)
|
||||||
|
// This is a workaround for DelayQueue not always treating this case correctly.
|
||||||
|
return Poll::Ready(None);
|
||||||
|
}
|
||||||
self.deadlines.poll_expired(cx).map_ok(|expired| {
|
self.deadlines.poll_expired(cx).map_ok(|expired| {
|
||||||
if let Some(RequestData {
|
if let Some(RequestData {
|
||||||
abort_handle, span, ..
|
abort_handle, span, ..
|
||||||
@@ -184,12 +189,31 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn remove_request_doesnt_abort() {
|
async fn remove_request_doesnt_abort() {
|
||||||
let mut in_flight_requests = InFlightRequests::default();
|
let mut in_flight_requests = InFlightRequests::default();
|
||||||
|
assert!(in_flight_requests.deadlines.is_empty());
|
||||||
|
|
||||||
let abort_registration = in_flight_requests
|
let abort_registration = in_flight_requests
|
||||||
.start_request(0, SystemTime::now(), Span::current())
|
.start_request(
|
||||||
|
0,
|
||||||
|
SystemTime::now() + std::time::Duration::from_secs(10),
|
||||||
|
Span::current(),
|
||||||
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||||
|
|
||||||
|
// Precondition: Pending expiration
|
||||||
|
assert_matches!(
|
||||||
|
in_flight_requests.poll_expired(&mut noop_context()),
|
||||||
|
Poll::Pending
|
||||||
|
);
|
||||||
|
assert!(!in_flight_requests.deadlines.is_empty());
|
||||||
|
|
||||||
assert_matches!(in_flight_requests.remove_request(0), Some(_));
|
assert_matches!(in_flight_requests.remove_request(0), Some(_));
|
||||||
|
// Postcondition: No pending expirations
|
||||||
|
assert!(in_flight_requests.deadlines.is_empty());
|
||||||
|
assert_matches!(
|
||||||
|
in_flight_requests.poll_expired(&mut noop_context()),
|
||||||
|
Poll::Ready(None)
|
||||||
|
);
|
||||||
assert_matches!(
|
assert_matches!(
|
||||||
abortable_future.poll_unpin(&mut noop_context()),
|
abortable_future.poll_unpin(&mut noop_context()),
|
||||||
Poll::Pending
|
Poll::Pending
|
||||||
|
|||||||
49
tarpc/src/server/incoming.rs
Normal file
49
tarpc/src/server/incoming.rs
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
use super::{
|
||||||
|
limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel},
|
||||||
|
Channel,
|
||||||
|
};
|
||||||
|
use futures::prelude::*;
|
||||||
|
use std::{fmt, hash::Hash};
|
||||||
|
|
||||||
|
#[cfg(feature = "tokio1")]
|
||||||
|
use super::{tokio::TokioServerExecutor, Serve};
|
||||||
|
|
||||||
|
/// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel).
|
||||||
|
pub trait Incoming<C>
|
||||||
|
where
|
||||||
|
Self: Sized + Stream<Item = C>,
|
||||||
|
C: Channel,
|
||||||
|
{
|
||||||
|
/// Enforces channel per-key limits.
|
||||||
|
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> MaxChannelsPerKey<Self, K, KF>
|
||||||
|
where
|
||||||
|
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||||
|
KF: Fn(&C) -> K,
|
||||||
|
{
|
||||||
|
MaxChannelsPerKey::new(self, n, keymaker)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Caps the number of concurrent requests per channel.
|
||||||
|
fn max_concurrent_requests_per_channel(self, n: usize) -> MaxRequestsPerChannel<Self> {
|
||||||
|
MaxRequestsPerChannel::new(self, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled
|
||||||
|
/// concurrently by spawning on tokio's default executor, and each request will be also
|
||||||
|
/// be spawned on tokio's default executor.
|
||||||
|
#[cfg(feature = "tokio1")]
|
||||||
|
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||||
|
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
|
||||||
|
where
|
||||||
|
S: Serve<C::Req, Resp = C::Resp>,
|
||||||
|
{
|
||||||
|
TokioServerExecutor::new(self, serve)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, C> Incoming<C> for S
|
||||||
|
where
|
||||||
|
S: Sized + Stream<Item = C>,
|
||||||
|
C: Channel,
|
||||||
|
{
|
||||||
|
}
|
||||||
5
tarpc/src/server/limits.rs
Normal file
5
tarpc/src/server/limits.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
/// Provides functionality to limit the number of active channels.
|
||||||
|
pub mod channels_per_key;
|
||||||
|
|
||||||
|
/// Provides a [channel](crate::server::Channel) that limits the number of in-flight requests.
|
||||||
|
pub mod requests_per_channel;
|
||||||
@@ -9,20 +9,23 @@ use crate::{
|
|||||||
util::Compact,
|
util::Compact,
|
||||||
};
|
};
|
||||||
use fnv::FnvHashMap;
|
use fnv::FnvHashMap;
|
||||||
use futures::{future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
|
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
||||||
use pin_project::pin_project;
|
use pin_project::pin_project;
|
||||||
use std::sync::{Arc, Weak};
|
use std::sync::{Arc, Weak};
|
||||||
use std::{
|
use std::{
|
||||||
collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin,
|
collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin,
|
||||||
time::SystemTime,
|
|
||||||
};
|
};
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tracing::{debug, info, trace};
|
use tracing::{debug, info, trace};
|
||||||
|
|
||||||
/// A single-threaded filter that drops channels based on per-key limits.
|
/// An [`Incoming`](crate::server::incoming::Incoming) stream that drops new channels based on
|
||||||
|
/// per-key limits.
|
||||||
|
///
|
||||||
|
/// The decision to drop a Channel is made once at the time the Channel materializes. Once a
|
||||||
|
/// Channel is yielded, it will not be prematurely dropped.
|
||||||
#[pin_project]
|
#[pin_project]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct ChannelFilter<S, K, F>
|
pub struct MaxChannelsPerKey<S, K, F>
|
||||||
where
|
where
|
||||||
K: Eq + Hash,
|
K: Eq + Hash,
|
||||||
{
|
{
|
||||||
@@ -35,7 +38,7 @@ where
|
|||||||
keymaker: F,
|
keymaker: F,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A channel that is tracked by a ChannelFilter.
|
/// A channel that is tracked by [`MaxChannelsPerKey`].
|
||||||
#[pin_project]
|
#[pin_project]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct TrackedChannel<C, K> {
|
pub struct TrackedChannel<C, K> {
|
||||||
@@ -116,15 +119,6 @@ where
|
|||||||
fn transport(&self) -> &Self::Transport {
|
fn transport(&self) -> &Self::Transport {
|
||||||
self.inner.transport()
|
self.inner.transport()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_request(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
id: u64,
|
|
||||||
deadline: SystemTime,
|
|
||||||
span: tracing::Span,
|
|
||||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
|
||||||
self.inner_pin_mut().start_request(id, deadline, span)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<C, K> TrackedChannel<C, K> {
|
impl<C, K> TrackedChannel<C, K> {
|
||||||
@@ -139,7 +133,7 @@ impl<C, K> TrackedChannel<C, K> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, K, F> ChannelFilter<S, K, F>
|
impl<S, K, F> MaxChannelsPerKey<S, K, F>
|
||||||
where
|
where
|
||||||
K: Eq + Hash,
|
K: Eq + Hash,
|
||||||
S: Stream,
|
S: Stream,
|
||||||
@@ -148,7 +142,7 @@ where
|
|||||||
/// Sheds new channels to stay under configured limits.
|
/// Sheds new channels to stay under configured limits.
|
||||||
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
|
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
|
||||||
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel();
|
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel();
|
||||||
ChannelFilter {
|
MaxChannelsPerKey {
|
||||||
listener: listener.fuse(),
|
listener: listener.fuse(),
|
||||||
channels_per_key,
|
channels_per_key,
|
||||||
dropped_keys,
|
dropped_keys,
|
||||||
@@ -159,7 +153,7 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, K, F> ChannelFilter<S, K, F>
|
impl<S, K, F> MaxChannelsPerKey<S, K, F>
|
||||||
where
|
where
|
||||||
S: Stream,
|
S: Stream,
|
||||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||||
@@ -251,7 +245,7 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, K, F> Stream for ChannelFilter<S, K, F>
|
impl<S, K, F> Stream for MaxChannelsPerKey<S, K, F>
|
||||||
where
|
where
|
||||||
S: Stream,
|
S: Stream,
|
||||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||||
@@ -354,7 +348,7 @@ fn channel_filter_increment_channels_for_key() {
|
|||||||
key: &'static str,
|
key: &'static str,
|
||||||
}
|
}
|
||||||
let (_, listener) = futures::channel::mpsc::unbounded();
|
let (_, listener) = futures::channel::mpsc::unbounded();
|
||||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||||
pin_mut!(filter);
|
pin_mut!(filter);
|
||||||
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
|
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
|
||||||
assert_eq!(Arc::strong_count(&tracker1), 1);
|
assert_eq!(Arc::strong_count(&tracker1), 1);
|
||||||
@@ -375,7 +369,7 @@ fn channel_filter_handle_new_channel() {
|
|||||||
key: &'static str,
|
key: &'static str,
|
||||||
}
|
}
|
||||||
let (_, listener) = futures::channel::mpsc::unbounded();
|
let (_, listener) = futures::channel::mpsc::unbounded();
|
||||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||||
pin_mut!(filter);
|
pin_mut!(filter);
|
||||||
let channel1 = filter
|
let channel1 = filter
|
||||||
.as_mut()
|
.as_mut()
|
||||||
@@ -407,7 +401,7 @@ fn channel_filter_poll_listener() {
|
|||||||
key: &'static str,
|
key: &'static str,
|
||||||
}
|
}
|
||||||
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||||
pin_mut!(filter);
|
pin_mut!(filter);
|
||||||
|
|
||||||
new_channels
|
new_channels
|
||||||
@@ -443,7 +437,7 @@ fn channel_filter_poll_closed_channels() {
|
|||||||
key: &'static str,
|
key: &'static str,
|
||||||
}
|
}
|
||||||
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||||
pin_mut!(filter);
|
pin_mut!(filter);
|
||||||
|
|
||||||
new_channels
|
new_channels
|
||||||
@@ -471,7 +465,7 @@ fn channel_filter_stream() {
|
|||||||
key: &'static str,
|
key: &'static str,
|
||||||
}
|
}
|
||||||
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||||
pin_mut!(filter);
|
pin_mut!(filter);
|
||||||
|
|
||||||
new_channels
|
new_channels
|
||||||
@@ -4,45 +4,49 @@
|
|||||||
// license that can be found in the LICENSE file or at
|
// license that can be found in the LICENSE file or at
|
||||||
// https://opensource.org/licenses/MIT.
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
use super::{Channel, Config};
|
use crate::{
|
||||||
use crate::{Response, ServerError};
|
server::{Channel, Config},
|
||||||
use futures::{future::AbortRegistration, prelude::*, ready, task::*};
|
Response, ServerError,
|
||||||
|
};
|
||||||
|
use futures::{prelude::*, ready, task::*};
|
||||||
use pin_project::pin_project;
|
use pin_project::pin_project;
|
||||||
use std::{io, pin::Pin, time::SystemTime};
|
use std::{io, pin::Pin};
|
||||||
use tracing::Span;
|
|
||||||
|
|
||||||
/// A [`Channel`] that limits the number of concurrent
|
/// A [`Channel`] that limits the number of concurrent requests by throttling.
|
||||||
/// requests by throttling.
|
///
|
||||||
|
/// Note that this is a very basic throttling heuristic. It is easy to set a number that is too low
|
||||||
|
/// for the resources available to the server. For production use cases, a more advanced throttler
|
||||||
|
/// is likely needed.
|
||||||
#[pin_project]
|
#[pin_project]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Throttler<C> {
|
pub struct MaxRequests<C> {
|
||||||
max_in_flight_requests: usize,
|
max_in_flight_requests: usize,
|
||||||
#[pin]
|
#[pin]
|
||||||
inner: C,
|
inner: C,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<C> Throttler<C> {
|
impl<C> MaxRequests<C> {
|
||||||
/// Returns the inner channel.
|
/// Returns the inner channel.
|
||||||
pub fn get_ref(&self) -> &C {
|
pub fn get_ref(&self) -> &C {
|
||||||
&self.inner
|
&self.inner
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<C> Throttler<C>
|
impl<C> MaxRequests<C>
|
||||||
where
|
where
|
||||||
C: Channel,
|
C: Channel,
|
||||||
{
|
{
|
||||||
/// Returns a new `Throttler` that wraps the given channel and limits concurrent requests to
|
/// Returns a new `MaxRequests` that wraps the given channel and limits concurrent requests to
|
||||||
/// `max_in_flight_requests`.
|
/// `max_in_flight_requests`.
|
||||||
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
|
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
|
||||||
Throttler {
|
MaxRequests {
|
||||||
max_in_flight_requests,
|
max_in_flight_requests,
|
||||||
inner,
|
inner,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<C> Stream for Throttler<C>
|
impl<C> Stream for MaxRequests<C>
|
||||||
where
|
where
|
||||||
C: Channel,
|
C: Channel,
|
||||||
{
|
{
|
||||||
@@ -54,19 +58,18 @@ where
|
|||||||
ready!(self.as_mut().project().inner.poll_ready(cx)?);
|
ready!(self.as_mut().project().inner.poll_ready(cx)?);
|
||||||
|
|
||||||
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
|
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
|
||||||
Some(request) => {
|
Some(r) => {
|
||||||
tracing::debug!(
|
let _entered = r.span.enter();
|
||||||
rpc.trace_id = %request.context.trace_id(),
|
tracing::info!(
|
||||||
in_flight_requests = self.as_mut().in_flight_requests(),
|
in_flight_requests = self.as_mut().in_flight_requests(),
|
||||||
max_in_flight_requests = *self.as_mut().project().max_in_flight_requests,
|
"ThrottleRequest",
|
||||||
"At in-flight request limit",
|
|
||||||
);
|
);
|
||||||
|
|
||||||
self.as_mut().start_send(Response {
|
self.as_mut().start_send(Response {
|
||||||
request_id: request.id,
|
request_id: r.request.id,
|
||||||
message: Err(ServerError {
|
message: Err(ServerError {
|
||||||
kind: io::ErrorKind::WouldBlock,
|
kind: io::ErrorKind::WouldBlock,
|
||||||
detail: Some("Server throttled the request.".into()),
|
detail: "server throttled the request.".into(),
|
||||||
}),
|
}),
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
@@ -77,7 +80,7 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<C> Sink<Response<<C as Channel>::Resp>> for Throttler<C>
|
impl<C> Sink<Response<<C as Channel>::Resp>> for MaxRequests<C>
|
||||||
where
|
where
|
||||||
C: Channel,
|
C: Channel,
|
||||||
{
|
{
|
||||||
@@ -103,13 +106,13 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<C> AsRef<C> for Throttler<C> {
|
impl<C> AsRef<C> for MaxRequests<C> {
|
||||||
fn as_ref(&self) -> &C {
|
fn as_ref(&self) -> &C {
|
||||||
&self.inner
|
&self.inner
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<C> Channel for Throttler<C>
|
impl<C> Channel for MaxRequests<C>
|
||||||
where
|
where
|
||||||
C: Channel,
|
C: Channel,
|
||||||
{
|
{
|
||||||
@@ -128,27 +131,19 @@ where
|
|||||||
fn transport(&self) -> &Self::Transport {
|
fn transport(&self) -> &Self::Transport {
|
||||||
self.inner.transport()
|
self.inner.transport()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_request(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
id: u64,
|
|
||||||
deadline: SystemTime,
|
|
||||||
span: Span,
|
|
||||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
|
||||||
self.project().inner.start_request(id, deadline, span)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A stream of throttling channels.
|
/// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on
|
||||||
|
/// the number of in-flight requests.
|
||||||
#[pin_project]
|
#[pin_project]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct ThrottlerStream<S> {
|
pub struct MaxRequestsPerChannel<S> {
|
||||||
#[pin]
|
#[pin]
|
||||||
inner: S,
|
inner: S,
|
||||||
max_in_flight_requests: usize,
|
max_in_flight_requests: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> ThrottlerStream<S>
|
impl<S> MaxRequestsPerChannel<S>
|
||||||
where
|
where
|
||||||
S: Stream,
|
S: Stream,
|
||||||
<S as Stream>::Item: Channel,
|
<S as Stream>::Item: Channel,
|
||||||
@@ -161,16 +156,16 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> Stream for ThrottlerStream<S>
|
impl<S> Stream for MaxRequestsPerChannel<S>
|
||||||
where
|
where
|
||||||
S: Stream,
|
S: Stream,
|
||||||
<S as Stream>::Item: Channel,
|
<S as Stream>::Item: Channel,
|
||||||
{
|
{
|
||||||
type Item = Throttler<<S as Stream>::Item>;
|
type Item = MaxRequests<<S as Stream>::Item>;
|
||||||
|
|
||||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||||
match ready!(self.as_mut().project().inner.poll_next(cx)) {
|
match ready!(self.as_mut().project().inner.poll_next(cx)) {
|
||||||
Some(channel) => Poll::Ready(Some(Throttler::new(
|
Some(channel) => Poll::Ready(Some(MaxRequests::new(
|
||||||
channel,
|
channel,
|
||||||
*self.project().max_in_flight_requests,
|
*self.project().max_in_flight_requests,
|
||||||
))),
|
))),
|
||||||
@@ -183,19 +178,20 @@ where
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
use crate::{
|
use crate::server::{
|
||||||
server::{
|
testing::{self, FakeChannel, PollExt},
|
||||||
in_flight_requests::AlreadyExistsError,
|
TrackedRequest,
|
||||||
testing::{self, FakeChannel, PollExt},
|
|
||||||
},
|
|
||||||
Request,
|
|
||||||
};
|
};
|
||||||
use pin_utils::pin_mut;
|
use pin_utils::pin_mut;
|
||||||
use std::{marker::PhantomData, time::Duration};
|
use std::{
|
||||||
|
marker::PhantomData,
|
||||||
|
time::{Duration, SystemTime},
|
||||||
|
};
|
||||||
|
use tracing::Span;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn throttler_in_flight_requests() {
|
async fn throttler_in_flight_requests() {
|
||||||
let throttler = Throttler {
|
let throttler = MaxRequests {
|
||||||
max_in_flight_requests: 0,
|
max_in_flight_requests: 0,
|
||||||
inner: FakeChannel::default::<isize, isize>(),
|
inner: FakeChannel::default::<isize, isize>(),
|
||||||
};
|
};
|
||||||
@@ -215,28 +211,9 @@ mod tests {
|
|||||||
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
|
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn throttler_start_request() {
|
|
||||||
let throttler = Throttler {
|
|
||||||
max_in_flight_requests: 0,
|
|
||||||
inner: FakeChannel::default::<isize, isize>(),
|
|
||||||
};
|
|
||||||
|
|
||||||
pin_mut!(throttler);
|
|
||||||
throttler
|
|
||||||
.as_mut()
|
|
||||||
.start_request(
|
|
||||||
1,
|
|
||||||
SystemTime::now() + Duration::from_secs(1),
|
|
||||||
Span::current(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(throttler.inner.in_flight_requests.len(), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn throttler_poll_next_done() {
|
fn throttler_poll_next_done() {
|
||||||
let throttler = Throttler {
|
let throttler = MaxRequests {
|
||||||
max_in_flight_requests: 0,
|
max_in_flight_requests: 0,
|
||||||
inner: FakeChannel::default::<isize, isize>(),
|
inner: FakeChannel::default::<isize, isize>(),
|
||||||
};
|
};
|
||||||
@@ -247,7 +224,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn throttler_poll_next_some() -> io::Result<()> {
|
fn throttler_poll_next_some() -> io::Result<()> {
|
||||||
let throttler = Throttler {
|
let throttler = MaxRequests {
|
||||||
max_in_flight_requests: 1,
|
max_in_flight_requests: 1,
|
||||||
inner: FakeChannel::default::<isize, isize>(),
|
inner: FakeChannel::default::<isize, isize>(),
|
||||||
};
|
};
|
||||||
@@ -259,7 +236,7 @@ mod tests {
|
|||||||
throttler
|
throttler
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.poll_next(&mut testing::cx())?
|
.poll_next(&mut testing::cx())?
|
||||||
.map(|r| r.map(|r| (r.id, r.message))),
|
.map(|r| r.map(|r| (r.request.id, r.request.message))),
|
||||||
Poll::Ready(Some((0, 1)))
|
Poll::Ready(Some((0, 1)))
|
||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -267,7 +244,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn throttler_poll_next_throttled() {
|
fn throttler_poll_next_throttled() {
|
||||||
let throttler = Throttler {
|
let throttler = MaxRequests {
|
||||||
max_in_flight_requests: 0,
|
max_in_flight_requests: 0,
|
||||||
inner: FakeChannel::default::<isize, isize>(),
|
inner: FakeChannel::default::<isize, isize>(),
|
||||||
};
|
};
|
||||||
@@ -283,7 +260,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn throttler_poll_next_throttled_sink_not_ready() {
|
fn throttler_poll_next_throttled_sink_not_ready() {
|
||||||
let throttler = Throttler {
|
let throttler = MaxRequests {
|
||||||
max_in_flight_requests: 0,
|
max_in_flight_requests: 0,
|
||||||
inner: PendingSink::default::<isize, isize>(),
|
inner: PendingSink::default::<isize, isize>(),
|
||||||
};
|
};
|
||||||
@@ -294,7 +271,8 @@ mod tests {
|
|||||||
ghost: PhantomData<fn(Out) -> In>,
|
ghost: PhantomData<fn(Out) -> In>,
|
||||||
}
|
}
|
||||||
impl PendingSink<(), ()> {
|
impl PendingSink<(), ()> {
|
||||||
pub fn default<Req, Resp>() -> PendingSink<io::Result<Request<Req>>, Response<Resp>> {
|
pub fn default<Req, Resp>(
|
||||||
|
) -> PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||||
PendingSink { ghost: PhantomData }
|
PendingSink { ghost: PhantomData }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -319,7 +297,7 @@ mod tests {
|
|||||||
Poll::Pending
|
Poll::Pending
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
impl<Req, Resp> Channel for PendingSink<io::Result<Request<Req>>, Response<Resp>> {
|
impl<Req, Resp> Channel for PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||||
type Req = Req;
|
type Req = Req;
|
||||||
type Resp = Resp;
|
type Resp = Resp;
|
||||||
type Transport = ();
|
type Transport = ();
|
||||||
@@ -332,20 +310,12 @@ mod tests {
|
|||||||
fn transport(&self) -> &() {
|
fn transport(&self) -> &() {
|
||||||
&()
|
&()
|
||||||
}
|
}
|
||||||
fn start_request(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
_id: u64,
|
|
||||||
_deadline: SystemTime,
|
|
||||||
_span: tracing::Span,
|
|
||||||
) -> Result<AbortRegistration, AlreadyExistsError> {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn throttler_start_send() {
|
async fn throttler_start_send() {
|
||||||
let throttler = Throttler {
|
let throttler = MaxRequests {
|
||||||
max_in_flight_requests: 0,
|
max_in_flight_requests: 0,
|
||||||
inner: FakeChannel::default::<isize, isize>(),
|
inner: FakeChannel::default::<isize, isize>(),
|
||||||
};
|
};
|
||||||
@@ -6,10 +6,10 @@
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
context,
|
context,
|
||||||
server::{Channel, Config},
|
server::{Channel, Config, TrackedRequest},
|
||||||
Request, Response,
|
Request, Response,
|
||||||
};
|
};
|
||||||
use futures::{future::AbortRegistration, task::*, Sink, Stream};
|
use futures::{task::*, Sink, Stream};
|
||||||
use pin_project::pin_project;
|
use pin_project::pin_project;
|
||||||
use std::{collections::VecDeque, io, pin::Pin, time::SystemTime};
|
use std::{collections::VecDeque, io, pin::Pin, time::SystemTime};
|
||||||
use tracing::Span;
|
use tracing::Span;
|
||||||
@@ -62,7 +62,7 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Req, Resp> Channel for FakeChannel<io::Result<Request<Req>>, Response<Resp>>
|
impl<Req, Resp> Channel for FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>>
|
||||||
where
|
where
|
||||||
Req: Unpin,
|
Req: Unpin,
|
||||||
{
|
{
|
||||||
@@ -81,34 +81,28 @@ where
|
|||||||
fn transport(&self) -> &() {
|
fn transport(&self) -> &() {
|
||||||
&()
|
&()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_request(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
id: u64,
|
|
||||||
deadline: SystemTime,
|
|
||||||
span: Span,
|
|
||||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
|
||||||
self.project()
|
|
||||||
.in_flight_requests
|
|
||||||
.start_request(id, deadline, span)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Req, Resp> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
|
impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||||
pub fn push_req(&mut self, id: u64, message: Req) {
|
pub fn push_req(&mut self, id: u64, message: Req) {
|
||||||
self.stream.push_back(Ok(Request {
|
let (_, abort_registration) = futures::future::AbortHandle::new_pair();
|
||||||
context: context::Context {
|
self.stream.push_back(Ok(TrackedRequest {
|
||||||
deadline: SystemTime::UNIX_EPOCH,
|
request: Request {
|
||||||
trace_context: Default::default(),
|
context: context::Context {
|
||||||
|
deadline: SystemTime::UNIX_EPOCH,
|
||||||
|
trace_context: Default::default(),
|
||||||
|
},
|
||||||
|
id,
|
||||||
|
message,
|
||||||
},
|
},
|
||||||
id,
|
abort_registration,
|
||||||
message,
|
span: Span::none(),
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FakeChannel<(), ()> {
|
impl FakeChannel<(), ()> {
|
||||||
pub fn default<Req, Resp>() -> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
|
pub fn default<Req, Resp>() -> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||||
FakeChannel {
|
FakeChannel {
|
||||||
stream: Default::default(),
|
stream: Default::default(),
|
||||||
sink: Default::default(),
|
sink: Default::default(),
|
||||||
|
|||||||
111
tarpc/src/server/tokio.rs
Normal file
111
tarpc/src/server/tokio.rs
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
use super::{Channel, Requests, Serve};
|
||||||
|
use futures::{prelude::*, ready, task::*};
|
||||||
|
use pin_project::pin_project;
|
||||||
|
use std::pin::Pin;
|
||||||
|
|
||||||
|
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
|
||||||
|
/// for each new channel. Returned by
|
||||||
|
/// [`Incoming::execute`](crate::server::incoming::Incoming::execute).
|
||||||
|
#[pin_project]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct TokioServerExecutor<T, S> {
|
||||||
|
#[pin]
|
||||||
|
inner: T,
|
||||||
|
serve: S,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, S> TokioServerExecutor<T, S> {
|
||||||
|
pub(crate) fn new(inner: T, serve: S) -> Self {
|
||||||
|
Self { inner, serve }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A future that drives the server by [spawning](tokio::spawn) each [response
|
||||||
|
/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by
|
||||||
|
/// [`Channel::execute`](crate::server::Channel::execute).
|
||||||
|
#[pin_project]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct TokioChannelExecutor<T, S> {
|
||||||
|
#[pin]
|
||||||
|
inner: T,
|
||||||
|
serve: S,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, S> TokioServerExecutor<T, S> {
|
||||||
|
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
||||||
|
self.as_mut().project().inner
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, S> TokioChannelExecutor<T, S> {
|
||||||
|
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
||||||
|
self.as_mut().project().inner
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send + 'static execution helper methods.
|
||||||
|
|
||||||
|
impl<C> Requests<C>
|
||||||
|
where
|
||||||
|
C: Channel,
|
||||||
|
C::Req: Send + 'static,
|
||||||
|
C::Resp: Send + 'static,
|
||||||
|
{
|
||||||
|
/// Executes all requests using the given service function. Requests are handled concurrently
|
||||||
|
/// by [spawning](::tokio::spawn) each handler on tokio's default executor.
|
||||||
|
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
|
||||||
|
where
|
||||||
|
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
|
||||||
|
{
|
||||||
|
TokioChannelExecutor { inner: self, serve }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
|
||||||
|
where
|
||||||
|
St: Sized + Stream<Item = C>,
|
||||||
|
C: Channel + Send + 'static,
|
||||||
|
C::Req: Send + 'static,
|
||||||
|
C::Resp: Send + 'static,
|
||||||
|
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
||||||
|
Se::Fut: Send,
|
||||||
|
{
|
||||||
|
type Output = ();
|
||||||
|
|
||||||
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||||
|
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
||||||
|
tokio::spawn(channel.execute(self.serve.clone()));
|
||||||
|
}
|
||||||
|
tracing::info!("Server shutting down.");
|
||||||
|
Poll::Ready(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
|
||||||
|
where
|
||||||
|
C: Channel + 'static,
|
||||||
|
C::Req: Send + 'static,
|
||||||
|
C::Resp: Send + 'static,
|
||||||
|
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
||||||
|
S::Fut: Send,
|
||||||
|
{
|
||||||
|
type Output = ();
|
||||||
|
|
||||||
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
|
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
||||||
|
match response_handler {
|
||||||
|
Ok(resp) => {
|
||||||
|
let server = self.serve.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
resp.execute(server).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!("Requests stream errored out: {}", e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Poll::Ready(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -71,9 +71,9 @@ pub struct SpanId(u64);
|
|||||||
#[repr(u8)]
|
#[repr(u8)]
|
||||||
pub enum SamplingDecision {
|
pub enum SamplingDecision {
|
||||||
/// The associated span was sampled by its creating process. Child spans must also be sampled.
|
/// The associated span was sampled by its creating process. Child spans must also be sampled.
|
||||||
Sampled = opentelemetry::trace::TRACE_FLAG_SAMPLED,
|
Sampled,
|
||||||
/// The associated span was not sampled by its creating process.
|
/// The associated span was not sampled by its creating process.
|
||||||
Unsampled = opentelemetry::trace::TRACE_FLAG_NOT_SAMPLED,
|
Unsampled,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Context {
|
impl Context {
|
||||||
@@ -173,8 +173,8 @@ impl TryFrom<&tracing::Span> for Context {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<&dyn opentelemetry::trace::Span> for Context {
|
impl From<opentelemetry::trace::SpanRef<'_>> for Context {
|
||||||
fn from(span: &dyn opentelemetry::trace::Span) -> Self {
|
fn from(span: opentelemetry::trace::SpanRef<'_>) -> Self {
|
||||||
let otel_ctx = span.span_context();
|
let otel_ctx = span.span_context();
|
||||||
Self {
|
Self {
|
||||||
trace_id: TraceId::from(otel_ctx.trace_id()),
|
trace_id: TraceId::from(otel_ctx.trace_id()),
|
||||||
@@ -184,6 +184,15 @@ impl From<&dyn opentelemetry::trace::Span> for Context {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<SamplingDecision> for opentelemetry::trace::TraceFlags {
|
||||||
|
fn from(decision: SamplingDecision) -> Self {
|
||||||
|
match decision {
|
||||||
|
SamplingDecision::Sampled => opentelemetry::trace::TraceFlags::SAMPLED,
|
||||||
|
SamplingDecision::Unsampled => opentelemetry::trace::TraceFlags::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<&opentelemetry::trace::SpanContext> for SamplingDecision {
|
impl From<&opentelemetry::trace::SpanContext> for SamplingDecision {
|
||||||
fn from(context: &opentelemetry::trace::SpanContext) -> Self {
|
fn from(context: &opentelemetry::trace::SpanContext) -> Self {
|
||||||
if context.is_sampled() {
|
if context.is_sampled() {
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use crate::{
|
use crate::{
|
||||||
client, context,
|
client, context,
|
||||||
server::{BaseChannel, Incoming},
|
server::{incoming::Incoming, BaseChannel},
|
||||||
transport::{
|
transport::{
|
||||||
self,
|
self,
|
||||||
channel::{Channel, UnboundedChannel},
|
channel::{Channel, UnboundedChannel},
|
||||||
@@ -170,7 +170,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn integration() -> io::Result<()> {
|
async fn integration() -> anyhow::Result<()> {
|
||||||
let _ = tracing_subscriber::fmt::try_init();
|
let _ = tracing_subscriber::fmt::try_init();
|
||||||
|
|
||||||
let (client_channel, server_channel) = transport::channel::unbounded();
|
let (client_channel, server_channel) = transport::channel::unbounded();
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
use futures::prelude::*;
|
use futures::prelude::*;
|
||||||
use std::io;
|
|
||||||
use tarpc::serde_transport;
|
use tarpc::serde_transport;
|
||||||
use tarpc::{
|
use tarpc::{
|
||||||
client, context,
|
client, context,
|
||||||
server::{BaseChannel, Incoming},
|
server::{incoming::Incoming, BaseChannel},
|
||||||
};
|
};
|
||||||
use tokio_serde::formats::Json;
|
use tokio_serde::formats::Json;
|
||||||
|
|
||||||
@@ -33,7 +32,7 @@ impl ColorProtocol for ColorServer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_call() -> io::Result<()> {
|
async fn test_call() -> anyhow::Result<()> {
|
||||||
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
|
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
|
||||||
let addr = transport.local_addr();
|
let addr = transport.local_addr();
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ use std::time::{Duration, SystemTime};
|
|||||||
use tarpc::{
|
use tarpc::{
|
||||||
client::{self},
|
client::{self},
|
||||||
context,
|
context,
|
||||||
server::{self, BaseChannel, Channel, Incoming},
|
server::{self, incoming::Incoming, BaseChannel, Channel},
|
||||||
transport::channel,
|
transport::channel,
|
||||||
};
|
};
|
||||||
use tokio::join;
|
use tokio::join;
|
||||||
|
|||||||
Reference in New Issue
Block a user