34 Commits

Author SHA1 Message Date
Tim Kuehn
a3a6404a30 Prepare release of 0.28.0 2022-04-06 22:07:07 -07:00
Tim Kuehn
b36eac80b1 Bump minimum rust version to 1.58.0 2022-04-06 21:53:56 -07:00
Bruno
d7070e4bc3 Update opentelemetry and related dependencies (#362) 2022-04-03 14:09:14 -07:00
Tim Kuehn
b5d1828308 Use captured identifiers in format strings.
This was stabilized in Rust 1.58.0: https://blog.rust-lang.org/2022/01/13/Rust-1.58.0.html
2022-01-13 15:00:44 -08:00
Zak Cutner
92cfe63c4f Use single-threaded Tokio runtime (#360) 2022-01-06 20:59:51 -08:00
Tim Kuehn
839a2f067c Update example to latest version of Clap 2021-12-27 22:56:17 -08:00
David Kleingeld
b5d593488c Derive more traits for RpcError (#359)
Makes RpcError derive Clone, PartialEq, Eq, Hash, Serialize and Deserialize.
2021-12-27 22:00:58 -08:00
Tim Kuehn
eea38b8bf4 Simplify code with const assert!.
The code that prevents compilation on systems where usize is larger than
u64 previously used a const index-out-of-bounds trick. That code can now
be replaced with assert!, as const panic! has landed in 1.57.0 stable.
2021-12-03 15:20:33 -08:00
Shi Yan
70493c15f4 Fix a compiling issue of the official example (#358)
Fix a compiling issue of the official example because of the following error :

```
error[E0599]: the method `execute` exists for struct `BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>>`, but its trait bounds were not satisfied
  --> src/main.rs:39:25
   |
39 |     tokio::spawn(server.execute(HelloServer.serve()));
   |                         ^^^^^^^ method cannot be called on `BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>>` due to unsatisfied trait bounds
   |
   = note: the following trait bounds were not satisfied:
           `<&BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>> as futures::Stream>::Item = _`
           which is required by `&BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>>: tarpc::server::incoming::Incoming<_>`
           `&BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>>: futures::Stream`
           which is required by `&BaseChannel<_, _, UnboundedChannel<ClientMessage<_>, Response<_>>>: tarpc::server::incoming::Incoming<_>`
   = help: items from traits can only be used if the trait is in scope
help: the following trait is implemented but not in scope; perhaps add a `use` for it:
   |
1  | use tarpc::server::Channel;
   |
```

See https://github.com/google/tarpc/pull/358#issuecomment-981953193 for the root cause.
2021-11-29 17:01:16 -08:00
baptiste0928
f7c5d6a7c3 Fix example-service (#355)
Fixes the compilation of the example-service crate (the Clap trait has been renamed Parser in clap-rs/clap@d840d56).
2021-11-15 08:39:04 -08:00
Scott Kirkpatrick
98c5d2a18b Re-add typo fixes (#353)
The typo fixes that were added by commit b5d9aaa
were accidentally reverted by commit 1e680e3, this
will add them back
2021-11-08 10:07:21 -08:00
Tim Kuehn
46b534f7c6 Use HashMap::shrink_to in impl of Comapct::compact. 2021-10-21 17:03:57 -07:00
Tim Kuehn
42b4fc52b1 Set rust-version to 1.56 2021-10-21 16:08:15 -07:00
Tim Kuehn
350dbcdad0 Upgrade to Rust 2021! 2021-10-21 14:10:21 -07:00
Tim Kuehn
b1b4461d89 Prepare release of 0.27.2 2021-10-08 22:31:56 -07:00
Tim Kuehn
f694b7573a Close TcpStream when client disconnects.
An attempt at a clean shutdown helps the server to drop its connections
more quickly.

Testing this uncovered a latent bug in DelayQueue wherein `poll_expired`
yields `Pending` when empty. A workaround was added to
`InFlightRequests::poll_expired`: check if there are actually any
outstanding requests before calling `DelayQueue::poll_expired`.
2021-10-08 22:13:24 -07:00
Tim Kuehn
1e680e3a5a Fix typos in docs.
Fixes https://github.com/google/tarpc/issues/352.
2021-10-08 19:19:50 -07:00
Tim Kuehn
2591d21e94 Update release notes to mention io::Error = 2021-09-23 13:57:43 -07:00
Tim Kuehn
6632f68d95 Prepare for 0.27 release 2021-09-22 15:41:34 -07:00
Dmitry Kakurin
25985ad56a Update README.md (#350)
Fixed 2 typos
2021-09-01 17:58:49 -07:00
Tim Kuehn
d6a24e9420 Address Clippy lint 2021-08-24 12:40:18 -07:00
Tim Kuehn
281a78f3c7 Add tokio-serde-bincode feature 2021-08-24 12:37:57 -07:00
Julian Tescher
a0787d0091 Update to opentelemetry 0.16.x (#349) 2021-08-17 00:00:07 -04:00
Frederik-Baetens
d2acba0e8a add serde-transport-json feature flag (#346)
In general, it should be possible to use, or at least import all functionality of a library, when having only that library in your cargo.toml.
2021-05-06 08:41:57 -07:00
Tim Kuehn
ea7b6763c4 Refactor server module.
In the interest of the user's attention, some ancillary APIs have been
moved to new submodules:

- server::limits contains what was previously called Throttler and
  ChannelFilter. Both of those names were very generic, when the methods
  applied by these types were very specific (and also simplistic). Renames
  have occurred:
  - ThrottlerStream => MaxRequestsPerChannel
  - Throttler => MaxRequests
  - ChannelFilter => MaxChannelsPerKey
- server::incoming contains the Incoming trait.
- server::tokio contains the tokio-specific helper types.

The 5 structs and 1 enum remaining in the base server module are all
core to the functioning of the server.
2021-04-21 17:05:49 -07:00
Tim Kuehn
eb67c540b9 Use more structured errors in client. 2021-04-21 14:54:45 -07:00
Tim Kuehn
4151d0abd3 Move Span creation into BaseChannel.
It's important for Channel decorators, like Throttler, to have access to
the Span. This means that the BaseChannel becomes responsible for
starting its own requests. Actually, this simplifies the integration for
the Channel users, as they can assume any yielded requests are already
tracked.

This entails the following breaking changes:

- removed trait method Channel::start_request as it is now done
  internally.
2021-04-21 14:54:45 -07:00
Tim Kuehn
d0c11a6efa Change RPC error type from io::Error => RpcError.
Becaue tarpc is a library, not an application, it should strive to
use structured errors in its API so that users have maximal flexibility
in how they handle errors. io::Error makes that hard, because it is a
kitchen-sink error type.

RPCs in particular only have 3 classes of errors:

- The connection breaks.
- The request expires.
- The server decides not to process the request.

(Of course, RPCs themselves can have application-specific errors, but
from the perspective of the RPC library, those can be classified as
successful responsees).
2021-04-20 18:29:55 -07:00
Tim Kuehn
82c4da1743 Prepare release of v0.26.2 2021-04-20 11:28:15 -07:00
Tim Kuehn
0a15e0b75c Rustdoc: link RPC futures to their methods. 2021-04-20 11:25:26 -07:00
Tim Kuehn
0b315c29bf It's not currently possible to document the enum variants, which means
projects that #[deny(missing_docs)] wouldn't compile if using tarpc
services.
2021-04-20 09:01:39 -07:00
Tim Kuehn
56f09bf61f Fix log that's split across lines. 2021-04-17 17:15:16 -07:00
Tim Kuehn
6d82e82419 Fix formatting 2021-04-16 16:51:21 -07:00
Tim Kuehn
9bebaf814a Address clippy lint 2021-04-14 17:49:27 -07:00
33 changed files with 845 additions and 661 deletions

View File

@@ -40,7 +40,7 @@ rather than in a separate language such as .proto. This means there's no separat
process, and no context switching between different languages.
Some other features of tarpc:
- Pluggable transport: any type impling `Stream<Item = Request> + Sink<Response>` can be
- Pluggable transport: any type implementing `Stream<Item = Request> + Sink<Response>` can be
used as a transport to connect the client and server.
- `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
- Cascading cancellation: dropping a request will send a cancellation message to the server.
@@ -55,7 +55,7 @@ Some other features of tarpc:
[tracing](https://github.com/tokio-rs/tracing) primitives extended with
[OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
[Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
each RPC can be traced through the client, server, amd other dependencies downstream of the
each RPC can be traced through the client, server, and other dependencies downstream of the
server. Even for applications not connected to a distributed tracing collector, the
instrumentation can also be ingested by regular loggers like
[env_logger](https://github.com/env-logger-rs/env_logger/).
@@ -67,7 +67,7 @@ Some other features of tarpc:
Add to your `Cargo.toml` dependencies:
```toml
tarpc = "0.26"
tarpc = "0.28"
```
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
@@ -80,8 +80,9 @@ This example uses [tokio](https://tokio.rs), so add the following dependencies t
your `Cargo.toml`:
```toml
futures = "1.0"
tarpc = { version = "0.26", features = ["tokio1"] }
anyhow = "1.0"
futures = "0.3"
tarpc = { version = "0.28", features = ["tokio1"] }
tokio = { version = "1.0", features = ["macros"] }
```
@@ -99,9 +100,8 @@ use futures::{
};
use tarpc::{
client, context,
server::{self, Incoming},
server::{self, incoming::Incoming, Channel},
};
use std::io;
// This is the service definition. It looks a lot like a trait definition.
// It defines one RPC, hello, which takes one arg, name, and returns a String.
@@ -128,7 +128,7 @@ impl World for HelloServer {
type HelloFut = Ready<String>;
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
future::ready(format!("Hello, {}!", name))
future::ready(format!("Hello, {name}!"))
}
}
```
@@ -140,7 +140,7 @@ available behind the `tcp` feature.
```rust
#[tokio::main]
async fn main() -> io::Result<()> {
async fn main() -> anyhow::Result<()> {
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
let server = server::BaseChannel::with_defaults(server_transport);
@@ -155,7 +155,7 @@ async fn main() -> io::Result<()> {
// specifies a deadline and trace information which can be helpful in debugging requests.
let hello = client.hello(context::current(), "Stim".to_string()).await?;
println!("{}", hello);
println!("{hello}");
Ok(())
}

View File

@@ -1,3 +1,50 @@
## 0.28.0 (2022-04-06)
### Breaking Changes
- The minimum supported Rust version has increased to 1.58.0.
- The version of opentelemetry depended on by tarpc has increased to 0.17.0.
## 0.27.2 (2021-10-08)
### Fixes
Clients will now close their transport before dropping it. An attempt at a clean shutdown can help
the server drop its connections more quickly.
## 0.27.1 (2021-09-22)
### Breaking Changes
### RPC error type is changing
RPC return types are changing from `Result<Response, io::Error>` to `Result<Response,
tarpc::client::RpcError>`.
Becaue tarpc is a library, not an application, it should strive to
use structured errors in its API so that users have maximal flexibility
in how they handle errors. io::Error makes that hard, because it is a
kitchen-sink error type.
RPCs in particular only have 3 classes of errors:
- The connection breaks.
- The request expires.
- The server decides not to process the request.
RPC responses can also contain application-specific errors, but from the
perspective of the RPC library, those are opaque to the framework, classified
as successful responsees.
### Open Telemetry
The Opentelemetry dependency is updated to version 0.16.x.
## 0.27.0 (2021-09-22)
This version was yanked due to tarpc-plugins version mismatches.
## 0.26.0 (2021-04-14)
### New Features
@@ -68,10 +115,10 @@ tracing_subscriber::fmt.
### References
[1] https://github.com/tokio-rs/tracing
[2] https://opentelemetry.io
[3] https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger
[4] https://github.com/env-logger-rs/env_logger
1. https://github.com/tokio-rs/tracing
2. https://opentelemetry.io
3. https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger
4. https://github.com/env-logger-rs/env_logger
## 0.25.0 (2021-03-10)

View File

@@ -1,8 +1,9 @@
[package]
name = "tarpc-example-service"
version = "0.10.0"
rust-version = "1.56"
authors = ["Tim Kuehn <tikue@google.com>"]
edition = "2018"
edition = "2021"
license = "MIT"
documentation = "https://docs.rs/tarpc-example-service"
homepage = "https://github.com/google/tarpc"
@@ -14,18 +15,16 @@ description = "An example server built on tarpc."
[dependencies]
anyhow = "1.0"
clap = "3.0.0-beta.2"
clap = { version = "3.0.0-rc.9", features = ["derive"] }
log = "0.4"
futures = "0.3"
opentelemetry = { version = "0.13", features = ["rt-tokio"] }
opentelemetry-jaeger = { version = "0.12", features = ["tokio"] }
opentelemetry = { version = "0.16", features = ["rt-tokio"] }
opentelemetry-jaeger = { version = "0.15", features = ["rt-tokio"] }
rand = "0.8"
serde = { version = "1.0" }
tarpc = { version = "0.26", path = "../tarpc", features = ["full"] }
tarpc = { version = "0.28", path = "../tarpc", features = ["full"] }
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
tokio-serde = { version = "0.8", features = ["json"] }
tracing = { version = "0.1" }
tracing-opentelemetry = "0.12"
tracing-opentelemetry = "0.15"
tracing-subscriber = "0.2"
[lib]

View File

@@ -4,14 +4,14 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use clap::Clap;
use clap::Parser;
use service::{init_tracing, WorldClient};
use std::{net::SocketAddr, time::Duration};
use tarpc::{client, context, tokio_serde::formats::Json};
use tokio::time::sleep;
use tracing::Instrument;
#[derive(Clap)]
#[derive(Parser)]
struct Flags {
/// Sets the server address to connect to.
#[clap(long)]

View File

@@ -4,7 +4,7 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use clap::Clap;
use clap::Parser;
use futures::{future, prelude::*};
use rand::{
distributions::{Distribution, Uniform},
@@ -17,12 +17,12 @@ use std::{
};
use tarpc::{
context,
server::{self, Channel, Incoming},
server::{self, incoming::Incoming, Channel},
tokio_serde::formats::Json,
};
use tokio::time;
#[derive(Clap)]
#[derive(Parser)]
struct Flags {
/// Sets the port number to listen on.
#[clap(long)]
@@ -40,7 +40,7 @@ impl World for HelloServer {
let sleep_time =
Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng()));
time::sleep(sleep_time).await;
format!("Hello, {}! You are connected from {}", name, self.0)
format!("Hello, {name}! You are connected from {}", self.0)
}
}

View File

@@ -1,8 +1,9 @@
[package]
name = "tarpc-plugins"
version = "0.11.0"
version = "0.12.0"
rust-version = "1.56"
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
edition = "2018"
edition = "2021"
license = "MIT"
documentation = "https://docs.rs/tarpc-plugins"
homepage = "https://github.com/google/tarpc"

View File

@@ -83,7 +83,7 @@ impl Parse for Service {
ident_errors,
syn::Error::new(
rpc.ident.span(),
format!("method name conflicts with generated fn `{}::serve`", ident)
format!("method name conflicts with generated fn `{ident}::serve`")
)
);
}
@@ -270,14 +270,14 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
let methods = rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>();
let request_names = methods
.iter()
.map(|m| format!("{}.{}", ident, m))
.map(|m| format!("{ident}.{m}"))
.collect::<Vec<_>>();
ServiceGenerator {
response_fut_name,
service_ident: ident,
server_ident: &format_ident!("Serve{}", ident),
response_fut_ident: &Ident::new(&response_fut_name, ident.span()),
response_fut_ident: &Ident::new(response_fut_name, ident.span()),
client_ident: &format_ident!("{}Client", ident),
request_ident: &format_ident!("{}Request", ident),
response_ident: &format_ident!("{}Response", ident),
@@ -306,7 +306,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
.collect::<Vec<_>>(),
future_types: &camel_case_fn_names
.iter()
.map(|name| parse_str(&format!("{}Fut", name)).unwrap())
.map(|name| parse_str(&format!("{name}Fut")).unwrap())
.collect::<Vec<_>>(),
derive_serialize: derive_serialize.as_ref(),
}
@@ -406,14 +406,10 @@ fn verify_types_were_provided(
) -> syn::Result<()> {
let mut result = Ok(());
for (method, expected) in expected {
if provided
.iter()
.find(|typedecl| typedecl.ident == expected)
.is_none()
{
if !provided.iter().any(|typedecl| typedecl.ident == expected) {
let mut e = syn::Error::new(
span,
format!("not all trait items implemented, missing: `{}`", expected),
format!("not all trait items implemented, missing: `{expected}`"),
);
let fn_span = method.sig.fn_token.span();
e.extend(syn::Error::new(
@@ -483,7 +479,7 @@ impl<'a> ServiceGenerator<'a> {
),
output,
)| {
let ty_doc = format!("The response future returned by {}.", ident);
let ty_doc = format!("The response future returned by [`{service_ident}::{ident}`].");
quote! {
#[doc = #ty_doc]
type #future_type: std::future::Future<Output = #output>;
@@ -582,6 +578,7 @@ impl<'a> ServiceGenerator<'a> {
quote! {
/// The request sent over the wire from the client to the server.
#[allow(missing_docs)]
#[derive(Debug)]
#derive_serialize
#vis enum #request_ident {
@@ -602,6 +599,7 @@ impl<'a> ServiceGenerator<'a> {
quote! {
/// The response sent over the wire from the server to the client.
#[allow(missing_docs)]
#[derive(Debug)]
#derive_serialize
#vis enum #response_ident {
@@ -622,6 +620,7 @@ impl<'a> ServiceGenerator<'a> {
quote! {
/// A future resolving to a server response.
#[allow(missing_docs)]
#vis enum #response_fut_ident<S: #service_ident> {
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
}
@@ -747,7 +746,7 @@ impl<'a> ServiceGenerator<'a> {
#[allow(unused)]
#( #method_attrs )*
#vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*)
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
-> impl std::future::Future<Output = Result<#return_types, tarpc::client::RpcError>> + '_ {
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
let resp = self.0.call(ctx, #request_names, request);
async move {

View File

@@ -1,8 +1,12 @@
[package]
name = "tarpc"
version = "0.26.0"
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
edition = "2018"
version = "0.28.0"
rust-version = "1.58.0"
authors = [
"Adam Wright <adam.austin.wright@gmail.com>",
"Tim Kuehn <timothy.j.kuehn@gmail.com>",
]
edition = "2021"
license = "MIT"
documentation = "https://docs.rs/tarpc"
homepage = "https://github.com/google/tarpc"
@@ -16,11 +20,20 @@ description = "An RPC framework for Rust with a focus on ease of use."
default = []
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"]
tokio1 = ["tokio/rt-multi-thread"]
tokio1 = ["tokio/rt"]
serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
serde-transport-json = ["tokio-serde/json"]
serde-transport-bincode = ["tokio-serde/bincode"]
tcp = ["tokio/net"]
full = ["serde1", "tokio1", "serde-transport", "tcp"]
full = [
"serde1",
"tokio1",
"serde-transport",
"serde-transport-json",
"serde-transport-bincode",
"tcp",
]
[badges]
travis-ci = { repository = "google/tarpc" }
@@ -34,14 +47,17 @@ pin-project = "1.0"
rand = "0.8"
serde = { optional = true, version = "1.0", features = ["derive"] }
static_assertions = "1.1.0"
tarpc-plugins = { path = "../plugins", version = "0.11" }
tarpc-plugins = { path = "../plugins", version = "0.12" }
thiserror = "1.0"
tokio = { version = "1", features = ["time"] }
tokio-util = { version = "0.6.3", features = ["time"] }
tokio-util = { version = "0.6.9", features = ["time"] }
tokio-serde = { optional = true, version = "0.8" }
tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
tracing-opentelemetry = { version = "0.12", default-features = false }
opentelemetry = { version = "0.13", default-features = false }
tracing = { version = "0.1", default-features = false, features = [
"attributes",
"log",
] }
tracing-opentelemetry = { version = "0.17.2", default-features = false }
opentelemetry = { version = "0.17.0", default-features = false }
[dev-dependencies]
@@ -50,11 +66,13 @@ bincode = "1.3"
bytes = { version = "1", features = ["serde"] }
flate2 = "1.0"
futures-test = "0.3"
opentelemetry = { version = "0.13", default-features = false, features = ["rt-tokio"] }
opentelemetry-jaeger = { version = "0.12", features = ["tokio"] }
opentelemetry = { version = "0.17.0", default-features = false, features = [
"rt-tokio",
] }
opentelemetry-jaeger = { version = "0.16.0", features = ["rt-tokio"] }
pin-utils = "0.1.0-alpha"
serde_bytes = "0.11"
tracing-subscriber = "0.2"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tokio = { version = "1", features = ["full", "test-util"] }
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
trybuild = "1.0"

View File

@@ -7,8 +7,8 @@ use tarpc::{
client, context,
serde_transport::tcp,
server::{BaseChannel, Channel},
tokio_serde::formats::Bincode,
};
use tokio_serde::formats::Bincode;
/// Type of compression that should be enabled on the request. The transport is free to ignore this.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)]
@@ -54,7 +54,7 @@ where
if algorithm != CompressionAlgorithm::Deflate {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Compression algorithm {:?} not supported", algorithm),
format!("Compression algorithm {algorithm:?} not supported"),
));
}
let mut deflater = DeflateDecoder::new(payload.as_slice());
@@ -102,7 +102,7 @@ struct HelloServer;
#[tarpc::server]
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
format!("Hey, {}!", name)
format!("Hey, {name}!")
}
}

View File

@@ -1,9 +1,7 @@
use futures::future;
use tarpc::context::Context;
use tarpc::serde_transport as transport;
use tarpc::server::{BaseChannel, Channel};
use tarpc::{context::Context, tokio_serde::formats::Bincode};
use tokio::net::{UnixListener, UnixStream};
use tokio_serde::formats::Bincode;
use tokio_util::codec::length_delimited::LengthDelimitedCodec;
#[tarpc::service]
@@ -14,16 +12,13 @@ pub trait PingService {
#[derive(Clone)]
struct Service;
#[tarpc::server]
impl PingService for Service {
type PingFut = future::Ready<()>;
fn ping(self, _: Context) -> Self::PingFut {
future::ready(())
}
async fn ping(self, _: Context) {}
}
#[tokio::main]
async fn main() -> std::io::Result<()> {
async fn main() -> anyhow::Result<()> {
let bind_addr = "/tmp/tarpc_on_unix_example.sock";
let _ = std::fs::remove_file(bind_addr);
@@ -46,5 +41,7 @@ async fn main() -> std::io::Result<()> {
PingServiceClient::new(Default::default(), transport)
.spawn()
.ping(tarpc::context::current())
.await
.await?;
Ok(())
}

View File

@@ -290,7 +290,7 @@ fn init_tracing(service_name: &str) -> anyhow::Result<()> {
.install_batch(opentelemetry::runtime::Tokio)?;
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::from_default_env())
.with(tracing_subscriber::filter::EnvFilter::from_default_env())
.with(tracing_subscriber::fmt::layer())
.with(tracing_opentelemetry::layer().with_tracer(tracer))
.try_init()?;

View File

@@ -5,7 +5,6 @@
// https://opensource.org/licenses/MIT.
use futures::future::{self, Ready};
use std::io;
use tarpc::{
client, context,
server::{self, Channel},
@@ -30,12 +29,12 @@ impl World for HelloServer {
type HelloFut = Ready<String>;
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
future::ready(format!("Hello, {}!", name))
future::ready(format!("Hello, {name}!"))
}
}
#[tokio::main]
async fn main() -> io::Result<()> {
async fn main() -> anyhow::Result<()> {
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
let server = server::BaseChannel::with_defaults(server_transport);
@@ -50,7 +49,7 @@ async fn main() -> io::Result<()> {
// specifies a deadline and trace information which can be helpful in debugging requests.
let hello = client.hello(context::current(), "Stim".to_string()).await?;
println!("{}", hello);
println!("{hello}");
Ok(())
}

View File

@@ -9,7 +9,7 @@ use futures::{future, prelude::*};
use std::env;
use tarpc::{
client, context,
server::{BaseChannel, Incoming},
server::{incoming::Incoming, BaseChannel},
};
use tokio_serde::formats::Json;
use tracing_subscriber::prelude::*;

View File

@@ -8,14 +8,14 @@
mod in_flight_requests;
use crate::{context, trace, ClientMessage, Request, Response, Transport};
use crate::{context, trace, ClientMessage, Request, Response, ServerError, Transport};
use futures::{prelude::*, ready, stream::Fuse, task::*};
use in_flight_requests::InFlightRequests;
use in_flight_requests::{DeadlineExceededError, InFlightRequests};
use pin_project::pin_project;
use std::{
convert::TryFrom,
error::Error,
fmt, io, mem,
fmt, mem,
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
@@ -60,15 +60,16 @@ pub struct NewClient<C, D> {
impl<C, D, E> NewClient<C, D>
where
D: Future<Output = Result<(), E>> + Send + 'static,
E: std::fmt::Display,
E: std::error::Error + Send + Sync + 'static,
{
/// Helper method to spawn the dispatch on the default executor.
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
pub fn spawn(self) -> C {
let dispatch = self
.dispatch
.unwrap_or_else(move |e| tracing::warn!("Connection broken: {}", e));
let dispatch = self.dispatch.unwrap_or_else(move |e| {
let e = anyhow::Error::new(e);
tracing::warn!("Connection broken: {:?}", e);
});
tokio::spawn(dispatch);
self.client
}
@@ -80,14 +81,10 @@ impl<C, D> fmt::Debug for NewClient<C, D> {
}
}
#[allow(dead_code)]
#[allow(clippy::no_effect)]
const CHECK_USIZE: () = {
if std::mem::size_of::<usize>() > std::mem::size_of::<u64>() {
// TODO: replace this with panic!() as soon as RFC 2345 gets stabilized
["usize is too big to fit in u64"][42];
}
};
const _CHECK_USIZE: () = assert!(
std::mem::size_of::<usize>() <= std::mem::size_of::<u64>(),
"usize is too big to fit in u64"
);
/// Handles communication from the client to request dispatch.
#[derive(Debug)]
@@ -125,7 +122,7 @@ impl<Req, Resp> Channel<Req, Resp> {
mut ctx: context::Context,
request_name: &str,
request: Req,
) -> io::Result<Resp> {
) -> Result<Resp, RpcError> {
let span = Span::current();
ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
tracing::warn!(
@@ -156,7 +153,7 @@ impl<Req, Resp> Channel<Req, Resp> {
response_completion,
})
.await
.map_err(|mpsc::error::SendError(_)| io::Error::from(io::ErrorKind::ConnectionReset))?;
.map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?;
response_guard.response().await
}
}
@@ -164,23 +161,45 @@ impl<Req, Resp> Channel<Req, Resp> {
/// A server response that is completed by request dispatch when the corresponding response
/// arrives off the wire.
struct ResponseGuard<'a, Resp> {
response: &'a mut oneshot::Receiver<Response<Resp>>,
response: &'a mut oneshot::Receiver<Result<Response<Resp>, DeadlineExceededError>>,
cancellation: &'a RequestCancellation,
request_id: u64,
}
/// An error that can occur in the processing of an RPC. This is not request-specific errors but
/// rather cross-cutting errors that can always occur.
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub enum RpcError {
/// The client disconnected from the server.
#[error("the client disconnected from the server")]
Disconnected,
/// The request exceeded its deadline.
#[error("the request exceeded its deadline")]
DeadlineExceeded,
/// The server aborted request processing.
#[error("the server aborted request processing")]
Server(#[from] ServerError),
}
impl From<DeadlineExceededError> for RpcError {
fn from(_: DeadlineExceededError) -> Self {
RpcError::DeadlineExceeded
}
}
impl<Resp> ResponseGuard<'_, Resp> {
async fn response(mut self) -> io::Result<Resp> {
async fn response(mut self) -> Result<Resp, RpcError> {
let response = (&mut self.response).await;
// Cancel drop logic once a response has been received.
mem::forget(self);
match response {
Ok(resp) => Ok(resp.message?),
Ok(resp) => Ok(resp?.message?),
Err(oneshot::error::RecvError { .. }) => {
// The oneshot is Canceled when the dispatch task ends. In that case,
// there's nothing listening on the other side, so there's no point in
// propagating cancellation.
Err(io::Error::from(io::ErrorKind::ConnectionReset))
Err(RpcError::Disconnected)
}
}
}
@@ -257,11 +276,23 @@ pub enum ChannelError<E>
where
E: Error + Send + Sync + 'static,
{
/// An error occurred reading from, or writing to, the transport.
#[error("an error occurred in the transport: {0}")]
Transport(#[source] E),
/// An error occurred while polling expired requests.
#[error("an error occurred while polling expired requests: {0}")]
/// Could not read from the transport.
#[error("could not read from the transport")]
Read(#[source] E),
/// Could not ready the transport for writes.
#[error("could not ready the transport for writes")]
Ready(#[source] E),
/// Could not write to the transport.
#[error("could not write to the transport")]
Write(#[source] E),
/// Could not flush the transport.
#[error("could not flush the transport")]
Flush(#[source] E),
/// Could not close the write end of the transport.
#[error("could not close the write end of the transport")]
Close(#[source] E),
/// Could not poll expired requests.
#[error("could not poll expired requests")]
Timer(#[source] tokio::time::error::Error),
}
@@ -277,6 +308,42 @@ where
self.as_mut().project().transport
}
fn poll_ready<'a>(
self: &'a mut Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), ChannelError<C::Error>>> {
self.transport_pin_mut()
.poll_ready(cx)
.map_err(ChannelError::Ready)
}
fn start_send(
self: &mut Pin<&mut Self>,
message: ClientMessage<Req>,
) -> Result<(), ChannelError<C::Error>> {
self.transport_pin_mut()
.start_send(message)
.map_err(ChannelError::Write)
}
fn poll_flush<'a>(
self: &'a mut Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), ChannelError<C::Error>>> {
self.transport_pin_mut()
.poll_flush(cx)
.map_err(ChannelError::Flush)
}
fn poll_close<'a>(
self: &'a mut Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), ChannelError<C::Error>>> {
self.transport_pin_mut()
.poll_close(cx)
.map_err(ChannelError::Close)
}
fn canceled_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut CanceledRequests {
self.as_mut().project().canceled_requests
}
@@ -293,7 +360,7 @@ where
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
self.transport_pin_mut()
.poll_next(cx)
.map_err(ChannelError::Transport)
.map_err(ChannelError::Read)
.map_ok(|response| {
self.complete(response);
})
@@ -308,21 +375,13 @@ where
Closed,
}
let pending_requests_status = match self
.as_mut()
.poll_write_request(cx)
.map_err(ChannelError::Transport)?
{
let pending_requests_status = match self.as_mut().poll_write_request(cx)? {
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::Pending,
};
let canceled_requests_status = match self
.as_mut()
.poll_write_cancel(cx)
.map_err(ChannelError::Transport)?
{
let canceled_requests_status = match self.as_mut().poll_write_cancel(cx)? {
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::Pending,
@@ -344,18 +403,12 @@ where
match (pending_requests_status, canceled_requests_status) {
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
ready!(self
.transport_pin_mut()
.poll_flush(cx)
.map_err(ChannelError::Transport)?);
ready!(self.poll_close(cx)?);
Poll::Ready(None)
}
(ReceiverStatus::Pending, _) | (_, ReceiverStatus::Pending) => {
// No more messages to process, so flush any messages buffered in the transport.
ready!(self
.transport_pin_mut()
.poll_flush(cx)
.map_err(ChannelError::Transport)?);
ready!(self.poll_flush(cx)?);
// Even if we fully-flush, we return Pending, because we have no more requests
// or cancellations right now.
@@ -371,7 +424,7 @@ where
fn poll_next_request(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<DispatchRequest<Req, Resp>, C::Error>>> {
) -> Poll<Option<Result<DispatchRequest<Req, Resp>, ChannelError<C::Error>>>> {
if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
tracing::info!(
"At in-flight request capacity ({}/{}).",
@@ -409,7 +462,7 @@ where
fn poll_next_cancellation(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(context::Context, Span, u64), C::Error>>> {
) -> Poll<Option<Result<(context::Context, Span, u64), ChannelError<C::Error>>>> {
ready!(self.ensure_writeable(cx)?);
loop {
@@ -431,9 +484,9 @@ where
fn ensure_writeable<'a>(
self: &'a mut Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(), C::Error>>> {
while self.transport_pin_mut().poll_ready(cx)?.is_pending() {
ready!(self.transport_pin_mut().poll_flush(cx)?);
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
while self.poll_ready(cx)?.is_pending() {
ready!(self.poll_flush(cx)?);
}
Poll::Ready(Some(Ok(())))
}
@@ -441,7 +494,7 @@ where
fn poll_write_request<'a>(
self: &'a mut Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(), C::Error>>> {
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
let DispatchRequest {
ctx,
span,
@@ -465,7 +518,7 @@ where
trace_context: ctx.trace_context,
},
});
self.transport_pin_mut().start_send(request)?;
self.start_send(request)?;
let deadline = ctx.deadline;
tracing::info!(
tarpc.deadline = %humantime::format_rfc3339(deadline),
@@ -482,7 +535,7 @@ where
fn poll_write_cancel<'a>(
self: &'a mut Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(), C::Error>>> {
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
let (context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
Some(triple) => triple,
None => return Poll::Ready(None),
@@ -493,7 +546,7 @@ where
trace_context: context.trace_context,
request_id,
};
self.transport_pin_mut().start_send(cancel)?;
self.start_send(cancel)?;
tracing::info!("CancelRequest");
Poll::Ready(Some(Ok(())))
}
@@ -549,7 +602,7 @@ struct DispatchRequest<Req, Resp> {
pub span: Span,
pub request_id: u64,
pub request: Req,
pub response_completion: oneshot::Sender<Response<Resp>>,
pub response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
}
/// Sends request cancellation signals.
@@ -596,7 +649,10 @@ mod tests {
RequestDispatch, ResponseGuard,
};
use crate::{
client::{in_flight_requests::InFlightRequests, Config},
client::{
in_flight_requests::{DeadlineExceededError, InFlightRequests},
Config,
},
context,
transport::{self, channel::UnboundedChannel},
ClientMessage, Response,
@@ -630,7 +686,7 @@ mod tests {
.await
.unwrap();
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
assert_matches!(rx.try_recv(), Ok(Response { request_id: 0, message: Ok(resp) }) if resp == "Resp");
assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp");
}
#[tokio::test]
@@ -651,10 +707,10 @@ mod tests {
async fn dispatch_response_doesnt_cancel_after_complete() {
let (cancellation, mut canceled_requests) = cancellations();
let (tx, mut response) = oneshot::channel();
tx.send(Response {
tx.send(Ok(Response {
request_id: 0,
message: Ok("well done"),
})
}))
.unwrap();
// resp's drop() is run, but should not send a cancel message.
ResponseGuard {
@@ -799,8 +855,8 @@ mod tests {
async fn send_request<'a>(
channel: &'a mut Channel<String, String>,
request: &str,
response_completion: oneshot::Sender<Response<String>>,
response: &'a mut oneshot::Receiver<Response<String>>,
response_completion: oneshot::Sender<Result<Response<String>, DeadlineExceededError>>,
response: &'a mut oneshot::Receiver<Result<Response<String>, DeadlineExceededError>>,
) -> ResponseGuard<'a, String> {
let request_id =
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();

View File

@@ -1,12 +1,11 @@
use crate::{
context,
util::{Compact, TimeUntil},
Response, ServerError,
Response,
};
use fnv::FnvHashMap;
use std::{
collections::hash_map,
io,
task::{Context, Poll},
};
use tokio::sync::oneshot;
@@ -29,11 +28,17 @@ impl<Resp> Default for InFlightRequests<Resp> {
}
}
/// The request exceeded its deadline.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
#[error("the request exceeded its deadline")]
pub struct DeadlineExceededError;
#[derive(Debug)]
struct RequestData<Resp> {
ctx: context::Context,
span: Span,
response_completion: oneshot::Sender<Response<Resp>>,
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
/// The key to remove the timer for the request's deadline.
deadline_key: delay_queue::Key,
}
@@ -60,7 +65,7 @@ impl<Resp> InFlightRequests<Resp> {
request_id: u64,
ctx: context::Context,
span: Span,
response_completion: oneshot::Sender<Response<Resp>>,
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
) -> Result<(), AlreadyExistsError> {
match self.request_data.entry(request_id) {
hash_map::Entry::Vacant(vacant) => {
@@ -85,7 +90,7 @@ impl<Resp> InFlightRequests<Resp> {
tracing::info!("ReceiveResponse");
self.request_data.compact(0.1);
self.deadlines.remove(&request_data.deadline_key);
let _ = request_data.response_completion.send(response);
let _ = request_data.response_completion.send(Ok(response));
return true;
}
@@ -124,19 +129,9 @@ impl<Resp> InFlightRequests<Resp> {
self.request_data.compact(0.1);
let _ = request_data
.response_completion
.send(Self::deadline_exceeded_error(request_id));
.send(Err(DeadlineExceededError));
}
request_id
})
}
fn deadline_exceeded_error(request_id: u64) -> Response<Resp> {
Response {
request_id,
message: Err(ServerError {
kind: io::ErrorKind::TimedOut,
detail: Some("Client dropped expired request.".to_string()),
}),
}
}
}

View File

@@ -92,7 +92,7 @@ impl SpanExt for tracing::Span {
.with_remote_span_context(opentelemetry::trace::SpanContext::new(
opentelemetry::trace::TraceId::from(context.trace_context.trace_id),
opentelemetry::trace::SpanId::from(context.trace_context.span_id),
context.trace_context.sampling_decision as u8,
opentelemetry::trace::TraceFlags::from(context.trace_context.sampling_decision),
true,
opentelemetry::trace::TraceState::default(),
))

View File

@@ -27,7 +27,7 @@
//! process, and no context switching between different languages.
//!
//! Some other features of tarpc:
//! - Pluggable transport: any type impling `Stream<Item = Request> + Sink<Response>` can be
//! - Pluggable transport: any type implementing `Stream<Item = Request> + Sink<Response>` can be
//! used as a transport to connect the client and server.
//! - `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
//! - Cascading cancellation: dropping a request will send a cancellation message to the server.
@@ -42,7 +42,7 @@
//! [tracing](https://github.com/tokio-rs/tracing) primitives extended with
//! [OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
//! [Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
//! each RPC can be traced through the client, server, amd other dependencies downstream of the
//! each RPC can be traced through the client, server, and other dependencies downstream of the
//! server. Even for applications not connected to a distributed tracing collector, the
//! instrumentation can also be ingested by regular loggers like
//! [env_logger](https://github.com/env-logger-rs/env_logger/).
@@ -54,7 +54,7 @@
//! Add to your `Cargo.toml` dependencies:
//!
//! ```toml
//! tarpc = "0.26"
//! tarpc = "0.28"
//! ```
//!
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
@@ -67,8 +67,9 @@
//! your `Cargo.toml`:
//!
//! ```toml
//! futures = "1.0"
//! tarpc = { version = "0.26", features = ["tokio1"] }
//! anyhow = "1.0"
//! futures = "0.3"
//! tarpc = { version = "0.28", features = ["tokio1"] }
//! tokio = { version = "1.0", features = ["macros"] }
//! ```
//!
@@ -87,9 +88,8 @@
//! };
//! use tarpc::{
//! client, context,
//! server::{self, Incoming},
//! server::{self, incoming::Incoming, Channel},
//! };
//! use std::io;
//!
//! // This is the service definition. It looks a lot like a trait definition.
//! // It defines one RPC, hello, which takes one arg, name, and returns a String.
@@ -111,9 +111,8 @@
//! # };
//! # use tarpc::{
//! # client, context,
//! # server::{self, Incoming},
//! # server::{self, incoming::Incoming},
//! # };
//! # use std::io;
//! # // This is the service definition. It looks a lot like a trait definition.
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
//! # #[tarpc::service]
@@ -133,7 +132,7 @@
//! type HelloFut = Ready<String>;
//!
//! fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
//! future::ready(format!("Hello, {}!", name))
//! future::ready(format!("Hello, {name}!"))
//! }
//! }
//! ```
@@ -153,7 +152,6 @@
//! # client, context,
//! # server::{self, Channel},
//! # };
//! # use std::io;
//! # // This is the service definition. It looks a lot like a trait definition.
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
//! # #[tarpc::service]
@@ -170,14 +168,14 @@
//! # // an associated type representing the future output by the fn.
//! # type HelloFut = Ready<String>;
//! # fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
//! # future::ready(format!("Hello, {}!", name))
//! # future::ready(format!("Hello, {name}!"))
//! # }
//! # }
//! # #[cfg(not(feature = "tokio1"))]
//! # fn main() {}
//! # #[cfg(feature = "tokio1")]
//! #[tokio::main]
//! async fn main() -> io::Result<()> {
//! async fn main() -> anyhow::Result<()> {
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
//!
//! let server = server::BaseChannel::with_defaults(server_transport);
@@ -192,7 +190,7 @@
//! // specifies a deadline and trace information which can be helpful in debugging requests.
//! let hello = client.hello(context::current(), "Stim".to_string()).await?;
//!
//! println!("{}", hello);
//! println!("{hello}");
//!
//! Ok(())
//! }
@@ -266,7 +264,7 @@ pub use tarpc_plugins::service;
/// #[tarpc::server]
/// impl World for HelloServer {
/// async fn hello(self, _: context::Context, name: String) -> String {
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
/// format!("Hello, {name}! You are connected from {:?}.", self.0)
/// }
/// }
/// ```
@@ -292,7 +290,7 @@ pub use tarpc_plugins::service;
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
/// + Send>> {
/// Box::pin(async move {
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
/// format!("Hello, {name}! You are connected from {:?}.", self.0)
/// })
/// }
/// }
@@ -364,8 +362,9 @@ pub struct Response<T> {
pub message: Result<T, ServerError>,
}
/// An error response from a server to a client.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
/// An error indicating the server aborted the request early, e.g., due to request throttling.
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
#[error("{kind:?}: {detail}")]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct ServerError {
@@ -380,13 +379,7 @@ pub struct ServerError {
/// The type of error that occurred to fail the request.
pub kind: io::ErrorKind,
/// A message describing more detail about the error that occurred.
pub detail: Option<String>,
}
impl From<ServerError> for io::Error {
fn from(e: ServerError) -> io::Error {
io::Error::new(e.kind, e.detail.unwrap_or_default())
}
pub detail: String,
}
impl<T> Request<T> {

View File

@@ -42,12 +42,10 @@ where
type Item = io::Result<Item>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
self.project().inner.poll_next(cx).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("while reading from transport: {}", e.into()),
)
})
self.project()
.inner
.poll_next(cx)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
}
@@ -63,39 +61,31 @@ where
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().inner.poll_ready(cx).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("while readying write half of transport: {}", e.into()),
)
})
self.project()
.inner
.poll_ready(cx)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
self.project().inner.start_send(item).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("while writing to transport: {}", e.into()),
)
})
self.project()
.inner
.start_send(item)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().inner.poll_flush(cx).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("while flushing transport: {}", e.into()),
)
})
self.project()
.inner
.poll_flush(cx)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().inner.poll_close(cx).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("while closing write half of transport: {}", e.into()),
)
})
self.project()
.inner
.poll_close(cx)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
}

View File

@@ -10,6 +10,7 @@ use crate::{
context::{self, SpanExt},
trace, ClientMessage, Request, Response, Transport,
};
use ::tokio::sync::mpsc;
use futures::{
future::{AbortRegistration, Abortable},
prelude::*,
@@ -19,23 +20,23 @@ use futures::{
};
use in_flight_requests::{AlreadyExistsError, InFlightRequests};
use pin_project::pin_project;
use std::{
convert::TryFrom, error::Error, fmt, hash::Hash, marker::PhantomData, pin::Pin,
time::SystemTime,
};
use tokio::sync::mpsc;
use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin};
use tracing::{info_span, instrument::Instrument, Span};
mod filter;
mod in_flight_requests;
#[cfg(test)]
mod testing;
mod throttle;
pub use self::{
filter::ChannelFilter,
throttle::{Throttler, ThrottlerStream},
};
/// Provides functionality to apply server limits.
pub mod limits;
/// Provides helper methods for streams of Channels.
pub mod incoming;
/// Provides convenience functionality for tokio-enabled applications.
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
pub mod tokio;
/// Settings that control the behavior of [channels](Channel).
#[derive(Clone, Debug)]
@@ -94,51 +95,13 @@ where
}
}
/// An extension trait for [streams](Stream) of [`Channels`](Channel).
pub trait Incoming<C>
where
Self: Sized + Stream<Item = C>,
C: Channel,
{
/// Enforces channel per-key limits.
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> filter::ChannelFilter<Self, K, KF>
where
K: fmt::Display + Eq + Hash + Clone + Unpin,
KF: Fn(&C) -> K,
{
ChannelFilter::new(self, n, keymaker)
}
/// Caps the number of concurrent requests per channel.
fn max_concurrent_requests_per_channel(self, n: usize) -> ThrottlerStream<Self> {
ThrottlerStream::new(self, n)
}
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled
/// concurrently by spawning on tokio's default executor, and each request will be also
/// be spawned on tokio's default executor.
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
where
S: Serve<C::Req, Resp = C::Resp>,
{
TokioServerExecutor { inner: self, serve }
}
}
impl<S, C> Incoming<C> for S
where
S: Sized + Stream<Item = C>,
C: Channel,
{
}
/// BaseChannel is a [Transport] that keeps track of in-flight requests. It converts a
/// [`Transport`](Transport) of [`ClientMessages`](ClientMessage) into a stream of
/// [requests](ClientMessage::Request).
/// BaseChannel is the standard implementation of a [`Channel`].
///
/// Besides requests, the other type of client message is [cancellation
/// BaseChannel manages a [`Transport`](Transport) of client [`messages`](ClientMessage) and
/// implements a [`Stream`] of [requests](TrackedRequest). See the [`Channel`] documentation for
/// how to use channels.
///
/// Besides requests, the other type of client message handled by `BaseChannel` is [cancellation
/// messages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation
/// messages. Instead, it internally handles them by cancelling corresponding requests (removing
/// the corresponding in-flight requests and aborting their handlers).
@@ -190,6 +153,47 @@ where
fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<T>> {
self.as_mut().project().transport
}
fn start_request(
self: Pin<&mut Self>,
mut request: Request<Req>,
) -> Result<TrackedRequest<Req>, AlreadyExistsError> {
let span = info_span!(
"RPC",
rpc.trace_id = %request.context.trace_id(),
otel.kind = "server",
otel.name = tracing::field::Empty,
);
span.set_context(&request.context);
request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
tracing::trace!(
"OpenTelemetry subscriber not installed; making unsampled \
child context."
);
request.context.trace_context.new_child()
});
let entered = span.enter();
tracing::info!("ReceiveRequest");
let start = self.project().in_flight_requests.start_request(
request.id,
request.context.deadline,
span.clone(),
);
match start {
Ok(abort_registration) => {
drop(entered);
Ok(TrackedRequest {
request,
abort_registration,
span,
})
}
Err(AlreadyExistsError) => {
tracing::trace!("DuplicateRequest");
Err(AlreadyExistsError)
}
}
}
}
impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
@@ -198,8 +202,20 @@ impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
}
}
/// The server end of an open connection with a client, streaming in requests from, and sinking
/// responses to, the client.
/// A request tracked by a [`Channel`].
#[derive(Debug)]
pub struct TrackedRequest<Req> {
/// The request sent by the client.
pub request: Request<Req>,
/// A registration to abort a future when the [`Channel`] that produced this request stops
/// tracking it.
pub abort_registration: AbortRegistration,
/// A span representing the server processing of this request.
pub span: Span,
}
/// The server end of an open connection with a client, receiving requests from, and sending
/// responses to, the client. `Channel` is a [`Transport`] with request lifecycle management.
///
/// The ways to use a Channel, in order of simplest to most complex, is:
/// 1. [`Channel::execute`] - Requires the `tokio1` feature. This method is best for those who
@@ -210,18 +226,20 @@ impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
/// [`execute`](InFlightRequest::execute) method. If using `execute`, request processing will
/// automatically cease when either the request deadline is reached or when a corresponding
/// cancellation message is received by the Channel.
/// 3. [`Sink::start_send`] - A user is free to manually send responses to requests produced by a
/// Channel using [`Sink::start_send`] in lieu of the previous methods. If not using one of the
/// previous execute methods, then nothing will automatically cancel requests or set up the
/// request context. However, the Channel will still clean up resources upon deadline expiration
/// or cancellation. In the case that the Channel cleans up resources related to a request
/// before the response is sent, the response can still be sent into the Channel later on.
/// Because there is no guarantee that a cancellation message will ever be received for a
/// request, or that requests come with reasonably short deadlines, services should strive to
/// clean up Channel resources by sending a response for every request.
/// 3. [`Stream::next`](futures::stream::StreamExt::next) /
/// [`Sink::send`](futures::sink::SinkExt::send) - A user is free to manually read requests
/// from, and send responses into, a Channel in lieu of the previous methods. Channels stream
/// [`TrackedRequests`](TrackedRequest), which, in addition to the request itself, contains the
/// server [`Span`] and request lifetime [`AbortRegistration`]. Wrapping response
/// logic in an [`Abortable`] future using the abort registration will ensure that the response
/// does not execute longer than the request deadline. The `Channel` itself will clean up
/// request state once either the deadline expires, or a cancellation message is received, or a
/// response is sent. Because there is no guarantee that a cancellation message will ever be
/// received for a request, or that requests come with reasonably short deadlines, services
/// should strive to clean up Channel resources by sending a response for every request.
pub trait Channel
where
Self: Transport<Response<<Self as Channel>::Resp>, Request<<Self as Channel>::Req>>,
Self: Transport<Response<<Self as Channel>::Resp>, TrackedRequest<<Self as Channel>::Req>>,
{
/// Type of request item.
type Req;
@@ -241,24 +259,23 @@ where
/// Returns the transport underlying the channel.
fn transport(&self) -> &Self::Transport;
/// Caps the number of concurrent requests to `limit`.
fn max_concurrent_requests(self, limit: usize) -> Throttler<Self>
/// Caps the number of concurrent requests to `limit`. An error will be returned for requests
/// over the concurrency limit.
///
/// Note that this is a very
/// simplistic throttling heuristic. It is easy to set a number that is too low for the
/// resources available to the server. For production use cases, a more advanced throttler is
/// likely needed.
fn max_concurrent_requests(
self,
limit: usize,
) -> limits::requests_per_channel::MaxRequests<Self>
where
Self: Sized,
{
Throttler::new(self, limit)
limits::requests_per_channel::MaxRequests::new(self, limit)
}
/// Tells the Channel that request with ID `request_id` is being handled.
/// The request will be tracked until a response with the same ID is sent
/// to the Channel or the deadline expires, whichever happens first.
fn start_request(
self: Pin<&mut Self>,
request_id: u64,
deadline: SystemTime,
span: Span,
) -> Result<AbortRegistration, AlreadyExistsError>;
/// Returns a stream of requests that automatically handle request cancellation and response
/// routing.
///
@@ -279,11 +296,11 @@ where
}
/// Runs the channel until completion by executing all requests using the given service
/// function. Request handlers are run concurrently by [spawning](tokio::spawn) on tokio's
/// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's
/// default executor.
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
fn execute<S>(self, serve: S) -> TokioChannelExecutor<Requests<Self>, S>
fn execute<S>(self, serve: S) -> self::tokio::TokioChannelExecutor<Requests<Self>, S>
where
Self: Sized,
S: Serve<Self::Req, Resp = Self::Resp> + Send + 'static,
@@ -306,16 +323,17 @@ where
Transport(#[source] E),
/// An error occurred while polling expired requests.
#[error("an error occurred while polling expired requests: {0}")]
Timer(#[source] tokio::time::error::Error),
Timer(#[source] ::tokio::time::error::Error),
}
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
{
type Item = Result<Request<Req>, ChannelError<T::Error>>;
type Item = Result<TrackedRequest<Req>, ChannelError<T::Error>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
#[derive(Debug)]
enum ReceiverStatus {
Ready,
Pending,
@@ -343,7 +361,16 @@ where
{
Poll::Ready(Some(message)) => match message {
ClientMessage::Request(request) => {
return Poll::Ready(Some(Ok(request)));
match self.as_mut().start_request(request) {
Ok(request) => return Poll::Ready(Some(Ok(request))),
Err(AlreadyExistsError) => {
// Instead of closing the channel if a duplicate request is sent,
// just ignore it, since it's already being processed. Note that we
// cannot return Poll::Pending here, since nothing has scheduled a
// wakeup yet.
continue;
}
}
}
ClientMessage::Cancel {
trace_context,
@@ -362,6 +389,11 @@ where
Poll::Pending => Pending,
};
tracing::trace!(
"Expired requests: {:?}, Inbound: {:?}",
expiration_status,
request_status
);
match (expiration_status, request_status) {
(Ready, _) | (_, Ready) => continue,
(Closed, Closed) => return Poll::Ready(None),
@@ -405,6 +437,7 @@ where
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
tracing::trace!("poll_flush");
self.project()
.transport
.poll_flush(cx)
@@ -444,17 +477,6 @@ where
fn transport(&self) -> &Self::Transport {
self.get_ref()
}
fn start_request(
self: Pin<&mut Self>,
request_id: u64,
deadline: SystemTime,
span: Span,
) -> Result<AbortRegistration, AlreadyExistsError> {
self.project()
.in_flight_requests
.start_request(request_id, deadline, span)
}
}
/// A stream of requests coming over a channel. `Requests` also drives the sending of responses, so
@@ -492,54 +514,12 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<InFlightRequest<C::Req, C::Resp>, C::Error>>> {
loop {
match ready!(self.channel_pin_mut().poll_next(cx)?) {
Some(mut request) => {
let span = info_span!(
"RPC",
rpc.trace_id = %request.context.trace_id(),
otel.kind = "server",
otel.name = tracing::field::Empty,
);
span.set_context(&request.context);
request.context.trace_context =
trace::Context::try_from(&span).unwrap_or_else(|_| {
tracing::trace!(
"OpenTelemetry subscriber not installed; making unsampled
child context."
);
request.context.trace_context.new_child()
});
let entered = span.enter();
tracing::info!("ReceiveRequest");
let start = self.channel_pin_mut().start_request(
request.id,
request.context.deadline,
span.clone(),
);
match start {
Ok(abort_registration) => {
let response_tx = self.responses_tx.clone();
drop(entered);
return Poll::Ready(Some(Ok(InFlightRequest {
request,
response_tx,
abort_registration,
span,
})));
}
// Instead of closing the channel if a duplicate request is sent, just
// ignore it, since it's already being processed. Note that we cannot
// return Poll::Pending here, since nothing has scheduled a wakeup yet.
Err(AlreadyExistsError) => {
tracing::trace!("DuplicateRequest");
continue;
}
}
}
None => return Poll::Ready(None),
}
}
self.channel_pin_mut()
.poll_next(cx)
.map_ok(|request| InFlightRequest {
request,
response_tx: self.responses_tx.clone(),
})
}
fn pump_write(
@@ -619,16 +599,14 @@ where
/// A request produced by [Channel::requests].
#[derive(Debug)]
pub struct InFlightRequest<Req, Res> {
request: Request<Req>,
request: TrackedRequest<Req>,
response_tx: mpsc::Sender<Response<Res>>,
abort_registration: AbortRegistration,
span: Span,
}
impl<Req, Res> InFlightRequest<Req, Res> {
/// Returns a reference to the request.
pub fn get(&self) -> &Request<Req> {
&self.request
&self.request.request
}
/// Returns a [future](Future) that executes the request using the given [service
@@ -647,15 +625,18 @@ impl<Req, Res> InFlightRequest<Req, Res> {
S: Serve<Req, Resp = Res>,
{
let Self {
abort_registration,
request:
Request {
context,
message,
id: request_id,
},
response_tx,
span,
request:
TrackedRequest {
abort_registration,
span,
request:
Request {
context,
message,
id: request_id,
},
},
} = self;
let method = serve.method(&message);
span.record("otel.name", &method.unwrap_or(""));
@@ -704,129 +685,22 @@ where
}
}
// Send + 'static execution helper methods.
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
impl<C> Requests<C>
where
C: Channel,
C::Req: Send + 'static,
C::Resp: Send + 'static,
{
/// Executes all requests using the given service function. Requests are handled concurrently
/// by [spawning](tokio::spawn) each handler on tokio's default executor.
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
where
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
{
TokioChannelExecutor { inner: self, serve }
}
}
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
/// for each new channel.
#[pin_project]
#[derive(Debug)]
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
pub struct TokioServerExecutor<T, S> {
#[pin]
inner: T,
serve: S,
}
/// A future that drives the server by [spawning](tokio::spawn) each [response
/// handler](InFlightRequest::execute) on tokio's default executor.
#[pin_project]
#[derive(Debug)]
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
pub struct TokioChannelExecutor<T, S> {
#[pin]
inner: T,
serve: S,
}
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
impl<T, S> TokioServerExecutor<T, S> {
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
self.as_mut().project().inner
}
}
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
impl<T, S> TokioChannelExecutor<T, S> {
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
self.as_mut().project().inner
}
}
#[cfg(feature = "tokio1")]
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
where
St: Sized + Stream<Item = C>,
C: Channel + Send + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
Se::Fut: Send,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
tokio::spawn(channel.execute(self.serve.clone()));
}
tracing::info!("Server shutting down.");
Poll::Ready(())
}
}
#[cfg(feature = "tokio1")]
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
where
C: Channel + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
S::Fut: Send,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
match response_handler {
Ok(resp) => {
let server = self.serve.clone();
tokio::spawn(async move {
resp.execute(server).await;
});
}
Err(e) => {
tracing::warn!("Requests stream errored out: {}", e);
break;
}
}
}
Poll::Ready(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use super::{in_flight_requests::AlreadyExistsError, BaseChannel, Channel, Config, Requests};
use crate::{
trace,
context, trace,
transport::channel::{self, UnboundedChannel},
ClientMessage, Request, Response,
};
use assert_matches::assert_matches;
use futures::future::{pending, Aborted};
use futures::{
future::{pending, AbortRegistration, Abortable, Aborted},
prelude::*,
Future,
};
use futures_test::task::noop_context;
use std::time::Duration;
use std::{pin::Pin, task::Poll};
fn test_channel<Req, Resp>() -> (
Pin<Box<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>,
@@ -892,12 +766,18 @@ mod tests {
channel
.as_mut()
.start_request(0, SystemTime::now(), Span::current())
.start_request(Request {
id: 0,
context: context::current(),
message: (),
})
.unwrap();
assert_matches!(
channel
.as_mut()
.start_request(0, SystemTime::now(), Span::current()),
channel.as_mut().start_request(Request {
id: 0,
context: context::current(),
message: ()
}),
Err(AlreadyExistsError)
);
}
@@ -907,13 +787,21 @@ mod tests {
let (mut channel, _tx) = test_channel::<(), ()>();
tokio::time::pause();
let abort_registration0 = channel
let req0 = channel
.as_mut()
.start_request(0, SystemTime::now(), Span::current())
.start_request(Request {
id: 0,
context: context::current(),
message: (),
})
.unwrap();
let abort_registration1 = channel
let req1 = channel
.as_mut()
.start_request(1, SystemTime::now(), Span::current())
.start_request(Request {
id: 1,
context: context::current(),
message: (),
})
.unwrap();
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
@@ -921,8 +809,8 @@ mod tests {
channel.as_mut().poll_next(&mut noop_context()),
Poll::Pending
);
assert_matches!(test_abortable(abort_registration0).await, Err(Aborted));
assert_matches!(test_abortable(abort_registration1).await, Err(Aborted));
assert_matches!(test_abortable(req0.abort_registration).await, Err(Aborted));
assert_matches!(test_abortable(req1.abort_registration).await, Err(Aborted));
}
#[tokio::test]
@@ -930,13 +818,13 @@ mod tests {
let (mut channel, mut tx) = test_channel::<(), ()>();
tokio::time::pause();
let abort_registration = channel
let req = channel
.as_mut()
.start_request(
0,
SystemTime::now() + Duration::from_millis(100),
Span::current(),
)
.start_request(Request {
id: 0,
context: context::current(),
message: (),
})
.unwrap();
tx.send(ClientMessage::Cancel {
@@ -951,7 +839,7 @@ mod tests {
Poll::Pending
);
assert_matches!(test_abortable(abort_registration).await, Err(Aborted));
assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted));
}
#[tokio::test]
@@ -961,11 +849,11 @@ mod tests {
tokio::time::pause();
let _abort_registration = channel
.as_mut()
.start_request(
0,
SystemTime::now() + Duration::from_millis(100),
Span::current(),
)
.start_request(Request {
id: 0,
context: context::current(),
message: (),
})
.unwrap();
drop(tx);
@@ -1001,9 +889,13 @@ mod tests {
let (mut channel, mut tx) = test_channel::<(), ()>();
tokio::time::pause();
let abort_registration = channel
let req = channel
.as_mut()
.start_request(0, SystemTime::now(), Span::current())
.start_request(Request {
id: 0,
context: context::current(),
message: (),
})
.unwrap();
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
@@ -1013,7 +905,7 @@ mod tests {
channel.as_mut().poll_next(&mut noop_context()),
Poll::Ready(Some(Ok(_)))
);
assert_matches!(test_abortable(abort_registration).await, Err(Aborted));
assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted));
}
#[tokio::test]
@@ -1022,7 +914,11 @@ mod tests {
channel
.as_mut()
.start_request(0, SystemTime::now(), Span::current())
.start_request(Request {
id: 0,
context: context::current(),
message: (),
})
.unwrap();
assert_eq!(channel.in_flight_requests(), 1);
channel
@@ -1043,7 +939,11 @@ mod tests {
requests
.as_mut()
.channel_pin_mut()
.start_request(0, SystemTime::now(), Span::current())
.start_request(Request {
id: 0,
context: context::current(),
message: (),
})
.unwrap();
requests
.as_mut()
@@ -1069,7 +969,11 @@ mod tests {
requests
.as_mut()
.channel_pin_mut()
.start_request(1, SystemTime::now(), Span::current())
.start_request(Request {
id: 1,
context: context::current(),
message: (),
})
.unwrap();
assert_matches!(
@@ -1086,7 +990,11 @@ mod tests {
requests
.as_mut()
.channel_pin_mut()
.start_request(0, SystemTime::now(), Span::current())
.start_request(Request {
id: 0,
context: context::current(),
message: (),
})
.unwrap();
requests
.as_mut()
@@ -1101,7 +1009,11 @@ mod tests {
requests
.as_mut()
.channel_pin_mut()
.start_request(1, SystemTime::now(), Span::current())
.start_request(Request {
id: 1,
context: context::current(),
message: (),
})
.unwrap();
requests
.as_mut()

View File

@@ -98,6 +98,11 @@ impl InFlightRequests {
&mut self,
cx: &mut Context,
) -> Poll<Option<Result<u64, tokio::time::error::Error>>> {
if self.deadlines.is_empty() {
// TODO(https://github.com/tokio-rs/tokio/issues/4161)
// This is a workaround for DelayQueue not always treating this case correctly.
return Poll::Ready(None);
}
self.deadlines.poll_expired(cx).map_ok(|expired| {
if let Some(RequestData {
abort_handle, span, ..
@@ -184,12 +189,31 @@ mod tests {
#[tokio::test]
async fn remove_request_doesnt_abort() {
let mut in_flight_requests = InFlightRequests::default();
assert!(in_flight_requests.deadlines.is_empty());
let abort_registration = in_flight_requests
.start_request(0, SystemTime::now(), Span::current())
.start_request(
0,
SystemTime::now() + std::time::Duration::from_secs(10),
Span::current(),
)
.unwrap();
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
// Precondition: Pending expiration
assert_matches!(
in_flight_requests.poll_expired(&mut noop_context()),
Poll::Pending
);
assert!(!in_flight_requests.deadlines.is_empty());
assert_matches!(in_flight_requests.remove_request(0), Some(_));
// Postcondition: No pending expirations
assert!(in_flight_requests.deadlines.is_empty());
assert_matches!(
in_flight_requests.poll_expired(&mut noop_context()),
Poll::Ready(None)
);
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
Poll::Pending

View File

@@ -0,0 +1,49 @@
use super::{
limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel},
Channel,
};
use futures::prelude::*;
use std::{fmt, hash::Hash};
#[cfg(feature = "tokio1")]
use super::{tokio::TokioServerExecutor, Serve};
/// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel).
pub trait Incoming<C>
where
Self: Sized + Stream<Item = C>,
C: Channel,
{
/// Enforces channel per-key limits.
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> MaxChannelsPerKey<Self, K, KF>
where
K: fmt::Display + Eq + Hash + Clone + Unpin,
KF: Fn(&C) -> K,
{
MaxChannelsPerKey::new(self, n, keymaker)
}
/// Caps the number of concurrent requests per channel.
fn max_concurrent_requests_per_channel(self, n: usize) -> MaxRequestsPerChannel<Self> {
MaxRequestsPerChannel::new(self, n)
}
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled
/// concurrently by spawning on tokio's default executor, and each request will be also
/// be spawned on tokio's default executor.
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
where
S: Serve<C::Req, Resp = C::Resp>,
{
TokioServerExecutor::new(self, serve)
}
}
impl<S, C> Incoming<C> for S
where
S: Sized + Stream<Item = C>,
C: Channel,
{
}

View File

@@ -0,0 +1,5 @@
/// Provides functionality to limit the number of active channels.
pub mod channels_per_key;
/// Provides a [channel](crate::server::Channel) that limits the number of in-flight requests.
pub mod requests_per_channel;

View File

@@ -9,20 +9,23 @@ use crate::{
util::Compact,
};
use fnv::FnvHashMap;
use futures::{future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
use futures::{prelude::*, ready, stream::Fuse, task::*};
use pin_project::pin_project;
use std::sync::{Arc, Weak};
use std::{
collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin,
time::SystemTime,
};
use tokio::sync::mpsc;
use tracing::{debug, info, trace};
/// A single-threaded filter that drops channels based on per-key limits.
/// An [`Incoming`](crate::server::incoming::Incoming) stream that drops new channels based on
/// per-key limits.
///
/// The decision to drop a Channel is made once at the time the Channel materializes. Once a
/// Channel is yielded, it will not be prematurely dropped.
#[pin_project]
#[derive(Debug)]
pub struct ChannelFilter<S, K, F>
pub struct MaxChannelsPerKey<S, K, F>
where
K: Eq + Hash,
{
@@ -35,7 +38,7 @@ where
keymaker: F,
}
/// A channel that is tracked by a ChannelFilter.
/// A channel that is tracked by [`MaxChannelsPerKey`].
#[pin_project]
#[derive(Debug)]
pub struct TrackedChannel<C, K> {
@@ -116,15 +119,6 @@ where
fn transport(&self) -> &Self::Transport {
self.inner.transport()
}
fn start_request(
mut self: Pin<&mut Self>,
id: u64,
deadline: SystemTime,
span: tracing::Span,
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
self.inner_pin_mut().start_request(id, deadline, span)
}
}
impl<C, K> TrackedChannel<C, K> {
@@ -139,7 +133,7 @@ impl<C, K> TrackedChannel<C, K> {
}
}
impl<S, K, F> ChannelFilter<S, K, F>
impl<S, K, F> MaxChannelsPerKey<S, K, F>
where
K: Eq + Hash,
S: Stream,
@@ -148,7 +142,7 @@ where
/// Sheds new channels to stay under configured limits.
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel();
ChannelFilter {
MaxChannelsPerKey {
listener: listener.fuse(),
channels_per_key,
dropped_keys,
@@ -159,7 +153,7 @@ where
}
}
impl<S, K, F> ChannelFilter<S, K, F>
impl<S, K, F> MaxChannelsPerKey<S, K, F>
where
S: Stream,
K: fmt::Display + Eq + Hash + Clone + Unpin,
@@ -251,7 +245,7 @@ where
}
}
impl<S, K, F> Stream for ChannelFilter<S, K, F>
impl<S, K, F> Stream for MaxChannelsPerKey<S, K, F>
where
S: Stream,
K: fmt::Display + Eq + Hash + Clone + Unpin,
@@ -354,7 +348,7 @@ fn channel_filter_increment_channels_for_key() {
key: &'static str,
}
let (_, listener) = futures::channel::mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
assert_eq!(Arc::strong_count(&tracker1), 1);
@@ -375,7 +369,7 @@ fn channel_filter_handle_new_channel() {
key: &'static str,
}
let (_, listener) = futures::channel::mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
let channel1 = filter
.as_mut()
@@ -407,7 +401,7 @@ fn channel_filter_poll_listener() {
key: &'static str,
}
let (new_channels, listener) = futures::channel::mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels
@@ -443,7 +437,7 @@ fn channel_filter_poll_closed_channels() {
key: &'static str,
}
let (new_channels, listener) = futures::channel::mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels
@@ -471,7 +465,7 @@ fn channel_filter_stream() {
key: &'static str,
}
let (new_channels, listener) = futures::channel::mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels

View File

@@ -4,45 +4,49 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use super::{Channel, Config};
use crate::{Response, ServerError};
use futures::{future::AbortRegistration, prelude::*, ready, task::*};
use crate::{
server::{Channel, Config},
Response, ServerError,
};
use futures::{prelude::*, ready, task::*};
use pin_project::pin_project;
use std::{io, pin::Pin, time::SystemTime};
use tracing::Span;
use std::{io, pin::Pin};
/// A [`Channel`] that limits the number of concurrent
/// requests by throttling.
/// A [`Channel`] that limits the number of concurrent requests by throttling.
///
/// Note that this is a very basic throttling heuristic. It is easy to set a number that is too low
/// for the resources available to the server. For production use cases, a more advanced throttler
/// is likely needed.
#[pin_project]
#[derive(Debug)]
pub struct Throttler<C> {
pub struct MaxRequests<C> {
max_in_flight_requests: usize,
#[pin]
inner: C,
}
impl<C> Throttler<C> {
impl<C> MaxRequests<C> {
/// Returns the inner channel.
pub fn get_ref(&self) -> &C {
&self.inner
}
}
impl<C> Throttler<C>
impl<C> MaxRequests<C>
where
C: Channel,
{
/// Returns a new `Throttler` that wraps the given channel and limits concurrent requests to
/// Returns a new `MaxRequests` that wraps the given channel and limits concurrent requests to
/// `max_in_flight_requests`.
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
Throttler {
MaxRequests {
max_in_flight_requests,
inner,
}
}
}
impl<C> Stream for Throttler<C>
impl<C> Stream for MaxRequests<C>
where
C: Channel,
{
@@ -54,19 +58,18 @@ where
ready!(self.as_mut().project().inner.poll_ready(cx)?);
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
Some(request) => {
tracing::debug!(
rpc.trace_id = %request.context.trace_id(),
Some(r) => {
let _entered = r.span.enter();
tracing::info!(
in_flight_requests = self.as_mut().in_flight_requests(),
max_in_flight_requests = *self.as_mut().project().max_in_flight_requests,
"At in-flight request limit",
"ThrottleRequest",
);
self.as_mut().start_send(Response {
request_id: request.id,
request_id: r.request.id,
message: Err(ServerError {
kind: io::ErrorKind::WouldBlock,
detail: Some("Server throttled the request.".into()),
detail: "server throttled the request.".into(),
}),
})?;
}
@@ -77,7 +80,7 @@ where
}
}
impl<C> Sink<Response<<C as Channel>::Resp>> for Throttler<C>
impl<C> Sink<Response<<C as Channel>::Resp>> for MaxRequests<C>
where
C: Channel,
{
@@ -103,13 +106,13 @@ where
}
}
impl<C> AsRef<C> for Throttler<C> {
impl<C> AsRef<C> for MaxRequests<C> {
fn as_ref(&self) -> &C {
&self.inner
}
}
impl<C> Channel for Throttler<C>
impl<C> Channel for MaxRequests<C>
where
C: Channel,
{
@@ -128,27 +131,19 @@ where
fn transport(&self) -> &Self::Transport {
self.inner.transport()
}
fn start_request(
self: Pin<&mut Self>,
id: u64,
deadline: SystemTime,
span: Span,
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
self.project().inner.start_request(id, deadline, span)
}
}
/// A stream of throttling channels.
/// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on
/// the number of in-flight requests.
#[pin_project]
#[derive(Debug)]
pub struct ThrottlerStream<S> {
pub struct MaxRequestsPerChannel<S> {
#[pin]
inner: S,
max_in_flight_requests: usize,
}
impl<S> ThrottlerStream<S>
impl<S> MaxRequestsPerChannel<S>
where
S: Stream,
<S as Stream>::Item: Channel,
@@ -161,16 +156,16 @@ where
}
}
impl<S> Stream for ThrottlerStream<S>
impl<S> Stream for MaxRequestsPerChannel<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
type Item = Throttler<<S as Stream>::Item>;
type Item = MaxRequests<<S as Stream>::Item>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match ready!(self.as_mut().project().inner.poll_next(cx)) {
Some(channel) => Poll::Ready(Some(Throttler::new(
Some(channel) => Poll::Ready(Some(MaxRequests::new(
channel,
*self.project().max_in_flight_requests,
))),
@@ -183,19 +178,20 @@ where
mod tests {
use super::*;
use crate::{
server::{
in_flight_requests::AlreadyExistsError,
testing::{self, FakeChannel, PollExt},
},
Request,
use crate::server::{
testing::{self, FakeChannel, PollExt},
TrackedRequest,
};
use pin_utils::pin_mut;
use std::{marker::PhantomData, time::Duration};
use std::{
marker::PhantomData,
time::{Duration, SystemTime},
};
use tracing::Span;
#[tokio::test]
async fn throttler_in_flight_requests() {
let throttler = Throttler {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
@@ -215,28 +211,9 @@ mod tests {
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
}
#[tokio::test]
async fn throttler_start_request() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler
.as_mut()
.start_request(
1,
SystemTime::now() + Duration::from_secs(1),
Span::current(),
)
.unwrap();
assert_eq!(throttler.inner.in_flight_requests.len(), 1);
}
#[test]
fn throttler_poll_next_done() {
let throttler = Throttler {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
@@ -247,7 +224,7 @@ mod tests {
#[test]
fn throttler_poll_next_some() -> io::Result<()> {
let throttler = Throttler {
let throttler = MaxRequests {
max_in_flight_requests: 1,
inner: FakeChannel::default::<isize, isize>(),
};
@@ -259,7 +236,7 @@ mod tests {
throttler
.as_mut()
.poll_next(&mut testing::cx())?
.map(|r| r.map(|r| (r.id, r.message))),
.map(|r| r.map(|r| (r.request.id, r.request.message))),
Poll::Ready(Some((0, 1)))
);
Ok(())
@@ -267,7 +244,7 @@ mod tests {
#[test]
fn throttler_poll_next_throttled() {
let throttler = Throttler {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
@@ -283,7 +260,7 @@ mod tests {
#[test]
fn throttler_poll_next_throttled_sink_not_ready() {
let throttler = Throttler {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: PendingSink::default::<isize, isize>(),
};
@@ -294,7 +271,8 @@ mod tests {
ghost: PhantomData<fn(Out) -> In>,
}
impl PendingSink<(), ()> {
pub fn default<Req, Resp>() -> PendingSink<io::Result<Request<Req>>, Response<Resp>> {
pub fn default<Req, Resp>(
) -> PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
PendingSink { ghost: PhantomData }
}
}
@@ -319,7 +297,7 @@ mod tests {
Poll::Pending
}
}
impl<Req, Resp> Channel for PendingSink<io::Result<Request<Req>>, Response<Resp>> {
impl<Req, Resp> Channel for PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
type Req = Req;
type Resp = Resp;
type Transport = ();
@@ -332,20 +310,12 @@ mod tests {
fn transport(&self) -> &() {
&()
}
fn start_request(
self: Pin<&mut Self>,
_id: u64,
_deadline: SystemTime,
_span: tracing::Span,
) -> Result<AbortRegistration, AlreadyExistsError> {
unimplemented!()
}
}
}
#[tokio::test]
async fn throttler_start_send() {
let throttler = Throttler {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};

View File

@@ -6,10 +6,10 @@
use crate::{
context,
server::{Channel, Config},
server::{Channel, Config, TrackedRequest},
Request, Response,
};
use futures::{future::AbortRegistration, task::*, Sink, Stream};
use futures::{task::*, Sink, Stream};
use pin_project::pin_project;
use std::{collections::VecDeque, io, pin::Pin, time::SystemTime};
use tracing::Span;
@@ -62,7 +62,7 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
}
}
impl<Req, Resp> Channel for FakeChannel<io::Result<Request<Req>>, Response<Resp>>
impl<Req, Resp> Channel for FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>>
where
Req: Unpin,
{
@@ -81,34 +81,28 @@ where
fn transport(&self) -> &() {
&()
}
fn start_request(
self: Pin<&mut Self>,
id: u64,
deadline: SystemTime,
span: Span,
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
self.project()
.in_flight_requests
.start_request(id, deadline, span)
}
}
impl<Req, Resp> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
pub fn push_req(&mut self, id: u64, message: Req) {
self.stream.push_back(Ok(Request {
context: context::Context {
deadline: SystemTime::UNIX_EPOCH,
trace_context: Default::default(),
let (_, abort_registration) = futures::future::AbortHandle::new_pair();
self.stream.push_back(Ok(TrackedRequest {
request: Request {
context: context::Context {
deadline: SystemTime::UNIX_EPOCH,
trace_context: Default::default(),
},
id,
message,
},
id,
message,
abort_registration,
span: Span::none(),
}));
}
}
impl FakeChannel<(), ()> {
pub fn default<Req, Resp>() -> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
pub fn default<Req, Resp>() -> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
FakeChannel {
stream: Default::default(),
sink: Default::default(),

111
tarpc/src/server/tokio.rs Normal file
View File

@@ -0,0 +1,111 @@
use super::{Channel, Requests, Serve};
use futures::{prelude::*, ready, task::*};
use pin_project::pin_project;
use std::pin::Pin;
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
/// for each new channel. Returned by
/// [`Incoming::execute`](crate::server::incoming::Incoming::execute).
#[pin_project]
#[derive(Debug)]
pub struct TokioServerExecutor<T, S> {
#[pin]
inner: T,
serve: S,
}
impl<T, S> TokioServerExecutor<T, S> {
pub(crate) fn new(inner: T, serve: S) -> Self {
Self { inner, serve }
}
}
/// A future that drives the server by [spawning](tokio::spawn) each [response
/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by
/// [`Channel::execute`](crate::server::Channel::execute).
#[pin_project]
#[derive(Debug)]
pub struct TokioChannelExecutor<T, S> {
#[pin]
inner: T,
serve: S,
}
impl<T, S> TokioServerExecutor<T, S> {
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
self.as_mut().project().inner
}
}
impl<T, S> TokioChannelExecutor<T, S> {
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
self.as_mut().project().inner
}
}
// Send + 'static execution helper methods.
impl<C> Requests<C>
where
C: Channel,
C::Req: Send + 'static,
C::Resp: Send + 'static,
{
/// Executes all requests using the given service function. Requests are handled concurrently
/// by [spawning](::tokio::spawn) each handler on tokio's default executor.
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
where
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
{
TokioChannelExecutor { inner: self, serve }
}
}
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
where
St: Sized + Stream<Item = C>,
C: Channel + Send + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
Se::Fut: Send,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
tokio::spawn(channel.execute(self.serve.clone()));
}
tracing::info!("Server shutting down.");
Poll::Ready(())
}
}
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
where
C: Channel + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
S::Fut: Send,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
match response_handler {
Ok(resp) => {
let server = self.serve.clone();
tokio::spawn(async move {
resp.execute(server).await;
});
}
Err(e) => {
tracing::warn!("Requests stream errored out: {}", e);
break;
}
}
}
Poll::Ready(())
}
}

View File

@@ -71,9 +71,9 @@ pub struct SpanId(u64);
#[repr(u8)]
pub enum SamplingDecision {
/// The associated span was sampled by its creating process. Child spans must also be sampled.
Sampled = opentelemetry::trace::TRACE_FLAG_SAMPLED,
Sampled,
/// The associated span was not sampled by its creating process.
Unsampled = opentelemetry::trace::TRACE_FLAG_NOT_SAMPLED,
Unsampled,
}
impl Context {
@@ -138,25 +138,25 @@ impl From<u64> for SpanId {
impl From<opentelemetry::trace::TraceId> for TraceId {
fn from(trace_id: opentelemetry::trace::TraceId) -> Self {
Self::from(trace_id.to_u128())
Self::from(u128::from_be_bytes(trace_id.to_bytes()))
}
}
impl From<TraceId> for opentelemetry::trace::TraceId {
fn from(trace_id: TraceId) -> Self {
Self::from_u128(trace_id.into())
Self::from_bytes(u128::from(trace_id).to_be_bytes())
}
}
impl From<opentelemetry::trace::SpanId> for SpanId {
fn from(span_id: opentelemetry::trace::SpanId) -> Self {
Self::from(span_id.to_u64())
Self::from(u64::from_be_bytes(span_id.to_bytes()))
}
}
impl From<SpanId> for opentelemetry::trace::SpanId {
fn from(span_id: SpanId) -> Self {
Self::from_u64(span_id.0)
Self::from_bytes(u64::from(span_id).to_be_bytes())
}
}
@@ -173,8 +173,8 @@ impl TryFrom<&tracing::Span> for Context {
}
}
impl From<&dyn opentelemetry::trace::Span> for Context {
fn from(span: &dyn opentelemetry::trace::Span) -> Self {
impl From<opentelemetry::trace::SpanRef<'_>> for Context {
fn from(span: opentelemetry::trace::SpanRef<'_>) -> Self {
let otel_ctx = span.span_context();
Self {
trace_id: TraceId::from(otel_ctx.trace_id()),
@@ -184,6 +184,15 @@ impl From<&dyn opentelemetry::trace::Span> for Context {
}
}
impl From<SamplingDecision> for opentelemetry::trace::TraceFlags {
fn from(decision: SamplingDecision) -> Self {
match decision {
SamplingDecision::Sampled => opentelemetry::trace::TraceFlags::SAMPLED,
SamplingDecision::Unsampled => opentelemetry::trace::TraceFlags::default(),
}
}
}
impl From<&opentelemetry::trace::SpanContext> for SamplingDecision {
fn from(context: &opentelemetry::trace::SpanContext) -> Self {
if context.is_sampled() {

View File

@@ -151,7 +151,7 @@ impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
mod tests {
use crate::{
client, context,
server::{BaseChannel, Incoming},
server::{incoming::Incoming, BaseChannel},
transport::{
self,
channel::{Channel, UnboundedChannel},
@@ -170,7 +170,7 @@ mod tests {
}
#[tokio::test]
async fn integration() -> io::Result<()> {
async fn integration() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
let (client_channel, server_channel) = transport::channel::unbounded();
@@ -181,7 +181,7 @@ mod tests {
future::ready(request.parse::<u64>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("{:?} is not an int", request),
format!("{request:?} is not an int"),
)
}))
}),

View File

@@ -38,11 +38,34 @@ where
H: BuildHasher,
{
fn compact(&mut self, usage_ratio_threshold: f64) {
if self.capacity() > 1000 {
let usage_ratio = self.len() as f64 / self.capacity() as f64;
if usage_ratio < usage_ratio_threshold {
self.shrink_to_fit();
}
}
let usage_ratio_threshold = usage_ratio_threshold.clamp(f64::MIN_POSITIVE, 1.);
let cap = f64::max(1000., self.len() as f64 / usage_ratio_threshold);
self.shrink_to(cap as usize);
}
}
#[test]
fn test_compact() {
let mut map = HashMap::with_capacity(2048);
assert_eq!(map.capacity(), 3584);
// Make usage ratio 25%
for i in 0..896 {
map.insert(format!("k{i}"), "v");
}
map.compact(-1.0);
assert_eq!(map.capacity(), 3584);
map.compact(0.25);
assert_eq!(map.capacity(), 3584);
map.compact(0.50);
assert_eq!(map.capacity(), 1792);
map.compact(1.0);
assert_eq!(map.capacity(), 1792);
map.compact(2.0);
assert_eq!(map.capacity(), 1792);
}

View File

@@ -7,8 +7,8 @@ struct HelloServer;
#[tarpc::server]
impl World for HelloServer {
fn hello(name: String) -> String {
format!("Hello, {}!", name)
fn hello(name: String) -> String {
format!("Hello, {name}!", name)
}
}

View File

@@ -7,5 +7,5 @@ error: not all trait items implemented, missing: `HelloFut`
error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async
--> $DIR/tarpc_server_missing_async.rs:10:5
|
10 | fn hello(name: String) -> String {
10 | fn hello(name: String) -> String {
| ^^

View File

@@ -1,9 +1,8 @@
use futures::prelude::*;
use std::io;
use tarpc::serde_transport;
use tarpc::{
client, context,
server::{BaseChannel, Incoming},
server::{incoming::Incoming, BaseChannel},
};
use tokio_serde::formats::Json;
@@ -33,7 +32,7 @@ impl ColorProtocol for ColorServer {
}
#[tokio::test]
async fn test_call() -> io::Result<()> {
async fn test_call() -> anyhow::Result<()> {
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
let addr = transport.local_addr();
tokio::spawn(

View File

@@ -7,7 +7,7 @@ use std::time::{Duration, SystemTime};
use tarpc::{
client::{self},
context,
server::{self, BaseChannel, Channel, Incoming},
server::{self, incoming::Incoming, BaseChannel, Channel},
transport::channel,
};
use tokio::join;
@@ -31,7 +31,7 @@ impl Service for Server {
type HeyFut = Ready<String>;
fn hey(self, _: context::Context, name: String) -> Self::HeyFut {
ready(format!("Hey, {}.", name))
ready(format!("Hey, {name}."))
}
}