Instrument tarpc with tracing.
tarpc is now instrumented with tracing primitives extended with OpenTelemetry traces. Using a compatible tracing-opentelemetry subscriber like Jaeger, each RPC can be traced through the client, server, and other dependencies downstream of the server. Even for applications not connected to a distributed tracing collector, the instrumentation can also be ingested by regular loggers like env_logger.

# Breaking Changes

## Logging

Logged events are now structured using tracing. For applications using a logger and not a tracing subscriber, these logs may look different or contain information in a less consumable form. The easiest solution is to add a tracing subscriber that logs to stdout, such as tracing_subscriber::fmt (a minimal sketch follows the references below).

## Context

- Context no longer has parent_span, which was never actually needed: the context sent in an RPC is inherently the parent context. For the purposes of distributed tracing, the client side of the RPC has all the information necessary to link the span to its parent; the server side need do nothing more than export the (trace ID, span ID) tuple.
- Context has a new field, SamplingDecision, which has two variants, Sampled and Unsampled. This field can be used by downstream systems to determine whether a trace needs to be exported. If the parent span is sampled, the expectation is that all child spans are exported as well; to do otherwise could result in lossy traces being exported. Note that if an OpenTelemetry tracing subscriber is not installed, the fallback context will still be used, but the Context's sampling decision will always be inherited from the parent Context's sampling decision.
- Context::scope has been removed. Context propagation is now done via tracing's task-local spans, which can be propagated across tasks via Span::in_scope. When a service receives a request, it attaches an OpenTelemetry context to the local Span created before request handling, and this context contains the request deadline. The span-local deadline is retrieved by Context::current, but it cannot be modified in a way that makes future Context::current calls return a different deadline. However, the deadline in the context passed into an RPC call will override it, so users can retrieve the current context and then modify the deadline field, as has been historically possible (see the deadline sketch below).
- Context propagation precedence changes: when an RPC is initiated, the current Span's OpenTelemetry context takes precedence over the trace context passed into the RPC method. If there is no current Span, then the trace context argument is used as it has been historically. Note that OpenTelemetry context propagation requires an OpenTelemetry tracing subscriber to be installed.

## Server

- The server::Channel trait now has an additional required associated type and method that returns the underlying transport. This makes it more ergonomic for users to retrieve transport-specific information, like the peer's IP address. BaseChannel implements Channel::transport by returning the underlying transport, and channel decorators like Throttler simply delegate to the Channel::transport method of the wrapped channel (a simplified sketch appears after the example server changes below).

# References

[1] https://github.com/tokio-rs/tracing
[2] https://opentelemetry.io
[3] https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger
[4] https://github.com/env-logger-rs/env_logger
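To make the Logging change concrete, here is a minimal sketch of the stdout fallback. It assumes tracing-subscriber 0.2 with its default features (including env-filter); the builder methods shown are that crate's API, not tarpc's.

```rust
// A minimal sketch: replaces a plain `env_logger::init()` so that tarpc's
// structured events become visible, filtered by RUST_LOG as before.
fn main() {
    tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .init();

    tracing::info!("subscriber installed; tarpc events will now be logged to stdout");
}
```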
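And the deadline rule from the Context section, as a sketch. The WorldClient and hello names come from the example service changed in this diff; the function itself is hypothetical.

```rust
use std::time::{Duration, SystemTime};

// Sketch: `context::current()` picks up the span-local deadline, and while it
// cannot be mutated in place for future `current()` calls, a modified copy
// passed into the RPC overrides it.
async fn hello_with_longer_deadline(client: &service::WorldClient) -> anyhow::Result<String> {
    let mut ctx = tarpc::context::current();
    ctx.deadline = SystemTime::now() + Duration::from_secs(30);
    Ok(client.hello(ctx, "world".to_string()).await?)
}
```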
@@ -1,7 +1,11 @@
[workspace]
resolver = "2"

members = [
    "example-service",
    "tarpc",
    "plugins",
]

[profile.dev]
split-debuginfo = "unpacked"
@@ -42,7 +42,6 @@ process, and no context switching between different languages.
Some other features of tarpc:
- Pluggable transport: any type impling `Stream<Item = Request> + Sink<Response>` can be
  used as a transport to connect the client and server.
- `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
- Cascading cancellation: dropping a request will send a cancellation message to the server.
  The server will cease any unfinished work on the request, subsequently cancelling any of its
  own requests, repeating for the entire chain of transitive dependencies.

@@ -51,6 +50,14 @@ Some other features of tarpc:
  requests sent by the server that use the request context will propagate the request deadline.
  For example, if a server is handling a request with a 10s deadline, does 2s of work, then
  sends a request to another server, that server will see an 8s deadline.
- Distributed tracing: tarpc is instrumented with [tracing](https://github.com/tokio-rs/tracing)
  primitives extended with [OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible
  tracing subscriber like
  [Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
  each RPC can be traced through the client, server, and other dependencies downstream of the
  server. Even for applications not connected to a distributed tracing collector, the
  instrumentation can also be ingested by regular loggers like
  [env_logger](https://github.com/env-logger-rs/env_logger/).
- Serde serialization: enabling the `serde1` Cargo feature will make service requests and
  responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
  be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
@@ -13,13 +13,21 @@ readme = "../README.md"
description = "An example server built on tarpc."

[dependencies]
clap = "2.33"
env_logger = "0.8"
anyhow = "1.0"
clap = "3.0.0-beta.2"
log = "0.4"
futures = "0.3"
opentelemetry = { version = "0.13", features = ["rt-tokio"] }
opentelemetry-jaeger = { version = "0.12", features = ["tokio"] }
rand = "0.8"
serde = { version = "1.0" }
tarpc = { path = "../tarpc", features = ["full"] }
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
tokio-serde = { version = "0.8", features = ["json"] }
tracing = { version = "0.1" }
tracing-appender = "0.1"
tracing-opentelemetry = "0.12"
tracing-subscriber = "0.2"

[lib]
name = "service"
@@ -4,57 +4,49 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

use clap::{App, Arg};
use std::{io, net::SocketAddr};
use clap::Clap;
use service::{init_tracing, WorldClient};
use std::{net::SocketAddr, time::Duration};
use tarpc::{client, context, tokio_serde::formats::Json};
use tokio::time::sleep;
use tracing::Instrument;

#[derive(Clap)]
struct Flags {
    /// Sets the server address to connect to.
    #[clap(long)]
    server_addr: SocketAddr,
    /// Sets the name to say hello to.
    #[clap(long)]
    name: String,
}

#[tokio::main]
async fn main() -> io::Result<()> {
    env_logger::init();
async fn main() -> anyhow::Result<()> {
    let flags = Flags::parse();
    let _uninstall = init_tracing("Tarpc Example Client")?;

    let flags = App::new("Hello Client")
        .version("0.1")
        .author("Tim <tikue@google.com>")
        .about("Say hello!")
        .arg(
            Arg::with_name("server_addr")
                .long("server_addr")
                .value_name("ADDRESS")
                .help("Sets the server address to connect to.")
                .required(true)
                .takes_value(true),
        )
        .arg(
            Arg::with_name("name")
                .short("n")
                .long("name")
                .value_name("STRING")
                .help("Sets the name to say hello to.")
                .required(true)
                .takes_value(true),
        )
        .get_matches();

    let server_addr = flags.value_of("server_addr").unwrap();
    let server_addr = server_addr
        .parse::<SocketAddr>()
        .unwrap_or_else(|e| panic!(r#"--server_addr value "{}" invalid: {}"#, server_addr, e));

    let name = flags.value_of("name").unwrap().into();

    let mut transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default);
    transport.config_mut().max_frame_length(usize::MAX);
    let transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);

    // WorldClient is generated by the service attribute. It has a constructor `new` that takes a
    // config and any Transport as input.
    let client = service::WorldClient::new(client::Config::default(), transport.await?).spawn()?;
    let client = WorldClient::new(client::Config::default(), transport.await?).spawn()?;

    // The client has an RPC method for each RPC defined in the annotated trait. It takes the same
    // args as defined, with the addition of a Context, which is always the first arg. The Context
    // specifies a deadline and trace information which can be helpful in debugging requests.
    let hello = client.hello(context::current(), name).await?;
    let hello = async move {
        // Send the request twice, just to be safe! ;)
        tokio::select! {
            hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 }
            hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 }
        }
    }
    .instrument(tracing::info_span!("Two Hellos"))
    .await;

    println!("{}", hello);
    tracing::info!("{:?}", hello);

    // Let the background span processor finish.
    sleep(Duration::from_micros(1)).await;
    opentelemetry::global::shutdown_tracer_provider();

    Ok(())
}
@@ -4,6 +4,9 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

use std::env;
use tracing_subscriber::{fmt::format::FmtSpan, prelude::*};

/// This is the service definition. It looks a lot like a trait definition.
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
#[tarpc::service]

@@ -11,3 +14,28 @@ pub trait World {
    /// Returns a greeting for name.
    async fn hello(name: String) -> String;
}

/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
pub fn init_tracing(
    service_name: &str,
) -> anyhow::Result<tracing_appender::non_blocking::WorkerGuard> {
    env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");

    let tracer = opentelemetry_jaeger::new_pipeline()
        .with_service_name(service_name)
        .with_max_packet_size(2usize.pow(13))
        .install_batch(opentelemetry::runtime::Tokio)?;
    let (non_blocking, guard) = tracing_appender::non_blocking(std::io::stdout());

    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::from_default_env())
        .with(
            tracing_subscriber::fmt::layer()
                .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)
                .with_writer(non_blocking),
        )
        .with(tracing_opentelemetry::layer().with_tracer(tracer))
        .try_init()?;

    Ok(guard)
}
@@ -4,18 +4,30 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

use clap::{App, Arg};
use clap::Clap;
use futures::{future, prelude::*};
use service::World;
use rand::{
    distributions::{Distribution, Uniform},
    thread_rng,
};
use service::{init_tracing, World};
use std::{
    io,
    net::{IpAddr, SocketAddr},
    time::Duration,
};
use tarpc::{
    context,
    server::{self, Channel, Incoming},
    tokio_serde::formats::Json,
};
use tokio::time;

#[derive(Clap)]
struct Flags {
    /// Sets the port number to listen on.
    #[clap(long)]
    port: u16,
}

// This is the type that implements the generated World trait. It is the business logic
// and is used to start the server.

@@ -25,35 +37,19 @@ struct HelloServer(SocketAddr);

#[tarpc::server]
impl World for HelloServer {
    async fn hello(self, _: context::Context, name: String) -> String {
        let sleep_time =
            Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng()));
        time::sleep(sleep_time).await;
        format!("Hello, {}! You are connected from {:?}.", name, self.0)
    }
}

#[tokio::main]
async fn main() -> io::Result<()> {
    env_logger::init();
async fn main() -> anyhow::Result<()> {
    let flags = Flags::parse();
    let _uninstall = init_tracing("Tarpc Example Server")?;

    let flags = App::new("Hello Server")
        .version("0.1")
        .author("Tim <tikue@google.com>")
        .about("Say hello!")
        .arg(
            Arg::with_name("port")
                .short("p")
                .long("port")
                .value_name("NUMBER")
                .help("Sets the port number to listen on")
                .required(true)
                .takes_value(true),
        )
        .get_matches();

    let port = flags.value_of("port").unwrap();
    let port = port
        .parse()
        .unwrap_or_else(|e| panic!(r#"--port value "{}" invalid: {}"#, port, e));

    let server_addr = (IpAddr::from([0, 0, 0, 0]), port);
    let server_addr = (IpAddr::from([0, 0, 0, 0]), flags.port);

    // JSON transport is provided by the json_transport tarpc module. It makes it easy
    // to start up a serde-powered json serialization strategy over TCP.

@@ -64,12 +60,12 @@ async fn main() -> io::Result<()> {
        .filter_map(|r| future::ready(r.ok()))
        .map(server::BaseChannel::with_defaults)
        // Limit channels to 1 per IP.
        .max_channels_per_key(1, |t| t.as_ref().peer_addr().unwrap().ip())
        .max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip())
        // serve is generated by the service attribute. It takes as input any type implementing
        // the generated World trait.
        .map(|channel| {
            let server = HelloServer(channel.as_ref().as_ref().peer_addr().unwrap());
            channel.requests().execute(server.serve())
            let server = HelloServer(channel.transport().peer_addr().unwrap());
            channel.execute(server.serve())
        })
        // Max 10 channels.
        .buffer_unordered(10)
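The server hunk above calls the new Channel::transport accessor twice. Below is a self-contained sketch of the delegation pattern described in the commit message; the Channel trait here is a simplified stand-in, not tarpc's real server::Channel, and BaseChannel/Throttler are reduced to the one relevant method.

```rust
// Simplified stand-in for tarpc's server::Channel, showing only the new
// associated type and accessor; the real trait has many more items.
trait Channel {
    type Transport;
    fn transport(&self) -> &Self::Transport;
}

// BaseChannel returns the underlying transport directly.
struct BaseChannel<T> {
    transport: T,
}

impl<T> Channel for BaseChannel<T> {
    type Transport = T;
    fn transport(&self) -> &T {
        &self.transport
    }
}

// A decorator in the style of Throttler just delegates to the wrapped channel.
struct Throttler<C> {
    inner: C,
}

impl<C: Channel> Channel for Throttler<C> {
    type Transport = C::Transport;
    fn transport(&self) -> &C::Transport {
        self.inner.transport()
    }
}

fn main() {
    let channel = Throttler {
        inner: BaseChannel { transport: "127.0.0.1:54321" },
    };
    // Transport-specific info (like a peer address) is now one call away.
    println!("peer = {}", channel.transport());
}
```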
@@ -30,4 +30,4 @@ proc-macro = true
assert-type-eq = "0.1.0"
futures = "0.3"
serde = { version = "1.0", features = ["derive"] }
tarpc = { path = "../tarpc" }
tarpc = { path = "../tarpc", features = ["serde1"] }
@@ -267,6 +267,12 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
        None
    };

    let methods = rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>();
    let request_names = methods
        .iter()
        .map(|m| format!("{}.{}", ident, m))
        .collect::<Vec<_>>();

    ServiceGenerator {
        response_fut_name,
        service_ident: ident,

@@ -278,7 +284,8 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
        vis,
        args,
        method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
        method_idents: &rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>(),
        method_idents: &methods,
        request_names: &*request_names,
        attrs,
        rpcs,
        return_types: &rpcs

@@ -441,6 +448,7 @@ struct ServiceGenerator<'a> {
    camel_case_idents: &'a [Ident],
    future_types: &'a [Type],
    method_idents: &'a [&'a Ident],
    request_names: &'a [String],
    method_attrs: &'a [&'a [Attribute]],
    args: &'a [&'a [PatType]],
    return_types: &'a [&'a Type],

@@ -524,6 +532,7 @@ impl<'a> ServiceGenerator<'a> {
            camel_case_idents,
            arg_pats,
            method_idents,
            request_names,
            ..
        } = self;

@@ -534,6 +543,16 @@ impl<'a> ServiceGenerator<'a> {
            type Resp = #response_ident;
            type Fut = #response_fut_ident<S>;

            fn method(&self, req: &#request_ident) -> Option<&'static str> {
                Some(match req {
                    #(
                        #request_ident::#camel_case_idents{..} => {
                            #request_names
                        }
                    )*
                })
            }

            fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
                match req {
                    #(

@@ -714,6 +733,7 @@ impl<'a> ServiceGenerator<'a> {
            method_attrs,
            vis,
            method_idents,
            request_names,
            args,
            return_types,
            arg_pats,

@@ -729,7 +749,7 @@ impl<'a> ServiceGenerator<'a> {
            #vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*)
                -> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
                let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
                let resp = self.0.call(ctx, request);
                let resp = self.0.call(ctx, #request_names, request);
                async move {
                    match resp.await? {
                        #response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
@@ -30,26 +30,31 @@ anyhow = "1.0"
fnv = "1.0"
futures = "0.3"
humantime = "2.0"
log = "0.4"
pin-project = "1.0"
rand = "0.7"
rand = "0.8"
serde = { optional = true, version = "1.0", features = ["derive"] }
static_assertions = "1.1.0"
tarpc-plugins = { path = "../plugins", version = "0.10" }
tokio = { version = "1", features = ["time"] }
tokio-util = { version = "0.6.3", features = ["time"] }
tokio-serde = { optional = true, version = "0.8" }
tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
tracing-opentelemetry = { version = "0.12", default-features = false }
opentelemetry = { version = "0.13", default-features = false }

[dev-dependencies]
assert_matches = "1.4"
bincode = "1.3"
bytes = { version = "1", features = ["serde"] }
env_logger = "0.8"
flate2 = "1.0"
futures-test = "0.3"
log = "0.4"
opentelemetry = { version = "0.13", default-features = false, features = ["rt-tokio"] }
opentelemetry-jaeger = { version = "0.12", features = ["tokio"] }
pin-utils = "0.1.0-alpha"
serde_bytes = "0.11"
tracing-appender = "0.1"
tracing-subscriber = "0.2"
tokio = { version = "1", features = ["full", "test-util"] }
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
trybuild = "1.0"

@@ -63,7 +68,7 @@ name = "compression"
required-features = ["serde-transport", "tcp"]

[[example]]
name = "server_calling_server"
name = "tracing"
required-features = ["full"]

[[example]]
@@ -38,11 +38,10 @@ use futures::{
    future::{self, AbortHandle},
    prelude::*,
};
use log::info;
use publisher::Publisher as _;
use std::{
    collections::HashMap,
    io,
    env, io,
    net::SocketAddr,
    sync::{Arc, Mutex, RwLock},
};

@@ -54,6 +53,8 @@ use tarpc::{
};
use tokio::net::ToSocketAddrs;
use tokio_serde::formats::Json;
use tracing::info;
use tracing_subscriber::prelude::*;

pub mod subscriber {
    #[tarpc::service]

@@ -83,10 +84,7 @@ impl subscriber::Subscriber for Subscriber {
    }

    async fn receive(self, _: context::Context, topic: String, message: String) {
        info!(
            "[{}] received message on topic '{}': {}",
            self.local_addr, topic, message
        );
        info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage")
    }
}

@@ -120,7 +118,7 @@ impl Subscriber {
        let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve()));
        tokio::spawn(async move {
            match handler.await {
                Ok(()) | Err(future::Aborted) => info!("[{}] subscriber shutdown.", local_addr),
                Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."),
            }
        });
        Ok(SubscriberHandle(abort_handle))

@@ -153,13 +151,13 @@ impl Publisher {
            subscriptions: self.clone().start_subscription_manager().await?,
        };

        info!("[{}] listening for publishers.", publisher_addrs.publisher);
        info!(publisher_addr = %publisher_addrs.publisher, "listening for publishers.",);
        tokio::spawn(async move {
            // Because this is just an example, we know there will only be one publisher. In more
            // realistic code, this would be a loop to continually accept new publisher
            // connections.
            let publisher = connecting_publishers.next().await.unwrap().unwrap();
            info!("[{}] publisher connected.", publisher.peer_addr().unwrap());
            info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected.");

            server::BaseChannel::with_defaults(publisher)
                .execute(self.serve())

@@ -174,7 +172,7 @@ impl Publisher {
            .await?
            .filter_map(|r| future::ready(r.ok()));
        let new_subscriber_addr = connecting_subscribers.get_ref().local_addr();
        info!("[{}] listening for subscribers.", new_subscriber_addr);
        info!(?new_subscriber_addr, "listening for subscribers.");

        tokio::spawn(async move {
            while let Some(conn) = connecting_subscribers.next().await {

@@ -215,7 +213,7 @@ impl Publisher {
            },
        );

        info!("[{}] subscribed to topics: {:?}", subscriber_addr, topics);
        info!(%subscriber_addr, ?topics, "subscribed to new topics");
        let mut subscriptions = self.subscriptions.write().unwrap();
        for topic in topics {
            subscriptions

@@ -235,9 +233,9 @@ impl Publisher {
        tokio::spawn(async move {
            if let Err(e) = client_dispatch.await {
                info!(
                    "[{}] subscriber connection broken: {:?}",
                    subscriber_addr, e
                )
                    %subscriber_addr,
                    error = %e,
                    "subscriber connection broken");
            }
            // Don't clean up the subscriber until initialization is done.
            let _ = subscriber_ready.await;

@@ -281,13 +279,31 @@ impl publisher::Publisher for Publisher {
    }
}

/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
fn init_tracing(service_name: &str) -> anyhow::Result<tracing_appender::non_blocking::WorkerGuard> {
    env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
    let tracer = opentelemetry_jaeger::new_pipeline()
        .with_service_name(service_name)
        .with_max_packet_size(2usize.pow(13))
        .install_batch(opentelemetry::runtime::Tokio)?;

    let (non_blocking, guard) = tracing_appender::non_blocking(std::io::stdout());

    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::from_default_env())
        .with(tracing_subscriber::fmt::layer().with_writer(non_blocking))
        .with(tracing_opentelemetry::layer().with_tracer(tracer))
        .try_init()?;

    Ok(guard)
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    env_logger::init();
    let _uninstall = init_tracing("Pub/Sub")?;

    let clients = Arc::new(Mutex::new(HashMap::new()));
    let addrs = Publisher {
        clients,
        clients: Arc::new(Mutex::new(HashMap::new())),
        subscriptions: Arc::new(RwLock::new(HashMap::new())),
    }
    .start()

@@ -337,6 +353,7 @@ async fn main() -> anyhow::Result<()> {
    )
    .await?;

    opentelemetry::global::shutdown_tracer_provider();
    info!("done.");

    Ok(())
@@ -6,12 +6,13 @@

use crate::{add::Add as AddService, double::Double as DoubleService};
use futures::{future, prelude::*};
use std::io;
use std::env;
use tarpc::{
    client, context,
    server::{BaseChannel, Incoming},
};
use tokio_serde::formats::Json;
use tracing_subscriber::prelude::*;

pub mod add {
    #[tarpc::service]

@@ -35,7 +36,6 @@ struct AddServer;

#[tarpc::server]
impl AddService for AddServer {
    async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
        log::info!("AddService {:#?}", context::current());
        x + y
    }
}

@@ -48,18 +48,34 @@ struct DoubleServer {

#[tarpc::server]
impl DoubleService for DoubleServer {
    async fn double(self, _: context::Context, x: i32) -> Result<i32, String> {
        let ctx = context::current();
        log::info!("DoubleService {:#?}", ctx);
        self.add_client
            .add(ctx, x, x)
            .add(context::current(), x, x)
            .await
            .map_err(|e| e.to_string())
    }
}

fn init_tracing(service_name: &str) -> anyhow::Result<tracing_appender::non_blocking::WorkerGuard> {
    env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
    let tracer = opentelemetry_jaeger::new_pipeline()
        .with_service_name(service_name)
        .with_max_packet_size(2usize.pow(13))
        .install_batch(opentelemetry::runtime::Tokio)?;

    let (non_blocking, guard) = tracing_appender::non_blocking(std::io::stdout());

    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::from_default_env())
        .with(tracing_subscriber::fmt::layer().with_writer(non_blocking))
        .with(tracing_opentelemetry::layer().with_tracer(tracer))
        .try_init()?;

    Ok(guard)
}

#[tokio::main]
async fn main() -> io::Result<()> {
    env_logger::init();
async fn main() -> anyhow::Result<()> {
    let _uninstall = init_tracing("tarpc_tracing_example")?;

    let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
        .await?

@@ -89,9 +105,11 @@ async fn main() -> io::Result<()> {
        double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?;

    let ctx = context::current();
    log::info!("Client {:#?}", ctx);
    for i in 1..=5 {
        eprintln!("{:?}", double_client.double(ctx, i).await?);
    for _ in 1..=5 {
        tracing::info!("{:?}", double_client.double(ctx, 1).await?);
    }

    opentelemetry::global::shutdown_tracer_provider();

    Ok(())
}
@@ -8,16 +8,13 @@

mod in_flight_requests;

use crate::{
    context, trace::SpanId, ClientMessage, PollContext, PollIo, Request, Response, Transport,
};
use crate::{context, trace, ClientMessage, PollContext, PollIo, Request, Response, Transport};
use futures::{prelude::*, ready, stream::Fuse, task::*};
use in_flight_requests::InFlightRequests;
use log::{info, trace};
use pin_project::pin_project;
use std::{
    convert::TryFrom,
    fmt, io,
    fmt, io, mem,
    pin::Pin,
    sync::{
        atomic::{AtomicUsize, Ordering},

@@ -25,6 +22,7 @@ use std::{
    },
};
use tokio::sync::{mpsc, oneshot};
use tracing::Span;

/// Settings that control the behavior of the client.
#[derive(Clone, Debug)]

@@ -67,11 +65,9 @@ where
    #[cfg(feature = "tokio1")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
    pub fn spawn(self) -> io::Result<C> {
        use log::warn;

        let dispatch = self
            .dispatch
            .unwrap_or_else(move |e| warn!("Connection broken: {}", e));
            .unwrap_or_else(move |e| tracing::warn!("Connection broken: {}", e));
        tokio::spawn(dispatch);
        Ok(self.client)
    }

@@ -114,72 +110,70 @@ impl<Req, Resp> Clone for Channel<Req, Resp> {

impl<Req, Resp> Channel<Req, Resp> {
    /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
    /// resolves when the request is sent (not when the response is received).
    fn send(
    /// resolves to the response.
    #[tracing::instrument(
        name = "RPC",
        skip(self, ctx, request_name, request),
        fields(
            rpc.trace_id = tracing::field::Empty,
            otel.kind = "client",
            otel.name = request_name)
    )]
    pub async fn call(
        &self,
        mut ctx: context::Context,
        request_name: &str,
        request: Req,
    ) -> impl Future<Output = io::Result<DispatchResponse<Resp>>> + '_ {
        // Convert the context to the call context.
        ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
        ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());

        let (response_completion, response) = oneshot::channel();
        let cancellation = self.cancellation.clone();
    ) -> io::Result<Resp> {
        let span = Span::current();
        ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
            tracing::warn!(
                "OpenTelemetry subscriber not installed; making unsampled child context."
            );
            ctx.trace_context.new_child()
        });
        span.record("rpc.trace_id", &tracing::field::display(ctx.trace_id()));
        let (response_completion, mut response) = oneshot::channel();
        let request_id =
            u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();

        // DispatchResponse impls Drop to cancel in-flight requests. It should be created before
        // ResponseGuard impls Drop to cancel in-flight requests. It should be created before
        // sending out the request; otherwise, the response future could be dropped after the
        // request is sent out but before DispatchResponse is created, rendering the cancellation
        // request is sent out but before ResponseGuard is created, rendering the cancellation
        // logic inactive.
        let response = DispatchResponse {
            response,
        let response_guard = ResponseGuard {
            response: &mut response,
            request_id,
            cancellation: Some(cancellation),
            ctx,
            cancellation: &self.cancellation,
        };
        async move {
            self.to_dispatch
                .send(DispatchRequest {
                    ctx,
                    request_id,
                    request,
                    response_completion,
                })
                .await
                .map_err(|mpsc::error::SendError(_)| {
                    io::Error::from(io::ErrorKind::ConnectionReset)
                })?;
            Ok(response)
        }
    }

    /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
    /// resolves to the response.
    pub async fn call(&self, ctx: context::Context, request: Req) -> io::Result<Resp> {
        let dispatch_response = self.send(ctx, request).await?;
        dispatch_response.await
        self.to_dispatch
            .send(DispatchRequest {
                ctx,
                span,
                request_id,
                request,
                response_completion,
            })
            .await
            .map_err(|mpsc::error::SendError(_)| io::Error::from(io::ErrorKind::ConnectionReset))?;
        response_guard.response().await
    }
}

/// A server response that is completed by request dispatch when the corresponding response
/// arrives off the wire.
#[derive(Debug)]
struct DispatchResponse<Resp> {
    response: oneshot::Receiver<Response<Resp>>,
    ctx: context::Context,
    cancellation: Option<RequestCancellation>,
struct ResponseGuard<'a, Resp> {
    response: &'a mut oneshot::Receiver<Response<Resp>>,
    cancellation: &'a RequestCancellation,
    request_id: u64,
}

impl<Resp> Future for DispatchResponse<Resp> {
    type Output = io::Result<Resp>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> {
        let resp = ready!(self.response.poll_unpin(cx));
        self.cancellation.take();
        Poll::Ready(match resp {
impl<Resp> ResponseGuard<'_, Resp> {
    async fn response(mut self) -> io::Result<Resp> {
        let response = (&mut self.response).await;
        // Cancel drop logic once a response has been received.
        mem::forget(self);
        match response {
            Ok(resp) => Ok(resp.message?),
            Err(oneshot::error::RecvError { .. }) => {
                // The oneshot is Canceled when the dispatch task ends. In that case,

@@ -187,27 +181,25 @@ impl<Resp> Future for DispatchResponse<Resp> {
                // propagating cancellation.
                Err(io::Error::from(io::ErrorKind::ConnectionReset))
            }
        })
    }
}
}

// Cancels the request when dropped, if not already complete.
impl<Resp> Drop for DispatchResponse<Resp> {
impl<Resp> Drop for ResponseGuard<'_, Resp> {
    fn drop(&mut self) {
        if let Some(cancellation) = &mut self.cancellation {
            // The receiver needs to be closed to handle the edge case that the request has not
            // yet been received by the dispatch task. It is possible for the cancel message to
            // arrive before the request itself, in which case the request could get stuck in the
            // dispatch map forever if the server never responds (e.g. if the server dies while
            // responding). Even if the server does respond, it will have unnecessarily done work
            // for a client no longer waiting for a response. To avoid this, the dispatch task
            // checks if the receiver is closed before inserting the request in the map. By
            // closing the receiver before sending the cancel message, it is guaranteed that if the
            // dispatch task misses an early-arriving cancellation message, then it will see the
            // receiver as closed.
            self.response.close();
            cancellation.cancel(self.request_id);
        }
        // The receiver needs to be closed to handle the edge case that the request has not
        // yet been received by the dispatch task. It is possible for the cancel message to
        // arrive before the request itself, in which case the request could get stuck in the
        // dispatch map forever if the server never responds (e.g. if the server dies while
        // responding). Even if the server does respond, it will have unnecessarily done work
        // for a client no longer waiting for a response. To avoid this, the dispatch task
        // checks if the receiver is closed before inserting the request in the map. By
        // closing the receiver before sending the cancel message, it is guaranteed that if the
        // dispatch task misses an early-arriving cancellation message, then it will see the
        // receiver as closed.
        self.response.close();
        self.cancellation.cancel(self.request_id);
    }
}
@@ -343,7 +335,7 @@ where
        cx: &mut Context<'_>,
    ) -> PollIo<DispatchRequest<Req, Resp>> {
        if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
            info!(
            tracing::info!(
                "At in-flight request capacity ({}/{}).",
                self.in_flight_requests().len(),
                self.config.max_in_flight_requests

@@ -360,10 +352,8 @@ where
        match ready!(self.pending_requests_mut().poll_recv(cx)) {
            Some(request) => {
                if request.response_completion.is_closed() {
                    trace!(
                        "[{}] Request canceled before being sent.",
                        request.ctx.trace_id()
                    );
                    let _entered = request.span.enter();
                    tracing::info!("AbortRequest");
                    continue;
                }

@@ -381,14 +371,15 @@ where
    fn poll_next_cancellation(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> PollIo<(context::Context, u64)> {
    ) -> PollIo<(context::Context, Span, u64)> {
        ready!(self.ensure_writeable(cx)?);

        loop {
            match ready!(self.canceled_requests_mut().poll_next_unpin(cx)) {
                Some(request_id) => {
                    if let Some(ctx) = self.in_flight_requests().cancel_request(request_id) {
                        return Poll::Ready(Some(Ok((ctx, request_id))));
                    if let Some((ctx, span)) = self.in_flight_requests().cancel_request(request_id)
                    {
                        return Poll::Ready(Some(Ok((ctx, span, request_id))));
                    }
                }
                None => return Poll::Ready(None),

@@ -407,46 +398,56 @@ where
    }

    fn poll_write_request<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
        let dispatch_request = match ready!(self.as_mut().poll_next_request(cx)?) {
        let DispatchRequest {
            ctx,
            span,
            request_id,
            request,
            response_completion,
        } = match ready!(self.as_mut().poll_next_request(cx)?) {
            Some(dispatch_request) => dispatch_request,
            None => return Poll::Ready(None),
        };
        let entered = span.enter();
        // poll_next_request only returns Ready if there is room to buffer another request.
        // Therefore, we can call write_request without fear of erroring due to a full
        // buffer.
        let request_id = dispatch_request.request_id;
        let request_id = request_id;
        let request = ClientMessage::Request(Request {
            id: request_id,
            message: dispatch_request.request,
            message: request,
            context: context::Context {
                deadline: dispatch_request.ctx.deadline,
                trace_context: dispatch_request.ctx.trace_context,
                deadline: ctx.deadline,
                trace_context: ctx.trace_context,
            },
        });
        self.transport_pin_mut().start_send(request)?;
        let deadline = ctx.deadline;
        tracing::info!(
            tarpc.deadline = %humantime::format_rfc3339(deadline),
            "SendRequest"
        );
        drop(entered);

        self.in_flight_requests()
            .insert_request(
                request_id,
                dispatch_request.ctx,
                dispatch_request.response_completion,
            )
            .insert_request(request_id, ctx, span, response_completion)
            .expect("Request IDs should be unique");
        Poll::Ready(Some(Ok(())))
    }

    fn poll_write_cancel<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
        let (context, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
            Some((context, request_id)) => (context, request_id),
        let (context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
            Some(triple) => triple,
            None => return Poll::Ready(None),
        };
        let _entered = span.enter();

        let trace_id = *context.trace_id();
        let cancel = ClientMessage::Cancel {
            trace_context: context.trace_context,
            request_id,
        };
        self.transport_pin_mut().start_send(cancel)?;
        trace!("[{}] Cancel message sent.", trace_id);
        tracing::info!("CancelRequest");
        Poll::Ready(Some(Ok(())))
    }

@@ -473,15 +474,15 @@ where
                .context("failed to write to transport")?,
        ) {
            (Poll::Ready(None), _) => {
                info!("Shutdown: read half closed, so shutting down.");
                tracing::info!("Shutdown: read half closed, so shutting down.");
                return Poll::Ready(Ok(()));
            }
            (read, Poll::Ready(None)) => {
                if self.in_flight_requests().is_empty() {
                    info!("Shutdown: write half closed, and no requests in flight.");
                if self.in_flight_requests.is_empty() {
                    tracing::info!("Shutdown: write half closed, and no requests in flight.");
                    return Poll::Ready(Ok(()));
                }
                info!(
                tracing::info!(
                    "Shutdown: write half closed, and {} requests in flight.",
                    self.in_flight_requests().len()
                );

@@ -502,6 +503,7 @@ where
#[derive(Debug)]
struct DispatchRequest<Req, Resp> {
    pub ctx: context::Context,
    pub span: Span,
    pub request_id: u64,
    pub request: Req,
    pub response_completion: oneshot::Sender<Response<Resp>>,

@@ -518,16 +520,14 @@ struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
/// Returns a channel to send request cancellation messages.
fn cancellations() -> (RequestCancellation, CanceledRequests) {
    // Unbounded because messages are sent in the drop fn. This is fine, because it's still
    // bounded by the number of in-flight requests. Additionally, each request has a clone
    // of the sender, so the bounded channel would have the same behavior,
    // since it guarantees a slot.
    // bounded by the number of in-flight requests.
    let (tx, rx) = mpsc::unbounded_channel();
    (RequestCancellation(tx), CanceledRequests(rx))
}

impl RequestCancellation {
    /// Cancels the request with ID `request_id`.
    fn cancel(&mut self, request_id: u64) {
    fn cancel(&self, request_id: u64) {
        let _ = self.0.send(request_id);
    }
}
@@ -549,8 +549,8 @@ impl Stream for CanceledRequests {
#[cfg(test)]
mod tests {
    use super::{
        cancellations, CanceledRequests, Channel, DispatchResponse, RequestCancellation,
        RequestDispatch,
        cancellations, CanceledRequests, Channel, DispatchRequest, RequestCancellation,
        RequestDispatch, ResponseGuard,
    };
    use crate::{
        client::{in_flight_requests::InFlightRequests, Config},

@@ -558,19 +558,46 @@ mod tests {
        transport::{self, channel::UnboundedChannel},
        ClientMessage, Response,
    };
    use assert_matches::assert_matches;
    use futures::{prelude::*, task::*};
    use std::{pin::Pin, sync::atomic::AtomicUsize, sync::Arc};
    use std::{
        convert::TryFrom,
        pin::Pin,
        sync::atomic::{AtomicUsize, Ordering},
        sync::Arc,
    };
    use tokio::sync::{mpsc, oneshot};
    use tracing::Span;

    #[tokio::test]
    async fn response_completes_request_future() {
        let (mut dispatch, mut _channel, mut server_channel) = set_up();
        let cx = &mut Context::from_waker(&noop_waker_ref());
        let (tx, mut rx) = oneshot::channel();

        dispatch
            .in_flight_requests
            .insert_request(0, context::current(), Span::current(), tx)
            .unwrap();
        server_channel
            .send(Response {
                request_id: 0,
                message: Ok("Resp".into()),
            })
            .await
            .unwrap();
        assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
        assert_matches!(rx.try_recv(), Ok(Response { request_id: 0, message: Ok(resp) }) if resp == "Resp");
    }

    #[tokio::test]
    async fn dispatch_response_cancels_on_drop() {
        let (cancellation, mut canceled_requests) = cancellations();
        let (_, response) = oneshot::channel();
        drop(DispatchResponse::<u32> {
            response,
            cancellation: Some(cancellation),
        let (_, mut response) = oneshot::channel();
        drop(ResponseGuard::<u32> {
            response: &mut response,
            cancellation: &cancellation,
            request_id: 3,
            ctx: context::current(),
        });
        // resp's drop() is run, which should send a cancel message.
        let cx = &mut Context::from_waker(&noop_waker_ref());

@@ -580,23 +607,22 @@ mod tests {
    #[tokio::test]
    async fn dispatch_response_doesnt_cancel_after_complete() {
        let (cancellation, mut canceled_requests) = cancellations();
        let (tx, response) = oneshot::channel();
        let (tx, mut response) = oneshot::channel();
        tx.send(Response {
            request_id: 0,
            message: Ok("well done"),
        })
        .unwrap();
        {
            DispatchResponse {
                response,
                cancellation: Some(cancellation),
                request_id: 3,
                ctx: context::current(),
            }
            .await
            .unwrap();
        // resp's drop() is run, but should not send a cancel message.
        // resp's drop() is run, but should not send a cancel message.
        ResponseGuard {
            response: &mut response,
            cancellation: &cancellation,
            request_id: 3,
        }
        .response()
        .await
        .unwrap();
        drop(cancellation);
        let cx = &mut Context::from_waker(&noop_waker_ref());
        assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(None));
    }

@@ -604,12 +630,12 @@ mod tests {
    #[tokio::test]
    async fn stage_request() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let dispatch = Pin::new(&mut dispatch);
        let cx = &mut Context::from_waker(&noop_waker_ref());
        let (tx, mut rx) = oneshot::channel();

        let _resp = send_request(&mut channel, "hi").await;
        let _resp = send_request(&mut channel, "hi", tx, &mut rx).await;

        let req = dispatch.poll_next_request(cx).ready();
        let req = dispatch.as_mut().poll_next_request(cx).ready();
        assert!(req.is_some());

        let req = req.unwrap();

@@ -621,10 +647,10 @@ mod tests {
    #[tokio::test]
    async fn stage_request_channel_dropped_doesnt_panic() {
        let (mut dispatch, mut channel, mut server_channel) = set_up();
        let mut dispatch = Pin::new(&mut dispatch);
        let cx = &mut Context::from_waker(&noop_waker_ref());
        let (tx, mut rx) = oneshot::channel();

        let _ = send_request(&mut channel, "hi").await;
        let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
        drop(channel);

        assert!(dispatch.as_mut().poll(cx).is_ready());

@@ -642,61 +668,68 @@ mod tests {
    #[tokio::test]
    async fn stage_request_response_future_dropped_is_canceled_before_sending() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let dispatch = Pin::new(&mut dispatch);
        let cx = &mut Context::from_waker(&noop_waker_ref());
        let (tx, mut rx) = oneshot::channel();

        let _ = send_request(&mut channel, "hi").await;
        let _ = send_request(&mut channel, "hi", tx, &mut rx).await;

        // Drop the channel so polling returns none if no requests are currently ready.
        drop(channel);
        // Test that a request future dropped before it's processed by dispatch will cause the request
        // to not be added to the in-flight request map.
        assert!(dispatch.poll_next_request(cx).ready().is_none());
        assert!(dispatch.as_mut().poll_next_request(cx).ready().is_none());
    }

    #[tokio::test]
    async fn stage_request_response_future_dropped_is_canceled_after_sending() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let cx = &mut Context::from_waker(&noop_waker_ref());
        let mut dispatch = Pin::new(&mut dispatch);
        let (tx, mut rx) = oneshot::channel();

        let req = send_request(&mut channel, "hi").await;
        let req = send_request(&mut channel, "hi", tx, &mut rx).await;

        assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
        assert!(!dispatch.in_flight_requests().is_empty());
        assert!(!dispatch.in_flight_requests.is_empty());

        // Test that a request future dropped after it's processed by dispatch will cause the request
        // to be removed from the in-flight request map.
        drop(req);
        if let Poll::Ready(Some(_)) = dispatch.as_mut().poll_next_cancellation(cx).unwrap() {
            // ok
        } else {
            panic!("Expected request to be cancelled")
        };
        assert!(dispatch.in_flight_requests().is_empty());
        assert_matches!(
            dispatch.as_mut().poll_next_cancellation(cx),
            Poll::Ready(Some(Ok(_)))
        );
        assert!(dispatch.in_flight_requests.is_empty());
    }

    #[tokio::test]
    async fn stage_request_response_closed_skipped() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let dispatch = Pin::new(&mut dispatch);
        let cx = &mut Context::from_waker(&noop_waker_ref());
        let (tx, mut rx) = oneshot::channel();

        // Test that a request future that's closed its receiver but not yet canceled its request --
        // i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
        // map.
        let mut resp = send_request(&mut channel, "hi").await;
        let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
        resp.response.close();

        assert!(dispatch.poll_next_request(cx).is_pending());
        assert!(dispatch.as_mut().poll_next_request(cx).is_pending());
    }

    fn set_up() -> (
        RequestDispatch<String, String, UnboundedChannel<Response<String>, ClientMessage<String>>>,
        Pin<
            Box<
                RequestDispatch<
                    String,
                    String,
                    UnboundedChannel<Response<String>, ClientMessage<String>>,
                >,
            >,
        >,
        Channel<String, String>,
        UnboundedChannel<ClientMessage<String>, Response<String>>,
    ) {
        let _ = env_logger::try_init();
        let _ = tracing_subscriber::fmt().with_test_writer().try_init();

        let (to_dispatch, pending_requests) = mpsc::channel(1);
        let (cancel_tx, canceled_requests) = mpsc::unbounded_channel();

@@ -717,17 +750,31 @@ mod tests {
            next_request_id: Arc::new(AtomicUsize::new(0)),
        };

        (dispatch, channel, server_channel)
        (Box::pin(dispatch), channel, server_channel)
    }

    async fn send_request(
        channel: &mut Channel<String, String>,
    async fn send_request<'a>(
        channel: &'a mut Channel<String, String>,
        request: &str,
    ) -> DispatchResponse<String> {
        channel
            .send(context::current(), request.to_string())
            .await
            .unwrap()
        response_completion: oneshot::Sender<Response<String>>,
        response: &'a mut oneshot::Receiver<Response<String>>,
    ) -> ResponseGuard<'a, String> {
        let request_id =
            u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
        let request = DispatchRequest {
            ctx: context::current(),
            span: Span::current(),
            request_id,
            request: request.to_string(),
            response_completion,
        };
        channel.to_dispatch.send(request).await.unwrap();

        ResponseGuard {
            response,
            cancellation: &channel.cancellation,
            request_id,
        }
    }

    async fn send_response(
@@ -5,7 +5,6 @@ use crate::{
};
use fnv::FnvHashMap;
use futures::ready;
use log::{debug, trace};
use std::{
    collections::hash_map,
    io,

@@ -13,6 +12,7 @@ use std::{
};
use tokio::sync::oneshot;
use tokio_util::time::delay_queue::{self, DelayQueue};
use tracing::Span;

/// Requests already written to the wire that haven't yet received responses.
#[derive(Debug)]

@@ -33,6 +33,7 @@ impl<Resp> Default for InFlightRequests<Resp> {
#[derive(Debug)]
struct RequestData<Resp> {
    ctx: context::Context,
    span: Span,
    response_completion: oneshot::Sender<Response<Resp>>,
    /// The key to remove the timer for the request's deadline.
    deadline_key: delay_queue::Key,

@@ -59,20 +60,16 @@ impl<Resp> InFlightRequests<Resp> {
        &mut self,
        request_id: u64,
        ctx: context::Context,
        span: Span,
        response_completion: oneshot::Sender<Response<Resp>>,
    ) -> Result<(), AlreadyExistsError> {
        match self.request_data.entry(request_id) {
            hash_map::Entry::Vacant(vacant) => {
                let timeout = ctx.deadline.time_until();
                trace!(
                    "[{}] Queuing request with timeout {:?}.",
                    ctx.trace_id(),
                    timeout,
                );

                let deadline_key = self.deadlines.insert(request_id, timeout);
                vacant.insert(RequestData {
                    ctx,
                    span,
                    response_completion,
                    deadline_key,
                });

@@ -85,15 +82,15 @@ impl<Resp> InFlightRequests<Resp> {
    /// Removes a request without aborting. Returns true iff the request was found.
    pub fn complete_request(&mut self, response: Response<Resp>) -> bool {
        if let Some(request_data) = self.request_data.remove(&response.request_id) {
            let _entered = request_data.span.enter();
            tracing::info!("ReceiveResponse");
            self.request_data.compact(0.1);

            trace!("[{}] Received response.", request_data.ctx.trace_id());
            self.deadlines.remove(&request_data.deadline_key);
            request_data.complete(response);
            let _ = request_data.response_completion.send(response);
            return true;
        }

        debug!(
        tracing::debug!(
            "No in-flight request found for request_id = {}.",
            response.request_id
        );

@@ -104,12 +101,11 @@ impl<Resp> InFlightRequests<Resp> {

    /// Cancels a request without completing (typically used when a request handle was dropped
    /// before the request completed).
    pub fn cancel_request(&mut self, request_id: u64) -> Option<context::Context> {
    pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> {
        if let Some(request_data) = self.request_data.remove(&request_id) {
            self.request_data.compact(0.1);
            trace!("[{}] Cancelling request.", request_data.ctx.trace_id());
            self.deadlines.remove(&request_data.deadline_key);
            Some(request_data.ctx)
            Some((request_data.ctx, request_data.span))
        } else {
            None
        }

@@ -122,8 +118,12 @@ impl<Resp> InFlightRequests<Resp> {
            Some(Ok(expired)) => {
                let request_id = expired.into_inner();
                if let Some(request_data) = self.request_data.remove(&request_id) {
                    let _entered = request_data.span.enter();
                    tracing::error!("DeadlineExceeded");
                    self.request_data.compact(0.1);
                    request_data.complete(Self::deadline_exceeded_error(request_id));
                    let _ = request_data
                        .response_completion
                        .send(Self::deadline_exceeded_error(request_id));
                }
                Some(Ok(request_id))
            }

@@ -142,21 +142,3 @@ impl<Resp> InFlightRequests<Resp> {
        }
    }
}

/// When InFlightRequests is dropped, any outstanding requests are completed with a
/// deadline-exceeded error.
impl<Resp> Drop for InFlightRequests<Resp> {
    fn drop(&mut self) {
        let deadlines = &mut self.deadlines;
        for (_, request_data) in self.request_data.drain() {
            let expired = deadlines.remove(&request_data.deadline_key);
            request_data.complete(Self::deadline_exceeded_error(expired.into_inner()));
        }
    }
}

impl<Resp> RequestData<Resp> {
    fn complete(self, response: Response<Resp>) {
        let _ = self.response_completion.send(response);
    }
}
@@ -8,8 +8,13 @@
 //! client to server and is used by the server to enforce response deadlines.

 use crate::trace::{self, TraceId};
+use opentelemetry::trace::TraceContextExt;
 use static_assertions::assert_impl_all;
-use std::time::{Duration, SystemTime};
+use std::{
+    convert::TryFrom,
+    time::{Duration, SystemTime},
+};
+use tracing_opentelemetry::OpenTelemetrySpanExt;

 /// A request context that carries request-scoped information like deadlines and trace information.
 /// It is sent from client to server and is used by the server to enforce response deadlines.
@@ -33,10 +38,6 @@ pub struct Context {

 assert_impl_all!(Context: Send, Sync);

-tokio::task_local! {
-    static CURRENT_CONTEXT: Context;
-}
-
 fn ten_seconds_from_now() -> SystemTime {
     SystemTime::now() + Duration::from_secs(10)
 }
@@ -46,64 +47,56 @@ pub fn current() -> Context {
     Context::current()
 }

-impl Context {
-    /// Returns a Context containing a new root trace context and a default deadline.
-    pub fn new_root() -> Self {
-        Self {
-            deadline: ten_seconds_from_now(),
-            trace_context: trace::Context::new_root(),
-        }
-    }
+#[derive(Clone)]
+struct Deadline(SystemTime);
+
+impl Default for Deadline {
+    fn default() -> Self {
+        Self(ten_seconds_from_now())
+    }
+}
+
+impl Context {
     /// Returns the context for the current request, or a default Context if no request is active.
     pub fn current() -> Self {
-        CURRENT_CONTEXT
-            .try_with(Self::clone)
-            .unwrap_or_else(|_| Self::new_root())
+        let span = tracing::Span::current();
+        Self {
+            trace_context: trace::Context::try_from(&span)
+                .unwrap_or_else(|_| trace::Context::default()),
+            deadline: span
+                .context()
+                .get::<Deadline>()
+                .cloned()
+                .unwrap_or_default()
+                .0,
+        }
     }

     /// Returns the ID of the request-scoped trace.
     pub fn trace_id(&self) -> &TraceId {
         &self.trace_context.trace_id
     }
 }

-    /// Run a future with this context as the current context.
-    pub async fn scope<F>(self, f: F) -> F::Output
-    where
-        F: std::future::Future,
-    {
-        CURRENT_CONTEXT.scope(self, f).await
-    }
-}
+/// An extension trait for [`tracing::Span`] for propagating tarpc Contexts.
+pub(crate) trait SpanExt {
+    /// Sets the given context on this span. Newly-created spans will be children of the given
+    /// context's trace context.
+    fn set_context(&self, context: &Context);
+}
+
+impl SpanExt for tracing::Span {
+    fn set_context(&self, context: &Context) {
+        self.set_parent(
+            opentelemetry::Context::new()
+                .with_remote_span_context(opentelemetry::trace::SpanContext::new(
+                    opentelemetry::trace::TraceId::from(context.trace_context.trace_id),
+                    opentelemetry::trace::SpanId::from(context.trace_context.span_id),
+                    context.trace_context.sampling_decision as u8,
+                    true,
+                    opentelemetry::trace::TraceState::default(),
+                ))
+                .with_value(Deadline(context.deadline)),
+        );
+    }
+}

-#[cfg(test)]
-use {
-    assert_matches::assert_matches, futures::prelude::*, futures_test::task::noop_context,
-    std::task::Poll,
-};
-
-#[test]
-fn context_current_has_no_parent() {
-    let ctx = current();
-    assert_matches!(ctx.trace_context.parent_id, None);
-}
-
-#[test]
-fn context_root_has_no_parent() {
-    let ctx = Context::new_root();
-    assert_matches!(ctx.trace_context.parent_id, None);
-}
-
-#[test]
-fn context_scope() {
-    let ctx = Context::new_root();
-    let mut ctx_copy = Box::pin(ctx.scope(async { current() }));
-    assert_matches!(ctx_copy.poll_unpin(&mut noop_context()),
-        Poll::Ready(Context {
-            deadline,
-            trace_context: trace::Context { trace_id, span_id, parent_id },
-        }) if deadline == ctx.deadline
-            && trace_id == ctx.trace_context.trace_id
-            && span_id == ctx.trace_context.span_id
-            && parent_id == ctx.trace_context.parent_id);
-}
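Editor's note: with scope-based propagation removed, `Context::current()` now derives the trace context and deadline from the active tracing span, and per-call deadlines are set by value. A usage sketch under stated assumptions: `WorldClient` and its `hello` method are hypothetical generated-client names, not part of this diff.

use std::time::{Duration, SystemTime};

async fn call_with_long_deadline(client: &WorldClient) -> std::io::Result<String> {
    // Inherits the current span's trace context and the span-local default deadline.
    let mut ctx = tarpc::context::current();
    // The deadline passed into the RPC call overrides the span-local one.
    ctx.deadline = SystemTime::now() + Duration::from_secs(30);
    client.hello(ctx, "world".to_string()).await
}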
@@ -199,6 +199,7 @@
 #![cfg_attr(docsrs, feature(doc_cfg))]

+#[cfg(feature = "serde1")]
 #[doc(hidden)]
 pub use serde;

 #[cfg(feature = "serde-transport")]

@@ -6,7 +6,10 @@

 //! Provides a server that concurrently handles many connections sending multiplexed requests.

-use crate::{context, ClientMessage, PollIo, Request, Response, Transport};
+use crate::{
+    context::{self, SpanExt},
+    trace, ClientMessage, PollIo, Request, Response, Transport,
+};
 use futures::{
     future::{AbortRegistration, Abortable},
     prelude::*,
@@ -14,12 +17,11 @@ use futures::{
     stream::Fuse,
     task::*,
 };
-use humantime::format_rfc3339;
 use in_flight_requests::{AlreadyExistsError, InFlightRequests};
-use log::{debug, info, trace};
 use pin_project::pin_project;
-use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
+use std::{convert::TryFrom, fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
 use tokio::sync::mpsc;
+use tracing::{info_span, instrument::Instrument, Span};

 mod filter;
 mod in_flight_requests;
@@ -67,6 +69,11 @@ pub trait Serve<Req> {
     /// Type of response future.
     type Fut: Future<Output = Self::Resp>;

+    /// Extracts a method name from the request.
+    fn method(&self, _request: &Req) -> Option<&'static str> {
+        None
+    }
+
     /// Responds to a single request.
     fn serve(self, ctx: context::Context, req: Req) -> Self::Fut;
 }
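Editor's note: a sketch of what the new `method` hook enables for a hand-rolled `Serve` impl; the generated service code provides the equivalent automatically, and `EchoServe` is a hypothetical type used only for illustration.

use futures::future::{self, Ready};

#[derive(Clone)]
struct EchoServe;

impl Serve<String> for EchoServe {
    type Resp = String;
    type Fut = Ready<String>;

    fn method(&self, _request: &String) -> Option<&'static str> {
        Some("echo") // recorded as the server span's otel.name
    }

    fn serve(self, _ctx: context::Context, request: String) -> Self::Fut {
        future::ready(request)
    }
}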
@@ -219,12 +226,18 @@ where
     /// Type of response sink item.
     type Resp;

+    /// The wrapped transport.
+    type Transport;
+
     /// Configuration of the channel.
     fn config(&self) -> &Config;

     /// Returns the number of in-flight requests over this channel.
     fn in_flight_requests(&self) -> usize;

+    /// Returns the transport underlying the channel.
+    fn transport(&self) -> &Self::Transport;
+
     /// Caps the number of concurrent requests to `limit`.
     fn max_concurrent_requests(self, limit: usize) -> Throttler<Self>
     where
@@ -240,6 +253,7 @@ where
         self: Pin<&mut Self>,
         id: u64,
         deadline: SystemTime,
+        span: Span,
     ) -> Result<AbortRegistration, AlreadyExistsError>;

     /// Returns a stream of requests that automatically handle request cancellation and response
@@ -294,12 +308,9 @@ where

         loop {
             let expiration_status = match self.in_flight_requests_mut().poll_expired(cx)? {
-                Poll::Ready(Some(request_id)) => {
-                    // No need to send a response, since the client wouldn't be waiting for one
-                    // anymore.
-                    debug!("Request {} did not complete before deadline", request_id);
-                    Ready
-                }
+                // No need to send a response, since the client wouldn't be waiting for one
+                // anymore.
+                Poll::Ready(Some(_)) => Ready,
                 Poll::Ready(None) => Closed,
                 Poll::Pending => Pending,
             };
@@ -313,18 +324,10 @@ where
                     trace_context,
                     request_id,
                 } => {
-                    if self.in_flight_requests_mut().cancel_request(request_id) {
-                        let remaining = self.in_flight_requests.len();
-                        trace!(
-                            "[{}] Request canceled. In-flight requests = {}",
-                            trace_context.trace_id,
-                            remaining,
-                        );
-                    } else {
-                        trace!(
-                            "[{}] Received cancellation, but response handler \
-                             is already complete.",
-                            trace_context.trace_id,
+                    if !self.in_flight_requests_mut().cancel_request(request_id) {
+                        tracing::trace!(
+                            rpc.trace_id = %trace_context.trace_id,
+                            "Received cancellation, but response handler is already complete.",
                         );
                     }
                     Ready
@@ -354,11 +357,19 @@ where
     }

     fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
-        self.as_mut()
+        if let Some(span) = self
+            .as_mut()
             .project()
             .in_flight_requests
-            .remove_request(response.request_id);
-        self.project().transport.start_send(response)
+            .remove_request(response.request_id)
+        {
+            let _entered = span.enter();
+            tracing::info!("SendResponse");
+            self.project().transport.start_send(response)
+        } else {
+            // If the request isn't tracked anymore, there's no need to send the response.
+            Ok(())
+        }
     }

     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
@@ -382,6 +393,7 @@ where
 {
     type Req = Req;
     type Resp = Resp;
+    type Transport = T;

     fn config(&self) -> &Config {
         &self.config
@@ -391,14 +403,19 @@ where
         self.in_flight_requests.len()
     }

+    fn transport(&self) -> &Self::Transport {
+        self.get_ref()
+    }
+
     fn start_request(
         self: Pin<&mut Self>,
         id: u64,
         deadline: SystemTime,
+        span: Span,
     ) -> Result<AbortRegistration, AlreadyExistsError> {
         self.project()
             .in_flight_requests
-            .start_request(id, deadline)
+            .start_request(id, deadline, span)
     }
 }
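Editor's note: a sketch of why `Channel::transport` exists; server code can reach transport-specific data (e.g. a peer address) without threading it through separately. The `HasPeerAddr` trait and `peer_ip` function are hypothetical stand-ins for whatever the concrete transport actually exposes.

use std::net::SocketAddr;

trait HasPeerAddr {
    fn peer_addr(&self) -> SocketAddr;
}

fn peer_ip<C>(channel: &C) -> SocketAddr
where
    C: Channel,
    C::Transport: HasPeerAddr,
{
    // BaseChannel returns the wrapped transport; decorators delegate inward.
    channel.transport().peer_addr()
}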
@@ -412,9 +429,9 @@ where
     #[pin]
     channel: C,
     /// Responses waiting to be written to the wire.
-    pending_responses: mpsc::Receiver<(context::Context, Response<C::Resp>)>,
+    pending_responses: mpsc::Receiver<Response<C::Resp>>,
     /// Handed out to request handlers to fan in responses.
-    responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
+    responses_tx: mpsc::Sender<Response<C::Resp>>,
 }

 impl<C> Requests<C>
@@ -429,7 +446,7 @@ where
     /// Returns the inner channel over which messages are sent and received.
     pub fn pending_responses_mut<'a>(
         self: &'a mut Pin<&mut Self>,
-    ) -> &'a mut mpsc::Receiver<(context::Context, Response<C::Resp>)> {
+    ) -> &'a mut mpsc::Receiver<Response<C::Resp>> {
         self.as_mut().project().pending_responses
     }

@@ -439,33 +456,45 @@ where
     ) -> PollIo<InFlightRequest<C::Req, C::Resp>> {
         loop {
             match ready!(self.channel_pin_mut().poll_next(cx)?) {
-                Some(request) => {
-                    trace!(
-                        "[{}] Handling request with deadline {}.",
-                        request.context.trace_id(),
-                        format_rfc3339(request.context.deadline),
+                Some(mut request) => {
+                    let span = info_span!(
+                        "RPC",
+                        rpc.trace_id = %request.context.trace_id(),
+                        otel.kind = "server",
+                        otel.name = tracing::field::Empty,
                     );
-
-                    match self
-                        .channel_pin_mut()
-                        .start_request(request.id, request.context.deadline)
-                    {
+                    span.set_context(&request.context);
+                    request.context.trace_context =
+                        trace::Context::try_from(&span).unwrap_or_else(|_| {
+                            tracing::trace!(
+                                "OpenTelemetry subscriber not installed; making unsampled
+                                 child context."
+                            );
+                            request.context.trace_context.new_child()
+                        });
+                    let entered = span.enter();
+                    tracing::info!("ReceiveRequest");
+                    let start = self.channel_pin_mut().start_request(
+                        request.id,
+                        request.context.deadline,
+                        span.clone(),
+                    );
+                    match start {
                         Ok(abort_registration) => {
+                            let response_tx = self.responses_tx.clone();
+                            drop(entered);
                             return Poll::Ready(Some(Ok(InFlightRequest {
                                 request,
-                                response_tx: self.responses_tx.clone(),
+                                response_tx,
                                 abort_registration,
-                            })))
+                                span,
+                            })));
                         }
                         // Instead of closing the channel if a duplicate request is sent, just
                         // ignore it, since it's already being processed. Note that we cannot
                         // return Poll::Pending here, since nothing has scheduled a wakeup yet.
                         Err(AlreadyExistsError) => {
-                            info!(
-                                "[{}] Request ID {} delivered more than once.",
-                                request.context.trace_id(),
-                                request.id
-                            );
+                            tracing::trace!("DuplicateRequest");
                             continue;
                         }
                     }
@@ -481,12 +510,7 @@ where
         read_half_closed: bool,
     ) -> PollIo<()> {
         match self.as_mut().poll_next_response(cx)? {
-            Poll::Ready(Some((context, response))) => {
-                trace!(
-                    "[{}] Staging response. In-flight requests = {}.",
-                    context.trace_id(),
-                    self.channel.in_flight_requests(),
-                );
+            Poll::Ready(Some(response)) => {
                 // A Ready result from poll_next_response means the Channel is ready to be written
                 // to. Therefore, we can call start_send without worry of a full buffer.
                 self.channel_pin_mut().start_send(response)?;
@@ -520,7 +544,7 @@ where
     fn poll_next_response(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
-    ) -> PollIo<(context::Context, Response<C::Resp>)> {
+    ) -> PollIo<Response<C::Resp>> {
         ready!(self.ensure_writeable(cx)?);

         match ready!(self.pending_responses_mut().poll_recv(cx)) {
@@ -555,8 +579,9 @@ where
 #[derive(Debug)]
 pub struct InFlightRequest<Req, Res> {
     request: Request<Req>,
-    response_tx: mpsc::Sender<(context::Context, Response<Res>)>,
+    response_tx: mpsc::Sender<Response<Res>>,
     abort_registration: AbortRegistration,
+    span: Span,
 }

 impl<Req, Res> InFlightRequest<Req, Res> {
@@ -576,36 +601,38 @@ impl<Req, Res> InFlightRequest<Req, Res> {
     /// message](ClientMessage::Cancel) for this request.
     /// 2. The request [deadline](crate::context::Context::deadline) is reached.
     /// 3. The service function completes.
-    pub fn execute<S>(self, serve: S) -> impl Future<Output = ()>
+    pub async fn execute<S>(self, serve: S)
     where
         S: Serve<Req, Resp = Res>,
     {
         let Self {
             abort_registration,
-            request,
-            response_tx,
-        } = self;
-        Abortable::new(
-            async move {
-                let Request {
+            request:
+                Request {
                     context,
                     message,
                     id: request_id,
-                } = request;
-                context
-                    .scope(async {
-                        let response = serve.serve(context, message).await;
-                        let response = Response {
-                            request_id,
-                            message: Ok(response),
-                        };
-                        let _ = response_tx.send((context, response)).await;
-                    })
-                    .await;
+                },
+            response_tx,
+            span,
+        } = self;
+        let method = serve.method(&message);
+        span.record("otel.name", &method.unwrap_or(""));
+        let _ = Abortable::new(
+            async move {
+                let response = serve.serve(context, message).await;
+                tracing::info!("CompleteRequest");
+                let response = Response {
+                    request_id,
+                    message: Ok(response),
+                };
+                let _ = response_tx.send(response).await;
+                tracing::info!("BufferResponse");
             },
             abort_registration,
         )
         .unwrap_or_else(|_| {})
+        .instrument(span)
         .await;
     }
 }
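Editor's note: the `.instrument(span)` call above is the core of span propagation here — the span travels with the future, so events emitted on whatever task polls it land in the right span. A standalone sketch of the same pattern, using only the tracing API:

use tracing::{info_span, instrument::Instrument};

async fn spawn_in_span() {
    let span = info_span!("RPC", otel.kind = "server");
    // Events emitted inside the spawned future are attributed to `span`,
    // even though it runs on a different Tokio task.
    tokio::spawn(
        async {
            tracing::info!("handling request");
        }
        .instrument(span),
    );
}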
@@ -711,7 +738,7 @@ where
         while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
             tokio::spawn(channel.execute(self.serve.clone()));
         }
-        info!("Server shutting down.");
+        tracing::info!("Server shutting down.");
         Poll::Ready(())
     }
 }
@@ -737,7 +764,7 @@ where
                 });
             }
             Err(e) => {
-                info!("Requests stream errored out: {}", e);
+                tracing::info!("Requests stream errored out: {}", e);
                 break;
             }
         }
@@ -823,10 +850,12 @@ mod tests {

         channel
             .as_mut()
-            .start_request(0, SystemTime::now())
+            .start_request(0, SystemTime::now(), Span::current())
             .unwrap();
         assert_matches!(
-            channel.as_mut().start_request(0, SystemTime::now()),
+            channel
+                .as_mut()
+                .start_request(0, SystemTime::now(), Span::current()),
             Err(AlreadyExistsError)
         );
     }
@@ -838,11 +867,11 @@ mod tests {
         tokio::time::pause();
         let abort_registration0 = channel
             .as_mut()
-            .start_request(0, SystemTime::now())
+            .start_request(0, SystemTime::now(), Span::current())
             .unwrap();
         let abort_registration1 = channel
             .as_mut()
-            .start_request(1, SystemTime::now())
+            .start_request(1, SystemTime::now(), Span::current())
             .unwrap();
         tokio::time::advance(std::time::Duration::from_secs(1000)).await;

@@ -861,7 +890,11 @@ mod tests {
         tokio::time::pause();
         let abort_registration = channel
             .as_mut()
-            .start_request(0, SystemTime::now() + Duration::from_millis(100))
+            .start_request(
+                0,
+                SystemTime::now() + Duration::from_millis(100),
+                Span::current(),
+            )
             .unwrap();

         tx.send(ClientMessage::Cancel {
@@ -886,7 +919,11 @@ mod tests {
         tokio::time::pause();
         let _abort_registration = channel
             .as_mut()
-            .start_request(0, SystemTime::now() + Duration::from_millis(100))
+            .start_request(
+                0,
+                SystemTime::now() + Duration::from_millis(100),
+                Span::current(),
+            )
             .unwrap();

         drop(tx);
@@ -924,7 +961,7 @@ mod tests {
         tokio::time::pause();
         let abort_registration = channel
             .as_mut()
-            .start_request(0, SystemTime::now())
+            .start_request(0, SystemTime::now(), Span::current())
             .unwrap();
         tokio::time::advance(std::time::Duration::from_secs(1000)).await;

@@ -943,7 +980,7 @@ mod tests {

         channel
             .as_mut()
-            .start_request(0, SystemTime::now())
+            .start_request(0, SystemTime::now(), Span::current())
             .unwrap();
         assert_eq!(channel.in_flight_requests(), 1);
         channel
@@ -961,6 +998,11 @@ mod tests {
         let (mut requests, _tx) = test_bounded_requests::<(), ()>(0);

         // Response written to the transport.
+        requests
+            .as_mut()
+            .channel_pin_mut()
+            .start_request(0, SystemTime::now(), Span::current())
+            .unwrap();
         requests
             .as_mut()
             .channel_pin_mut()
@@ -975,20 +1017,17 @@ mod tests {
             .as_mut()
             .project()
             .responses_tx
-            .send((
-                context::current(),
-                Response {
-                    request_id: 1,
-                    message: Ok(()),
-                },
-            ))
+            .send(Response {
+                request_id: 1,
+                message: Ok(()),
+            })
             .await
             .unwrap();

         requests
             .as_mut()
             .channel_pin_mut()
-            .start_request(1, SystemTime::now())
+            .start_request(1, SystemTime::now(), Span::current())
             .unwrap();

         assert_matches!(
@@ -1002,6 +1041,11 @@ mod tests {
         let (mut requests, _tx) = test_bounded_requests::<(), ()>(0);

         // Response written to the transport.
+        requests
+            .as_mut()
+            .channel_pin_mut()
+            .start_request(0, SystemTime::now(), Span::current())
+            .unwrap();
         requests
             .as_mut()
             .channel_pin_mut()
@@ -1014,22 +1058,18 @@ mod tests {
         // Response waiting to be written.
         requests
             .as_mut()
-            .project()
-            .responses_tx
-            .send((
-                context::current(),
-                Response {
-                    request_id: 1,
-                    message: Ok(()),
-                },
-            ))
-            .await
+            .channel_pin_mut()
+            .start_request(1, SystemTime::now(), Span::current())
             .unwrap();

         requests
             .as_mut()
-            .channel_pin_mut()
-            .start_request(1, SystemTime::now())
+            .project()
+            .responses_tx
+            .send(Response {
+                request_id: 1,
+                message: Ok(()),
+            })
+            .await
             .unwrap();

         assert_matches!(
@@ -10,7 +10,6 @@ use crate::{
 };
 use fnv::FnvHashMap;
 use futures::{future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
-use log::{debug, info, trace};
 use pin_project::pin_project;
 use std::sync::{Arc, Weak};
 use std::{
@@ -18,6 +17,7 @@ use std::{
     time::SystemTime,
 };
 use tokio::sync::mpsc;
+use tracing::{debug, info, trace};

 /// A single-threaded filter that drops channels based on per-key limits.
 #[pin_project]
@@ -103,6 +103,7 @@ where
 {
     type Req = C::Req;
     type Resp = C::Resp;
+    type Transport = C::Transport;

     fn config(&self) -> &server::Config {
         self.inner.config()
@@ -112,12 +113,17 @@ where
         self.inner.in_flight_requests()
     }

+    fn transport(&self) -> &Self::Transport {
+        self.inner.transport()
+    }
+
     fn start_request(
         mut self: Pin<&mut Self>,
         id: u64,
         deadline: SystemTime,
+        span: tracing::Span,
     ) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
-        self.inner_pin_mut().start_request(id, deadline)
+        self.inner_pin_mut().start_request(id, deadline, span)
     }
 }
@@ -171,11 +177,10 @@ where
         let tracker = self.as_mut().increment_channels_for_key(key.clone())?;

         trace!(
-            "[{}] Opening channel ({}/{}) channels for key.",
-            key,
-            Arc::strong_count(&tracker),
-            self.channels_per_key
-        );
+            channel_filter_key = %key,
+            open_channels = Arc::strong_count(&tracker),
+            max_open_channels = self.channels_per_key,
+            "Opening channel");

         Ok(TrackedChannel {
             tracker,
@@ -200,9 +205,10 @@ where
                 let count = o.get().strong_count();
                 if count >= TryFrom::try_from(*self_.channels_per_key).unwrap() {
                     info!(
-                        "[{}] Opened max channels from key ({}/{}).",
-                        key, count, self_.channels_per_key
-                    );
+                        channel_filter_key = %key,
+                        open_channels = count,
+                        max_open_channels = *self_.channels_per_key,
+                        "At open channel limit");
                     Err(key)
                 } else {
                     Ok(o.get().upgrade().unwrap_or_else(|| {
@@ -233,7 +239,9 @@ where
         let self_ = self.project();
         match ready!(self_.dropped_keys.poll_recv(cx)) {
             Some(key) => {
-                debug!("All channels dropped for key [{}]", key);
+                debug!(
+                    channel_filter_key = %key,
+                    "All channels dropped");
                 self_.key_counts.remove(&key);
                 self_.key_counts.compact(0.1);
                 Poll::Ready(())
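Editor's note: the log rewrites in this file all follow the same shift — values move out of the format string into structured key-value fields that tracing subscribers can index and filter. A minimal sketch of the before/after:

fn log_open(key: &str, open: usize, max: usize) {
    // Old style baked everything into the message:
    //     debug!("[{}] Opening channel ({}/{}) channels for key.", key, open, max);
    // New style records fields separately from the message:
    tracing::debug!(
        channel_filter_key = %key,
        open_channels = open,
        max_open_channels = max,
        "Opening channel",
    );
}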
@@ -14,6 +14,7 @@ use std::{
     time::SystemTime,
 };
 use tokio_util::time::delay_queue::{self, DelayQueue};
+use tracing::Span;

 /// A data structure that tracks in-flight requests. It aborts requests,
 /// either on demand or when a request deadline expires.
@@ -23,13 +24,15 @@ pub struct InFlightRequests {
     deadlines: DelayQueue<u64>,
 }

-#[derive(Debug)]
 /// Data needed to clean up a single in-flight request.
+#[derive(Debug)]
 struct RequestData {
     /// Aborts the response handler for the associated request.
     abort_handle: AbortHandle,
     /// The key to remove the timer for the request's deadline.
     deadline_key: delay_queue::Key,
+    /// The client span.
+    span: Span,
 }

 /// An error returned when a request attempted to start with the same ID as a request already
@@ -48,6 +51,7 @@ impl InFlightRequests {
         &mut self,
         request_id: u64,
         deadline: SystemTime,
+        span: Span,
     ) -> Result<AbortRegistration, AlreadyExistsError> {
         match self.request_data.entry(request_id) {
             hash_map::Entry::Vacant(vacant) => {
@@ -57,6 +61,7 @@ impl InFlightRequests {
                 vacant.insert(RequestData {
                     abort_handle,
                     deadline_key,
+                    span,
                 });
                 Ok(abort_registration)
             }
@@ -66,12 +71,17 @@ impl InFlightRequests {

     /// Cancels an in-flight request. Returns true iff the request was found.
     pub fn cancel_request(&mut self, request_id: u64) -> bool {
-        if let Some(request_data) = self.request_data.remove(&request_id) {
+        if let Some(RequestData {
+            span,
+            abort_handle,
+            deadline_key,
+        }) = self.request_data.remove(&request_id)
+        {
+            let _entered = span.enter();
             self.request_data.compact(0.1);
-
-            request_data.abort_handle.abort();
-            self.deadlines.remove(&request_data.deadline_key);
-
+            abort_handle.abort();
+            self.deadlines.remove(&deadline_key);
+            tracing::info!("ReceiveCancel");
             true
         } else {
             false
@@ -80,15 +90,13 @@ impl InFlightRequests {

     /// Removes a request without aborting. Returns true iff the request was found.
     /// This method should be used when a response is being sent.
-    pub fn remove_request(&mut self, request_id: u64) -> bool {
+    pub fn remove_request(&mut self, request_id: u64) -> Option<Span> {
         if let Some(request_data) = self.request_data.remove(&request_id) {
             self.request_data.compact(0.1);
-
             self.deadlines.remove(&request_data.deadline_key);
-
-            true
+            Some(request_data.span)
         } else {
-            false
+            None
         }
     }

@@ -96,9 +104,14 @@ impl InFlightRequests {
     pub fn poll_expired(&mut self, cx: &mut Context) -> PollIo<u64> {
         Poll::Ready(match ready!(self.deadlines.poll_expired(cx)) {
             Some(Ok(expired)) => {
-                if let Some(request_data) = self.request_data.remove(expired.get_ref()) {
+                if let Some(RequestData {
+                    abort_handle, span, ..
+                }) = self.request_data.remove(expired.get_ref())
+                {
+                    let _entered = span.enter();
                     self.request_data.compact(0.1);
-                    request_data.abort_handle.abort();
+                    abort_handle.abort();
+                    tracing::error!("DeadlineExceeded");
                 }
                 Some(Ok(expired.into_inner()))
             }
@@ -133,7 +146,7 @@ mod tests {
         let mut in_flight_requests = InFlightRequests::default();
         assert_eq!(in_flight_requests.len(), 0);
         in_flight_requests
-            .start_request(0, SystemTime::now())
+            .start_request(0, SystemTime::now(), Span::current())
             .unwrap();
         assert_eq!(in_flight_requests.len(), 1);
     }
@@ -142,7 +155,7 @@ mod tests {
     async fn polling_expired_aborts() {
         let mut in_flight_requests = InFlightRequests::default();
         let abort_registration = in_flight_requests
-            .start_request(0, SystemTime::now())
+            .start_request(0, SystemTime::now(), Span::current())
             .unwrap();
         let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));

@@ -164,7 +177,7 @@ mod tests {
     async fn cancel_request_aborts() {
         let mut in_flight_requests = InFlightRequests::default();
         let abort_registration = in_flight_requests
-            .start_request(0, SystemTime::now())
+            .start_request(0, SystemTime::now(), Span::current())
             .unwrap();
         let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));

@@ -180,11 +193,11 @@ mod tests {
     async fn remove_request_doesnt_abort() {
         let mut in_flight_requests = InFlightRequests::default();
         let abort_registration = in_flight_requests
-            .start_request(0, SystemTime::now())
+            .start_request(0, SystemTime::now(), Span::current())
             .unwrap();
         let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));

-        assert_eq!(in_flight_requests.remove_request(0), true);
+        assert_matches!(in_flight_requests.remove_request(0), Some(_));
         assert_matches!(
             abortable_future.poll_unpin(&mut noop_context()),
             Poll::Pending
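Editor's note: a sketch of the enter-then-log pattern this file now uses everywhere. While the returned guard is alive, events are attributed to the stored request span rather than to whatever span the poller happens to be running in:

fn log_cancel(span: &tracing::Span) {
    let _entered = span.enter();
    tracing::info!("ReceiveCancel"); // recorded inside `span`
} // guard dropped; subsequent events revert to the caller's span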
@@ -11,10 +11,8 @@ use crate::{
 };
 use futures::{future::AbortRegistration, task::*, Sink, Stream};
 use pin_project::pin_project;
-use std::collections::VecDeque;
-use std::io;
-use std::pin::Pin;
-use std::time::SystemTime;
+use std::{collections::VecDeque, io, pin::Pin, time::SystemTime};
+use tracing::Span;

 #[pin_project]
 pub(crate) struct FakeChannel<In, Out> {
@@ -70,6 +68,7 @@ where
 {
     type Req = Req;
     type Resp = Resp;
+    type Transport = ();

     fn config(&self) -> &Config {
         &self.config
@@ -79,14 +78,19 @@ where
         self.in_flight_requests.len()
     }

+    fn transport(&self) -> &() {
+        &()
+    }
+
     fn start_request(
         self: Pin<&mut Self>,
         id: u64,
         deadline: SystemTime,
+        span: Span,
     ) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
         self.project()
             .in_flight_requests
-            .start_request(id, deadline)
+            .start_request(id, deadline, span)
     }
 }
@@ -7,9 +7,9 @@
 use super::{Channel, Config};
 use crate::{Response, ServerError};
 use futures::{future::AbortRegistration, prelude::*, ready, task::*};
-use log::debug;
 use pin_project::pin_project;
 use std::{io, pin::Pin, time::SystemTime};
+use tracing::Span;

 /// A [`Channel`] that limits the number of concurrent
 /// requests by throttling.
@@ -55,11 +55,11 @@ where

         match ready!(self.as_mut().project().inner.poll_next(cx)?) {
             Some(request) => {
-                debug!(
-                    "[{}] Client has reached in-flight request limit ({}/{}).",
-                    request.context.trace_id(),
-                    self.as_mut().in_flight_requests(),
-                    self.as_mut().project().max_in_flight_requests,
+                tracing::debug!(
+                    rpc.trace_id = %request.context.trace_id(),
+                    in_flight_requests = self.as_mut().in_flight_requests(),
+                    max_in_flight_requests = *self.as_mut().project().max_in_flight_requests,
+                    "At in-flight request limit",
                 );

                 self.as_mut().start_send(Response {
@@ -112,6 +112,7 @@ where
 {
     type Req = <C as Channel>::Req;
     type Resp = <C as Channel>::Resp;
+    type Transport = <C as Channel>::Transport;

     fn in_flight_requests(&self) -> usize {
         self.inner.in_flight_requests()
@@ -121,12 +122,17 @@ where
         self.inner.config()
     }

+    fn transport(&self) -> &Self::Transport {
+        self.inner.transport()
+    }
+
     fn start_request(
         self: Pin<&mut Self>,
         id: u64,
         deadline: SystemTime,
+        span: Span,
     ) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
-        self.project().inner.start_request(id, deadline)
+        self.project().inner.start_request(id, deadline, span)
     }
 }
@@ -196,7 +202,11 @@ mod tests {
         throttler
             .inner
             .in_flight_requests
-            .start_request(i, SystemTime::now() + Duration::from_secs(1))
+            .start_request(
+                i,
+                SystemTime::now() + Duration::from_secs(1),
+                Span::current(),
+            )
             .unwrap();
         }
         assert_eq!(throttler.as_mut().in_flight_requests(), 5);
@@ -212,7 +222,11 @@ mod tests {
         pin_mut!(throttler);
         throttler
             .as_mut()
-            .start_request(1, SystemTime::now() + Duration::from_secs(1))
+            .start_request(
+                1,
+                SystemTime::now() + Duration::from_secs(1),
+                Span::current(),
+            )
             .unwrap();
         assert_eq!(throttler.inner.in_flight_requests.len(), 1);
     }
@@ -305,16 +319,21 @@ mod tests {
     impl<Req, Resp> Channel for PendingSink<io::Result<Request<Req>>, Response<Resp>> {
         type Req = Req;
         type Resp = Resp;
+        type Transport = ();
         fn config(&self) -> &Config {
             unimplemented!()
         }
         fn in_flight_requests(&self) -> usize {
             0
         }
+        fn transport(&self) -> &() {
+            &()
+        }
         fn start_request(
             self: Pin<&mut Self>,
             _id: u64,
             _deadline: SystemTime,
+            _span: tracing::Span,
         ) -> Result<AbortRegistration, AlreadyExistsError> {
             unimplemented!()
         }
@@ -332,7 +351,11 @@ mod tests {
         throttler
             .inner
             .in_flight_requests
-            .start_request(0, SystemTime::now() + Duration::from_secs(1))
+            .start_request(
+                0,
+                SystemTime::now() + Duration::from_secs(1),
+                Span::current(),
+            )
             .unwrap();
         throttler
             .as_mut()
@@ -16,11 +16,14 @@
 //! This crate's design is based on [opencensus
 //! tracing](https://opencensus.io/core-concepts/tracing/).

+use opentelemetry::trace::TraceContextExt;
 use rand::Rng;
 use std::{
+    convert::TryFrom,
     fmt::{self, Formatter},
-    mem,
+    num::{NonZeroU128, NonZeroU64},
 };
+use tracing_opentelemetry::OpenTelemetrySpanExt;

 /// A context for tracing the execution of processes, distributed or otherwise.
 ///
@@ -36,33 +39,50 @@ pub struct Context {
     /// before making an RPC, and the span ID is sent to the server. The server is free to create
     /// its own spans, for which it sets the client's span as the parent span.
     pub span_id: SpanId,
-    /// An identifier of the span that originated the current span. For example, if a server sends
-    /// an RPC in response to a client request that included a span, the server would create a span
-    /// for the RPC and set its parent to the span_id in the incoming request's context.
-    ///
-    /// If `parent_id` is `None`, then this is a root context.
-    pub parent_id: Option<SpanId>,
+    /// Indicates whether a sampler has already decided whether or not to sample the trace
+    /// associated with the Context. If `sampling_decision` is None, then a decision has not yet
+    /// been made. Downstream samplers do not need to abide by "no sample" decisions--for example,
+    /// an upstream client may choose to never sample, which may not make sense for the client's
+    /// dependencies. On the other hand, if an upstream process has chosen to sample this trace,
+    /// then the downstream samplers are expected to respect that decision and also sample the
+    /// trace. Otherwise, the full trace would not be able to be reconstructed.
+    pub sampling_decision: SamplingDecision,
 }

 /// A 128-bit UUID identifying a trace. All spans caused by the same originating span share the
 /// same trace ID.
-#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
+#[derive(Default, PartialEq, Eq, Hash, Clone, Copy)]
 #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
 pub struct TraceId(#[cfg_attr(feature = "serde1", serde(with = "u128_serde"))] u128);

 /// A 64-bit identifier of a span within a trace. The identifier is unique within the span's trace.
-#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
+#[derive(Default, PartialEq, Eq, Hash, Clone, Copy)]
 #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
 pub struct SpanId(u64);

+/// Indicates whether a sampler has decided whether or not to sample the trace associated with the
+/// Context. Downstream samplers do not need to abide by "no sample" decisions--for example, an
+/// upstream client may choose to never sample, which may not make sense for the client's
+/// dependencies. On the other hand, if an upstream process has chosen to sample this trace, then
+/// the downstream samplers are expected to respect that decision and also sample the trace.
+/// Otherwise, the full trace would not be able to be reconstructed reliably.
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
+#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+pub enum SamplingDecision {
+    /// The associated span was sampled by its creating process. Child spans must also be sampled.
+    Sampled = opentelemetry::trace::TRACE_FLAG_SAMPLED,
+    /// The associated span was not sampled by its creating process.
+    Unsampled = opentelemetry::trace::TRACE_FLAG_NOT_SAMPLED,
+}
+
 impl Context {
-    /// Constructs a new root context. A root context is one with no parent span.
-    pub fn new_root() -> Self {
-        let rng = &mut rand::thread_rng();
-        Context {
-            trace_id: TraceId::random(rng),
-            span_id: SpanId::random(rng),
-            parent_id: None,
+    /// Constructs a new context with the trace ID and sampling decision inherited from the parent.
+    pub(crate) fn new_child(&self) -> Self {
+        Self {
+            trace_id: self.trace_id,
+            span_id: SpanId::random(&mut rand::thread_rng()),
+            sampling_decision: self.sampling_decision,
         }
     }
 }
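Editor's note: a sketch of the inheritance rule `new_child` encodes — the child keeps the parent's trace ID and sampling decision and only mints a fresh span ID, which is what keeps a sampled trace complete end to end. Since `new_child` is pub(crate), this is shown as it would run inside tarpc:

fn sampling_is_inherited() {
    let parent = Context {
        trace_id: TraceId::from(1u128),
        span_id: SpanId::from(42u64),
        sampling_decision: SamplingDecision::Sampled,
    };
    let child = parent.new_child();
    assert_eq!(child.trace_id, parent.trace_id);
    assert_eq!(child.sampling_decision, SamplingDecision::Sampled);
    assert_ne!(child.span_id, parent.span_id); // fresh span ID, with overwhelming probability
}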
@@ -71,17 +91,119 @@ impl TraceId {
     /// Returns a random trace ID that can be assumed to be globally unique if `rng` generates
     /// actually-random numbers.
     pub fn random<R: Rng>(rng: &mut R) -> Self {
-        TraceId(u128::from(rng.next_u64()) << mem::size_of::<u64>() | u128::from(rng.next_u64()))
+        TraceId(rng.gen::<NonZeroU128>().get())
     }

     /// Returns true iff the trace ID is 0.
     pub fn is_none(&self) -> bool {
         self.0 == 0
     }
 }

 impl SpanId {
     /// Returns a random span ID that can be assumed to be unique within a single trace.
     pub fn random<R: Rng>(rng: &mut R) -> Self {
-        SpanId(rng.next_u64())
+        SpanId(rng.gen::<NonZeroU64>().get())
     }

     /// Returns true iff the span ID is 0.
     pub fn is_none(&self) -> bool {
         self.0 == 0
     }
 }

+impl From<TraceId> for u128 {
+    fn from(trace_id: TraceId) -> Self {
+        trace_id.0
+    }
+}
+
+impl From<u128> for TraceId {
+    fn from(trace_id: u128) -> Self {
+        Self(trace_id)
+    }
+}
+
+impl From<SpanId> for u64 {
+    fn from(span_id: SpanId) -> Self {
+        span_id.0
+    }
+}
+
+impl From<u64> for SpanId {
+    fn from(span_id: u64) -> Self {
+        Self(span_id)
+    }
+}
+
+impl From<opentelemetry::trace::TraceId> for TraceId {
+    fn from(trace_id: opentelemetry::trace::TraceId) -> Self {
+        Self::from(trace_id.to_u128())
+    }
+}
+
+impl From<TraceId> for opentelemetry::trace::TraceId {
+    fn from(trace_id: TraceId) -> Self {
+        Self::from_u128(trace_id.into())
+    }
+}
+
+impl From<opentelemetry::trace::SpanId> for SpanId {
+    fn from(span_id: opentelemetry::trace::SpanId) -> Self {
+        Self::from(span_id.to_u64())
+    }
+}
+
+impl From<SpanId> for opentelemetry::trace::SpanId {
+    fn from(span_id: SpanId) -> Self {
+        Self::from_u64(span_id.0)
+    }
+}
+
+impl TryFrom<&tracing::Span> for Context {
+    type Error = NoActiveSpan;
+
+    fn try_from(span: &tracing::Span) -> Result<Self, NoActiveSpan> {
+        let context = span.context();
+        if context.has_active_span() {
+            Ok(Self::from(context.span()))
+        } else {
+            Err(NoActiveSpan)
+        }
+    }
+}
+
+impl From<&dyn opentelemetry::trace::Span> for Context {
+    fn from(span: &dyn opentelemetry::trace::Span) -> Self {
+        let otel_ctx = span.span_context();
+        Self {
+            trace_id: TraceId::from(otel_ctx.trace_id()),
+            span_id: SpanId::from(otel_ctx.span_id()),
+            sampling_decision: SamplingDecision::from(otel_ctx),
+        }
+    }
+}
+
+impl From<&opentelemetry::trace::SpanContext> for SamplingDecision {
+    fn from(context: &opentelemetry::trace::SpanContext) -> Self {
+        if context.is_sampled() {
+            SamplingDecision::Sampled
+        } else {
+            SamplingDecision::Unsampled
+        }
+    }
+}
+
+impl Default for SamplingDecision {
+    fn default() -> Self {
+        Self::Unsampled
+    }
+}
+
+/// Returned when a [`Context`] cannot be constructed from a [`Span`](tracing::Span).
+#[derive(Debug)]
+pub struct NoActiveSpan;

 impl fmt::Display for TraceId {
     fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
         write!(f, "{:02x}", self.0)?;
@@ -89,6 +211,13 @@ impl fmt::Display for TraceId {
     }
 }

+impl fmt::Debug for TraceId {
+    fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
+        write!(f, "{:02x}", self.0)?;
+        Ok(())
+    }
+}
+
 impl fmt::Display for SpanId {
     fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
         write!(f, "{:02x}", self.0)?;
@@ -96,6 +225,13 @@ impl fmt::Display for SpanId {
     }
 }

+impl fmt::Debug for SpanId {
+    fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
+        write!(f, "{:02x}", self.0)?;
+        Ok(())
+    }
+}
+
 #[cfg(feature = "serde1")]
 mod u128_serde {
     pub fn serialize<S>(u: &u128, serializer: S) -> Result<S::Ok, S::Error>
@@ -152,12 +152,12 @@ mod tests {
     };
     use assert_matches::assert_matches;
     use futures::{prelude::*, stream};
-    use log::trace;
     use std::io;
+    use tracing::trace;

     #[tokio::test]
     async fn integration() -> io::Result<()> {
-        let _ = env_logger::try_init();
+        let _ = tracing_subscriber::fmt::try_init();

         let (client_channel, server_channel) = transport::channel::unbounded();
         tokio::spawn(
@@ -175,8 +175,8 @@ mod tests {

         let client = client::new(client::Config::default(), client_channel).spawn()?;

-        let response1 = client.call(context::current(), "123".into()).await?;
-        let response2 = client.call(context::current(), "abc".into()).await?;
+        let response1 = client.call(context::current(), "", "123".into()).await?;
+        let response2 = client.call(context::current(), "", "abc".into()).await?;

         trace!("response1: {:?}, response2: {:?}", response1, response2);

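Editor's note: the tests above replace env_logger with a tracing subscriber; applications migrating from log-based output can do the same once at startup. A minimal sketch:

fn init_tracing() {
    // try_init returns Err if a global subscriber is already installed;
    // the tests ignore the error so repeated initialization is harmless.
    let _ = tracing_subscriber::fmt::try_init();
}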
@@ -40,7 +40,7 @@ impl Service for Server {

 #[tokio::test]
 async fn sequential() -> io::Result<()> {
-    let _ = env_logger::try_init();
+    let _ = tracing_subscriber::fmt::try_init();

     let (tx, rx) = channel::unbounded();

@@ -82,7 +82,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
         }
     }

-    let _ = env_logger::try_init();
+    let _ = tracing_subscriber::fmt::try_init();

     let (tx, rx) = channel::unbounded();

@@ -117,7 +117,7 @@ async fn serde() -> io::Result<()> {
     use tarpc::serde_transport;
     use tokio_serde::formats::Json;

-    let _ = env_logger::try_init();
+    let _ = tracing_subscriber::fmt::try_init();

     let transport = tarpc::serde_transport::tcp::listen("localhost:56789", Json::default).await?;
     let addr = transport.local_addr();
@@ -143,7 +143,7 @@ async fn serde() -> io::Result<()> {

 #[tokio::test]
 async fn concurrent() -> io::Result<()> {
-    let _ = env_logger::try_init();
+    let _ = tracing_subscriber::fmt::try_init();

     let (tx, rx) = channel::unbounded();
     tokio::spawn(
@@ -167,7 +167,7 @@ async fn concurrent() -> io::Result<()> {

 #[tokio::test]
 async fn concurrent_join() -> io::Result<()> {
-    let _ = env_logger::try_init();
+    let _ = tracing_subscriber::fmt::try_init();

     let (tx, rx) = channel::unbounded();
     tokio::spawn(
@@ -192,7 +192,7 @@ async fn concurrent_join() -> io::Result<()> {

 #[tokio::test]
 async fn concurrent_join_all() -> io::Result<()> {
-    let _ = env_logger::try_init();
+    let _ = tracing_subscriber::fmt::try_init();

     let (tx, rx) = channel::unbounded();
     tokio::spawn(