mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-02-23 15:49:54 +01:00
Compare commits
28 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6632f68d95 | ||
|
|
25985ad56a | ||
|
|
d6a24e9420 | ||
|
|
281a78f3c7 | ||
|
|
a0787d0091 | ||
|
|
d2acba0e8a | ||
|
|
ea7b6763c4 | ||
|
|
eb67c540b9 | ||
|
|
4151d0abd3 | ||
|
|
d0c11a6efa | ||
|
|
82c4da1743 | ||
|
|
0a15e0b75c | ||
|
|
0b315c29bf | ||
|
|
56f09bf61f | ||
|
|
6d82e82419 | ||
|
|
9bebaf814a | ||
|
|
5f4d6e6008 | ||
|
|
07d07d7ba3 | ||
|
|
a41bbf65b2 | ||
|
|
21e2f7ca62 | ||
|
|
7b7c182411 | ||
|
|
db0c778ead | ||
|
|
c3efb83ac1 | ||
|
|
3d7b0171fe | ||
|
|
c191ff5b2e | ||
|
|
90bc7f741d | ||
|
|
d3f6c01df2 | ||
|
|
c6450521e6 |
@@ -1,7 +1,11 @@
|
||||
[workspace]
|
||||
resolver = "2"
|
||||
|
||||
members = [
|
||||
"example-service",
|
||||
"tarpc",
|
||||
"plugins",
|
||||
]
|
||||
|
||||
[profile.dev]
|
||||
split-debuginfo = "unpacked"
|
||||
|
||||
22
README.md
22
README.md
@@ -40,7 +40,7 @@ rather than in a separate language such as .proto. This means there's no separat
|
||||
process, and no context switching between different languages.
|
||||
|
||||
Some other features of tarpc:
|
||||
- Pluggable transport: any type impling `Stream<Item = Request> + Sink<Response>` can be
|
||||
- Pluggable transport: any type implementing `Stream<Item = Request> + Sink<Response>` can be
|
||||
used as a transport to connect the client and server.
|
||||
- `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
|
||||
- Cascading cancellation: dropping a request will send a cancellation message to the server.
|
||||
@@ -51,6 +51,14 @@ Some other features of tarpc:
|
||||
requests sent by the server that use the request context will propagate the request deadline.
|
||||
For example, if a server is handling a request with a 10s deadline, does 2s of work, then
|
||||
sends a request to another server, that server will see an 8s deadline.
|
||||
- Distributed tracing: tarpc is instrumented with
|
||||
[tracing](https://github.com/tokio-rs/tracing) primitives extended with
|
||||
[OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
|
||||
[Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
|
||||
each RPC can be traced through the client, server, and other dependencies downstream of the
|
||||
server. Even for applications not connected to a distributed tracing collector, the
|
||||
instrumentation can also be ingested by regular loggers like
|
||||
[env_logger](https://github.com/env-logger-rs/env_logger/).
|
||||
- Serde serialization: enabling the `serde1` Cargo feature will make service requests and
|
||||
responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
|
||||
be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
|
||||
@@ -59,7 +67,7 @@ Some other features of tarpc:
|
||||
Add to your `Cargo.toml` dependencies:
|
||||
|
||||
```toml
|
||||
tarpc = "0.25"
|
||||
tarpc = "0.27"
|
||||
```
|
||||
|
||||
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||
@@ -72,8 +80,9 @@ This example uses [tokio](https://tokio.rs), so add the following dependencies t
|
||||
your `Cargo.toml`:
|
||||
|
||||
```toml
|
||||
anyhow = "1.0"
|
||||
futures = "1.0"
|
||||
tarpc = { version = "0.25", features = ["tokio1"] }
|
||||
tarpc = { version = "0.27", features = ["tokio1"] }
|
||||
tokio = { version = "1.0", features = ["macros"] }
|
||||
```
|
||||
|
||||
@@ -91,9 +100,8 @@ use futures::{
|
||||
};
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{self, Incoming},
|
||||
server::{self, incoming::Incoming},
|
||||
};
|
||||
use std::io;
|
||||
|
||||
// This is the service definition. It looks a lot like a trait definition.
|
||||
// It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
@@ -132,7 +140,7 @@ available behind the `tcp` feature.
|
||||
|
||||
```rust
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
|
||||
let server = server::BaseChannel::with_defaults(server_transport);
|
||||
@@ -140,7 +148,7 @@ async fn main() -> io::Result<()> {
|
||||
|
||||
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||
// that takes a config and any Transport as input.
|
||||
let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
|
||||
let mut client = WorldClient::new(client::Config::default(), client_transport).spawn();
|
||||
|
||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||
|
||||
88
RELEASES.md
88
RELEASES.md
@@ -1,3 +1,89 @@
|
||||
## 0.27.1 (2021-09-22)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
The Opentelemetry dependency is updated to version 0.16.x.
|
||||
|
||||
## 0.27.0 (2021-09-22)
|
||||
|
||||
This version was yanked due to tarpc-plugins version mismatches.
|
||||
|
||||
|
||||
## 0.26.0 (2021-04-14)
|
||||
|
||||
### New Features
|
||||
|
||||
#### Tracing
|
||||
|
||||
tarpc is now instrumented with tracing primitives extended with
|
||||
OpenTelemetry traces. Using a compatible tracing-opentelemetry
|
||||
subscriber like Jaeger, each RPC can be traced through the client,
|
||||
server, amd other dependencies downstream of the server. Even for
|
||||
applications not connected to a distributed tracing collector, the
|
||||
instrumentation can also be ingested by regular loggers like env_logger.
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
#### Logging
|
||||
|
||||
Logged events are now structured using tracing. For applications using a
|
||||
logger and not a tracing subscriber, these logs may look different or
|
||||
contain information in a less consumable manner. The easiest solution is
|
||||
to add a tracing subscriber that logs to stdout, such as
|
||||
tracing_subscriber::fmt.
|
||||
|
||||
#### Context
|
||||
|
||||
- Context no longer has parent_span, which was actually never needed,
|
||||
because the context sent in an RPC is inherently the parent context.
|
||||
For purposes of distributed tracing, the client side of the RPC has all
|
||||
necessary information to link the span to its parent; the server side
|
||||
need do nothing more than export the (trace ID, span ID) tuple.
|
||||
- Context has a new field, SamplingDecision, which has two variants,
|
||||
Sampled and Unsampled. This field can be used by downstream systems to
|
||||
determine whether a trace needs to be exported. If the parent span is
|
||||
sampled, the expectation is that all child spans be exported, as well;
|
||||
to do otherwise could result in lossy traces being exported. Note that
|
||||
if an Openetelemetry tracing subscriber is not installed, the fallback
|
||||
context will still be used, but the Context's sampling decision will
|
||||
always be inherited by the parent Context's sampling decision.
|
||||
- Context::scope has been removed. Context propagation is now done via
|
||||
tracing's task-local spans. Spans can be propagated across tasks via
|
||||
Span::in_scope. When a service receives a request, it attaches an
|
||||
Opentelemetry context to the local Span created before request handling,
|
||||
and this context contains the request deadline. This span-local deadline
|
||||
is retrieved by Context::current, but it cannot be modified so that
|
||||
future Context::current calls contain a different deadline. However, the
|
||||
deadline in the context passed into an RPC call will override it, so
|
||||
users can retrieve the current context and then modify the deadline
|
||||
field, as has been historically possible.
|
||||
- Context propgation precedence changes: when an RPC is initiated, the
|
||||
current Span's Opentelemetry context takes precedence over the trace
|
||||
context passed into the RPC method. If there is no current Span, then
|
||||
the trace context argument is used as it has been historically. Note
|
||||
that Opentelemetry context propagation requires an Opentelemetry
|
||||
tracing subscriber to be installed.
|
||||
|
||||
#### Server
|
||||
|
||||
- The server::Channel trait now has an additional required associated
|
||||
type and method which returns the underlying transport. This makes it
|
||||
more ergonomic for users to retrieve transport-specific information,
|
||||
like IP Address. BaseChannel implements Channel::transport by returning
|
||||
the underlying transport, and channel decorators like Throttler just
|
||||
delegate to the Channel::transport method of the wrapped channel.
|
||||
|
||||
#### Client
|
||||
|
||||
- NewClient::spawn no longer returns a result, as spawn can't fail.
|
||||
|
||||
### References
|
||||
|
||||
1. https://github.com/tokio-rs/tracing
|
||||
2. https://opentelemetry.io
|
||||
3. https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger
|
||||
4. https://github.com/env-logger-rs/env_logger
|
||||
|
||||
## 0.25.0 (2021-03-10)
|
||||
|
||||
### Breaking Changes
|
||||
@@ -170,7 +256,7 @@ nameable futures and will just be boxing the return type anyway. This macro does
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- Enums had _non_exhaustive fields replaced with the #[non_exhaustive] attribute.
|
||||
- Enums had `_non_exhaustive` fields replaced with the #[non_exhaustive] attribute.
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc-example-service"
|
||||
version = "0.9.0"
|
||||
version = "0.10.0"
|
||||
authors = ["Tim Kuehn <tikue@google.com>"]
|
||||
edition = "2018"
|
||||
license = "MIT"
|
||||
@@ -13,12 +13,18 @@ readme = "../README.md"
|
||||
description = "An example server built on tarpc."
|
||||
|
||||
[dependencies]
|
||||
clap = "2.33"
|
||||
env_logger = "0.8"
|
||||
anyhow = "1.0"
|
||||
clap = "3.0.0-beta.2"
|
||||
log = "0.4"
|
||||
futures = "0.3"
|
||||
serde = { version = "1.0" }
|
||||
tarpc = { version = "0.25", path = "../tarpc", features = ["full"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
opentelemetry = { version = "0.16", features = ["rt-tokio"] }
|
||||
opentelemetry-jaeger = { version = "0.15", features = ["rt-tokio"] }
|
||||
rand = "0.8"
|
||||
tarpc = { version = "0.27", path = "../tarpc", features = ["full"] }
|
||||
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
|
||||
tracing = { version = "0.1" }
|
||||
tracing-opentelemetry = "0.15"
|
||||
tracing-subscriber = "0.2"
|
||||
|
||||
[lib]
|
||||
name = "service"
|
||||
|
||||
@@ -4,57 +4,49 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use clap::{App, Arg};
|
||||
use std::{io, net::SocketAddr};
|
||||
use clap::Clap;
|
||||
use service::{init_tracing, WorldClient};
|
||||
use std::{net::SocketAddr, time::Duration};
|
||||
use tarpc::{client, context, tokio_serde::formats::Json};
|
||||
use tokio::time::sleep;
|
||||
use tracing::Instrument;
|
||||
|
||||
#[derive(Clap)]
|
||||
struct Flags {
|
||||
/// Sets the server address to connect to.
|
||||
#[clap(long)]
|
||||
server_addr: SocketAddr,
|
||||
/// Sets the name to say hello to.
|
||||
#[clap(long)]
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
env_logger::init();
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let flags = Flags::parse();
|
||||
init_tracing("Tarpc Example Client")?;
|
||||
|
||||
let flags = App::new("Hello Client")
|
||||
.version("0.1")
|
||||
.author("Tim <tikue@google.com>")
|
||||
.about("Say hello!")
|
||||
.arg(
|
||||
Arg::with_name("server_addr")
|
||||
.long("server_addr")
|
||||
.value_name("ADDRESS")
|
||||
.help("Sets the server address to connect to.")
|
||||
.required(true)
|
||||
.takes_value(true),
|
||||
)
|
||||
.arg(
|
||||
Arg::with_name("name")
|
||||
.short("n")
|
||||
.long("name")
|
||||
.value_name("STRING")
|
||||
.help("Sets the name to say hello to.")
|
||||
.required(true)
|
||||
.takes_value(true),
|
||||
)
|
||||
.get_matches();
|
||||
|
||||
let server_addr = flags.value_of("server_addr").unwrap();
|
||||
let server_addr = server_addr
|
||||
.parse::<SocketAddr>()
|
||||
.unwrap_or_else(|e| panic!(r#"--server_addr value "{}" invalid: {}"#, server_addr, e));
|
||||
|
||||
let name = flags.value_of("name").unwrap().into();
|
||||
|
||||
let mut transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default);
|
||||
transport.config_mut().max_frame_length(usize::MAX);
|
||||
let transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
|
||||
|
||||
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
|
||||
// config and any Transport as input.
|
||||
let client = service::WorldClient::new(client::Config::default(), transport.await?).spawn()?;
|
||||
let client = WorldClient::new(client::Config::default(), transport.await?).spawn();
|
||||
|
||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||
// specifies a deadline and trace information which can be helpful in debugging requests.
|
||||
let hello = client.hello(context::current(), name).await?;
|
||||
let hello = async move {
|
||||
// Send the request twice, just to be safe! ;)
|
||||
tokio::select! {
|
||||
hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 }
|
||||
hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 }
|
||||
}
|
||||
}
|
||||
.instrument(tracing::info_span!("Two Hellos"))
|
||||
.await;
|
||||
|
||||
println!("{}", hello);
|
||||
tracing::info!("{:?}", hello);
|
||||
|
||||
// Let the background span processor finish.
|
||||
sleep(Duration::from_micros(1)).await;
|
||||
opentelemetry::global::shutdown_tracer_provider();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -4,6 +4,9 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use std::env;
|
||||
use tracing_subscriber::{fmt::format::FmtSpan, prelude::*};
|
||||
|
||||
/// This is the service definition. It looks a lot like a trait definition.
|
||||
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
#[tarpc::service]
|
||||
@@ -11,3 +14,21 @@ pub trait World {
|
||||
/// Returns a greeting for name.
|
||||
async fn hello(name: String) -> String;
|
||||
}
|
||||
|
||||
/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
|
||||
pub fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
||||
|
||||
let tracer = opentelemetry_jaeger::new_pipeline()
|
||||
.with_service_name(service_name)
|
||||
.with_max_packet_size(2usize.pow(13))
|
||||
.install_batch(opentelemetry::runtime::Tokio)?;
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.with(tracing_subscriber::fmt::layer().with_span_events(FmtSpan::NEW | FmtSpan::CLOSE))
|
||||
.with(tracing_opentelemetry::layer().with_tracer(tracer))
|
||||
.try_init()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -4,18 +4,30 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use clap::{App, Arg};
|
||||
use clap::Clap;
|
||||
use futures::{future, prelude::*};
|
||||
use service::World;
|
||||
use rand::{
|
||||
distributions::{Distribution, Uniform},
|
||||
thread_rng,
|
||||
};
|
||||
use service::{init_tracing, World};
|
||||
use std::{
|
||||
io,
|
||||
net::{IpAddr, SocketAddr},
|
||||
net::{IpAddr, Ipv6Addr, SocketAddr},
|
||||
time::Duration,
|
||||
};
|
||||
use tarpc::{
|
||||
context,
|
||||
server::{self, Channel, Incoming},
|
||||
server::{self, incoming::Incoming, Channel},
|
||||
tokio_serde::formats::Json,
|
||||
};
|
||||
use tokio::time;
|
||||
|
||||
#[derive(Clap)]
|
||||
struct Flags {
|
||||
/// Sets the port number to listen on.
|
||||
#[clap(long)]
|
||||
port: u16,
|
||||
}
|
||||
|
||||
// This is the type that implements the generated World trait. It is the business logic
|
||||
// and is used to start the server.
|
||||
@@ -25,35 +37,19 @@ struct HelloServer(SocketAddr);
|
||||
#[tarpc::server]
|
||||
impl World for HelloServer {
|
||||
async fn hello(self, _: context::Context, name: String) -> String {
|
||||
format!("Hello, {}! You are connected from {:?}.", name, self.0)
|
||||
let sleep_time =
|
||||
Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng()));
|
||||
time::sleep(sleep_time).await;
|
||||
format!("Hello, {}! You are connected from {}", name, self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
env_logger::init();
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let flags = Flags::parse();
|
||||
init_tracing("Tarpc Example Server")?;
|
||||
|
||||
let flags = App::new("Hello Server")
|
||||
.version("0.1")
|
||||
.author("Tim <tikue@google.com>")
|
||||
.about("Say hello!")
|
||||
.arg(
|
||||
Arg::with_name("port")
|
||||
.short("p")
|
||||
.long("port")
|
||||
.value_name("NUMBER")
|
||||
.help("Sets the port number to listen on")
|
||||
.required(true)
|
||||
.takes_value(true),
|
||||
)
|
||||
.get_matches();
|
||||
|
||||
let port = flags.value_of("port").unwrap();
|
||||
let port = port
|
||||
.parse()
|
||||
.unwrap_or_else(|e| panic!(r#"--port value "{}" invalid: {}"#, port, e));
|
||||
|
||||
let server_addr = (IpAddr::from([0, 0, 0, 0]), port);
|
||||
let server_addr = (IpAddr::V6(Ipv6Addr::LOCALHOST), flags.port);
|
||||
|
||||
// JSON transport is provided by the json_transport tarpc module. It makes it easy
|
||||
// to start up a serde-powered json serialization strategy over TCP.
|
||||
@@ -64,12 +60,12 @@ async fn main() -> io::Result<()> {
|
||||
.filter_map(|r| future::ready(r.ok()))
|
||||
.map(server::BaseChannel::with_defaults)
|
||||
// Limit channels to 1 per IP.
|
||||
.max_channels_per_key(1, |t| t.as_ref().peer_addr().unwrap().ip())
|
||||
.max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip())
|
||||
// serve is generated by the service attribute. It takes as input any type implementing
|
||||
// the generated World trait.
|
||||
.map(|channel| {
|
||||
let server = HelloServer(channel.as_ref().as_ref().peer_addr().unwrap());
|
||||
channel.requests().execute(server.serve())
|
||||
let server = HelloServer(channel.transport().peer_addr().unwrap());
|
||||
channel.execute(server.serve())
|
||||
})
|
||||
// Max 10 channels.
|
||||
.buffer_unordered(10)
|
||||
|
||||
@@ -67,7 +67,7 @@ else
|
||||
fi
|
||||
|
||||
printf "${PREFIX} Checking for rustfmt ... "
|
||||
command -v cargo fmt &>/dev/null
|
||||
command -v rustfmt &>/dev/null
|
||||
if [ $? == 0 ]; then
|
||||
printf "${SUCCESS}\n"
|
||||
else
|
||||
@@ -93,19 +93,19 @@ diff=""
|
||||
for file in $(git diff --name-only --cached);
|
||||
do
|
||||
if [ ${file: -3} == ".rs" ]; then
|
||||
diff="$diff$(cargo fmt -- --check $file)"
|
||||
diff="$diff$(rustfmt --edition 2018 --check $file)"
|
||||
if [ $? != 0 ]; then
|
||||
FMTRESULT=1
|
||||
fi
|
||||
fi
|
||||
done
|
||||
if grep --quiet "^[-+]" <<< "$diff"; then
|
||||
FMTRESULT=1
|
||||
fi
|
||||
|
||||
if [ "${TARPC_SKIP_RUSTFMT}" == 1 ]; then
|
||||
printf "${SKIPPED}\n"$?
|
||||
elif [ ${FMTRESULT} != 0 ]; then
|
||||
FAILED=1
|
||||
printf "${FAILURE}\n"
|
||||
echo "$diff" | sed 's/Using rustfmt config file.*$/d/'
|
||||
echo "$diff"
|
||||
else
|
||||
printf "${SUCCESS}\n"
|
||||
fi
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc-plugins"
|
||||
version = "0.10.0"
|
||||
version = "0.12.0"
|
||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||
edition = "2018"
|
||||
license = "MIT"
|
||||
@@ -30,4 +30,4 @@ proc-macro = true
|
||||
assert-type-eq = "0.1.0"
|
||||
futures = "0.3"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tarpc = { path = "../tarpc" }
|
||||
tarpc = { path = "../tarpc", features = ["serde1"] }
|
||||
|
||||
@@ -267,18 +267,25 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
None
|
||||
};
|
||||
|
||||
let methods = rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>();
|
||||
let request_names = methods
|
||||
.iter()
|
||||
.map(|m| format!("{}.{}", ident, m))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
ServiceGenerator {
|
||||
response_fut_name,
|
||||
service_ident: ident,
|
||||
server_ident: &format_ident!("Serve{}", ident),
|
||||
response_fut_ident: &Ident::new(&response_fut_name, ident.span()),
|
||||
response_fut_ident: &Ident::new(response_fut_name, ident.span()),
|
||||
client_ident: &format_ident!("{}Client", ident),
|
||||
request_ident: &format_ident!("{}Request", ident),
|
||||
response_ident: &format_ident!("{}Response", ident),
|
||||
vis,
|
||||
args,
|
||||
method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
|
||||
method_idents: &rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>(),
|
||||
method_idents: &methods,
|
||||
request_names: &*request_names,
|
||||
attrs,
|
||||
rpcs,
|
||||
return_types: &rpcs
|
||||
@@ -399,11 +406,7 @@ fn verify_types_were_provided(
|
||||
) -> syn::Result<()> {
|
||||
let mut result = Ok(());
|
||||
for (method, expected) in expected {
|
||||
if provided
|
||||
.iter()
|
||||
.find(|typedecl| typedecl.ident == expected)
|
||||
.is_none()
|
||||
{
|
||||
if !provided.iter().any(|typedecl| typedecl.ident == expected) {
|
||||
let mut e = syn::Error::new(
|
||||
span,
|
||||
format!("not all trait items implemented, missing: `{}`", expected),
|
||||
@@ -441,6 +444,7 @@ struct ServiceGenerator<'a> {
|
||||
camel_case_idents: &'a [Ident],
|
||||
future_types: &'a [Type],
|
||||
method_idents: &'a [&'a Ident],
|
||||
request_names: &'a [String],
|
||||
method_attrs: &'a [&'a [Attribute]],
|
||||
args: &'a [&'a [PatType]],
|
||||
return_types: &'a [&'a Type],
|
||||
@@ -475,7 +479,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
),
|
||||
output,
|
||||
)| {
|
||||
let ty_doc = format!("The response future returned by {}.", ident);
|
||||
let ty_doc = format!("The response future returned by [`{}::{}`].", service_ident, ident);
|
||||
quote! {
|
||||
#[doc = #ty_doc]
|
||||
type #future_type: std::future::Future<Output = #output>;
|
||||
@@ -524,6 +528,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
camel_case_idents,
|
||||
arg_pats,
|
||||
method_idents,
|
||||
request_names,
|
||||
..
|
||||
} = self;
|
||||
|
||||
@@ -534,6 +539,16 @@ impl<'a> ServiceGenerator<'a> {
|
||||
type Resp = #response_ident;
|
||||
type Fut = #response_fut_ident<S>;
|
||||
|
||||
fn method(&self, req: &#request_ident) -> Option<&'static str> {
|
||||
Some(match req {
|
||||
#(
|
||||
#request_ident::#camel_case_idents{..} => {
|
||||
#request_names
|
||||
}
|
||||
)*
|
||||
})
|
||||
}
|
||||
|
||||
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
|
||||
match req {
|
||||
#(
|
||||
@@ -563,6 +578,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
|
||||
quote! {
|
||||
/// The request sent over the wire from the client to the server.
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug)]
|
||||
#derive_serialize
|
||||
#vis enum #request_ident {
|
||||
@@ -583,6 +599,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
|
||||
quote! {
|
||||
/// The response sent over the wire from the server to the client.
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug)]
|
||||
#derive_serialize
|
||||
#vis enum #response_ident {
|
||||
@@ -603,6 +620,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
|
||||
quote! {
|
||||
/// A future resolving to a server response.
|
||||
#[allow(missing_docs)]
|
||||
#vis enum #response_fut_ident<S: #service_ident> {
|
||||
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
|
||||
}
|
||||
@@ -714,6 +732,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
method_attrs,
|
||||
vis,
|
||||
method_idents,
|
||||
request_names,
|
||||
args,
|
||||
return_types,
|
||||
arg_pats,
|
||||
@@ -727,9 +746,9 @@ impl<'a> ServiceGenerator<'a> {
|
||||
#[allow(unused)]
|
||||
#( #method_attrs )*
|
||||
#vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*)
|
||||
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
|
||||
-> impl std::future::Future<Output = Result<#return_types, tarpc::client::RpcError>> + '_ {
|
||||
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
|
||||
let resp = self.0.call(ctx, request);
|
||||
let resp = self.0.call(ctx, #request_names, request);
|
||||
async move {
|
||||
match resp.await? {
|
||||
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc"
|
||||
version = "0.25.1"
|
||||
version = "0.27.1"
|
||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||
edition = "2018"
|
||||
license = "MIT"
|
||||
@@ -17,10 +17,12 @@ default = []
|
||||
|
||||
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"]
|
||||
tokio1 = ["tokio/rt-multi-thread"]
|
||||
serde-transport = ["serde1", "tokio1", "tokio-serde/json", "tokio-util/codec"]
|
||||
serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
|
||||
serde-transport-json = ["tokio-serde/json"]
|
||||
serde-transport-bincode = ["tokio-serde/bincode"]
|
||||
tcp = ["tokio/net"]
|
||||
|
||||
full = ["serde1", "tokio1", "serde-transport", "tcp"]
|
||||
full = ["serde1", "tokio1", "serde-transport", "serde-transport-json", "serde-transport-bincode", "tcp"]
|
||||
|
||||
[badges]
|
||||
travis-ci = { repository = "google/tarpc" }
|
||||
@@ -30,26 +32,31 @@ anyhow = "1.0"
|
||||
fnv = "1.0"
|
||||
futures = "0.3"
|
||||
humantime = "2.0"
|
||||
log = "0.4"
|
||||
pin-project = "1.0"
|
||||
rand = "0.7"
|
||||
rand = "0.8"
|
||||
serde = { optional = true, version = "1.0", features = ["derive"] }
|
||||
static_assertions = "1.1.0"
|
||||
tarpc-plugins = { path = "../plugins", version = "0.10" }
|
||||
tarpc-plugins = { path = "../plugins", version = "0.12" }
|
||||
thiserror = "1.0"
|
||||
tokio = { version = "1", features = ["time"] }
|
||||
tokio-util = { version = "0.6.3", features = ["time"] }
|
||||
tokio-serde = { optional = true, version = "0.8" }
|
||||
tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
|
||||
tracing-opentelemetry = { version = "0.15", default-features = false }
|
||||
opentelemetry = { version = "0.16", default-features = false }
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
assert_matches = "1.4"
|
||||
bincode = "1.3"
|
||||
bytes = { version = "1", features = ["serde"] }
|
||||
env_logger = "0.8"
|
||||
flate2 = "1.0"
|
||||
futures-test = "0.3"
|
||||
log = "0.4"
|
||||
opentelemetry = { version = "0.16", default-features = false, features = ["rt-tokio"] }
|
||||
opentelemetry-jaeger = { version = "0.15", features = ["rt-tokio"] }
|
||||
pin-utils = "0.1.0-alpha"
|
||||
serde_bytes = "0.11"
|
||||
tracing-subscriber = "0.2"
|
||||
tokio = { version = "1", features = ["full", "test-util"] }
|
||||
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
|
||||
trybuild = "1.0"
|
||||
@@ -63,7 +70,7 @@ name = "compression"
|
||||
required-features = ["serde-transport", "tcp"]
|
||||
|
||||
[[example]]
|
||||
name = "server_calling_server"
|
||||
name = "tracing"
|
||||
required-features = ["full"]
|
||||
|
||||
[[example]]
|
||||
|
||||
@@ -7,8 +7,8 @@ use tarpc::{
|
||||
client, context,
|
||||
serde_transport::tcp,
|
||||
server::{BaseChannel, Channel},
|
||||
tokio_serde::formats::Bincode,
|
||||
};
|
||||
use tokio_serde::formats::Bincode;
|
||||
|
||||
/// Type of compression that should be enabled on the request. The transport is free to ignore this.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)]
|
||||
@@ -118,7 +118,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
});
|
||||
|
||||
let transport = tcp::connect(addr, Bincode::default).await?;
|
||||
let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn()?;
|
||||
let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn();
|
||||
|
||||
println!(
|
||||
"{}",
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
use futures::future;
|
||||
use tarpc::context::Context;
|
||||
use tarpc::serde_transport as transport;
|
||||
use tarpc::server::{BaseChannel, Channel};
|
||||
use tarpc::{context::Context, tokio_serde::formats::Bincode};
|
||||
use tokio::net::{UnixListener, UnixStream};
|
||||
use tokio_serde::formats::Bincode;
|
||||
use tokio_util::codec::length_delimited::LengthDelimitedCodec;
|
||||
|
||||
#[tarpc::service]
|
||||
@@ -14,16 +12,13 @@ pub trait PingService {
|
||||
#[derive(Clone)]
|
||||
struct Service;
|
||||
|
||||
#[tarpc::server]
|
||||
impl PingService for Service {
|
||||
type PingFut = future::Ready<()>;
|
||||
|
||||
fn ping(self, _: Context) -> Self::PingFut {
|
||||
future::ready(())
|
||||
}
|
||||
async fn ping(self, _: Context) {}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> std::io::Result<()> {
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let bind_addr = "/tmp/tarpc_on_unix_example.sock";
|
||||
|
||||
let _ = std::fs::remove_file(bind_addr);
|
||||
@@ -44,7 +39,9 @@ async fn main() -> std::io::Result<()> {
|
||||
let conn = UnixStream::connect(bind_addr).await?;
|
||||
let transport = transport::new(codec_builder.new_framed(conn), Bincode::default());
|
||||
PingServiceClient::new(Default::default(), transport)
|
||||
.spawn()?
|
||||
.spawn()
|
||||
.ping(tarpc::context::current())
|
||||
.await
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -38,10 +38,11 @@ use futures::{
|
||||
future::{self, AbortHandle},
|
||||
prelude::*,
|
||||
};
|
||||
use log::info;
|
||||
use publisher::Publisher as _;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
env,
|
||||
error::Error,
|
||||
io,
|
||||
net::SocketAddr,
|
||||
sync::{Arc, Mutex, RwLock},
|
||||
@@ -54,6 +55,8 @@ use tarpc::{
|
||||
};
|
||||
use tokio::net::ToSocketAddrs;
|
||||
use tokio_serde::formats::Json;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
pub mod subscriber {
|
||||
#[tarpc::service]
|
||||
@@ -83,10 +86,7 @@ impl subscriber::Subscriber for Subscriber {
|
||||
}
|
||||
|
||||
async fn receive(self, _: context::Context, topic: String, message: String) {
|
||||
info!(
|
||||
"[{}] received message on topic '{}': {}",
|
||||
self.local_addr, topic, message
|
||||
);
|
||||
info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,7 +120,7 @@ impl Subscriber {
|
||||
let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve()));
|
||||
tokio::spawn(async move {
|
||||
match handler.await {
|
||||
Ok(()) | Err(future::Aborted) => info!("[{}] subscriber shutdown.", local_addr),
|
||||
Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."),
|
||||
}
|
||||
});
|
||||
Ok(SubscriberHandle(abort_handle))
|
||||
@@ -153,13 +153,13 @@ impl Publisher {
|
||||
subscriptions: self.clone().start_subscription_manager().await?,
|
||||
};
|
||||
|
||||
info!("[{}] listening for publishers.", publisher_addrs.publisher);
|
||||
info!(publisher_addr = %publisher_addrs.publisher, "listening for publishers.",);
|
||||
tokio::spawn(async move {
|
||||
// Because this is just an example, we know there will only be one publisher. In more
|
||||
// realistic code, this would be a loop to continually accept new publisher
|
||||
// connections.
|
||||
let publisher = connecting_publishers.next().await.unwrap().unwrap();
|
||||
info!("[{}] publisher connected.", publisher.peer_addr().unwrap());
|
||||
info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected.");
|
||||
|
||||
server::BaseChannel::with_defaults(publisher)
|
||||
.execute(self.serve())
|
||||
@@ -174,7 +174,7 @@ impl Publisher {
|
||||
.await?
|
||||
.filter_map(|r| future::ready(r.ok()));
|
||||
let new_subscriber_addr = connecting_subscribers.get_ref().local_addr();
|
||||
info!("[{}] listening for subscribers.", new_subscriber_addr);
|
||||
info!(?new_subscriber_addr, "listening for subscribers.");
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(conn) = connecting_subscribers.next().await {
|
||||
@@ -215,7 +215,7 @@ impl Publisher {
|
||||
},
|
||||
);
|
||||
|
||||
info!("[{}] subscribed to topics: {:?}", subscriber_addr, topics);
|
||||
info!(%subscriber_addr, ?topics, "subscribed to new topics");
|
||||
let mut subscriptions = self.subscriptions.write().unwrap();
|
||||
for topic in topics {
|
||||
subscriptions
|
||||
@@ -226,18 +226,18 @@ impl Publisher {
|
||||
}
|
||||
}
|
||||
|
||||
fn start_subscriber_gc(
|
||||
fn start_subscriber_gc<E: Error>(
|
||||
self,
|
||||
subscriber_addr: SocketAddr,
|
||||
client_dispatch: impl Future<Output = anyhow::Result<()>> + Send + 'static,
|
||||
client_dispatch: impl Future<Output = Result<(), E>> + Send + 'static,
|
||||
subscriber_ready: oneshot::Receiver<()>,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = client_dispatch.await {
|
||||
info!(
|
||||
"[{}] subscriber connection broken: {:?}",
|
||||
subscriber_addr, e
|
||||
)
|
||||
%subscriber_addr,
|
||||
error = %e,
|
||||
"subscriber connection broken");
|
||||
}
|
||||
// Don't clean up the subscriber until initialization is done.
|
||||
let _ = subscriber_ready.await;
|
||||
@@ -281,13 +281,29 @@ impl publisher::Publisher for Publisher {
|
||||
}
|
||||
}
|
||||
|
||||
/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
|
||||
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
||||
let tracer = opentelemetry_jaeger::new_pipeline()
|
||||
.with_service_name(service_name)
|
||||
.with_max_packet_size(2usize.pow(13))
|
||||
.install_batch(opentelemetry::runtime::Tokio)?;
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(tracing_opentelemetry::layer().with_tracer(tracer))
|
||||
.try_init()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
env_logger::init();
|
||||
init_tracing("Pub/Sub")?;
|
||||
|
||||
let clients = Arc::new(Mutex::new(HashMap::new()));
|
||||
let addrs = Publisher {
|
||||
clients,
|
||||
clients: Arc::new(Mutex::new(HashMap::new())),
|
||||
subscriptions: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
.start()
|
||||
@@ -309,7 +325,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
client::Config::default(),
|
||||
tcp::connect(addrs.publisher, Json::default).await?,
|
||||
)
|
||||
.spawn()?;
|
||||
.spawn();
|
||||
|
||||
publisher
|
||||
.publish(context::current(), "calculus".into(), "sqrt(2)".into())
|
||||
@@ -337,6 +353,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
opentelemetry::global::shutdown_tracer_provider();
|
||||
info!("done.");
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use futures::future::{self, Ready};
|
||||
use std::io;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{self, Channel},
|
||||
@@ -35,7 +34,7 @@ impl World for HelloServer {
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
|
||||
let server = server::BaseChannel::with_defaults(server_transport);
|
||||
@@ -43,7 +42,7 @@ async fn main() -> io::Result<()> {
|
||||
|
||||
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||
// that takes a config and any Transport as input.
|
||||
let client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
|
||||
let client = WorldClient::new(client::Config::default(), client_transport).spawn();
|
||||
|
||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||
|
||||
@@ -6,12 +6,13 @@
|
||||
|
||||
use crate::{add::Add as AddService, double::Double as DoubleService};
|
||||
use futures::{future, prelude::*};
|
||||
use std::io;
|
||||
use std::env;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{BaseChannel, Incoming},
|
||||
server::{incoming::Incoming, BaseChannel},
|
||||
};
|
||||
use tokio_serde::formats::Json;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
pub mod add {
|
||||
#[tarpc::service]
|
||||
@@ -54,9 +55,25 @@ impl DoubleService for DoubleServer {
|
||||
}
|
||||
}
|
||||
|
||||
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
||||
let tracer = opentelemetry_jaeger::new_pipeline()
|
||||
.with_service_name(service_name)
|
||||
.with_max_packet_size(2usize.pow(13))
|
||||
.install_batch(opentelemetry::runtime::Tokio)?;
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(tracing_opentelemetry::layer().with_tracer(tracer))
|
||||
.try_init()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
env_logger::init();
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
init_tracing("tarpc_tracing_example")?;
|
||||
|
||||
let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
||||
.await?
|
||||
@@ -69,7 +86,7 @@ async fn main() -> io::Result<()> {
|
||||
tokio::spawn(add_server);
|
||||
|
||||
let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn()?;
|
||||
let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn();
|
||||
|
||||
let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
||||
.await?
|
||||
@@ -83,10 +100,14 @@ async fn main() -> io::Result<()> {
|
||||
|
||||
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
let double_client =
|
||||
double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?;
|
||||
double::DoubleClient::new(client::Config::default(), to_double_server).spawn();
|
||||
|
||||
for i in 1..=5 {
|
||||
eprintln!("{:?}", double_client.double(context::current(), i).await?);
|
||||
let ctx = context::current();
|
||||
for _ in 1..=5 {
|
||||
tracing::info!("{:?}", double_client.double(ctx, 1).await?);
|
||||
}
|
||||
|
||||
opentelemetry::global::shutdown_tracer_provider();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -8,16 +8,14 @@
|
||||
|
||||
mod in_flight_requests;
|
||||
|
||||
use crate::{
|
||||
context, trace::SpanId, ClientMessage, PollContext, PollIo, Request, Response, Transport,
|
||||
};
|
||||
use crate::{context, trace, ClientMessage, Request, Response, ServerError, Transport};
|
||||
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
||||
use in_flight_requests::InFlightRequests;
|
||||
use log::{info, trace};
|
||||
use in_flight_requests::{DeadlineExceededError, InFlightRequests};
|
||||
use pin_project::pin_project;
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
fmt, io,
|
||||
error::Error,
|
||||
fmt, mem,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
@@ -25,6 +23,7 @@ use std::{
|
||||
},
|
||||
};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tracing::Span;
|
||||
|
||||
/// Settings that control the behavior of the client.
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -61,19 +60,18 @@ pub struct NewClient<C, D> {
|
||||
impl<C, D, E> NewClient<C, D>
|
||||
where
|
||||
D: Future<Output = Result<(), E>> + Send + 'static,
|
||||
E: std::fmt::Display,
|
||||
E: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
/// Helper method to spawn the dispatch on the default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub fn spawn(self) -> io::Result<C> {
|
||||
use log::warn;
|
||||
|
||||
let dispatch = self
|
||||
.dispatch
|
||||
.unwrap_or_else(move |e| warn!("Connection broken: {}", e));
|
||||
pub fn spawn(self) -> C {
|
||||
let dispatch = self.dispatch.unwrap_or_else(move |e| {
|
||||
let e = anyhow::Error::new(e);
|
||||
tracing::warn!("Connection broken: {:?}", e);
|
||||
});
|
||||
tokio::spawn(dispatch);
|
||||
Ok(self.client)
|
||||
self.client
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,100 +112,117 @@ impl<Req, Resp> Clone for Channel<Req, Resp> {
|
||||
|
||||
impl<Req, Resp> Channel<Req, Resp> {
|
||||
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
|
||||
/// resolves when the request is sent (not when the response is received).
|
||||
fn send(
|
||||
/// resolves to the response.
|
||||
#[tracing::instrument(
|
||||
name = "RPC",
|
||||
skip(self, ctx, request_name, request),
|
||||
fields(
|
||||
rpc.trace_id = tracing::field::Empty,
|
||||
otel.kind = "client",
|
||||
otel.name = request_name)
|
||||
)]
|
||||
pub async fn call(
|
||||
&self,
|
||||
mut ctx: context::Context,
|
||||
request_name: &str,
|
||||
request: Req,
|
||||
) -> impl Future<Output = io::Result<DispatchResponse<Resp>>> + '_ {
|
||||
// Convert the context to the call context.
|
||||
ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
|
||||
ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
|
||||
|
||||
let (response_completion, response) = oneshot::channel();
|
||||
let cancellation = self.cancellation.clone();
|
||||
) -> Result<Resp, RpcError> {
|
||||
let span = Span::current();
|
||||
ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
|
||||
tracing::warn!(
|
||||
"OpenTelemetry subscriber not installed; making unsampled child context."
|
||||
);
|
||||
ctx.trace_context.new_child()
|
||||
});
|
||||
span.record("rpc.trace_id", &tracing::field::display(ctx.trace_id()));
|
||||
let (response_completion, mut response) = oneshot::channel();
|
||||
let request_id =
|
||||
u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
|
||||
|
||||
// DispatchResponse impls Drop to cancel in-flight requests. It should be created before
|
||||
// ResponseGuard impls Drop to cancel in-flight requests. It should be created before
|
||||
// sending out the request; otherwise, the response future could be dropped after the
|
||||
// request is sent out but before DispatchResponse is created, rendering the cancellation
|
||||
// request is sent out but before ResponseGuard is created, rendering the cancellation
|
||||
// logic inactive.
|
||||
let response = DispatchResponse {
|
||||
response,
|
||||
let response_guard = ResponseGuard {
|
||||
response: &mut response,
|
||||
request_id,
|
||||
cancellation: Some(cancellation),
|
||||
ctx,
|
||||
cancellation: &self.cancellation,
|
||||
};
|
||||
async move {
|
||||
self.to_dispatch
|
||||
.send(DispatchRequest {
|
||||
ctx,
|
||||
request_id,
|
||||
request,
|
||||
response_completion,
|
||||
})
|
||||
.await
|
||||
.map_err(|mpsc::error::SendError(_)| {
|
||||
io::Error::from(io::ErrorKind::ConnectionReset)
|
||||
})?;
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
|
||||
/// resolves to the response.
|
||||
pub async fn call(&self, ctx: context::Context, request: Req) -> io::Result<Resp> {
|
||||
let dispatch_response = self.send(ctx, request).await?;
|
||||
dispatch_response.await
|
||||
self.to_dispatch
|
||||
.send(DispatchRequest {
|
||||
ctx,
|
||||
span,
|
||||
request_id,
|
||||
request,
|
||||
response_completion,
|
||||
})
|
||||
.await
|
||||
.map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?;
|
||||
response_guard.response().await
|
||||
}
|
||||
}
|
||||
|
||||
/// A server response that is completed by request dispatch when the corresponding response
|
||||
/// arrives off the wire.
|
||||
#[derive(Debug)]
|
||||
struct DispatchResponse<Resp> {
|
||||
response: oneshot::Receiver<Response<Resp>>,
|
||||
ctx: context::Context,
|
||||
cancellation: Option<RequestCancellation>,
|
||||
struct ResponseGuard<'a, Resp> {
|
||||
response: &'a mut oneshot::Receiver<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
cancellation: &'a RequestCancellation,
|
||||
request_id: u64,
|
||||
}
|
||||
|
||||
impl<Resp> Future for DispatchResponse<Resp> {
|
||||
type Output = io::Result<Resp>;
|
||||
/// An error that can occur in the processing of an RPC. This is not request-specific errors but
|
||||
/// rather cross-cutting errors that can always occur.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum RpcError {
|
||||
/// The client disconnected from the server.
|
||||
#[error("the client disconnected from the server")]
|
||||
Disconnected,
|
||||
/// The request exceeded its deadline.
|
||||
#[error("the request exceeded its deadline")]
|
||||
DeadlineExceeded,
|
||||
/// The server aborted request processing.
|
||||
#[error("the server aborted request processing")]
|
||||
Server(#[from] ServerError),
|
||||
}
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> {
|
||||
let resp = ready!(self.response.poll_unpin(cx));
|
||||
self.cancellation.take();
|
||||
Poll::Ready(match resp {
|
||||
Ok(resp) => Ok(resp.message?),
|
||||
impl From<DeadlineExceededError> for RpcError {
|
||||
fn from(_: DeadlineExceededError) -> Self {
|
||||
RpcError::DeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
impl<Resp> ResponseGuard<'_, Resp> {
|
||||
async fn response(mut self) -> Result<Resp, RpcError> {
|
||||
let response = (&mut self.response).await;
|
||||
// Cancel drop logic once a response has been received.
|
||||
mem::forget(self);
|
||||
match response {
|
||||
Ok(resp) => Ok(resp?.message?),
|
||||
Err(oneshot::error::RecvError { .. }) => {
|
||||
// The oneshot is Canceled when the dispatch task ends. In that case,
|
||||
// there's nothing listening on the other side, so there's no point in
|
||||
// propagating cancellation.
|
||||
Err(io::Error::from(io::ErrorKind::ConnectionReset))
|
||||
Err(RpcError::Disconnected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cancels the request when dropped, if not already complete.
|
||||
impl<Resp> Drop for DispatchResponse<Resp> {
|
||||
impl<Resp> Drop for ResponseGuard<'_, Resp> {
|
||||
fn drop(&mut self) {
|
||||
if let Some(cancellation) = &mut self.cancellation {
|
||||
// The receiver needs to be closed to handle the edge case that the request has not
|
||||
// yet been received by the dispatch task. It is possible for the cancel message to
|
||||
// arrive before the request itself, in which case the request could get stuck in the
|
||||
// dispatch map forever if the server never responds (e.g. if the server dies while
|
||||
// responding). Even if the server does respond, it will have unnecessarily done work
|
||||
// for a client no longer waiting for a response. To avoid this, the dispatch task
|
||||
// checks if the receiver is closed before inserting the request in the map. By
|
||||
// closing the receiver before sending the cancel message, it is guaranteed that if the
|
||||
// dispatch task misses an early-arriving cancellation message, then it will see the
|
||||
// receiver as closed.
|
||||
self.response.close();
|
||||
cancellation.cancel(self.request_id);
|
||||
}
|
||||
// The receiver needs to be closed to handle the edge case that the request has not
|
||||
// yet been received by the dispatch task. It is possible for the cancel message to
|
||||
// arrive before the request itself, in which case the request could get stuck in the
|
||||
// dispatch map forever if the server never responds (e.g. if the server dies while
|
||||
// responding). Even if the server does respond, it will have unnecessarily done work
|
||||
// for a client no longer waiting for a response. To avoid this, the dispatch task
|
||||
// checks if the receiver is closed before inserting the request in the map. By
|
||||
// closing the receiver before sending the cancel message, it is guaranteed that if the
|
||||
// dispatch task misses an early-arriving cancellation message, then it will see the
|
||||
// receiver as closed.
|
||||
self.response.close();
|
||||
self.cancellation.cancel(self.request_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -258,6 +273,29 @@ pub struct RequestDispatch<Req, Resp, C> {
|
||||
config: Config,
|
||||
}
|
||||
|
||||
/// Critical errors that result in a Channel disconnecting.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ChannelError<E>
|
||||
where
|
||||
E: Error + Send + Sync + 'static,
|
||||
{
|
||||
/// Could not read from the transport.
|
||||
#[error("could not read from the transport")]
|
||||
Read(#[source] E),
|
||||
/// Could not ready the transport for writes.
|
||||
#[error("could not ready the transport for writes")]
|
||||
Ready(#[source] E),
|
||||
/// Could not write to the transport.
|
||||
#[error("could not write to the transport")]
|
||||
Write(#[source] E),
|
||||
/// Could not flush the transport.
|
||||
#[error("could not flush the transport")]
|
||||
Flush(#[source] E),
|
||||
/// Could not poll expired requests.
|
||||
#[error("could not poll expired requests")]
|
||||
Timer(#[source] tokio::time::error::Error),
|
||||
}
|
||||
|
||||
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
|
||||
where
|
||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
||||
@@ -270,6 +308,33 @@ where
|
||||
self.as_mut().project().transport
|
||||
}
|
||||
|
||||
fn poll_ready<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), ChannelError<C::Error>>> {
|
||||
self.transport_pin_mut()
|
||||
.poll_ready(cx)
|
||||
.map_err(ChannelError::Ready)
|
||||
}
|
||||
|
||||
fn start_send(
|
||||
self: &mut Pin<&mut Self>,
|
||||
message: ClientMessage<Req>,
|
||||
) -> Result<(), ChannelError<C::Error>> {
|
||||
self.transport_pin_mut()
|
||||
.start_send(message)
|
||||
.map_err(ChannelError::Write)
|
||||
}
|
||||
|
||||
fn poll_flush<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), ChannelError<C::Error>>> {
|
||||
self.transport_pin_mut()
|
||||
.poll_flush(cx)
|
||||
.map_err(ChannelError::Flush)
|
||||
}
|
||||
|
||||
fn canceled_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut CanceledRequests {
|
||||
self.as_mut().project().canceled_requests
|
||||
}
|
||||
@@ -280,17 +345,22 @@ where
|
||||
self.as_mut().project().pending_requests
|
||||
}
|
||||
|
||||
fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
||||
Poll::Ready(match ready!(self.transport_pin_mut().poll_next(cx)?) {
|
||||
Some(response) => {
|
||||
fn pump_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||
self.transport_pin_mut()
|
||||
.poll_next(cx)
|
||||
.map_err(ChannelError::Read)
|
||||
.map_ok(|response| {
|
||||
self.complete(response);
|
||||
Some(Ok(()))
|
||||
}
|
||||
None => None,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
||||
fn pump_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||
enum ReceiverStatus {
|
||||
Pending,
|
||||
Closed,
|
||||
@@ -311,7 +381,11 @@ where
|
||||
// Receiving Poll::Ready(None) when polling expired requests never indicates "Closed",
|
||||
// because there can temporarily be zero in-flight rquests. Therefore, there is no need to
|
||||
// track the status like is done with pending and cancelled requests.
|
||||
if let Poll::Ready(Some(_)) = self.in_flight_requests().poll_expired(cx)? {
|
||||
if let Poll::Ready(Some(_)) = self
|
||||
.in_flight_requests()
|
||||
.poll_expired(cx)
|
||||
.map_err(ChannelError::Timer)?
|
||||
{
|
||||
// Expired requests are considered complete; there is no compelling reason to send a
|
||||
// cancellation message to the server, since it will have already exhausted its
|
||||
// allotted processing time.
|
||||
@@ -320,12 +394,12 @@ where
|
||||
|
||||
match (pending_requests_status, canceled_requests_status) {
|
||||
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
|
||||
ready!(self.transport_pin_mut().poll_flush(cx)?);
|
||||
ready!(self.poll_flush(cx)?);
|
||||
Poll::Ready(None)
|
||||
}
|
||||
(ReceiverStatus::Pending, _) | (_, ReceiverStatus::Pending) => {
|
||||
// No more messages to process, so flush any messages buffered in the transport.
|
||||
ready!(self.transport_pin_mut().poll_flush(cx)?);
|
||||
ready!(self.poll_flush(cx)?);
|
||||
|
||||
// Even if we fully-flush, we return Pending, because we have no more requests
|
||||
// or cancellations right now.
|
||||
@@ -341,9 +415,9 @@ where
|
||||
fn poll_next_request(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<DispatchRequest<Req, Resp>> {
|
||||
) -> Poll<Option<Result<DispatchRequest<Req, Resp>, ChannelError<C::Error>>>> {
|
||||
if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
|
||||
info!(
|
||||
tracing::info!(
|
||||
"At in-flight request capacity ({}/{}).",
|
||||
self.in_flight_requests().len(),
|
||||
self.config.max_in_flight_requests
|
||||
@@ -360,10 +434,8 @@ where
|
||||
match ready!(self.pending_requests_mut().poll_recv(cx)) {
|
||||
Some(request) => {
|
||||
if request.response_completion.is_closed() {
|
||||
trace!(
|
||||
"[{}] Request canceled before being sent.",
|
||||
request.ctx.trace_id()
|
||||
);
|
||||
let _entered = request.span.enter();
|
||||
tracing::info!("AbortRequest");
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -381,14 +453,15 @@ where
|
||||
fn poll_next_cancellation(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<(context::Context, u64)> {
|
||||
) -> Poll<Option<Result<(context::Context, Span, u64), ChannelError<C::Error>>>> {
|
||||
ready!(self.ensure_writeable(cx)?);
|
||||
|
||||
loop {
|
||||
match ready!(self.canceled_requests_mut().poll_next_unpin(cx)) {
|
||||
Some(request_id) => {
|
||||
if let Some(ctx) = self.in_flight_requests().cancel_request(request_id) {
|
||||
return Poll::Ready(Some(Ok((ctx, request_id))));
|
||||
if let Some((ctx, span)) = self.in_flight_requests().cancel_request(request_id)
|
||||
{
|
||||
return Poll::Ready(Some(Ok((ctx, span, request_id))));
|
||||
}
|
||||
}
|
||||
None => return Poll::Ready(None),
|
||||
@@ -399,54 +472,73 @@ where
|
||||
/// Returns Ready if writing a message to the transport (i.e. via write_request or
|
||||
/// write_cancel) would not fail due to a full buffer. If the transport is not ready to be
|
||||
/// written to, flushes it until it is ready.
|
||||
fn ensure_writeable<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
||||
while self.transport_pin_mut().poll_ready(cx)?.is_pending() {
|
||||
ready!(self.transport_pin_mut().poll_flush(cx)?);
|
||||
fn ensure_writeable<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||
while self.poll_ready(cx)?.is_pending() {
|
||||
ready!(self.poll_flush(cx)?);
|
||||
}
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
|
||||
fn poll_write_request<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
||||
let dispatch_request = match ready!(self.as_mut().poll_next_request(cx)?) {
|
||||
fn poll_write_request<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||
let DispatchRequest {
|
||||
ctx,
|
||||
span,
|
||||
request_id,
|
||||
request,
|
||||
response_completion,
|
||||
} = match ready!(self.as_mut().poll_next_request(cx)?) {
|
||||
Some(dispatch_request) => dispatch_request,
|
||||
None => return Poll::Ready(None),
|
||||
};
|
||||
let entered = span.enter();
|
||||
// poll_next_request only returns Ready if there is room to buffer another request.
|
||||
// Therefore, we can call write_request without fear of erroring due to a full
|
||||
// buffer.
|
||||
let request_id = dispatch_request.request_id;
|
||||
let request_id = request_id;
|
||||
let request = ClientMessage::Request(Request {
|
||||
id: request_id,
|
||||
message: dispatch_request.request,
|
||||
message: request,
|
||||
context: context::Context {
|
||||
deadline: dispatch_request.ctx.deadline,
|
||||
trace_context: dispatch_request.ctx.trace_context,
|
||||
deadline: ctx.deadline,
|
||||
trace_context: ctx.trace_context,
|
||||
},
|
||||
});
|
||||
self.transport_pin_mut().start_send(request)?;
|
||||
self.start_send(request)?;
|
||||
let deadline = ctx.deadline;
|
||||
tracing::info!(
|
||||
tarpc.deadline = %humantime::format_rfc3339(deadline),
|
||||
"SendRequest"
|
||||
);
|
||||
drop(entered);
|
||||
|
||||
self.in_flight_requests()
|
||||
.insert_request(
|
||||
request_id,
|
||||
dispatch_request.ctx,
|
||||
dispatch_request.response_completion,
|
||||
)
|
||||
.insert_request(request_id, ctx, span, response_completion)
|
||||
.expect("Request IDs should be unique");
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
|
||||
fn poll_write_cancel<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
||||
let (context, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
|
||||
Some((context, request_id)) => (context, request_id),
|
||||
fn poll_write_cancel<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||
let (context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
|
||||
Some(triple) => triple,
|
||||
None => return Poll::Ready(None),
|
||||
};
|
||||
let _entered = span.enter();
|
||||
|
||||
let trace_id = *context.trace_id();
|
||||
let cancel = ClientMessage::Cancel {
|
||||
trace_context: context.trace_context,
|
||||
request_id,
|
||||
};
|
||||
self.transport_pin_mut().start_send(cancel)?;
|
||||
trace!("[{}] Cancel message sent.", trace_id);
|
||||
self.start_send(cancel)?;
|
||||
tracing::info!("CancelRequest");
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
|
||||
@@ -460,28 +552,24 @@ impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
|
||||
where
|
||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
||||
{
|
||||
type Output = anyhow::Result<()>;
|
||||
type Output = Result<(), ChannelError<C::Error>>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<anyhow::Result<()>> {
|
||||
fn poll(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), ChannelError<C::Error>>> {
|
||||
loop {
|
||||
match (
|
||||
self.as_mut()
|
||||
.pump_read(cx)
|
||||
.context("failed to read from transport")?,
|
||||
self.as_mut()
|
||||
.pump_write(cx)
|
||||
.context("failed to write to transport")?,
|
||||
) {
|
||||
match (self.as_mut().pump_read(cx)?, self.as_mut().pump_write(cx)?) {
|
||||
(Poll::Ready(None), _) => {
|
||||
info!("Shutdown: read half closed, so shutting down.");
|
||||
tracing::info!("Shutdown: read half closed, so shutting down.");
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
(read, Poll::Ready(None)) => {
|
||||
if self.in_flight_requests().is_empty() {
|
||||
info!("Shutdown: write half closed, and no requests in flight.");
|
||||
if self.in_flight_requests.is_empty() {
|
||||
tracing::info!("Shutdown: write half closed, and no requests in flight.");
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
info!(
|
||||
tracing::info!(
|
||||
"Shutdown: write half closed, and {} requests in flight.",
|
||||
self.in_flight_requests().len()
|
||||
);
|
||||
@@ -502,9 +590,10 @@ where
|
||||
#[derive(Debug)]
|
||||
struct DispatchRequest<Req, Resp> {
|
||||
pub ctx: context::Context,
|
||||
pub span: Span,
|
||||
pub request_id: u64,
|
||||
pub request: Req,
|
||||
pub response_completion: oneshot::Sender<Response<Resp>>,
|
||||
pub response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
}
|
||||
|
||||
/// Sends request cancellation signals.
|
||||
@@ -518,16 +607,14 @@ struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
|
||||
/// Returns a channel to send request cancellation messages.
|
||||
fn cancellations() -> (RequestCancellation, CanceledRequests) {
|
||||
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
|
||||
// bounded by the number of in-flight requests. Additionally, each request has a clone
|
||||
// of the sender, so the bounded channel would have the same behavior,
|
||||
// since it guarantees a slot.
|
||||
// bounded by the number of in-flight requests.
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
(RequestCancellation(tx), CanceledRequests(rx))
|
||||
}
|
||||
|
||||
impl RequestCancellation {
|
||||
/// Cancels the request with ID `request_id`.
|
||||
fn cancel(&mut self, request_id: u64) {
|
||||
fn cancel(&self, request_id: u64) {
|
||||
let _ = self.0.send(request_id);
|
||||
}
|
||||
}
|
||||
@@ -549,28 +636,58 @@ impl Stream for CanceledRequests {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
cancellations, CanceledRequests, Channel, DispatchResponse, RequestCancellation,
|
||||
RequestDispatch,
|
||||
cancellations, CanceledRequests, Channel, DispatchRequest, RequestCancellation,
|
||||
RequestDispatch, ResponseGuard,
|
||||
};
|
||||
use crate::{
|
||||
client::{in_flight_requests::InFlightRequests, Config},
|
||||
client::{
|
||||
in_flight_requests::{DeadlineExceededError, InFlightRequests},
|
||||
Config,
|
||||
},
|
||||
context,
|
||||
transport::{self, channel::UnboundedChannel},
|
||||
ClientMessage, Response,
|
||||
};
|
||||
use assert_matches::assert_matches;
|
||||
use futures::{prelude::*, task::*};
|
||||
use std::{pin::Pin, sync::atomic::AtomicUsize, sync::Arc};
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
pin::Pin,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
sync::Arc,
|
||||
};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tracing::Span;
|
||||
|
||||
#[tokio::test]
|
||||
async fn response_completes_request_future() {
|
||||
let (mut dispatch, mut _channel, mut server_channel) = set_up();
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
dispatch
|
||||
.in_flight_requests
|
||||
.insert_request(0, context::current(), Span::current(), tx)
|
||||
.unwrap();
|
||||
server_channel
|
||||
.send(Response {
|
||||
request_id: 0,
|
||||
message: Ok("Resp".into()),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
|
||||
assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_response_cancels_on_drop() {
|
||||
let (cancellation, mut canceled_requests) = cancellations();
|
||||
let (_, response) = oneshot::channel();
|
||||
drop(DispatchResponse::<u32> {
|
||||
response,
|
||||
cancellation: Some(cancellation),
|
||||
let (_, mut response) = oneshot::channel();
|
||||
drop(ResponseGuard::<u32> {
|
||||
response: &mut response,
|
||||
cancellation: &cancellation,
|
||||
request_id: 3,
|
||||
ctx: context::current(),
|
||||
});
|
||||
// resp's drop() is run, which should send a cancel message.
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
@@ -580,23 +697,22 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn dispatch_response_doesnt_cancel_after_complete() {
|
||||
let (cancellation, mut canceled_requests) = cancellations();
|
||||
let (tx, response) = oneshot::channel();
|
||||
tx.send(Response {
|
||||
let (tx, mut response) = oneshot::channel();
|
||||
tx.send(Ok(Response {
|
||||
request_id: 0,
|
||||
message: Ok("well done"),
|
||||
})
|
||||
}))
|
||||
.unwrap();
|
||||
{
|
||||
DispatchResponse {
|
||||
response,
|
||||
cancellation: Some(cancellation),
|
||||
request_id: 3,
|
||||
ctx: context::current(),
|
||||
}
|
||||
.await
|
||||
.unwrap();
|
||||
// resp's drop() is run, but should not send a cancel message.
|
||||
// resp's drop() is run, but should not send a cancel message.
|
||||
ResponseGuard {
|
||||
response: &mut response,
|
||||
cancellation: &cancellation,
|
||||
request_id: 3,
|
||||
}
|
||||
.response()
|
||||
.await
|
||||
.unwrap();
|
||||
drop(cancellation);
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(None));
|
||||
}
|
||||
@@ -604,12 +720,12 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn stage_request() {
|
||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||
let dispatch = Pin::new(&mut dispatch);
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
let _resp = send_request(&mut channel, "hi").await;
|
||||
let _resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
|
||||
let req = dispatch.poll_next_request(cx).ready();
|
||||
let req = dispatch.as_mut().poll_next_request(cx).ready();
|
||||
assert!(req.is_some());
|
||||
|
||||
let req = req.unwrap();
|
||||
@@ -621,10 +737,10 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn stage_request_channel_dropped_doesnt_panic() {
|
||||
let (mut dispatch, mut channel, mut server_channel) = set_up();
|
||||
let mut dispatch = Pin::new(&mut dispatch);
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
let _ = send_request(&mut channel, "hi").await;
|
||||
let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
drop(channel);
|
||||
|
||||
assert!(dispatch.as_mut().poll(cx).is_ready());
|
||||
@@ -642,61 +758,68 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn stage_request_response_future_dropped_is_canceled_before_sending() {
|
||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||
let dispatch = Pin::new(&mut dispatch);
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
let _ = send_request(&mut channel, "hi").await;
|
||||
let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
|
||||
// Drop the channel so polling returns none if no requests are currently ready.
|
||||
drop(channel);
|
||||
// Test that a request future dropped before it's processed by dispatch will cause the request
|
||||
// to not be added to the in-flight request map.
|
||||
assert!(dispatch.poll_next_request(cx).ready().is_none());
|
||||
assert!(dispatch.as_mut().poll_next_request(cx).ready().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stage_request_response_future_dropped_is_canceled_after_sending() {
|
||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let mut dispatch = Pin::new(&mut dispatch);
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
let req = send_request(&mut channel, "hi").await;
|
||||
let req = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
|
||||
assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
|
||||
assert!(!dispatch.in_flight_requests().is_empty());
|
||||
assert!(!dispatch.in_flight_requests.is_empty());
|
||||
|
||||
// Test that a request future dropped after it's processed by dispatch will cause the request
|
||||
// to be removed from the in-flight request map.
|
||||
drop(req);
|
||||
if let Poll::Ready(Some(_)) = dispatch.as_mut().poll_next_cancellation(cx).unwrap() {
|
||||
// ok
|
||||
} else {
|
||||
panic!("Expected request to be cancelled")
|
||||
};
|
||||
assert!(dispatch.in_flight_requests().is_empty());
|
||||
assert_matches!(
|
||||
dispatch.as_mut().poll_next_cancellation(cx),
|
||||
Poll::Ready(Some(Ok(_)))
|
||||
);
|
||||
assert!(dispatch.in_flight_requests.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stage_request_response_closed_skipped() {
|
||||
let (mut dispatch, mut channel, _server_channel) = set_up();
|
||||
let dispatch = Pin::new(&mut dispatch);
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
// Test that a request future that's closed its receiver but not yet canceled its request --
|
||||
// i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
|
||||
// map.
|
||||
let mut resp = send_request(&mut channel, "hi").await;
|
||||
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
resp.response.close();
|
||||
|
||||
assert!(dispatch.poll_next_request(cx).is_pending());
|
||||
assert!(dispatch.as_mut().poll_next_request(cx).is_pending());
|
||||
}
|
||||
|
||||
fn set_up() -> (
|
||||
RequestDispatch<String, String, UnboundedChannel<Response<String>, ClientMessage<String>>>,
|
||||
Pin<
|
||||
Box<
|
||||
RequestDispatch<
|
||||
String,
|
||||
String,
|
||||
UnboundedChannel<Response<String>, ClientMessage<String>>,
|
||||
>,
|
||||
>,
|
||||
>,
|
||||
Channel<String, String>,
|
||||
UnboundedChannel<ClientMessage<String>, Response<String>>,
|
||||
) {
|
||||
let _ = env_logger::try_init();
|
||||
let _ = tracing_subscriber::fmt().with_test_writer().try_init();
|
||||
|
||||
let (to_dispatch, pending_requests) = mpsc::channel(1);
|
||||
let (cancel_tx, canceled_requests) = mpsc::unbounded_channel();
|
||||
@@ -717,17 +840,31 @@ mod tests {
|
||||
next_request_id: Arc::new(AtomicUsize::new(0)),
|
||||
};
|
||||
|
||||
(dispatch, channel, server_channel)
|
||||
(Box::pin(dispatch), channel, server_channel)
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
channel: &mut Channel<String, String>,
|
||||
async fn send_request<'a>(
|
||||
channel: &'a mut Channel<String, String>,
|
||||
request: &str,
|
||||
) -> DispatchResponse<String> {
|
||||
channel
|
||||
.send(context::current(), request.to_string())
|
||||
.await
|
||||
.unwrap()
|
||||
response_completion: oneshot::Sender<Result<Response<String>, DeadlineExceededError>>,
|
||||
response: &'a mut oneshot::Receiver<Result<Response<String>, DeadlineExceededError>>,
|
||||
) -> ResponseGuard<'a, String> {
|
||||
let request_id =
|
||||
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
|
||||
let request = DispatchRequest {
|
||||
ctx: context::current(),
|
||||
span: Span::current(),
|
||||
request_id,
|
||||
request: request.to_string(),
|
||||
response_completion,
|
||||
};
|
||||
channel.to_dispatch.send(request).await.unwrap();
|
||||
|
||||
ResponseGuard {
|
||||
response,
|
||||
cancellation: &channel.cancellation,
|
||||
request_id,
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_response(
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
use crate::{
|
||||
context,
|
||||
util::{Compact, TimeUntil},
|
||||
PollIo, Response, ServerError,
|
||||
Response,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::ready;
|
||||
use log::{debug, trace};
|
||||
use std::{
|
||||
collections::hash_map,
|
||||
io,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_util::time::delay_queue::{self, DelayQueue};
|
||||
use tracing::Span;
|
||||
|
||||
/// Requests already written to the wire that haven't yet received responses.
|
||||
#[derive(Debug)]
|
||||
@@ -30,10 +28,17 @@ impl<Resp> Default for InFlightRequests<Resp> {
|
||||
}
|
||||
}
|
||||
|
||||
/// The request exceeded its deadline.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
#[non_exhaustive]
|
||||
#[error("the request exceeded its deadline")]
|
||||
pub struct DeadlineExceededError;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RequestData<Resp> {
|
||||
ctx: context::Context,
|
||||
response_completion: oneshot::Sender<Response<Resp>>,
|
||||
span: Span,
|
||||
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
/// The key to remove the timer for the request's deadline.
|
||||
deadline_key: delay_queue::Key,
|
||||
}
|
||||
@@ -59,20 +64,16 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
&mut self,
|
||||
request_id: u64,
|
||||
ctx: context::Context,
|
||||
response_completion: oneshot::Sender<Response<Resp>>,
|
||||
span: Span,
|
||||
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
) -> Result<(), AlreadyExistsError> {
|
||||
match self.request_data.entry(request_id) {
|
||||
hash_map::Entry::Vacant(vacant) => {
|
||||
let timeout = ctx.deadline.time_until();
|
||||
trace!(
|
||||
"[{}] Queuing request with timeout {:?}.",
|
||||
ctx.trace_id(),
|
||||
timeout,
|
||||
);
|
||||
|
||||
let deadline_key = self.deadlines.insert(request_id, timeout);
|
||||
vacant.insert(RequestData {
|
||||
ctx,
|
||||
span,
|
||||
response_completion,
|
||||
deadline_key,
|
||||
});
|
||||
@@ -85,15 +86,15 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
/// Removes a request without aborting. Returns true iff the request was found.
|
||||
pub fn complete_request(&mut self, response: Response<Resp>) -> bool {
|
||||
if let Some(request_data) = self.request_data.remove(&response.request_id) {
|
||||
let _entered = request_data.span.enter();
|
||||
tracing::info!("ReceiveResponse");
|
||||
self.request_data.compact(0.1);
|
||||
|
||||
trace!("[{}] Received response.", request_data.ctx.trace_id());
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
request_data.complete(response);
|
||||
let _ = request_data.response_completion.send(Ok(response));
|
||||
return true;
|
||||
}
|
||||
|
||||
debug!(
|
||||
tracing::debug!(
|
||||
"No in-flight request found for request_id = {}.",
|
||||
response.request_id
|
||||
);
|
||||
@@ -104,12 +105,11 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
|
||||
/// Cancels a request without completing (typically used when a request handle was dropped
|
||||
/// before the request completed).
|
||||
pub fn cancel_request(&mut self, request_id: u64) -> Option<context::Context> {
|
||||
pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> {
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
self.request_data.compact(0.1);
|
||||
trace!("[{}] Cancelling request.", request_data.ctx.trace_id());
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
Some(request_data.ctx)
|
||||
Some((request_data.ctx, request_data.span))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -117,46 +117,21 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
|
||||
/// Yields a request that has expired, completing it with a TimedOut error.
|
||||
/// The caller should send cancellation messages for any yielded request ID.
|
||||
pub fn poll_expired(&mut self, cx: &mut Context) -> PollIo<u64> {
|
||||
Poll::Ready(match ready!(self.deadlines.poll_expired(cx)) {
|
||||
Some(Ok(expired)) => {
|
||||
let request_id = expired.into_inner();
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
self.request_data.compact(0.1);
|
||||
request_data.complete(Self::deadline_exceeded_error(request_id));
|
||||
}
|
||||
Some(Ok(request_id))
|
||||
pub fn poll_expired(
|
||||
&mut self,
|
||||
cx: &mut Context,
|
||||
) -> Poll<Option<Result<u64, tokio::time::error::Error>>> {
|
||||
self.deadlines.poll_expired(cx).map_ok(|expired| {
|
||||
let request_id = expired.into_inner();
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
let _entered = request_data.span.enter();
|
||||
tracing::error!("DeadlineExceeded");
|
||||
self.request_data.compact(0.1);
|
||||
let _ = request_data
|
||||
.response_completion
|
||||
.send(Err(DeadlineExceededError));
|
||||
}
|
||||
Some(Err(e)) => Some(Err(io::Error::new(io::ErrorKind::Other, e))),
|
||||
None => None,
|
||||
request_id
|
||||
})
|
||||
}
|
||||
|
||||
fn deadline_exceeded_error(request_id: u64) -> Response<Resp> {
|
||||
Response {
|
||||
request_id,
|
||||
message: Err(ServerError {
|
||||
kind: io::ErrorKind::TimedOut,
|
||||
detail: Some("Client dropped expired request.".to_string()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// When InFlightRequests is dropped, any outstanding requests are completed with a
|
||||
/// deadline-exceeded error.
|
||||
impl<Resp> Drop for InFlightRequests<Resp> {
|
||||
fn drop(&mut self) {
|
||||
let deadlines = &mut self.deadlines;
|
||||
for (_, request_data) in self.request_data.drain() {
|
||||
let expired = deadlines.remove(&request_data.deadline_key);
|
||||
request_data.complete(Self::deadline_exceeded_error(expired.into_inner()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Resp> RequestData<Resp> {
|
||||
fn complete(self, response: Response<Resp>) {
|
||||
let _ = self.response_completion.send(response);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,8 +8,13 @@
|
||||
//! client to server and is used by the server to enforce response deadlines.
|
||||
|
||||
use crate::trace::{self, TraceId};
|
||||
use opentelemetry::trace::TraceContextExt;
|
||||
use static_assertions::assert_impl_all;
|
||||
use std::time::{Duration, SystemTime};
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
|
||||
/// A request context that carries request-scoped information like deadlines and trace information.
|
||||
/// It is sent from client to server and is used by the server to enforce response deadlines.
|
||||
@@ -22,14 +27,6 @@ use std::time::{Duration, SystemTime};
|
||||
pub struct Context {
|
||||
/// When the client expects the request to be complete by. The server should cancel the request
|
||||
/// if it is not complete by this time.
|
||||
#[cfg_attr(
|
||||
feature = "serde1",
|
||||
serde(serialize_with = "crate::util::serde::serialize_epoch_secs")
|
||||
)]
|
||||
#[cfg_attr(
|
||||
feature = "serde1",
|
||||
serde(deserialize_with = "crate::util::serde::deserialize_epoch_secs")
|
||||
)]
|
||||
#[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))]
|
||||
pub deadline: SystemTime,
|
||||
/// Uniquely identifies requests originating from the same source.
|
||||
@@ -41,23 +38,65 @@ pub struct Context {
|
||||
|
||||
assert_impl_all!(Context: Send, Sync);
|
||||
|
||||
#[cfg(feature = "serde1")]
|
||||
fn ten_seconds_from_now() -> SystemTime {
|
||||
SystemTime::now() + Duration::from_secs(10)
|
||||
}
|
||||
|
||||
/// Returns the context for the current request, or a default Context if no request is active.
|
||||
// TODO: populate Context with request-scoped data, with default fallbacks.
|
||||
pub fn current() -> Context {
|
||||
Context {
|
||||
deadline: SystemTime::now() + Duration::from_secs(10),
|
||||
trace_context: trace::Context::new_root(),
|
||||
Context::current()
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Deadline(SystemTime);
|
||||
|
||||
impl Default for Deadline {
|
||||
fn default() -> Self {
|
||||
Self(ten_seconds_from_now())
|
||||
}
|
||||
}
|
||||
|
||||
impl Context {
|
||||
/// Returns the context for the current request, or a default Context if no request is active.
|
||||
pub fn current() -> Self {
|
||||
let span = tracing::Span::current();
|
||||
Self {
|
||||
trace_context: trace::Context::try_from(&span)
|
||||
.unwrap_or_else(|_| trace::Context::default()),
|
||||
deadline: span
|
||||
.context()
|
||||
.get::<Deadline>()
|
||||
.cloned()
|
||||
.unwrap_or_default()
|
||||
.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the ID of the request-scoped trace.
|
||||
pub fn trace_id(&self) -> &TraceId {
|
||||
&self.trace_context.trace_id
|
||||
}
|
||||
}
|
||||
|
||||
/// An extension trait for [`tracing::Span`] for propagating tarpc Contexts.
|
||||
pub(crate) trait SpanExt {
|
||||
/// Sets the given context on this span. Newly-created spans will be children of the given
|
||||
/// context's trace context.
|
||||
fn set_context(&self, context: &Context);
|
||||
}
|
||||
|
||||
impl SpanExt for tracing::Span {
|
||||
fn set_context(&self, context: &Context) {
|
||||
self.set_parent(
|
||||
opentelemetry::Context::new()
|
||||
.with_remote_span_context(opentelemetry::trace::SpanContext::new(
|
||||
opentelemetry::trace::TraceId::from(context.trace_context.trace_id),
|
||||
opentelemetry::trace::SpanId::from(context.trace_context.span_id),
|
||||
opentelemetry::trace::TraceFlags::from(context.trace_context.sampling_decision),
|
||||
true,
|
||||
opentelemetry::trace::TraceState::default(),
|
||||
))
|
||||
.with_value(Deadline(context.deadline)),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,6 +38,14 @@
|
||||
//! requests sent by the server that use the request context will propagate the request deadline.
|
||||
//! For example, if a server is handling a request with a 10s deadline, does 2s of work, then
|
||||
//! sends a request to another server, that server will see an 8s deadline.
|
||||
//! - Distributed tracing: tarpc is instrumented with
|
||||
//! [tracing](https://github.com/tokio-rs/tracing) primitives extended with
|
||||
//! [OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like
|
||||
//! [Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
|
||||
//! each RPC can be traced through the client, server, amd other dependencies downstream of the
|
||||
//! server. Even for applications not connected to a distributed tracing collector, the
|
||||
//! instrumentation can also be ingested by regular loggers like
|
||||
//! [env_logger](https://github.com/env-logger-rs/env_logger/).
|
||||
//! - Serde serialization: enabling the `serde1` Cargo feature will make service requests and
|
||||
//! responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
|
||||
//! be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
|
||||
@@ -46,7 +54,7 @@
|
||||
//! Add to your `Cargo.toml` dependencies:
|
||||
//!
|
||||
//! ```toml
|
||||
//! tarpc = "0.25"
|
||||
//! tarpc = "0.27"
|
||||
//! ```
|
||||
//!
|
||||
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||
@@ -59,8 +67,9 @@
|
||||
//! your `Cargo.toml`:
|
||||
//!
|
||||
//! ```toml
|
||||
//! anyhow = "1.0"
|
||||
//! futures = "1.0"
|
||||
//! tarpc = { version = "0.25", features = ["tokio1"] }
|
||||
//! tarpc = { version = "0.27", features = ["tokio1"] }
|
||||
//! tokio = { version = "1.0", features = ["macros"] }
|
||||
//! ```
|
||||
//!
|
||||
@@ -79,9 +88,8 @@
|
||||
//! };
|
||||
//! use tarpc::{
|
||||
//! client, context,
|
||||
//! server::{self, Incoming},
|
||||
//! server::{self, incoming::Incoming},
|
||||
//! };
|
||||
//! use std::io;
|
||||
//!
|
||||
//! // This is the service definition. It looks a lot like a trait definition.
|
||||
//! // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
@@ -103,9 +111,8 @@
|
||||
//! # };
|
||||
//! # use tarpc::{
|
||||
//! # client, context,
|
||||
//! # server::{self, Incoming},
|
||||
//! # server::{self, incoming::Incoming},
|
||||
//! # };
|
||||
//! # use std::io;
|
||||
//! # // This is the service definition. It looks a lot like a trait definition.
|
||||
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
//! # #[tarpc::service]
|
||||
@@ -145,7 +152,6 @@
|
||||
//! # client, context,
|
||||
//! # server::{self, Channel},
|
||||
//! # };
|
||||
//! # use std::io;
|
||||
//! # // This is the service definition. It looks a lot like a trait definition.
|
||||
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
//! # #[tarpc::service]
|
||||
@@ -169,7 +175,7 @@
|
||||
//! # fn main() {}
|
||||
//! # #[cfg(feature = "tokio1")]
|
||||
//! #[tokio::main]
|
||||
//! async fn main() -> io::Result<()> {
|
||||
//! async fn main() -> anyhow::Result<()> {
|
||||
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
//!
|
||||
//! let server = server::BaseChannel::with_defaults(server_transport);
|
||||
@@ -177,7 +183,7 @@
|
||||
//!
|
||||
//! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||
//! // that takes a config and any Transport as input.
|
||||
//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
|
||||
//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn();
|
||||
//!
|
||||
//! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||
//! // args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||
@@ -199,6 +205,7 @@
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
#[cfg(feature = "serde1")]
|
||||
#[doc(hidden)]
|
||||
pub use serde;
|
||||
|
||||
#[cfg(feature = "serde-transport")]
|
||||
@@ -303,7 +310,7 @@ pub use crate::transport::sealed::Transport;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use futures::task::*;
|
||||
use std::{fmt::Display, io, time::SystemTime};
|
||||
use std::{error::Error, fmt::Display, io, time::SystemTime};
|
||||
|
||||
/// A message from a client to a server.
|
||||
#[derive(Debug)]
|
||||
@@ -355,8 +362,9 @@ pub struct Response<T> {
|
||||
pub message: Result<T, ServerError>,
|
||||
}
|
||||
|
||||
/// An error response from a server to a client.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
/// An error indicating the server aborted the request early, e.g., due to request throttling.
|
||||
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[error("{kind:?}: {detail}")]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ServerError {
|
||||
@@ -371,13 +379,7 @@ pub struct ServerError {
|
||||
/// The type of error that occurred to fail the request.
|
||||
pub kind: io::ErrorKind,
|
||||
/// A message describing more detail about the error that occurred.
|
||||
pub detail: Option<String>,
|
||||
}
|
||||
|
||||
impl From<ServerError> for io::Error {
|
||||
fn from(e: ServerError) -> io::Error {
|
||||
io::Error::new(e.kind, e.detail.unwrap_or_default())
|
||||
}
|
||||
pub detail: String,
|
||||
}
|
||||
|
||||
impl<T> Request<T> {
|
||||
@@ -387,7 +389,6 @@ impl<T> Request<T> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) type PollIo<T> = Poll<Option<io::Result<T>>>;
|
||||
pub(crate) trait PollContext<T> {
|
||||
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
@@ -399,7 +400,10 @@ pub(crate) trait PollContext<T> {
|
||||
F: FnOnce() -> C;
|
||||
}
|
||||
|
||||
impl<T> PollContext<T> for PollIo<T> {
|
||||
impl<T, E> PollContext<T> for Poll<Option<Result<T, E>>>
|
||||
where
|
||||
E: Error + Send + Sync + 'static,
|
||||
{
|
||||
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static,
|
||||
|
||||
@@ -42,14 +42,10 @@ where
|
||||
type Item = io::Result<Item>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
|
||||
match self.project().inner.poll_next(cx) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Ready(Some(Ok::<_, CodecError>(next))) => Poll::Ready(Some(Ok(next))),
|
||||
Poll::Ready(Some(Err::<_, CodecError>(e))) => {
|
||||
Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e))))
|
||||
}
|
||||
}
|
||||
self.project()
|
||||
.inner
|
||||
.poll_next(cx)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,7 +61,10 @@ where
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
convert(self.project().inner.poll_ready(cx))
|
||||
self.project()
|
||||
.inner
|
||||
.poll_ready(cx)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
|
||||
@@ -76,20 +75,20 @@ where
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
convert(self.project().inner.poll_flush(cx))
|
||||
self.project()
|
||||
.inner
|
||||
.poll_flush(cx)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
convert(self.project().inner.poll_close(cx))
|
||||
self.project()
|
||||
.inner
|
||||
.poll_close(cx)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
}
|
||||
|
||||
fn convert<E: Into<Box<dyn Error + Send + Sync>>>(
|
||||
poll: Poll<Result<(), E>>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
poll.map(|ready| ready.map_err(|e| io::Error::new(io::ErrorKind::Other, e)))
|
||||
}
|
||||
|
||||
/// Constructs a new transport from a framed transport and a serialization codec.
|
||||
pub fn new<S, Item, SinkItem, Codec>(
|
||||
framed_io: Framed<S, LengthDelimitedCodec>,
|
||||
|
||||
1156
tarpc/src/server.rs
1156
tarpc/src/server.rs
File diff suppressed because it is too large
Load Diff
@@ -1,19 +1,13 @@
|
||||
use crate::{
|
||||
util::{Compact, TimeUntil},
|
||||
PollIo,
|
||||
};
|
||||
use crate::util::{Compact, TimeUntil};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{
|
||||
future::{AbortHandle, AbortRegistration},
|
||||
ready,
|
||||
};
|
||||
use futures::future::{AbortHandle, AbortRegistration};
|
||||
use std::{
|
||||
collections::hash_map,
|
||||
io,
|
||||
task::{Context, Poll},
|
||||
time::SystemTime,
|
||||
};
|
||||
use tokio_util::time::delay_queue::{self, DelayQueue};
|
||||
use tracing::Span;
|
||||
|
||||
/// A data structure that tracks in-flight requests. It aborts requests,
|
||||
/// either on demand or when a request deadline expires.
|
||||
@@ -23,13 +17,15 @@ pub struct InFlightRequests {
|
||||
deadlines: DelayQueue<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Data needed to clean up a single in-flight request.
|
||||
#[derive(Debug)]
|
||||
struct RequestData {
|
||||
/// Aborts the response handler for the associated request.
|
||||
abort_handle: AbortHandle,
|
||||
/// The key to remove the timer for the request's deadline.
|
||||
deadline_key: delay_queue::Key,
|
||||
/// The client span.
|
||||
span: Span,
|
||||
}
|
||||
|
||||
/// An error returned when a request attempted to start with the same ID as a request already
|
||||
@@ -48,6 +44,7 @@ impl InFlightRequests {
|
||||
&mut self,
|
||||
request_id: u64,
|
||||
deadline: SystemTime,
|
||||
span: Span,
|
||||
) -> Result<AbortRegistration, AlreadyExistsError> {
|
||||
match self.request_data.entry(request_id) {
|
||||
hash_map::Entry::Vacant(vacant) => {
|
||||
@@ -57,6 +54,7 @@ impl InFlightRequests {
|
||||
vacant.insert(RequestData {
|
||||
abort_handle,
|
||||
deadline_key,
|
||||
span,
|
||||
});
|
||||
Ok(abort_registration)
|
||||
}
|
||||
@@ -66,12 +64,17 @@ impl InFlightRequests {
|
||||
|
||||
/// Cancels an in-flight request. Returns true iff the request was found.
|
||||
pub fn cancel_request(&mut self, request_id: u64) -> bool {
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
if let Some(RequestData {
|
||||
span,
|
||||
abort_handle,
|
||||
deadline_key,
|
||||
}) = self.request_data.remove(&request_id)
|
||||
{
|
||||
let _entered = span.enter();
|
||||
self.request_data.compact(0.1);
|
||||
|
||||
request_data.abort_handle.abort();
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
|
||||
abort_handle.abort();
|
||||
self.deadlines.remove(&deadline_key);
|
||||
tracing::info!("ReceiveCancel");
|
||||
true
|
||||
} else {
|
||||
false
|
||||
@@ -80,30 +83,32 @@ impl InFlightRequests {
|
||||
|
||||
/// Removes a request without aborting. Returns true iff the request was found.
|
||||
/// This method should be used when a response is being sent.
|
||||
pub fn remove_request(&mut self, request_id: u64) -> bool {
|
||||
pub fn remove_request(&mut self, request_id: u64) -> Option<Span> {
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
self.request_data.compact(0.1);
|
||||
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
|
||||
true
|
||||
Some(request_data.span)
|
||||
} else {
|
||||
false
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Yields a request that has expired, aborting any ongoing processing of that request.
|
||||
pub fn poll_expired(&mut self, cx: &mut Context) -> PollIo<u64> {
|
||||
Poll::Ready(match ready!(self.deadlines.poll_expired(cx)) {
|
||||
Some(Ok(expired)) => {
|
||||
if let Some(request_data) = self.request_data.remove(expired.get_ref()) {
|
||||
self.request_data.compact(0.1);
|
||||
request_data.abort_handle.abort();
|
||||
}
|
||||
Some(Ok(expired.into_inner()))
|
||||
pub fn poll_expired(
|
||||
&mut self,
|
||||
cx: &mut Context,
|
||||
) -> Poll<Option<Result<u64, tokio::time::error::Error>>> {
|
||||
self.deadlines.poll_expired(cx).map_ok(|expired| {
|
||||
if let Some(RequestData {
|
||||
abort_handle, span, ..
|
||||
}) = self.request_data.remove(expired.get_ref())
|
||||
{
|
||||
let _entered = span.enter();
|
||||
self.request_data.compact(0.1);
|
||||
abort_handle.abort();
|
||||
tracing::error!("DeadlineExceeded");
|
||||
}
|
||||
Some(Err(e)) => Some(Err(io::Error::new(io::ErrorKind::Other, e))),
|
||||
None => None,
|
||||
expired.into_inner()
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -118,75 +123,77 @@ impl Drop for InFlightRequests {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
use {
|
||||
assert_matches::assert_matches,
|
||||
futures::{
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use assert_matches::assert_matches;
|
||||
use futures::{
|
||||
future::{pending, Abortable},
|
||||
FutureExt,
|
||||
},
|
||||
futures_test::task::noop_context,
|
||||
};
|
||||
};
|
||||
use futures_test::task::noop_context;
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_request_increases_len() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
in_flight_requests
|
||||
.start_request(0, SystemTime::now())
|
||||
.unwrap();
|
||||
assert_eq!(in_flight_requests.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn polling_expired_aborts() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
let abort_registration = in_flight_requests
|
||||
.start_request(0, SystemTime::now())
|
||||
.unwrap();
|
||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||
|
||||
tokio::time::pause();
|
||||
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
|
||||
|
||||
assert_matches!(
|
||||
in_flight_requests.poll_expired(&mut noop_context()),
|
||||
Poll::Ready(Some(Ok(_)))
|
||||
);
|
||||
assert_matches!(
|
||||
abortable_future.poll_unpin(&mut noop_context()),
|
||||
Poll::Ready(Err(_))
|
||||
);
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancel_request_aborts() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
let abort_registration = in_flight_requests
|
||||
.start_request(0, SystemTime::now())
|
||||
.unwrap();
|
||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||
|
||||
assert_eq!(in_flight_requests.cancel_request(0), true);
|
||||
assert_matches!(
|
||||
abortable_future.poll_unpin(&mut noop_context()),
|
||||
Poll::Ready(Err(_))
|
||||
);
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remove_request_doesnt_abort() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
let abort_registration = in_flight_requests
|
||||
.start_request(0, SystemTime::now())
|
||||
.unwrap();
|
||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||
|
||||
assert_eq!(in_flight_requests.remove_request(0), true);
|
||||
assert_matches!(
|
||||
abortable_future.poll_unpin(&mut noop_context()),
|
||||
Poll::Pending
|
||||
);
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
#[tokio::test]
|
||||
async fn start_request_increases_len() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
in_flight_requests
|
||||
.start_request(0, SystemTime::now(), Span::current())
|
||||
.unwrap();
|
||||
assert_eq!(in_flight_requests.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn polling_expired_aborts() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
let abort_registration = in_flight_requests
|
||||
.start_request(0, SystemTime::now(), Span::current())
|
||||
.unwrap();
|
||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||
|
||||
tokio::time::pause();
|
||||
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
|
||||
|
||||
assert_matches!(
|
||||
in_flight_requests.poll_expired(&mut noop_context()),
|
||||
Poll::Ready(Some(Ok(_)))
|
||||
);
|
||||
assert_matches!(
|
||||
abortable_future.poll_unpin(&mut noop_context()),
|
||||
Poll::Ready(Err(_))
|
||||
);
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancel_request_aborts() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
let abort_registration = in_flight_requests
|
||||
.start_request(0, SystemTime::now(), Span::current())
|
||||
.unwrap();
|
||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||
|
||||
assert_eq!(in_flight_requests.cancel_request(0), true);
|
||||
assert_matches!(
|
||||
abortable_future.poll_unpin(&mut noop_context()),
|
||||
Poll::Ready(Err(_))
|
||||
);
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remove_request_doesnt_abort() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
let abort_registration = in_flight_requests
|
||||
.start_request(0, SystemTime::now(), Span::current())
|
||||
.unwrap();
|
||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||
|
||||
assert_matches!(in_flight_requests.remove_request(0), Some(_));
|
||||
assert_matches!(
|
||||
abortable_future.poll_unpin(&mut noop_context()),
|
||||
Poll::Pending
|
||||
);
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
49
tarpc/src/server/incoming.rs
Normal file
49
tarpc/src/server/incoming.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use super::{
|
||||
limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel},
|
||||
Channel,
|
||||
};
|
||||
use futures::prelude::*;
|
||||
use std::{fmt, hash::Hash};
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
use super::{tokio::TokioServerExecutor, Serve};
|
||||
|
||||
/// An extension trait for [streams](Stream) of [`Channels`](Channel).
|
||||
pub trait Incoming<C>
|
||||
where
|
||||
Self: Sized + Stream<Item = C>,
|
||||
C: Channel,
|
||||
{
|
||||
/// Enforces channel per-key limits.
|
||||
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> MaxChannelsPerKey<Self, K, KF>
|
||||
where
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
KF: Fn(&C) -> K,
|
||||
{
|
||||
MaxChannelsPerKey::new(self, n, keymaker)
|
||||
}
|
||||
|
||||
/// Caps the number of concurrent requests per channel.
|
||||
fn max_concurrent_requests_per_channel(self, n: usize) -> MaxRequestsPerChannel<Self> {
|
||||
MaxRequestsPerChannel::new(self, n)
|
||||
}
|
||||
|
||||
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled
|
||||
/// concurrently by spawning on tokio's default executor, and each request will be also
|
||||
/// be spawned on tokio's default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
|
||||
where
|
||||
S: Serve<C::Req, Resp = C::Resp>,
|
||||
{
|
||||
TokioServerExecutor::new(self, serve)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, C> Incoming<C> for S
|
||||
where
|
||||
S: Sized + Stream<Item = C>,
|
||||
C: Channel,
|
||||
{
|
||||
}
|
||||
5
tarpc/src/server/limits.rs
Normal file
5
tarpc/src/server/limits.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
/// Provides functionality to limit the number of active channels.
|
||||
pub mod channels_per_key;
|
||||
|
||||
/// Provides a [channel](crate::server::Channel) that limits the number of in-flight requests.
|
||||
pub mod requests_per_channel;
|
||||
@@ -9,20 +9,23 @@ use crate::{
|
||||
util::Compact,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
|
||||
use log::{debug, info, trace};
|
||||
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
||||
use pin_project::pin_project;
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::{
|
||||
collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin,
|
||||
time::SystemTime,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, info, trace};
|
||||
|
||||
/// A single-threaded filter that drops channels based on per-key limits.
|
||||
/// An [`Incoming`](crate::server::incoming::Incoming) stream that drops new channels based on
|
||||
/// per-key limits.
|
||||
///
|
||||
/// The decision to drop a Channel is made once at the time the Channel materializes. Once a
|
||||
/// Channel is yielded, it will not be prematurely dropped.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct ChannelFilter<S, K, F>
|
||||
pub struct MaxChannelsPerKey<S, K, F>
|
||||
where
|
||||
K: Eq + Hash,
|
||||
{
|
||||
@@ -35,7 +38,7 @@ where
|
||||
keymaker: F,
|
||||
}
|
||||
|
||||
/// A channel that is tracked by a ChannelFilter.
|
||||
/// A channel that is tracked by [`MaxChannelsPerKey`].
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct TrackedChannel<C, K> {
|
||||
@@ -103,6 +106,7 @@ where
|
||||
{
|
||||
type Req = C::Req;
|
||||
type Resp = C::Resp;
|
||||
type Transport = C::Transport;
|
||||
|
||||
fn config(&self) -> &server::Config {
|
||||
self.inner.config()
|
||||
@@ -112,12 +116,8 @@ where
|
||||
self.inner.in_flight_requests()
|
||||
}
|
||||
|
||||
fn start_request(
|
||||
mut self: Pin<&mut Self>,
|
||||
id: u64,
|
||||
deadline: SystemTime,
|
||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
||||
self.inner_pin_mut().start_request(id, deadline)
|
||||
fn transport(&self) -> &Self::Transport {
|
||||
self.inner.transport()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,7 +133,7 @@ impl<C, K> TrackedChannel<C, K> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, K, F> ChannelFilter<S, K, F>
|
||||
impl<S, K, F> MaxChannelsPerKey<S, K, F>
|
||||
where
|
||||
K: Eq + Hash,
|
||||
S: Stream,
|
||||
@@ -142,7 +142,7 @@ where
|
||||
/// Sheds new channels to stay under configured limits.
|
||||
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
|
||||
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel();
|
||||
ChannelFilter {
|
||||
MaxChannelsPerKey {
|
||||
listener: listener.fuse(),
|
||||
channels_per_key,
|
||||
dropped_keys,
|
||||
@@ -153,7 +153,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, K, F> ChannelFilter<S, K, F>
|
||||
impl<S, K, F> MaxChannelsPerKey<S, K, F>
|
||||
where
|
||||
S: Stream,
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
@@ -171,11 +171,10 @@ where
|
||||
let tracker = self.as_mut().increment_channels_for_key(key.clone())?;
|
||||
|
||||
trace!(
|
||||
"[{}] Opening channel ({}/{}) channels for key.",
|
||||
key,
|
||||
Arc::strong_count(&tracker),
|
||||
self.channels_per_key
|
||||
);
|
||||
channel_filter_key = %key,
|
||||
open_channels = Arc::strong_count(&tracker),
|
||||
max_open_channels = self.channels_per_key,
|
||||
"Opening channel");
|
||||
|
||||
Ok(TrackedChannel {
|
||||
tracker,
|
||||
@@ -200,9 +199,10 @@ where
|
||||
let count = o.get().strong_count();
|
||||
if count >= TryFrom::try_from(*self_.channels_per_key).unwrap() {
|
||||
info!(
|
||||
"[{}] Opened max channels from key ({}/{}).",
|
||||
key, count, self_.channels_per_key
|
||||
);
|
||||
channel_filter_key = %key,
|
||||
open_channels = count,
|
||||
max_open_channels = *self_.channels_per_key,
|
||||
"At open channel limit");
|
||||
Err(key)
|
||||
} else {
|
||||
Ok(o.get().upgrade().unwrap_or_else(|| {
|
||||
@@ -233,7 +233,9 @@ where
|
||||
let self_ = self.project();
|
||||
match ready!(self_.dropped_keys.poll_recv(cx)) {
|
||||
Some(key) => {
|
||||
debug!("All channels dropped for key [{}]", key);
|
||||
debug!(
|
||||
channel_filter_key = %key,
|
||||
"All channels dropped");
|
||||
self_.key_counts.remove(&key);
|
||||
self_.key_counts.compact(0.1);
|
||||
Poll::Ready(())
|
||||
@@ -243,7 +245,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, K, F> Stream for ChannelFilter<S, K, F>
|
||||
impl<S, K, F> Stream for MaxChannelsPerKey<S, K, F>
|
||||
where
|
||||
S: Stream,
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
@@ -346,7 +348,7 @@ fn channel_filter_increment_channels_for_key() {
|
||||
key: &'static str,
|
||||
}
|
||||
let (_, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
|
||||
assert_eq!(Arc::strong_count(&tracker1), 1);
|
||||
@@ -367,7 +369,7 @@ fn channel_filter_handle_new_channel() {
|
||||
key: &'static str,
|
||||
}
|
||||
let (_, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
let channel1 = filter
|
||||
.as_mut()
|
||||
@@ -399,7 +401,7 @@ fn channel_filter_poll_listener() {
|
||||
key: &'static str,
|
||||
}
|
||||
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
|
||||
new_channels
|
||||
@@ -435,7 +437,7 @@ fn channel_filter_poll_closed_channels() {
|
||||
key: &'static str,
|
||||
}
|
||||
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
|
||||
new_channels
|
||||
@@ -463,7 +465,7 @@ fn channel_filter_stream() {
|
||||
key: &'static str,
|
||||
}
|
||||
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
|
||||
new_channels
|
||||
349
tarpc/src/server/limits/requests_per_channel.rs
Normal file
349
tarpc/src/server/limits/requests_per_channel.rs
Normal file
@@ -0,0 +1,349 @@
|
||||
// Copyright 2020 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use crate::{
|
||||
server::{Channel, Config},
|
||||
Response, ServerError,
|
||||
};
|
||||
use futures::{prelude::*, ready, task::*};
|
||||
use pin_project::pin_project;
|
||||
use std::{io, pin::Pin};
|
||||
|
||||
/// A [`Channel`] that limits the number of concurrent requests by throttling.
|
||||
///
|
||||
/// Note that this is a very basic throttling heuristic. It is easy to set a number that is too low
|
||||
/// for the resources available to the server. For production use cases, a more advanced throttler
|
||||
/// is likely needed.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct MaxRequests<C> {
|
||||
max_in_flight_requests: usize,
|
||||
#[pin]
|
||||
inner: C,
|
||||
}
|
||||
|
||||
impl<C> MaxRequests<C> {
|
||||
/// Returns the inner channel.
|
||||
pub fn get_ref(&self) -> &C {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> MaxRequests<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
/// Returns a new `MaxRequests` that wraps the given channel and limits concurrent requests to
|
||||
/// `max_in_flight_requests`.
|
||||
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
|
||||
MaxRequests {
|
||||
max_in_flight_requests,
|
||||
inner,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Stream for MaxRequests<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
type Item = <C as Stream>::Item;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
|
||||
{
|
||||
ready!(self.as_mut().project().inner.poll_ready(cx)?);
|
||||
|
||||
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
|
||||
Some(r) => {
|
||||
let _entered = r.span.enter();
|
||||
tracing::info!(
|
||||
in_flight_requests = self.as_mut().in_flight_requests(),
|
||||
"ThrottleRequest",
|
||||
);
|
||||
|
||||
self.as_mut().start_send(Response {
|
||||
request_id: r.request.id,
|
||||
message: Err(ServerError {
|
||||
kind: io::ErrorKind::WouldBlock,
|
||||
detail: "server throttled the request.".into(),
|
||||
}),
|
||||
})?;
|
||||
}
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
self.project().inner.poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Sink<Response<<C as Channel>::Resp>> for MaxRequests<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
type Error = C::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn start_send(
|
||||
self: Pin<&mut Self>,
|
||||
item: Response<<C as Channel>::Resp>,
|
||||
) -> Result<(), Self::Error> {
|
||||
self.project().inner.start_send(item)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().inner.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().inner.poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> AsRef<C> for MaxRequests<C> {
|
||||
fn as_ref(&self) -> &C {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Channel for MaxRequests<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
type Req = <C as Channel>::Req;
|
||||
type Resp = <C as Channel>::Resp;
|
||||
type Transport = <C as Channel>::Transport;
|
||||
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
self.inner.in_flight_requests()
|
||||
}
|
||||
|
||||
fn config(&self) -> &Config {
|
||||
self.inner.config()
|
||||
}
|
||||
|
||||
fn transport(&self) -> &Self::Transport {
|
||||
self.inner.transport()
|
||||
}
|
||||
}
|
||||
|
||||
/// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on
|
||||
/// the number of in-flight requests.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct MaxRequestsPerChannel<S> {
|
||||
#[pin]
|
||||
inner: S,
|
||||
max_in_flight_requests: usize,
|
||||
}
|
||||
|
||||
impl<S> MaxRequestsPerChannel<S>
|
||||
where
|
||||
S: Stream,
|
||||
<S as Stream>::Item: Channel,
|
||||
{
|
||||
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
max_in_flight_requests,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Stream for MaxRequestsPerChannel<S>
|
||||
where
|
||||
S: Stream,
|
||||
<S as Stream>::Item: Channel,
|
||||
{
|
||||
type Item = MaxRequests<<S as Stream>::Item>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
match ready!(self.as_mut().project().inner.poll_next(cx)) {
|
||||
Some(channel) => Poll::Ready(Some(MaxRequests::new(
|
||||
channel,
|
||||
*self.project().max_in_flight_requests,
|
||||
))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use crate::server::{
|
||||
testing::{self, FakeChannel, PollExt},
|
||||
TrackedRequest,
|
||||
};
|
||||
use pin_utils::pin_mut;
|
||||
use std::{
|
||||
marker::PhantomData,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
use tracing::Span;
|
||||
|
||||
#[tokio::test]
|
||||
async fn throttler_in_flight_requests() {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
for i in 0..5 {
|
||||
throttler
|
||||
.inner
|
||||
.in_flight_requests
|
||||
.start_request(
|
||||
i,
|
||||
SystemTime::now() + Duration::from_secs(1),
|
||||
Span::current(),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_done() {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_some() -> io::Result<()> {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 1,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler.inner.push_req(0, 1);
|
||||
assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());
|
||||
assert_eq!(
|
||||
throttler
|
||||
.as_mut()
|
||||
.poll_next(&mut testing::cx())?
|
||||
.map(|r| r.map(|r| (r.request.id, r.request.message))),
|
||||
Poll::Ready(Some((0, 1)))
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_throttled() {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler.inner.push_req(1, 1);
|
||||
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
|
||||
assert_eq!(throttler.inner.sink.len(), 1);
|
||||
let resp = throttler.inner.sink.get(0).unwrap();
|
||||
assert_eq!(resp.request_id, 1);
|
||||
assert!(resp.message.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_throttled_sink_not_ready() {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 0,
|
||||
inner: PendingSink::default::<isize, isize>(),
|
||||
};
|
||||
pin_mut!(throttler);
|
||||
assert!(throttler.poll_next(&mut testing::cx()).is_pending());
|
||||
|
||||
struct PendingSink<In, Out> {
|
||||
ghost: PhantomData<fn(Out) -> In>,
|
||||
}
|
||||
impl PendingSink<(), ()> {
|
||||
pub fn default<Req, Resp>(
|
||||
) -> PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||
PendingSink { ghost: PhantomData }
|
||||
}
|
||||
}
|
||||
impl<In, Out> Stream for PendingSink<In, Out> {
|
||||
type Item = In;
|
||||
fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
impl<In, Out> Sink<Out> for PendingSink<In, Out> {
|
||||
type Error = io::Error;
|
||||
fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Pending
|
||||
}
|
||||
fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
|
||||
Err(io::Error::from(io::ErrorKind::WouldBlock))
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Pending
|
||||
}
|
||||
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
impl<Req, Resp> Channel for PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||
type Req = Req;
|
||||
type Resp = Resp;
|
||||
type Transport = ();
|
||||
fn config(&self) -> &Config {
|
||||
unimplemented!()
|
||||
}
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
0
|
||||
}
|
||||
fn transport(&self) -> &() {
|
||||
&()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn throttler_start_send() {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler
|
||||
.inner
|
||||
.in_flight_requests
|
||||
.start_request(
|
||||
0,
|
||||
SystemTime::now() + Duration::from_secs(1),
|
||||
Span::current(),
|
||||
)
|
||||
.unwrap();
|
||||
throttler
|
||||
.as_mut()
|
||||
.start_send(Response {
|
||||
request_id: 0,
|
||||
message: Ok(1),
|
||||
})
|
||||
.unwrap();
|
||||
assert_eq!(throttler.inner.in_flight_requests.len(), 0);
|
||||
assert_eq!(
|
||||
throttler.inner.sink.get(0),
|
||||
Some(&Response {
|
||||
request_id: 0,
|
||||
message: Ok(1),
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -6,15 +6,13 @@
|
||||
|
||||
use crate::{
|
||||
context,
|
||||
server::{Channel, Config},
|
||||
server::{Channel, Config, TrackedRequest},
|
||||
Request, Response,
|
||||
};
|
||||
use futures::{future::AbortRegistration, task::*, Sink, Stream};
|
||||
use futures::{task::*, Sink, Stream};
|
||||
use pin_project::pin_project;
|
||||
use std::collections::VecDeque;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::time::SystemTime;
|
||||
use std::{collections::VecDeque, io, pin::Pin, time::SystemTime};
|
||||
use tracing::Span;
|
||||
|
||||
#[pin_project]
|
||||
pub(crate) struct FakeChannel<In, Out> {
|
||||
@@ -64,12 +62,13 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> Channel for FakeChannel<io::Result<Request<Req>>, Response<Resp>>
|
||||
impl<Req, Resp> Channel for FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>>
|
||||
where
|
||||
Req: Unpin,
|
||||
{
|
||||
type Req = Req;
|
||||
type Resp = Resp;
|
||||
type Transport = ();
|
||||
|
||||
fn config(&self) -> &Config {
|
||||
&self.config
|
||||
@@ -79,32 +78,31 @@ where
|
||||
self.in_flight_requests.len()
|
||||
}
|
||||
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
id: u64,
|
||||
deadline: SystemTime,
|
||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
||||
self.project()
|
||||
.in_flight_requests
|
||||
.start_request(id, deadline)
|
||||
fn transport(&self) -> &() {
|
||||
&()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
|
||||
impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||
pub fn push_req(&mut self, id: u64, message: Req) {
|
||||
self.stream.push_back(Ok(Request {
|
||||
context: context::Context {
|
||||
deadline: SystemTime::UNIX_EPOCH,
|
||||
trace_context: Default::default(),
|
||||
let (_, abort_registration) = futures::future::AbortHandle::new_pair();
|
||||
self.stream.push_back(Ok(TrackedRequest {
|
||||
request: Request {
|
||||
context: context::Context {
|
||||
deadline: SystemTime::UNIX_EPOCH,
|
||||
trace_context: Default::default(),
|
||||
},
|
||||
id,
|
||||
message,
|
||||
},
|
||||
id,
|
||||
message,
|
||||
abort_registration,
|
||||
span: Span::none(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeChannel<(), ()> {
|
||||
pub fn default<Req, Resp>() -> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
|
||||
pub fn default<Req, Resp>() -> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||
FakeChannel {
|
||||
stream: Default::default(),
|
||||
sink: Default::default(),
|
||||
|
||||
@@ -1,347 +0,0 @@
|
||||
// Copyright 2020 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use super::{Channel, Config};
|
||||
use crate::{Response, ServerError};
|
||||
use futures::{future::AbortRegistration, prelude::*, ready, task::*};
|
||||
use log::debug;
|
||||
use pin_project::pin_project;
|
||||
use std::{io, pin::Pin, time::SystemTime};
|
||||
|
||||
/// A [`Channel`] that limits the number of concurrent
|
||||
/// requests by throttling.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct Throttler<C> {
|
||||
max_in_flight_requests: usize,
|
||||
#[pin]
|
||||
inner: C,
|
||||
}
|
||||
|
||||
impl<C> Throttler<C> {
|
||||
/// Returns the inner channel.
|
||||
pub fn get_ref(&self) -> &C {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Throttler<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
/// Returns a new `Throttler` that wraps the given channel and limits concurrent requests to
|
||||
/// `max_in_flight_requests`.
|
||||
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
|
||||
Throttler {
|
||||
max_in_flight_requests,
|
||||
inner,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Stream for Throttler<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
type Item = <C as Stream>::Item;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
|
||||
{
|
||||
ready!(self.as_mut().project().inner.poll_ready(cx)?);
|
||||
|
||||
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
|
||||
Some(request) => {
|
||||
debug!(
|
||||
"[{}] Client has reached in-flight request limit ({}/{}).",
|
||||
request.context.trace_id(),
|
||||
self.as_mut().in_flight_requests(),
|
||||
self.as_mut().project().max_in_flight_requests,
|
||||
);
|
||||
|
||||
self.as_mut().start_send(Response {
|
||||
request_id: request.id,
|
||||
message: Err(ServerError {
|
||||
kind: io::ErrorKind::WouldBlock,
|
||||
detail: Some("Server throttled the request.".into()),
|
||||
}),
|
||||
})?;
|
||||
}
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
self.project().inner.poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Sink<Response<<C as Channel>::Resp>> for Throttler<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: Response<<C as Channel>::Resp>) -> io::Result<()> {
|
||||
self.project().inner.start_send(item)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
self.project().inner.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
self.project().inner.poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> AsRef<C> for Throttler<C> {
|
||||
fn as_ref(&self) -> &C {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Channel for Throttler<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
type Req = <C as Channel>::Req;
|
||||
type Resp = <C as Channel>::Resp;
|
||||
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
self.inner.in_flight_requests()
|
||||
}
|
||||
|
||||
fn config(&self) -> &Config {
|
||||
self.inner.config()
|
||||
}
|
||||
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
id: u64,
|
||||
deadline: SystemTime,
|
||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
||||
self.project().inner.start_request(id, deadline)
|
||||
}
|
||||
}
|
||||
|
||||
/// A stream of throttling channels.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct ThrottlerStream<S> {
|
||||
#[pin]
|
||||
inner: S,
|
||||
max_in_flight_requests: usize,
|
||||
}
|
||||
|
||||
impl<S> ThrottlerStream<S>
|
||||
where
|
||||
S: Stream,
|
||||
<S as Stream>::Item: Channel,
|
||||
{
|
||||
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
max_in_flight_requests,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Stream for ThrottlerStream<S>
|
||||
where
|
||||
S: Stream,
|
||||
<S as Stream>::Item: Channel,
|
||||
{
|
||||
type Item = Throttler<<S as Stream>::Item>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
match ready!(self.as_mut().project().inner.poll_next(cx)) {
|
||||
Some(channel) => Poll::Ready(Some(Throttler::new(
|
||||
channel,
|
||||
*self.project().max_in_flight_requests,
|
||||
))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
use super::testing::{self, FakeChannel, PollExt};
|
||||
#[cfg(test)]
|
||||
use crate::Request;
|
||||
#[cfg(test)]
|
||||
use pin_utils::pin_mut;
|
||||
#[cfg(test)]
|
||||
use std::{marker::PhantomData, time::Duration};
|
||||
|
||||
#[tokio::test]
|
||||
async fn throttler_in_flight_requests() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
for i in 0..5 {
|
||||
throttler
|
||||
.inner
|
||||
.in_flight_requests
|
||||
.start_request(i, SystemTime::now() + Duration::from_secs(1))
|
||||
.unwrap();
|
||||
}
|
||||
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn throttler_start_request() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler
|
||||
.as_mut()
|
||||
.start_request(1, SystemTime::now() + Duration::from_secs(1))
|
||||
.unwrap();
|
||||
assert_eq!(throttler.inner.in_flight_requests.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_done() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_some() -> io::Result<()> {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 1,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler.inner.push_req(0, 1);
|
||||
assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());
|
||||
assert_eq!(
|
||||
throttler
|
||||
.as_mut()
|
||||
.poll_next(&mut testing::cx())?
|
||||
.map(|r| r.map(|r| (r.id, r.message))),
|
||||
Poll::Ready(Some((0, 1)))
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_throttled() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler.inner.push_req(1, 1);
|
||||
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
|
||||
assert_eq!(throttler.inner.sink.len(), 1);
|
||||
let resp = throttler.inner.sink.get(0).unwrap();
|
||||
assert_eq!(resp.request_id, 1);
|
||||
assert!(resp.message.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_throttled_sink_not_ready() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: PendingSink::default::<isize, isize>(),
|
||||
};
|
||||
pin_mut!(throttler);
|
||||
assert!(throttler.poll_next(&mut testing::cx()).is_pending());
|
||||
|
||||
struct PendingSink<In, Out> {
|
||||
ghost: PhantomData<fn(Out) -> In>,
|
||||
}
|
||||
impl PendingSink<(), ()> {
|
||||
pub fn default<Req, Resp>() -> PendingSink<io::Result<Request<Req>>, Response<Resp>> {
|
||||
PendingSink { ghost: PhantomData }
|
||||
}
|
||||
}
|
||||
impl<In, Out> Stream for PendingSink<In, Out> {
|
||||
type Item = In;
|
||||
fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
impl<In, Out> Sink<Out> for PendingSink<In, Out> {
|
||||
type Error = io::Error;
|
||||
fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Pending
|
||||
}
|
||||
fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
|
||||
Err(io::Error::from(io::ErrorKind::WouldBlock))
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Pending
|
||||
}
|
||||
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
impl<Req, Resp> Channel for PendingSink<io::Result<Request<Req>>, Response<Resp>> {
|
||||
type Req = Req;
|
||||
type Resp = Resp;
|
||||
fn config(&self) -> &Config {
|
||||
unimplemented!()
|
||||
}
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
0
|
||||
}
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
_id: u64,
|
||||
_deadline: SystemTime,
|
||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn throttler_start_send() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler
|
||||
.inner
|
||||
.in_flight_requests
|
||||
.start_request(0, SystemTime::now() + Duration::from_secs(1))
|
||||
.unwrap();
|
||||
throttler
|
||||
.as_mut()
|
||||
.start_send(Response {
|
||||
request_id: 0,
|
||||
message: Ok(1),
|
||||
})
|
||||
.unwrap();
|
||||
assert_eq!(throttler.inner.in_flight_requests.len(), 0);
|
||||
assert_eq!(
|
||||
throttler.inner.sink.get(0),
|
||||
Some(&Response {
|
||||
request_id: 0,
|
||||
message: Ok(1),
|
||||
})
|
||||
);
|
||||
}
|
||||
111
tarpc/src/server/tokio.rs
Normal file
111
tarpc/src/server/tokio.rs
Normal file
@@ -0,0 +1,111 @@
|
||||
use super::{Channel, Requests, Serve};
|
||||
use futures::{prelude::*, ready, task::*};
|
||||
use pin_project::pin_project;
|
||||
use std::pin::Pin;
|
||||
|
||||
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
|
||||
/// for each new channel. Returned by
|
||||
/// [`Incoming::execute`](crate::server::incoming::Incoming::execute).
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct TokioServerExecutor<T, S> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
serve: S,
|
||||
}
|
||||
|
||||
impl<T, S> TokioServerExecutor<T, S> {
|
||||
pub(crate) fn new(inner: T, serve: S) -> Self {
|
||||
Self { inner, serve }
|
||||
}
|
||||
}
|
||||
|
||||
/// A future that drives the server by [spawning](tokio::spawn) each [response
|
||||
/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by
|
||||
/// [`Channel::execute`](crate::server::Channel::execute).
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct TokioChannelExecutor<T, S> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
serve: S,
|
||||
}
|
||||
|
||||
impl<T, S> TokioServerExecutor<T, S> {
|
||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
||||
self.as_mut().project().inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S> TokioChannelExecutor<T, S> {
|
||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
||||
self.as_mut().project().inner
|
||||
}
|
||||
}
|
||||
|
||||
// Send + 'static execution helper methods.
|
||||
|
||||
impl<C> Requests<C>
|
||||
where
|
||||
C: Channel,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
{
|
||||
/// Executes all requests using the given service function. Requests are handled concurrently
|
||||
/// by [spawning](::tokio::spawn) each handler on tokio's default executor.
|
||||
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
|
||||
where
|
||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
|
||||
{
|
||||
TokioChannelExecutor { inner: self, serve }
|
||||
}
|
||||
}
|
||||
|
||||
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
|
||||
where
|
||||
St: Sized + Stream<Item = C>,
|
||||
C: Channel + Send + 'static,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
||||
Se::Fut: Send,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
||||
tokio::spawn(channel.execute(self.serve.clone()));
|
||||
}
|
||||
tracing::info!("Server shutting down.");
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
|
||||
where
|
||||
C: Channel + 'static,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
||||
S::Fut: Send,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
||||
match response_handler {
|
||||
Ok(resp) => {
|
||||
let server = self.serve.clone();
|
||||
tokio::spawn(async move {
|
||||
resp.execute(server).await;
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Requests stream errored out: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
@@ -16,11 +16,14 @@
|
||||
//! This crate's design is based on [opencensus
|
||||
//! tracing](https://opencensus.io/core-concepts/tracing/).
|
||||
|
||||
use opentelemetry::trace::TraceContextExt;
|
||||
use rand::Rng;
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
fmt::{self, Formatter},
|
||||
mem,
|
||||
num::{NonZeroU128, NonZeroU64},
|
||||
};
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
|
||||
/// A context for tracing the execution of processes, distributed or otherwise.
|
||||
///
|
||||
@@ -36,33 +39,50 @@ pub struct Context {
|
||||
/// before making an RPC, and the span ID is sent to the server. The server is free to create
|
||||
/// its own spans, for which it sets the client's span as the parent span.
|
||||
pub span_id: SpanId,
|
||||
/// An identifier of the span that originated the current span. For example, if a server sends
|
||||
/// an RPC in response to a client request that included a span, the server would create a span
|
||||
/// for the RPC and set its parent to the span_id in the incoming request's context.
|
||||
///
|
||||
/// If `parent_id` is `None`, then this is a root context.
|
||||
pub parent_id: Option<SpanId>,
|
||||
/// Indicates whether a sampler has already decided whether or not to sample the trace
|
||||
/// associated with the Context. If `sampling_decision` is None, then a decision has not yet
|
||||
/// been made. Downstream samplers do not need to abide by "no sample" decisions--for example,
|
||||
/// an upstream client may choose to never sample, which may not make sense for the client's
|
||||
/// dependencies. On the other hand, if an upstream process has chosen to sample this trace,
|
||||
/// then the downstream samplers are expected to respect that decision and also sample the
|
||||
/// trace. Otherwise, the full trace would not be able to be reconstructed.
|
||||
pub sampling_decision: SamplingDecision,
|
||||
}
|
||||
|
||||
/// A 128-bit UUID identifying a trace. All spans caused by the same originating span share the
|
||||
/// same trace ID.
|
||||
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
#[derive(Default, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct TraceId(u128);
|
||||
pub struct TraceId(#[cfg_attr(feature = "serde1", serde(with = "u128_serde"))] u128);
|
||||
|
||||
/// A 64-bit identifier of a span within a trace. The identifier is unique within the span's trace.
|
||||
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
#[derive(Default, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct SpanId(u64);
|
||||
|
||||
/// Indicates whether a sampler has decided whether or not to sample the trace associated with the
|
||||
/// Context. Downstream samplers do not need to abide by "no sample" decisions--for example, an
|
||||
/// upstream client may choose to never sample, which may not make sense for the client's
|
||||
/// dependencies. On the other hand, if an upstream process has chosen to sample this trace, then
|
||||
/// the downstream samplers are expected to respect that decision and also sample the trace.
|
||||
/// Otherwise, the full trace would not be able to be reconstructed reliably.
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[repr(u8)]
|
||||
pub enum SamplingDecision {
|
||||
/// The associated span was sampled by its creating process. Child spans must also be sampled.
|
||||
Sampled,
|
||||
/// The associated span was not sampled by its creating process.
|
||||
Unsampled,
|
||||
}
|
||||
|
||||
impl Context {
|
||||
/// Constructs a new root context. A root context is one with no parent span.
|
||||
pub fn new_root() -> Self {
|
||||
let rng = &mut rand::thread_rng();
|
||||
Context {
|
||||
trace_id: TraceId::random(rng),
|
||||
span_id: SpanId::random(rng),
|
||||
parent_id: None,
|
||||
/// Constructs a new context with the trace ID and sampling decision inherited from the parent.
|
||||
pub(crate) fn new_child(&self) -> Self {
|
||||
Self {
|
||||
trace_id: self.trace_id,
|
||||
span_id: SpanId::random(&mut rand::thread_rng()),
|
||||
sampling_decision: self.sampling_decision,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -71,17 +91,128 @@ impl TraceId {
|
||||
/// Returns a random trace ID that can be assumed to be globally unique if `rng` generates
|
||||
/// actually-random numbers.
|
||||
pub fn random<R: Rng>(rng: &mut R) -> Self {
|
||||
TraceId(u128::from(rng.next_u64()) << mem::size_of::<u64>() | u128::from(rng.next_u64()))
|
||||
TraceId(rng.gen::<NonZeroU128>().get())
|
||||
}
|
||||
|
||||
/// Returns true iff the trace ID is 0.
|
||||
pub fn is_none(&self) -> bool {
|
||||
self.0 == 0
|
||||
}
|
||||
}
|
||||
|
||||
impl SpanId {
|
||||
/// Returns a random span ID that can be assumed to be unique within a single trace.
|
||||
pub fn random<R: Rng>(rng: &mut R) -> Self {
|
||||
SpanId(rng.next_u64())
|
||||
SpanId(rng.gen::<NonZeroU64>().get())
|
||||
}
|
||||
|
||||
/// Returns true iff the span ID is 0.
|
||||
pub fn is_none(&self) -> bool {
|
||||
self.0 == 0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TraceId> for u128 {
|
||||
fn from(trace_id: TraceId) -> Self {
|
||||
trace_id.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u128> for TraceId {
|
||||
fn from(trace_id: u128) -> Self {
|
||||
Self(trace_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SpanId> for u64 {
|
||||
fn from(span_id: SpanId) -> Self {
|
||||
span_id.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u64> for SpanId {
|
||||
fn from(span_id: u64) -> Self {
|
||||
Self(span_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<opentelemetry::trace::TraceId> for TraceId {
|
||||
fn from(trace_id: opentelemetry::trace::TraceId) -> Self {
|
||||
Self::from(trace_id.to_u128())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TraceId> for opentelemetry::trace::TraceId {
|
||||
fn from(trace_id: TraceId) -> Self {
|
||||
Self::from_u128(trace_id.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<opentelemetry::trace::SpanId> for SpanId {
|
||||
fn from(span_id: opentelemetry::trace::SpanId) -> Self {
|
||||
Self::from(span_id.to_u64())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SpanId> for opentelemetry::trace::SpanId {
|
||||
fn from(span_id: SpanId) -> Self {
|
||||
Self::from_u64(span_id.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&tracing::Span> for Context {
|
||||
type Error = NoActiveSpan;
|
||||
|
||||
fn try_from(span: &tracing::Span) -> Result<Self, NoActiveSpan> {
|
||||
let context = span.context();
|
||||
if context.has_active_span() {
|
||||
Ok(Self::from(context.span()))
|
||||
} else {
|
||||
Err(NoActiveSpan)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<opentelemetry::trace::SpanRef<'_>> for Context {
|
||||
fn from(span: opentelemetry::trace::SpanRef<'_>) -> Self {
|
||||
let otel_ctx = span.span_context();
|
||||
Self {
|
||||
trace_id: TraceId::from(otel_ctx.trace_id()),
|
||||
span_id: SpanId::from(otel_ctx.span_id()),
|
||||
sampling_decision: SamplingDecision::from(otel_ctx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SamplingDecision> for opentelemetry::trace::TraceFlags {
|
||||
fn from(decision: SamplingDecision) -> Self {
|
||||
match decision {
|
||||
SamplingDecision::Sampled => opentelemetry::trace::TraceFlags::SAMPLED,
|
||||
SamplingDecision::Unsampled => opentelemetry::trace::TraceFlags::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&opentelemetry::trace::SpanContext> for SamplingDecision {
|
||||
fn from(context: &opentelemetry::trace::SpanContext) -> Self {
|
||||
if context.is_sampled() {
|
||||
SamplingDecision::Sampled
|
||||
} else {
|
||||
SamplingDecision::Unsampled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SamplingDecision {
|
||||
fn default() -> Self {
|
||||
Self::Unsampled
|
||||
}
|
||||
}
|
||||
|
||||
/// Returned when a [`Context`] cannot be constructed from a [`Span`](tracing::Span).
|
||||
#[derive(Debug)]
|
||||
pub struct NoActiveSpan;
|
||||
|
||||
impl fmt::Display for TraceId {
|
||||
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
|
||||
write!(f, "{:02x}", self.0)?;
|
||||
@@ -89,9 +220,42 @@ impl fmt::Display for TraceId {
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for TraceId {
|
||||
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
|
||||
write!(f, "{:02x}", self.0)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SpanId {
|
||||
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
|
||||
write!(f, "{:02x}", self.0)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for SpanId {
|
||||
fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
|
||||
write!(f, "{:02x}", self.0)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde1")]
|
||||
mod u128_serde {
|
||||
pub fn serialize<S>(u: &u128, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serde::Serialize::serialize(&u.to_le_bytes(), serializer)
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<u128, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
Ok(u128::from_le_bytes(serde::Deserialize::deserialize(
|
||||
deserializer,
|
||||
)?))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,22 +9,32 @@
|
||||
//! The rpc crate is transport- and protocol-agnostic. Any transport that impls [`Transport`](sealed::Transport)
|
||||
//! can be plugged in, using whatever protocol it wants.
|
||||
|
||||
use futures::prelude::*;
|
||||
use std::io;
|
||||
|
||||
pub mod channel;
|
||||
|
||||
pub(crate) mod sealed {
|
||||
use super::*;
|
||||
use futures::prelude::*;
|
||||
use std::error::Error;
|
||||
|
||||
/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
|
||||
pub trait Transport<SinkItem, Item>:
|
||||
Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error>
|
||||
pub trait Transport<SinkItem, Item>
|
||||
where
|
||||
Self: Stream<Item = Result<Item, <Self as Sink<SinkItem>>::Error>>,
|
||||
Self: Sink<SinkItem, Error = <Self as Transport<SinkItem, Item>>::TransportError>,
|
||||
<Self as Sink<SinkItem>>::Error: Error,
|
||||
{
|
||||
/// Associated type where clauses are not elaborated; this associated type allows users
|
||||
/// bounding types by Transport to avoid having to explicitly add `T::Error: Error` to their
|
||||
/// bounds.
|
||||
type TransportError: Error + Send + Sync + 'static;
|
||||
}
|
||||
|
||||
impl<T, SinkItem, Item> Transport<SinkItem, Item> for T where
|
||||
T: Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error> + ?Sized
|
||||
impl<T, SinkItem, Item, E> Transport<SinkItem, Item> for T
|
||||
where
|
||||
T: ?Sized,
|
||||
T: Stream<Item = Result<Item, E>>,
|
||||
T: Sink<SinkItem, Error = E>,
|
||||
T::Error: Error + Send + Sync + 'static,
|
||||
{
|
||||
type TransportError = E;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,13 +6,19 @@
|
||||
|
||||
//! Transports backed by in-memory channels.
|
||||
|
||||
use crate::PollIo;
|
||||
use futures::{task::*, Sink, Stream};
|
||||
use pin_project::pin_project;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::{error::Error, pin::Pin};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Errors that occur in the sending or receiving of messages over a channel.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ChannelError {
|
||||
/// An error occurred sending over the channel.
|
||||
#[error("an error occurred sending over the channel")]
|
||||
Send(#[source] Box<dyn Error + Send + Sync + 'static>),
|
||||
}
|
||||
|
||||
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
|
||||
/// [`Sink`].
|
||||
pub fn unbounded<SinkItem, Item>() -> (
|
||||
@@ -36,28 +42,33 @@ pub struct UnboundedChannel<Item, SinkItem> {
|
||||
}
|
||||
|
||||
impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
|
||||
type Item = Result<Item, io::Error>;
|
||||
type Item = Result<Item, ChannelError>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Item> {
|
||||
fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Item, ChannelError>>> {
|
||||
self.rx.poll_recv(cx).map(|option| option.map(Ok))
|
||||
}
|
||||
}
|
||||
|
||||
impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
|
||||
type Error = io::Error;
|
||||
const CLOSED_MESSAGE: &str = "the channel is closed and cannot accept new items for sending";
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
|
||||
type Error = ChannelError;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(if self.tx.is_closed() {
|
||||
Err(io::Error::from(io::ErrorKind::NotConnected))
|
||||
Err(ChannelError::Send(CLOSED_MESSAGE.into()))
|
||||
} else {
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||
self.tx
|
||||
.send(item)
|
||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
||||
.map_err(|_| ChannelError::Send(CLOSED_MESSAGE.into()))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
@@ -65,7 +76,7 @@ impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
// UnboundedSender can't initiate closure.
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
@@ -93,52 +104,45 @@ pub struct Channel<Item, SinkItem> {
|
||||
}
|
||||
|
||||
impl<Item, SinkItem> Stream for Channel<Item, SinkItem> {
|
||||
type Item = Result<Item, io::Error>;
|
||||
type Item = Result<Item, ChannelError>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Item> {
|
||||
fn poll_next(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Item, ChannelError>>> {
|
||||
self.project().rx.poll_next(cx).map(|option| option.map(Ok))
|
||||
}
|
||||
}
|
||||
|
||||
impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
|
||||
type Error = io::Error;
|
||||
type Error = ChannelError;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.project()
|
||||
.tx
|
||||
.poll_ready(cx)
|
||||
.map_err(convert_send_err_to_io)
|
||||
.map_err(|e| ChannelError::Send(Box::new(e)))
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||
self.project()
|
||||
.tx
|
||||
.start_send(item)
|
||||
.map_err(convert_send_err_to_io)
|
||||
.map_err(|e| ChannelError::Send(Box::new(e)))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.project()
|
||||
.tx
|
||||
.poll_flush(cx)
|
||||
.map_err(convert_send_err_to_io)
|
||||
.map_err(|e| ChannelError::Send(Box::new(e)))
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.project()
|
||||
.tx
|
||||
.poll_close(cx)
|
||||
.map_err(convert_send_err_to_io)
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_send_err_to_io(e: futures::channel::mpsc::SendError) -> io::Error {
|
||||
if e.is_disconnected() {
|
||||
io::Error::from(io::ErrorKind::NotConnected)
|
||||
} else if e.is_full() {
|
||||
io::Error::from(io::ErrorKind::WouldBlock)
|
||||
} else {
|
||||
io::Error::new(io::ErrorKind::Other, e)
|
||||
.map_err(|e| ChannelError::Send(Box::new(e)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,17 +151,27 @@ fn convert_send_err_to_io(e: futures::channel::mpsc::SendError) -> io::Error {
|
||||
mod tests {
|
||||
use crate::{
|
||||
client, context,
|
||||
server::{BaseChannel, Incoming},
|
||||
transport,
|
||||
server::{incoming::Incoming, BaseChannel},
|
||||
transport::{
|
||||
self,
|
||||
channel::{Channel, UnboundedChannel},
|
||||
},
|
||||
};
|
||||
use assert_matches::assert_matches;
|
||||
use futures::{prelude::*, stream};
|
||||
use log::trace;
|
||||
use std::io;
|
||||
use tracing::trace;
|
||||
|
||||
#[test]
|
||||
fn ensure_is_transport() {
|
||||
fn is_transport<SinkItem, Item, T: crate::Transport<SinkItem, Item>>() {}
|
||||
is_transport::<(), (), UnboundedChannel<(), ()>>();
|
||||
is_transport::<(), (), Channel<(), ()>>();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration() -> io::Result<()> {
|
||||
let _ = env_logger::try_init();
|
||||
async fn integration() -> anyhow::Result<()> {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let (client_channel, server_channel) = transport::channel::unbounded();
|
||||
tokio::spawn(
|
||||
@@ -173,10 +187,10 @@ mod tests {
|
||||
}),
|
||||
);
|
||||
|
||||
let client = client::new(client::Config::default(), client_channel).spawn()?;
|
||||
let client = client::new(client::Config::default(), client_channel).spawn();
|
||||
|
||||
let response1 = client.call(context::current(), "123".into()).await?;
|
||||
let response2 = client.call(context::current(), "abc".into()).await?;
|
||||
let response1 = client.call(context::current(), "", "123".into()).await?;
|
||||
let response2 = client.call(context::current(), "", "abc".into()).await?;
|
||||
|
||||
trace!("response1: {:?}, response2: {:?}", response1, response2);
|
||||
|
||||
|
||||
@@ -5,31 +5,7 @@
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::{
|
||||
io,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
/// Serializes `system_time` as a `u64` equal to the number of seconds since the epoch.
|
||||
pub fn serialize_epoch_secs<S>(system_time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
const ZERO_SECS: Duration = Duration::from_secs(0);
|
||||
system_time
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap_or(ZERO_SECS)
|
||||
.as_secs() // Only care about second precision
|
||||
.serialize(serializer)
|
||||
}
|
||||
|
||||
/// Deserializes [`SystemTime`] from a `u64` equal to the number of seconds since the epoch.
|
||||
pub fn deserialize_epoch_secs<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
Ok(SystemTime::UNIX_EPOCH + Duration::from_secs(u64::deserialize(deserializer)?))
|
||||
}
|
||||
use std::io;
|
||||
|
||||
/// Serializes [`io::ErrorKind`] as a `u32`.
|
||||
#[allow(clippy::trivially_copy_pass_by_ref)] // Exact fn signature required by serde derive
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
use futures::prelude::*;
|
||||
use std::io;
|
||||
use tarpc::serde_transport;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{BaseChannel, Incoming},
|
||||
server::{incoming::Incoming, BaseChannel},
|
||||
};
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
@@ -33,7 +32,7 @@ impl ColorProtocol for ColorServer {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_call() -> io::Result<()> {
|
||||
async fn test_call() -> anyhow::Result<()> {
|
||||
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
|
||||
let addr = transport.local_addr();
|
||||
tokio::spawn(
|
||||
@@ -45,7 +44,7 @@ async fn test_call() -> io::Result<()> {
|
||||
);
|
||||
|
||||
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
let client = ColorProtocolClient::new(client::Config::default(), transport).spawn()?;
|
||||
let client = ColorProtocolClient::new(client::Config::default(), transport).spawn();
|
||||
|
||||
let color = client
|
||||
.get_opposite_color(context::current(), TestData::White)
|
||||
|
||||
@@ -3,14 +3,11 @@ use futures::{
|
||||
future::{join_all, ready, Ready},
|
||||
prelude::*,
|
||||
};
|
||||
use std::{
|
||||
io,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
use std::time::{Duration, SystemTime};
|
||||
use tarpc::{
|
||||
client::{self},
|
||||
context,
|
||||
server::{self, BaseChannel, Channel, Incoming},
|
||||
server::{self, incoming::Incoming, BaseChannel, Channel},
|
||||
transport::channel,
|
||||
};
|
||||
use tokio::join;
|
||||
@@ -39,8 +36,8 @@ impl Service for Server {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sequential() -> io::Result<()> {
|
||||
let _ = env_logger::try_init();
|
||||
async fn sequential() -> anyhow::Result<()> {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
|
||||
@@ -50,7 +47,7 @@ async fn sequential() -> io::Result<()> {
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||
|
||||
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
|
||||
assert_matches!(
|
||||
@@ -61,7 +58,7 @@ async fn sequential() -> io::Result<()> {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
|
||||
async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> {
|
||||
#[tarpc_plugins::service]
|
||||
trait Loop {
|
||||
async fn r#loop();
|
||||
@@ -82,16 +79,14 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
let _ = env_logger::try_init();
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
|
||||
// Set up a client that initiates a long-lived request.
|
||||
// The request will complete in error when the server drops the connection.
|
||||
tokio::spawn(async move {
|
||||
let client = LoopClient::new(client::Config::default(), tx)
|
||||
.spawn()
|
||||
.unwrap();
|
||||
let client = LoopClient::new(client::Config::default(), tx).spawn();
|
||||
|
||||
let mut ctx = context::current();
|
||||
ctx.deadline = SystemTime::now() + Duration::from_secs(60 * 60);
|
||||
@@ -113,11 +108,11 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
|
||||
|
||||
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
|
||||
#[tokio::test]
|
||||
async fn serde() -> io::Result<()> {
|
||||
async fn serde() -> anyhow::Result<()> {
|
||||
use tarpc::serde_transport;
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
let _ = env_logger::try_init();
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let transport = tarpc::serde_transport::tcp::listen("localhost:56789", Json::default).await?;
|
||||
let addr = transport.local_addr();
|
||||
@@ -130,7 +125,7 @@ async fn serde() -> io::Result<()> {
|
||||
);
|
||||
|
||||
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
let client = ServiceClient::new(client::Config::default(), transport).spawn()?;
|
||||
let client = ServiceClient::new(client::Config::default(), transport).spawn();
|
||||
|
||||
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
|
||||
assert_matches!(
|
||||
@@ -142,8 +137,8 @@ async fn serde() -> io::Result<()> {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn concurrent() -> io::Result<()> {
|
||||
let _ = env_logger::try_init();
|
||||
async fn concurrent() -> anyhow::Result<()> {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
tokio::spawn(
|
||||
@@ -152,7 +147,7 @@ async fn concurrent() -> io::Result<()> {
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||
|
||||
let req1 = client.add(context::current(), 1, 2);
|
||||
let req2 = client.add(context::current(), 3, 4);
|
||||
@@ -166,8 +161,8 @@ async fn concurrent() -> io::Result<()> {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn concurrent_join() -> io::Result<()> {
|
||||
let _ = env_logger::try_init();
|
||||
async fn concurrent_join() -> anyhow::Result<()> {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
tokio::spawn(
|
||||
@@ -176,7 +171,7 @@ async fn concurrent_join() -> io::Result<()> {
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||
|
||||
let req1 = client.add(context::current(), 1, 2);
|
||||
let req2 = client.add(context::current(), 3, 4);
|
||||
@@ -191,8 +186,8 @@ async fn concurrent_join() -> io::Result<()> {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn concurrent_join_all() -> io::Result<()> {
|
||||
let _ = env_logger::try_init();
|
||||
async fn concurrent_join_all() -> anyhow::Result<()> {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
tokio::spawn(
|
||||
@@ -201,7 +196,7 @@ async fn concurrent_join_all() -> io::Result<()> {
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||
|
||||
let req1 = client.add(context::current(), 1, 2);
|
||||
let req2 = client.add(context::current(), 3, 4);
|
||||
@@ -214,7 +209,7 @@ async fn concurrent_join_all() -> io::Result<()> {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn counter() -> io::Result<()> {
|
||||
async fn counter() -> anyhow::Result<()> {
|
||||
#[tarpc::service]
|
||||
trait Counter {
|
||||
async fn count() -> u32;
|
||||
@@ -241,7 +236,7 @@ async fn counter() -> io::Result<()> {
|
||||
}
|
||||
});
|
||||
|
||||
let client = CounterClient::new(client::Config::default(), tx).spawn()?;
|
||||
let client = CounterClient::new(client::Config::default(), tx).spawn();
|
||||
assert_matches!(client.count(context::current()).await, Ok(1));
|
||||
assert_matches!(client.count(context::current()).await, Ok(2));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user