mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-02-23 15:49:54 +01:00
Compare commits
21 Commits
v0.26.0
...
client-clo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9e4c3a2b8b | ||
|
|
d78b24b631 | ||
|
|
49900d7a35 | ||
|
|
1e680e3a5a | ||
|
|
2591d21e94 | ||
|
|
6632f68d95 | ||
|
|
25985ad56a | ||
|
|
d6a24e9420 | ||
|
|
281a78f3c7 | ||
|
|
a0787d0091 | ||
|
|
d2acba0e8a | ||
|
|
ea7b6763c4 | ||
|
|
eb67c540b9 | ||
|
|
4151d0abd3 | ||
|
|
d0c11a6efa | ||
|
|
82c4da1743 | ||
|
|
0a15e0b75c | ||
|
|
0b315c29bf | ||
|
|
56f09bf61f | ||
|
|
6d82e82419 | ||
|
|
9bebaf814a |
12
README.md
12
README.md
@@ -67,7 +67,7 @@ Some other features of tarpc:
|
||||
Add to your `Cargo.toml` dependencies:
|
||||
|
||||
```toml
|
||||
tarpc = "0.26"
|
||||
tarpc = "0.27"
|
||||
```
|
||||
|
||||
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||
@@ -80,8 +80,9 @@ This example uses [tokio](https://tokio.rs), so add the following dependencies t
|
||||
your `Cargo.toml`:
|
||||
|
||||
```toml
|
||||
futures = "1.0"
|
||||
tarpc = { version = "0.26", features = ["tokio1"] }
|
||||
anyhow = "1.0"
|
||||
futures = "0.3"
|
||||
tarpc = { version = "0.27", features = ["tokio1"] }
|
||||
tokio = { version = "1.0", features = ["macros"] }
|
||||
```
|
||||
|
||||
@@ -99,9 +100,8 @@ use futures::{
|
||||
};
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{self, Incoming},
|
||||
server::{self, incoming::Incoming},
|
||||
};
|
||||
use std::io;
|
||||
|
||||
// This is the service definition. It looks a lot like a trait definition.
|
||||
// It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
@@ -140,7 +140,7 @@ available behind the `tcp` feature.
|
||||
|
||||
```rust
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
|
||||
let server = server::BaseChannel::with_defaults(server_transport);
|
||||
|
||||
41
RELEASES.md
41
RELEASES.md
@@ -1,3 +1,36 @@
|
||||
## 0.27.1 (2021-09-22)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
### RPC error type is changing
|
||||
|
||||
RPC return types are changing from `Result<Response, io::Error>` to `Result<Response,
|
||||
tarpc::client::RpcError>`.
|
||||
|
||||
Becaue tarpc is a library, not an application, it should strive to
|
||||
use structured errors in its API so that users have maximal flexibility
|
||||
in how they handle errors. io::Error makes that hard, because it is a
|
||||
kitchen-sink error type.
|
||||
|
||||
RPCs in particular only have 3 classes of errors:
|
||||
|
||||
- The connection breaks.
|
||||
- The request expires.
|
||||
- The server decides not to process the request.
|
||||
|
||||
RPC responses can also contain application-specific errors, but from the
|
||||
perspective of the RPC library, those are opaque to the framework, classified
|
||||
as successful responsees.
|
||||
|
||||
### Open Telemetry
|
||||
|
||||
The Opentelemetry dependency is updated to version 0.16.x.
|
||||
|
||||
## 0.27.0 (2021-09-22)
|
||||
|
||||
This version was yanked due to tarpc-plugins version mismatches.
|
||||
|
||||
|
||||
## 0.26.0 (2021-04-14)
|
||||
|
||||
### New Features
|
||||
@@ -68,10 +101,10 @@ tracing_subscriber::fmt.
|
||||
|
||||
### References
|
||||
|
||||
[1] https://github.com/tokio-rs/tracing
|
||||
[2] https://opentelemetry.io
|
||||
[3] https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger
|
||||
[4] https://github.com/env-logger-rs/env_logger
|
||||
1. https://github.com/tokio-rs/tracing
|
||||
2. https://opentelemetry.io
|
||||
3. https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger
|
||||
4. https://github.com/env-logger-rs/env_logger
|
||||
|
||||
## 0.25.0 (2021-03-10)
|
||||
|
||||
|
||||
@@ -17,15 +17,13 @@ anyhow = "1.0"
|
||||
clap = "3.0.0-beta.2"
|
||||
log = "0.4"
|
||||
futures = "0.3"
|
||||
opentelemetry = { version = "0.13", features = ["rt-tokio"] }
|
||||
opentelemetry-jaeger = { version = "0.12", features = ["tokio"] }
|
||||
opentelemetry = { version = "0.16", features = ["rt-tokio"] }
|
||||
opentelemetry-jaeger = { version = "0.15", features = ["rt-tokio"] }
|
||||
rand = "0.8"
|
||||
serde = { version = "1.0" }
|
||||
tarpc = { version = "0.26", path = "../tarpc", features = ["full"] }
|
||||
tarpc = { version = "0.27", path = "../tarpc", features = ["full"] }
|
||||
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
|
||||
tokio-serde = { version = "0.8", features = ["json"] }
|
||||
tracing = { version = "0.1" }
|
||||
tracing-opentelemetry = "0.12"
|
||||
tracing-opentelemetry = "0.15"
|
||||
tracing-subscriber = "0.2"
|
||||
|
||||
[lib]
|
||||
|
||||
@@ -17,7 +17,7 @@ use std::{
|
||||
};
|
||||
use tarpc::{
|
||||
context,
|
||||
server::{self, Channel, Incoming},
|
||||
server::{self, incoming::Incoming, Channel},
|
||||
tokio_serde::formats::Json,
|
||||
};
|
||||
use tokio::time;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc-plugins"
|
||||
version = "0.11.0"
|
||||
version = "0.12.0"
|
||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||
edition = "2018"
|
||||
license = "MIT"
|
||||
|
||||
@@ -277,7 +277,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
response_fut_name,
|
||||
service_ident: ident,
|
||||
server_ident: &format_ident!("Serve{}", ident),
|
||||
response_fut_ident: &Ident::new(&response_fut_name, ident.span()),
|
||||
response_fut_ident: &Ident::new(response_fut_name, ident.span()),
|
||||
client_ident: &format_ident!("{}Client", ident),
|
||||
request_ident: &format_ident!("{}Request", ident),
|
||||
response_ident: &format_ident!("{}Response", ident),
|
||||
@@ -406,11 +406,7 @@ fn verify_types_were_provided(
|
||||
) -> syn::Result<()> {
|
||||
let mut result = Ok(());
|
||||
for (method, expected) in expected {
|
||||
if provided
|
||||
.iter()
|
||||
.find(|typedecl| typedecl.ident == expected)
|
||||
.is_none()
|
||||
{
|
||||
if !provided.iter().any(|typedecl| typedecl.ident == expected) {
|
||||
let mut e = syn::Error::new(
|
||||
span,
|
||||
format!("not all trait items implemented, missing: `{}`", expected),
|
||||
@@ -483,7 +479,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
),
|
||||
output,
|
||||
)| {
|
||||
let ty_doc = format!("The response future returned by {}.", ident);
|
||||
let ty_doc = format!("The response future returned by [`{}::{}`].", service_ident, ident);
|
||||
quote! {
|
||||
#[doc = #ty_doc]
|
||||
type #future_type: std::future::Future<Output = #output>;
|
||||
@@ -582,6 +578,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
|
||||
quote! {
|
||||
/// The request sent over the wire from the client to the server.
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug)]
|
||||
#derive_serialize
|
||||
#vis enum #request_ident {
|
||||
@@ -602,6 +599,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
|
||||
quote! {
|
||||
/// The response sent over the wire from the server to the client.
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug)]
|
||||
#derive_serialize
|
||||
#vis enum #response_ident {
|
||||
@@ -622,6 +620,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
|
||||
quote! {
|
||||
/// A future resolving to a server response.
|
||||
#[allow(missing_docs)]
|
||||
#vis enum #response_fut_ident<S: #service_ident> {
|
||||
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
|
||||
}
|
||||
@@ -747,7 +746,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
#[allow(unused)]
|
||||
#( #method_attrs )*
|
||||
#vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*)
|
||||
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
|
||||
-> impl std::future::Future<Output = Result<#return_types, tarpc::client::RpcError>> + '_ {
|
||||
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
|
||||
let resp = self.0.call(ctx, #request_names, request);
|
||||
async move {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc"
|
||||
version = "0.26.0"
|
||||
version = "0.27.1"
|
||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||
edition = "2018"
|
||||
license = "MIT"
|
||||
@@ -18,9 +18,11 @@ default = []
|
||||
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"]
|
||||
tokio1 = ["tokio/rt-multi-thread"]
|
||||
serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
|
||||
serde-transport-json = ["tokio-serde/json"]
|
||||
serde-transport-bincode = ["tokio-serde/bincode"]
|
||||
tcp = ["tokio/net"]
|
||||
|
||||
full = ["serde1", "tokio1", "serde-transport", "tcp"]
|
||||
full = ["serde1", "tokio1", "serde-transport", "serde-transport-json", "serde-transport-bincode", "tcp"]
|
||||
|
||||
[badges]
|
||||
travis-ci = { repository = "google/tarpc" }
|
||||
@@ -34,14 +36,14 @@ pin-project = "1.0"
|
||||
rand = "0.8"
|
||||
serde = { optional = true, version = "1.0", features = ["derive"] }
|
||||
static_assertions = "1.1.0"
|
||||
tarpc-plugins = { path = "../plugins", version = "0.11" }
|
||||
tarpc-plugins = { path = "../plugins", version = "0.12" }
|
||||
thiserror = "1.0"
|
||||
tokio = { version = "1", features = ["time"] }
|
||||
tokio-util = { version = "0.6.3", features = ["time"] }
|
||||
tokio-serde = { optional = true, version = "0.8" }
|
||||
tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
|
||||
tracing-opentelemetry = { version = "0.12", default-features = false }
|
||||
opentelemetry = { version = "0.13", default-features = false }
|
||||
tracing-opentelemetry = { version = "0.15", default-features = false }
|
||||
opentelemetry = { version = "0.16", default-features = false }
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
@@ -50,8 +52,8 @@ bincode = "1.3"
|
||||
bytes = { version = "1", features = ["serde"] }
|
||||
flate2 = "1.0"
|
||||
futures-test = "0.3"
|
||||
opentelemetry = { version = "0.13", default-features = false, features = ["rt-tokio"] }
|
||||
opentelemetry-jaeger = { version = "0.12", features = ["tokio"] }
|
||||
opentelemetry = { version = "0.16", default-features = false, features = ["rt-tokio"] }
|
||||
opentelemetry-jaeger = { version = "0.15", features = ["rt-tokio"] }
|
||||
pin-utils = "0.1.0-alpha"
|
||||
serde_bytes = "0.11"
|
||||
tracing-subscriber = "0.2"
|
||||
|
||||
@@ -7,8 +7,8 @@ use tarpc::{
|
||||
client, context,
|
||||
serde_transport::tcp,
|
||||
server::{BaseChannel, Channel},
|
||||
tokio_serde::formats::Bincode,
|
||||
};
|
||||
use tokio_serde::formats::Bincode;
|
||||
|
||||
/// Type of compression that should be enabled on the request. The transport is free to ignore this.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)]
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
use futures::future;
|
||||
use tarpc::context::Context;
|
||||
use tarpc::serde_transport as transport;
|
||||
use tarpc::server::{BaseChannel, Channel};
|
||||
use tarpc::{context::Context, tokio_serde::formats::Bincode};
|
||||
use tokio::net::{UnixListener, UnixStream};
|
||||
use tokio_serde::formats::Bincode;
|
||||
use tokio_util::codec::length_delimited::LengthDelimitedCodec;
|
||||
|
||||
#[tarpc::service]
|
||||
@@ -14,16 +12,13 @@ pub trait PingService {
|
||||
#[derive(Clone)]
|
||||
struct Service;
|
||||
|
||||
#[tarpc::server]
|
||||
impl PingService for Service {
|
||||
type PingFut = future::Ready<()>;
|
||||
|
||||
fn ping(self, _: Context) -> Self::PingFut {
|
||||
future::ready(())
|
||||
}
|
||||
async fn ping(self, _: Context) {}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> std::io::Result<()> {
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let bind_addr = "/tmp/tarpc_on_unix_example.sock";
|
||||
|
||||
let _ = std::fs::remove_file(bind_addr);
|
||||
@@ -46,5 +41,7 @@ async fn main() -> std::io::Result<()> {
|
||||
PingServiceClient::new(Default::default(), transport)
|
||||
.spawn()
|
||||
.ping(tarpc::context::current())
|
||||
.await
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use futures::future::{self, Ready};
|
||||
use std::io;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{self, Channel},
|
||||
@@ -35,7 +34,7 @@ impl World for HelloServer {
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
|
||||
let server = server::BaseChannel::with_defaults(server_transport);
|
||||
|
||||
@@ -9,7 +9,7 @@ use futures::{future, prelude::*};
|
||||
use std::env;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{BaseChannel, Incoming},
|
||||
server::{incoming::Incoming, BaseChannel},
|
||||
};
|
||||
use tokio_serde::formats::Json;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
@@ -8,14 +8,14 @@
|
||||
|
||||
mod in_flight_requests;
|
||||
|
||||
use crate::{context, trace, ClientMessage, Request, Response, Transport};
|
||||
use crate::{context, trace, ClientMessage, Request, Response, ServerError, Transport};
|
||||
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
||||
use in_flight_requests::InFlightRequests;
|
||||
use in_flight_requests::{DeadlineExceededError, InFlightRequests};
|
||||
use pin_project::pin_project;
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
error::Error,
|
||||
fmt, io, mem,
|
||||
fmt, mem,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
@@ -60,15 +60,16 @@ pub struct NewClient<C, D> {
|
||||
impl<C, D, E> NewClient<C, D>
|
||||
where
|
||||
D: Future<Output = Result<(), E>> + Send + 'static,
|
||||
E: std::fmt::Display,
|
||||
E: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
/// Helper method to spawn the dispatch on the default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub fn spawn(self) -> C {
|
||||
let dispatch = self
|
||||
.dispatch
|
||||
.unwrap_or_else(move |e| tracing::warn!("Connection broken: {}", e));
|
||||
let dispatch = self.dispatch.unwrap_or_else(move |e| {
|
||||
let e = anyhow::Error::new(e);
|
||||
tracing::warn!("Connection broken: {:?}", e);
|
||||
});
|
||||
tokio::spawn(dispatch);
|
||||
self.client
|
||||
}
|
||||
@@ -125,7 +126,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
mut ctx: context::Context,
|
||||
request_name: &str,
|
||||
request: Req,
|
||||
) -> io::Result<Resp> {
|
||||
) -> Result<Resp, RpcError> {
|
||||
let span = Span::current();
|
||||
ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
|
||||
tracing::warn!(
|
||||
@@ -156,7 +157,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
response_completion,
|
||||
})
|
||||
.await
|
||||
.map_err(|mpsc::error::SendError(_)| io::Error::from(io::ErrorKind::ConnectionReset))?;
|
||||
.map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?;
|
||||
response_guard.response().await
|
||||
}
|
||||
}
|
||||
@@ -164,23 +165,44 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
/// A server response that is completed by request dispatch when the corresponding response
|
||||
/// arrives off the wire.
|
||||
struct ResponseGuard<'a, Resp> {
|
||||
response: &'a mut oneshot::Receiver<Response<Resp>>,
|
||||
response: &'a mut oneshot::Receiver<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
cancellation: &'a RequestCancellation,
|
||||
request_id: u64,
|
||||
}
|
||||
|
||||
/// An error that can occur in the processing of an RPC. This is not request-specific errors but
|
||||
/// rather cross-cutting errors that can always occur.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum RpcError {
|
||||
/// The client disconnected from the server.
|
||||
#[error("the client disconnected from the server")]
|
||||
Disconnected,
|
||||
/// The request exceeded its deadline.
|
||||
#[error("the request exceeded its deadline")]
|
||||
DeadlineExceeded,
|
||||
/// The server aborted request processing.
|
||||
#[error("the server aborted request processing")]
|
||||
Server(#[from] ServerError),
|
||||
}
|
||||
|
||||
impl From<DeadlineExceededError> for RpcError {
|
||||
fn from(_: DeadlineExceededError) -> Self {
|
||||
RpcError::DeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
impl<Resp> ResponseGuard<'_, Resp> {
|
||||
async fn response(mut self) -> io::Result<Resp> {
|
||||
async fn response(mut self) -> Result<Resp, RpcError> {
|
||||
let response = (&mut self.response).await;
|
||||
// Cancel drop logic once a response has been received.
|
||||
mem::forget(self);
|
||||
match response {
|
||||
Ok(resp) => Ok(resp.message?),
|
||||
Ok(resp) => Ok(resp?.message?),
|
||||
Err(oneshot::error::RecvError { .. }) => {
|
||||
// The oneshot is Canceled when the dispatch task ends. In that case,
|
||||
// there's nothing listening on the other side, so there's no point in
|
||||
// propagating cancellation.
|
||||
Err(io::Error::from(io::ErrorKind::ConnectionReset))
|
||||
Err(RpcError::Disconnected)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -257,11 +279,23 @@ pub enum ChannelError<E>
|
||||
where
|
||||
E: Error + Send + Sync + 'static,
|
||||
{
|
||||
/// An error occurred reading from, or writing to, the transport.
|
||||
#[error("an error occurred in the transport: {0}")]
|
||||
Transport(#[source] E),
|
||||
/// An error occurred while polling expired requests.
|
||||
#[error("an error occurred while polling expired requests: {0}")]
|
||||
/// Could not read from the transport.
|
||||
#[error("could not read from the transport")]
|
||||
Read(#[source] E),
|
||||
/// Could not ready the transport for writes.
|
||||
#[error("could not ready the transport for writes")]
|
||||
Ready(#[source] E),
|
||||
/// Could not write to the transport.
|
||||
#[error("could not write to the transport")]
|
||||
Write(#[source] E),
|
||||
/// Could not flush the transport.
|
||||
#[error("could not flush the transport")]
|
||||
Flush(#[source] E),
|
||||
/// Could not close the write end of the transport.
|
||||
#[error("could not close the write end of the transport")]
|
||||
Close(#[source] E),
|
||||
/// Could not poll expired requests.
|
||||
#[error("could not poll expired requests")]
|
||||
Timer(#[source] tokio::time::error::Error),
|
||||
}
|
||||
|
||||
@@ -277,6 +311,42 @@ where
|
||||
self.as_mut().project().transport
|
||||
}
|
||||
|
||||
fn poll_ready<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), ChannelError<C::Error>>> {
|
||||
self.transport_pin_mut()
|
||||
.poll_ready(cx)
|
||||
.map_err(ChannelError::Ready)
|
||||
}
|
||||
|
||||
fn start_send(
|
||||
self: &mut Pin<&mut Self>,
|
||||
message: ClientMessage<Req>,
|
||||
) -> Result<(), ChannelError<C::Error>> {
|
||||
self.transport_pin_mut()
|
||||
.start_send(message)
|
||||
.map_err(ChannelError::Write)
|
||||
}
|
||||
|
||||
fn poll_flush<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), ChannelError<C::Error>>> {
|
||||
self.transport_pin_mut()
|
||||
.poll_flush(cx)
|
||||
.map_err(ChannelError::Flush)
|
||||
}
|
||||
|
||||
fn poll_close<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), ChannelError<C::Error>>> {
|
||||
self.transport_pin_mut()
|
||||
.poll_close(cx)
|
||||
.map_err(ChannelError::Close)
|
||||
}
|
||||
|
||||
fn canceled_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut CanceledRequests {
|
||||
self.as_mut().project().canceled_requests
|
||||
}
|
||||
@@ -293,7 +363,7 @@ where
|
||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||
self.transport_pin_mut()
|
||||
.poll_next(cx)
|
||||
.map_err(ChannelError::Transport)
|
||||
.map_err(ChannelError::Read)
|
||||
.map_ok(|response| {
|
||||
self.complete(response);
|
||||
})
|
||||
@@ -308,21 +378,13 @@ where
|
||||
Closed,
|
||||
}
|
||||
|
||||
let pending_requests_status = match self
|
||||
.as_mut()
|
||||
.poll_write_request(cx)
|
||||
.map_err(ChannelError::Transport)?
|
||||
{
|
||||
let pending_requests_status = match self.as_mut().poll_write_request(cx)? {
|
||||
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
|
||||
Poll::Ready(None) => ReceiverStatus::Closed,
|
||||
Poll::Pending => ReceiverStatus::Pending,
|
||||
};
|
||||
|
||||
let canceled_requests_status = match self
|
||||
.as_mut()
|
||||
.poll_write_cancel(cx)
|
||||
.map_err(ChannelError::Transport)?
|
||||
{
|
||||
let canceled_requests_status = match self.as_mut().poll_write_cancel(cx)? {
|
||||
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
|
||||
Poll::Ready(None) => ReceiverStatus::Closed,
|
||||
Poll::Pending => ReceiverStatus::Pending,
|
||||
@@ -344,18 +406,12 @@ where
|
||||
|
||||
match (pending_requests_status, canceled_requests_status) {
|
||||
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
|
||||
ready!(self
|
||||
.transport_pin_mut()
|
||||
.poll_flush(cx)
|
||||
.map_err(ChannelError::Transport)?);
|
||||
ready!(self.poll_close(cx)?);
|
||||
Poll::Ready(None)
|
||||
}
|
||||
(ReceiverStatus::Pending, _) | (_, ReceiverStatus::Pending) => {
|
||||
// No more messages to process, so flush any messages buffered in the transport.
|
||||
ready!(self
|
||||
.transport_pin_mut()
|
||||
.poll_flush(cx)
|
||||
.map_err(ChannelError::Transport)?);
|
||||
ready!(self.poll_flush(cx)?);
|
||||
|
||||
// Even if we fully-flush, we return Pending, because we have no more requests
|
||||
// or cancellations right now.
|
||||
@@ -371,7 +427,7 @@ where
|
||||
fn poll_next_request(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<DispatchRequest<Req, Resp>, C::Error>>> {
|
||||
) -> Poll<Option<Result<DispatchRequest<Req, Resp>, ChannelError<C::Error>>>> {
|
||||
if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
|
||||
tracing::info!(
|
||||
"At in-flight request capacity ({}/{}).",
|
||||
@@ -409,7 +465,7 @@ where
|
||||
fn poll_next_cancellation(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<(context::Context, Span, u64), C::Error>>> {
|
||||
) -> Poll<Option<Result<(context::Context, Span, u64), ChannelError<C::Error>>>> {
|
||||
ready!(self.ensure_writeable(cx)?);
|
||||
|
||||
loop {
|
||||
@@ -431,9 +487,9 @@ where
|
||||
fn ensure_writeable<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<(), C::Error>>> {
|
||||
while self.transport_pin_mut().poll_ready(cx)?.is_pending() {
|
||||
ready!(self.transport_pin_mut().poll_flush(cx)?);
|
||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||
while self.poll_ready(cx)?.is_pending() {
|
||||
ready!(self.poll_flush(cx)?);
|
||||
}
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
@@ -441,7 +497,7 @@ where
|
||||
fn poll_write_request<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<(), C::Error>>> {
|
||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||
let DispatchRequest {
|
||||
ctx,
|
||||
span,
|
||||
@@ -465,7 +521,7 @@ where
|
||||
trace_context: ctx.trace_context,
|
||||
},
|
||||
});
|
||||
self.transport_pin_mut().start_send(request)?;
|
||||
self.start_send(request)?;
|
||||
let deadline = ctx.deadline;
|
||||
tracing::info!(
|
||||
tarpc.deadline = %humantime::format_rfc3339(deadline),
|
||||
@@ -482,7 +538,7 @@ where
|
||||
fn poll_write_cancel<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<(), C::Error>>> {
|
||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||
let (context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
|
||||
Some(triple) => triple,
|
||||
None => return Poll::Ready(None),
|
||||
@@ -493,7 +549,7 @@ where
|
||||
trace_context: context.trace_context,
|
||||
request_id,
|
||||
};
|
||||
self.transport_pin_mut().start_send(cancel)?;
|
||||
self.start_send(cancel)?;
|
||||
tracing::info!("CancelRequest");
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
@@ -549,7 +605,7 @@ struct DispatchRequest<Req, Resp> {
|
||||
pub span: Span,
|
||||
pub request_id: u64,
|
||||
pub request: Req,
|
||||
pub response_completion: oneshot::Sender<Response<Resp>>,
|
||||
pub response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
}
|
||||
|
||||
/// Sends request cancellation signals.
|
||||
@@ -596,7 +652,10 @@ mod tests {
|
||||
RequestDispatch, ResponseGuard,
|
||||
};
|
||||
use crate::{
|
||||
client::{in_flight_requests::InFlightRequests, Config},
|
||||
client::{
|
||||
in_flight_requests::{DeadlineExceededError, InFlightRequests},
|
||||
Config,
|
||||
},
|
||||
context,
|
||||
transport::{self, channel::UnboundedChannel},
|
||||
ClientMessage, Response,
|
||||
@@ -630,7 +689,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
|
||||
assert_matches!(rx.try_recv(), Ok(Response { request_id: 0, message: Ok(resp) }) if resp == "Resp");
|
||||
assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -651,10 +710,10 @@ mod tests {
|
||||
async fn dispatch_response_doesnt_cancel_after_complete() {
|
||||
let (cancellation, mut canceled_requests) = cancellations();
|
||||
let (tx, mut response) = oneshot::channel();
|
||||
tx.send(Response {
|
||||
tx.send(Ok(Response {
|
||||
request_id: 0,
|
||||
message: Ok("well done"),
|
||||
})
|
||||
}))
|
||||
.unwrap();
|
||||
// resp's drop() is run, but should not send a cancel message.
|
||||
ResponseGuard {
|
||||
@@ -799,8 +858,8 @@ mod tests {
|
||||
async fn send_request<'a>(
|
||||
channel: &'a mut Channel<String, String>,
|
||||
request: &str,
|
||||
response_completion: oneshot::Sender<Response<String>>,
|
||||
response: &'a mut oneshot::Receiver<Response<String>>,
|
||||
response_completion: oneshot::Sender<Result<Response<String>, DeadlineExceededError>>,
|
||||
response: &'a mut oneshot::Receiver<Result<Response<String>, DeadlineExceededError>>,
|
||||
) -> ResponseGuard<'a, String> {
|
||||
let request_id =
|
||||
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
use crate::{
|
||||
context,
|
||||
util::{Compact, TimeUntil},
|
||||
Response, ServerError,
|
||||
Response,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use std::{
|
||||
collections::hash_map,
|
||||
io,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio::sync::oneshot;
|
||||
@@ -29,11 +28,17 @@ impl<Resp> Default for InFlightRequests<Resp> {
|
||||
}
|
||||
}
|
||||
|
||||
/// The request exceeded its deadline.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
#[non_exhaustive]
|
||||
#[error("the request exceeded its deadline")]
|
||||
pub struct DeadlineExceededError;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RequestData<Resp> {
|
||||
ctx: context::Context,
|
||||
span: Span,
|
||||
response_completion: oneshot::Sender<Response<Resp>>,
|
||||
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
/// The key to remove the timer for the request's deadline.
|
||||
deadline_key: delay_queue::Key,
|
||||
}
|
||||
@@ -60,7 +65,7 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
request_id: u64,
|
||||
ctx: context::Context,
|
||||
span: Span,
|
||||
response_completion: oneshot::Sender<Response<Resp>>,
|
||||
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
) -> Result<(), AlreadyExistsError> {
|
||||
match self.request_data.entry(request_id) {
|
||||
hash_map::Entry::Vacant(vacant) => {
|
||||
@@ -85,7 +90,7 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
tracing::info!("ReceiveResponse");
|
||||
self.request_data.compact(0.1);
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
let _ = request_data.response_completion.send(response);
|
||||
let _ = request_data.response_completion.send(Ok(response));
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -124,19 +129,9 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
self.request_data.compact(0.1);
|
||||
let _ = request_data
|
||||
.response_completion
|
||||
.send(Self::deadline_exceeded_error(request_id));
|
||||
.send(Err(DeadlineExceededError));
|
||||
}
|
||||
request_id
|
||||
})
|
||||
}
|
||||
|
||||
fn deadline_exceeded_error(request_id: u64) -> Response<Resp> {
|
||||
Response {
|
||||
request_id,
|
||||
message: Err(ServerError {
|
||||
kind: io::ErrorKind::TimedOut,
|
||||
detail: Some("Client dropped expired request.".to_string()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ impl SpanExt for tracing::Span {
|
||||
.with_remote_span_context(opentelemetry::trace::SpanContext::new(
|
||||
opentelemetry::trace::TraceId::from(context.trace_context.trace_id),
|
||||
opentelemetry::trace::SpanId::from(context.trace_context.span_id),
|
||||
context.trace_context.sampling_decision as u8,
|
||||
opentelemetry::trace::TraceFlags::from(context.trace_context.sampling_decision),
|
||||
true,
|
||||
opentelemetry::trace::TraceState::default(),
|
||||
))
|
||||
|
||||
@@ -54,7 +54,7 @@
|
||||
//! Add to your `Cargo.toml` dependencies:
|
||||
//!
|
||||
//! ```toml
|
||||
//! tarpc = "0.26"
|
||||
//! tarpc = "0.27"
|
||||
//! ```
|
||||
//!
|
||||
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||
@@ -67,8 +67,9 @@
|
||||
//! your `Cargo.toml`:
|
||||
//!
|
||||
//! ```toml
|
||||
//! futures = "1.0"
|
||||
//! tarpc = { version = "0.26", features = ["tokio1"] }
|
||||
//! anyhow = "1.0"
|
||||
//! futures = "0.3"
|
||||
//! tarpc = { version = "0.27", features = ["tokio1"] }
|
||||
//! tokio = { version = "1.0", features = ["macros"] }
|
||||
//! ```
|
||||
//!
|
||||
@@ -87,9 +88,8 @@
|
||||
//! };
|
||||
//! use tarpc::{
|
||||
//! client, context,
|
||||
//! server::{self, Incoming},
|
||||
//! server::{self, incoming::Incoming},
|
||||
//! };
|
||||
//! use std::io;
|
||||
//!
|
||||
//! // This is the service definition. It looks a lot like a trait definition.
|
||||
//! // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
@@ -111,9 +111,8 @@
|
||||
//! # };
|
||||
//! # use tarpc::{
|
||||
//! # client, context,
|
||||
//! # server::{self, Incoming},
|
||||
//! # server::{self, incoming::Incoming},
|
||||
//! # };
|
||||
//! # use std::io;
|
||||
//! # // This is the service definition. It looks a lot like a trait definition.
|
||||
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
//! # #[tarpc::service]
|
||||
@@ -153,7 +152,6 @@
|
||||
//! # client, context,
|
||||
//! # server::{self, Channel},
|
||||
//! # };
|
||||
//! # use std::io;
|
||||
//! # // This is the service definition. It looks a lot like a trait definition.
|
||||
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
//! # #[tarpc::service]
|
||||
@@ -177,7 +175,7 @@
|
||||
//! # fn main() {}
|
||||
//! # #[cfg(feature = "tokio1")]
|
||||
//! #[tokio::main]
|
||||
//! async fn main() -> io::Result<()> {
|
||||
//! async fn main() -> anyhow::Result<()> {
|
||||
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
//!
|
||||
//! let server = server::BaseChannel::with_defaults(server_transport);
|
||||
@@ -364,8 +362,9 @@ pub struct Response<T> {
|
||||
pub message: Result<T, ServerError>,
|
||||
}
|
||||
|
||||
/// An error response from a server to a client.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
/// An error indicating the server aborted the request early, e.g., due to request throttling.
|
||||
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[error("{kind:?}: {detail}")]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ServerError {
|
||||
@@ -380,13 +379,7 @@ pub struct ServerError {
|
||||
/// The type of error that occurred to fail the request.
|
||||
pub kind: io::ErrorKind,
|
||||
/// A message describing more detail about the error that occurred.
|
||||
pub detail: Option<String>,
|
||||
}
|
||||
|
||||
impl From<ServerError> for io::Error {
|
||||
fn from(e: ServerError) -> io::Error {
|
||||
io::Error::new(e.kind, e.detail.unwrap_or_default())
|
||||
}
|
||||
pub detail: String,
|
||||
}
|
||||
|
||||
impl<T> Request<T> {
|
||||
|
||||
@@ -42,12 +42,10 @@ where
|
||||
type Item = io::Result<Item>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
|
||||
self.project().inner.poll_next(cx).map_err(|e| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("while reading from transport: {}", e.into()),
|
||||
)
|
||||
})
|
||||
self.project()
|
||||
.inner
|
||||
.poll_next(cx)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,39 +61,31 @@ where
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.project().inner.poll_ready(cx).map_err(|e| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("while readying write half of transport: {}", e.into()),
|
||||
)
|
||||
})
|
||||
self.project()
|
||||
.inner
|
||||
.poll_ready(cx)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
|
||||
self.project().inner.start_send(item).map_err(|e| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("while writing to transport: {}", e.into()),
|
||||
)
|
||||
})
|
||||
self.project()
|
||||
.inner
|
||||
.start_send(item)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.project().inner.poll_flush(cx).map_err(|e| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("while flushing transport: {}", e.into()),
|
||||
)
|
||||
})
|
||||
self.project()
|
||||
.inner
|
||||
.poll_flush(cx)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.project().inner.poll_close(cx).map_err(|e| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("while closing write half of transport: {}", e.into()),
|
||||
)
|
||||
})
|
||||
self.project()
|
||||
.inner
|
||||
.poll_close(cx)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ use crate::{
|
||||
context::{self, SpanExt},
|
||||
trace, ClientMessage, Request, Response, Transport,
|
||||
};
|
||||
use ::tokio::sync::mpsc;
|
||||
use futures::{
|
||||
future::{AbortRegistration, Abortable},
|
||||
prelude::*,
|
||||
@@ -19,23 +20,23 @@ use futures::{
|
||||
};
|
||||
use in_flight_requests::{AlreadyExistsError, InFlightRequests};
|
||||
use pin_project::pin_project;
|
||||
use std::{
|
||||
convert::TryFrom, error::Error, fmt, hash::Hash, marker::PhantomData, pin::Pin,
|
||||
time::SystemTime,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin};
|
||||
use tracing::{info_span, instrument::Instrument, Span};
|
||||
|
||||
mod filter;
|
||||
mod in_flight_requests;
|
||||
#[cfg(test)]
|
||||
mod testing;
|
||||
mod throttle;
|
||||
|
||||
pub use self::{
|
||||
filter::ChannelFilter,
|
||||
throttle::{Throttler, ThrottlerStream},
|
||||
};
|
||||
/// Provides functionality to apply server limits.
|
||||
pub mod limits;
|
||||
|
||||
/// Provides helper methods for streams of Channels.
|
||||
pub mod incoming;
|
||||
|
||||
/// Provides convenience functionality for tokio-enabled applications.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub mod tokio;
|
||||
|
||||
/// Settings that control the behavior of [channels](Channel).
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -94,51 +95,13 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// An extension trait for [streams](Stream) of [`Channels`](Channel).
|
||||
pub trait Incoming<C>
|
||||
where
|
||||
Self: Sized + Stream<Item = C>,
|
||||
C: Channel,
|
||||
{
|
||||
/// Enforces channel per-key limits.
|
||||
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> filter::ChannelFilter<Self, K, KF>
|
||||
where
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
KF: Fn(&C) -> K,
|
||||
{
|
||||
ChannelFilter::new(self, n, keymaker)
|
||||
}
|
||||
|
||||
/// Caps the number of concurrent requests per channel.
|
||||
fn max_concurrent_requests_per_channel(self, n: usize) -> ThrottlerStream<Self> {
|
||||
ThrottlerStream::new(self, n)
|
||||
}
|
||||
|
||||
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled
|
||||
/// concurrently by spawning on tokio's default executor, and each request will be also
|
||||
/// be spawned on tokio's default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
|
||||
where
|
||||
S: Serve<C::Req, Resp = C::Resp>,
|
||||
{
|
||||
TokioServerExecutor { inner: self, serve }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, C> Incoming<C> for S
|
||||
where
|
||||
S: Sized + Stream<Item = C>,
|
||||
C: Channel,
|
||||
{
|
||||
}
|
||||
|
||||
/// BaseChannel is a [Transport] that keeps track of in-flight requests. It converts a
|
||||
/// [`Transport`](Transport) of [`ClientMessages`](ClientMessage) into a stream of
|
||||
/// [requests](ClientMessage::Request).
|
||||
/// BaseChannel is the standard implementation of a [`Channel`].
|
||||
///
|
||||
/// Besides requests, the other type of client message is [cancellation
|
||||
/// BaseChannel manages a [`Transport`](Transport) of client [`messages`](ClientMessage) and
|
||||
/// implements a [`Stream`] of [requests](TrackedRequest). See the [`Channel`] documentation for
|
||||
/// how to use channels.
|
||||
///
|
||||
/// Besides requests, the other type of client message handled by `BaseChannel` is [cancellation
|
||||
/// messages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation
|
||||
/// messages. Instead, it internally handles them by cancelling corresponding requests (removing
|
||||
/// the corresponding in-flight requests and aborting their handlers).
|
||||
@@ -190,6 +153,47 @@ where
|
||||
fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<T>> {
|
||||
self.as_mut().project().transport
|
||||
}
|
||||
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
mut request: Request<Req>,
|
||||
) -> Result<TrackedRequest<Req>, AlreadyExistsError> {
|
||||
let span = info_span!(
|
||||
"RPC",
|
||||
rpc.trace_id = %request.context.trace_id(),
|
||||
otel.kind = "server",
|
||||
otel.name = tracing::field::Empty,
|
||||
);
|
||||
span.set_context(&request.context);
|
||||
request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
|
||||
tracing::trace!(
|
||||
"OpenTelemetry subscriber not installed; making unsampled \
|
||||
child context."
|
||||
);
|
||||
request.context.trace_context.new_child()
|
||||
});
|
||||
let entered = span.enter();
|
||||
tracing::info!("ReceiveRequest");
|
||||
let start = self.project().in_flight_requests.start_request(
|
||||
request.id,
|
||||
request.context.deadline,
|
||||
span.clone(),
|
||||
);
|
||||
match start {
|
||||
Ok(abort_registration) => {
|
||||
drop(entered);
|
||||
Ok(TrackedRequest {
|
||||
request,
|
||||
abort_registration,
|
||||
span,
|
||||
})
|
||||
}
|
||||
Err(AlreadyExistsError) => {
|
||||
tracing::trace!("DuplicateRequest");
|
||||
Err(AlreadyExistsError)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
|
||||
@@ -198,8 +202,20 @@ impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// The server end of an open connection with a client, streaming in requests from, and sinking
|
||||
/// responses to, the client.
|
||||
/// A request tracked by a [`Channel`].
|
||||
#[derive(Debug)]
|
||||
pub struct TrackedRequest<Req> {
|
||||
/// The request sent by the client.
|
||||
pub request: Request<Req>,
|
||||
/// A registration to abort a future when the [`Channel`] that produced this request stops
|
||||
/// tracking it.
|
||||
pub abort_registration: AbortRegistration,
|
||||
/// A span representing the server processing of this request.
|
||||
pub span: Span,
|
||||
}
|
||||
|
||||
/// The server end of an open connection with a client, receiving requests from, and sending
|
||||
/// responses to, the client. `Channel` is a [`Transport`] with request lifecycle management.
|
||||
///
|
||||
/// The ways to use a Channel, in order of simplest to most complex, is:
|
||||
/// 1. [`Channel::execute`] - Requires the `tokio1` feature. This method is best for those who
|
||||
@@ -210,18 +226,20 @@ impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
|
||||
/// [`execute`](InFlightRequest::execute) method. If using `execute`, request processing will
|
||||
/// automatically cease when either the request deadline is reached or when a corresponding
|
||||
/// cancellation message is received by the Channel.
|
||||
/// 3. [`Sink::start_send`] - A user is free to manually send responses to requests produced by a
|
||||
/// Channel using [`Sink::start_send`] in lieu of the previous methods. If not using one of the
|
||||
/// previous execute methods, then nothing will automatically cancel requests or set up the
|
||||
/// request context. However, the Channel will still clean up resources upon deadline expiration
|
||||
/// or cancellation. In the case that the Channel cleans up resources related to a request
|
||||
/// before the response is sent, the response can still be sent into the Channel later on.
|
||||
/// Because there is no guarantee that a cancellation message will ever be received for a
|
||||
/// request, or that requests come with reasonably short deadlines, services should strive to
|
||||
/// clean up Channel resources by sending a response for every request.
|
||||
/// 3. [`Stream::next`](futures::stream::StreamExt::next) /
|
||||
/// [`Sink::send`](futures::sink::SinkExt::send) - A user is free to manually read requests
|
||||
/// from, and send responses into, a Channel in lieu of the previous methods. Channels stream
|
||||
/// [`TrackedRequests`](TrackedRequest), which, in addition to the request itself, contains the
|
||||
/// server [`Span`] and request lifetime [`AbortRegistration`]. Wrapping response
|
||||
/// logic in an [`Abortable`] future using the abort registration will ensure that the response
|
||||
/// does not execute longer than the request deadline. The `Channel` itself will clean up
|
||||
/// request state once either the deadline expires, or a cancellation message is received, or a
|
||||
/// response is sent. Because there is no guarantee that a cancellation message will ever be
|
||||
/// received for a request, or that requests come with reasonably short deadlines, services
|
||||
/// should strive to clean up Channel resources by sending a response for every request.
|
||||
pub trait Channel
|
||||
where
|
||||
Self: Transport<Response<<Self as Channel>::Resp>, Request<<Self as Channel>::Req>>,
|
||||
Self: Transport<Response<<Self as Channel>::Resp>, TrackedRequest<<Self as Channel>::Req>>,
|
||||
{
|
||||
/// Type of request item.
|
||||
type Req;
|
||||
@@ -241,24 +259,23 @@ where
|
||||
/// Returns the transport underlying the channel.
|
||||
fn transport(&self) -> &Self::Transport;
|
||||
|
||||
/// Caps the number of concurrent requests to `limit`.
|
||||
fn max_concurrent_requests(self, limit: usize) -> Throttler<Self>
|
||||
/// Caps the number of concurrent requests to `limit`. An error will be returned for requests
|
||||
/// over the concurrency limit.
|
||||
///
|
||||
/// Note that this is a very
|
||||
/// simplistic throttling heuristic. It is easy to set a number that is too low for the
|
||||
/// resources available to the server. For production use cases, a more advanced throttler is
|
||||
/// likely needed.
|
||||
fn max_concurrent_requests(
|
||||
self,
|
||||
limit: usize,
|
||||
) -> limits::requests_per_channel::MaxRequests<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Throttler::new(self, limit)
|
||||
limits::requests_per_channel::MaxRequests::new(self, limit)
|
||||
}
|
||||
|
||||
/// Tells the Channel that request with ID `request_id` is being handled.
|
||||
/// The request will be tracked until a response with the same ID is sent
|
||||
/// to the Channel or the deadline expires, whichever happens first.
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
request_id: u64,
|
||||
deadline: SystemTime,
|
||||
span: Span,
|
||||
) -> Result<AbortRegistration, AlreadyExistsError>;
|
||||
|
||||
/// Returns a stream of requests that automatically handle request cancellation and response
|
||||
/// routing.
|
||||
///
|
||||
@@ -279,11 +296,11 @@ where
|
||||
}
|
||||
|
||||
/// Runs the channel until completion by executing all requests using the given service
|
||||
/// function. Request handlers are run concurrently by [spawning](tokio::spawn) on tokio's
|
||||
/// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's
|
||||
/// default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
fn execute<S>(self, serve: S) -> TokioChannelExecutor<Requests<Self>, S>
|
||||
fn execute<S>(self, serve: S) -> self::tokio::TokioChannelExecutor<Requests<Self>, S>
|
||||
where
|
||||
Self: Sized,
|
||||
S: Serve<Self::Req, Resp = Self::Resp> + Send + 'static,
|
||||
@@ -306,16 +323,17 @@ where
|
||||
Transport(#[source] E),
|
||||
/// An error occurred while polling expired requests.
|
||||
#[error("an error occurred while polling expired requests: {0}")]
|
||||
Timer(#[source] tokio::time::error::Error),
|
||||
Timer(#[source] ::tokio::time::error::Error),
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
{
|
||||
type Item = Result<Request<Req>, ChannelError<T::Error>>;
|
||||
type Item = Result<TrackedRequest<Req>, ChannelError<T::Error>>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
#[derive(Debug)]
|
||||
enum ReceiverStatus {
|
||||
Ready,
|
||||
Pending,
|
||||
@@ -343,7 +361,16 @@ where
|
||||
{
|
||||
Poll::Ready(Some(message)) => match message {
|
||||
ClientMessage::Request(request) => {
|
||||
return Poll::Ready(Some(Ok(request)));
|
||||
match self.as_mut().start_request(request) {
|
||||
Ok(request) => return Poll::Ready(Some(Ok(request))),
|
||||
Err(AlreadyExistsError) => {
|
||||
// Instead of closing the channel if a duplicate request is sent,
|
||||
// just ignore it, since it's already being processed. Note that we
|
||||
// cannot return Poll::Pending here, since nothing has scheduled a
|
||||
// wakeup yet.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
ClientMessage::Cancel {
|
||||
trace_context,
|
||||
@@ -362,6 +389,11 @@ where
|
||||
Poll::Pending => Pending,
|
||||
};
|
||||
|
||||
tracing::trace!(
|
||||
"Expired requests: {:?}, Inbound: {:?}",
|
||||
expiration_status,
|
||||
request_status
|
||||
);
|
||||
match (expiration_status, request_status) {
|
||||
(Ready, _) | (_, Ready) => continue,
|
||||
(Closed, Closed) => return Poll::Ready(None),
|
||||
@@ -405,6 +437,7 @@ where
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
tracing::trace!("poll_flush");
|
||||
self.project()
|
||||
.transport
|
||||
.poll_flush(cx)
|
||||
@@ -444,17 +477,6 @@ where
|
||||
fn transport(&self) -> &Self::Transport {
|
||||
self.get_ref()
|
||||
}
|
||||
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
request_id: u64,
|
||||
deadline: SystemTime,
|
||||
span: Span,
|
||||
) -> Result<AbortRegistration, AlreadyExistsError> {
|
||||
self.project()
|
||||
.in_flight_requests
|
||||
.start_request(request_id, deadline, span)
|
||||
}
|
||||
}
|
||||
|
||||
/// A stream of requests coming over a channel. `Requests` also drives the sending of responses, so
|
||||
@@ -492,54 +514,12 @@ where
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<InFlightRequest<C::Req, C::Resp>, C::Error>>> {
|
||||
loop {
|
||||
match ready!(self.channel_pin_mut().poll_next(cx)?) {
|
||||
Some(mut request) => {
|
||||
let span = info_span!(
|
||||
"RPC",
|
||||
rpc.trace_id = %request.context.trace_id(),
|
||||
otel.kind = "server",
|
||||
otel.name = tracing::field::Empty,
|
||||
);
|
||||
span.set_context(&request.context);
|
||||
request.context.trace_context =
|
||||
trace::Context::try_from(&span).unwrap_or_else(|_| {
|
||||
tracing::trace!(
|
||||
"OpenTelemetry subscriber not installed; making unsampled
|
||||
child context."
|
||||
);
|
||||
request.context.trace_context.new_child()
|
||||
});
|
||||
let entered = span.enter();
|
||||
tracing::info!("ReceiveRequest");
|
||||
let start = self.channel_pin_mut().start_request(
|
||||
request.id,
|
||||
request.context.deadline,
|
||||
span.clone(),
|
||||
);
|
||||
match start {
|
||||
Ok(abort_registration) => {
|
||||
let response_tx = self.responses_tx.clone();
|
||||
drop(entered);
|
||||
return Poll::Ready(Some(Ok(InFlightRequest {
|
||||
request,
|
||||
response_tx,
|
||||
abort_registration,
|
||||
span,
|
||||
})));
|
||||
}
|
||||
// Instead of closing the channel if a duplicate request is sent, just
|
||||
// ignore it, since it's already being processed. Note that we cannot
|
||||
// return Poll::Pending here, since nothing has scheduled a wakeup yet.
|
||||
Err(AlreadyExistsError) => {
|
||||
tracing::trace!("DuplicateRequest");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
self.channel_pin_mut()
|
||||
.poll_next(cx)
|
||||
.map_ok(|request| InFlightRequest {
|
||||
request,
|
||||
response_tx: self.responses_tx.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn pump_write(
|
||||
@@ -619,16 +599,14 @@ where
|
||||
/// A request produced by [Channel::requests].
|
||||
#[derive(Debug)]
|
||||
pub struct InFlightRequest<Req, Res> {
|
||||
request: Request<Req>,
|
||||
request: TrackedRequest<Req>,
|
||||
response_tx: mpsc::Sender<Response<Res>>,
|
||||
abort_registration: AbortRegistration,
|
||||
span: Span,
|
||||
}
|
||||
|
||||
impl<Req, Res> InFlightRequest<Req, Res> {
|
||||
/// Returns a reference to the request.
|
||||
pub fn get(&self) -> &Request<Req> {
|
||||
&self.request
|
||||
&self.request.request
|
||||
}
|
||||
|
||||
/// Returns a [future](Future) that executes the request using the given [service
|
||||
@@ -647,15 +625,18 @@ impl<Req, Res> InFlightRequest<Req, Res> {
|
||||
S: Serve<Req, Resp = Res>,
|
||||
{
|
||||
let Self {
|
||||
abort_registration,
|
||||
request:
|
||||
Request {
|
||||
context,
|
||||
message,
|
||||
id: request_id,
|
||||
},
|
||||
response_tx,
|
||||
span,
|
||||
request:
|
||||
TrackedRequest {
|
||||
abort_registration,
|
||||
span,
|
||||
request:
|
||||
Request {
|
||||
context,
|
||||
message,
|
||||
id: request_id,
|
||||
},
|
||||
},
|
||||
} = self;
|
||||
let method = serve.method(&message);
|
||||
span.record("otel.name", &method.unwrap_or(""));
|
||||
@@ -704,129 +685,22 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
// Send + 'static execution helper methods.
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
impl<C> Requests<C>
|
||||
where
|
||||
C: Channel,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
{
|
||||
/// Executes all requests using the given service function. Requests are handled concurrently
|
||||
/// by [spawning](tokio::spawn) each handler on tokio's default executor.
|
||||
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
|
||||
where
|
||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
|
||||
{
|
||||
TokioChannelExecutor { inner: self, serve }
|
||||
}
|
||||
}
|
||||
|
||||
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
|
||||
/// for each new channel.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub struct TokioServerExecutor<T, S> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
serve: S,
|
||||
}
|
||||
|
||||
/// A future that drives the server by [spawning](tokio::spawn) each [response
|
||||
/// handler](InFlightRequest::execute) on tokio's default executor.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub struct TokioChannelExecutor<T, S> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
serve: S,
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
impl<T, S> TokioServerExecutor<T, S> {
|
||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
||||
self.as_mut().project().inner
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
impl<T, S> TokioChannelExecutor<T, S> {
|
||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
||||
self.as_mut().project().inner
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
|
||||
where
|
||||
St: Sized + Stream<Item = C>,
|
||||
C: Channel + Send + 'static,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
||||
Se::Fut: Send,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
||||
tokio::spawn(channel.execute(self.serve.clone()));
|
||||
}
|
||||
tracing::info!("Server shutting down.");
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
|
||||
where
|
||||
C: Channel + 'static,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
||||
S::Fut: Send,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
||||
match response_handler {
|
||||
Ok(resp) => {
|
||||
let server = self.serve.clone();
|
||||
tokio::spawn(async move {
|
||||
resp.execute(server).await;
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Requests stream errored out: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use super::{in_flight_requests::AlreadyExistsError, BaseChannel, Channel, Config, Requests};
|
||||
use crate::{
|
||||
trace,
|
||||
context, trace,
|
||||
transport::channel::{self, UnboundedChannel},
|
||||
ClientMessage, Request, Response,
|
||||
};
|
||||
use assert_matches::assert_matches;
|
||||
use futures::future::{pending, Aborted};
|
||||
use futures::{
|
||||
future::{pending, AbortRegistration, Abortable, Aborted},
|
||||
prelude::*,
|
||||
Future,
|
||||
};
|
||||
use futures_test::task::noop_context;
|
||||
use std::time::Duration;
|
||||
use std::{pin::Pin, task::Poll};
|
||||
|
||||
fn test_channel<Req, Resp>() -> (
|
||||
Pin<Box<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>,
|
||||
@@ -892,12 +766,18 @@ mod tests {
|
||||
|
||||
channel
|
||||
.as_mut()
|
||||
.start_request(0, SystemTime::now(), Span::current())
|
||||
.start_request(Request {
|
||||
id: 0,
|
||||
context: context::current(),
|
||||
message: (),
|
||||
})
|
||||
.unwrap();
|
||||
assert_matches!(
|
||||
channel
|
||||
.as_mut()
|
||||
.start_request(0, SystemTime::now(), Span::current()),
|
||||
channel.as_mut().start_request(Request {
|
||||
id: 0,
|
||||
context: context::current(),
|
||||
message: ()
|
||||
}),
|
||||
Err(AlreadyExistsError)
|
||||
);
|
||||
}
|
||||
@@ -907,13 +787,21 @@ mod tests {
|
||||
let (mut channel, _tx) = test_channel::<(), ()>();
|
||||
|
||||
tokio::time::pause();
|
||||
let abort_registration0 = channel
|
||||
let req0 = channel
|
||||
.as_mut()
|
||||
.start_request(0, SystemTime::now(), Span::current())
|
||||
.start_request(Request {
|
||||
id: 0,
|
||||
context: context::current(),
|
||||
message: (),
|
||||
})
|
||||
.unwrap();
|
||||
let abort_registration1 = channel
|
||||
let req1 = channel
|
||||
.as_mut()
|
||||
.start_request(1, SystemTime::now(), Span::current())
|
||||
.start_request(Request {
|
||||
id: 1,
|
||||
context: context::current(),
|
||||
message: (),
|
||||
})
|
||||
.unwrap();
|
||||
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
|
||||
|
||||
@@ -921,8 +809,8 @@ mod tests {
|
||||
channel.as_mut().poll_next(&mut noop_context()),
|
||||
Poll::Pending
|
||||
);
|
||||
assert_matches!(test_abortable(abort_registration0).await, Err(Aborted));
|
||||
assert_matches!(test_abortable(abort_registration1).await, Err(Aborted));
|
||||
assert_matches!(test_abortable(req0.abort_registration).await, Err(Aborted));
|
||||
assert_matches!(test_abortable(req1.abort_registration).await, Err(Aborted));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -930,13 +818,13 @@ mod tests {
|
||||
let (mut channel, mut tx) = test_channel::<(), ()>();
|
||||
|
||||
tokio::time::pause();
|
||||
let abort_registration = channel
|
||||
let req = channel
|
||||
.as_mut()
|
||||
.start_request(
|
||||
0,
|
||||
SystemTime::now() + Duration::from_millis(100),
|
||||
Span::current(),
|
||||
)
|
||||
.start_request(Request {
|
||||
id: 0,
|
||||
context: context::current(),
|
||||
message: (),
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
tx.send(ClientMessage::Cancel {
|
||||
@@ -951,7 +839,7 @@ mod tests {
|
||||
Poll::Pending
|
||||
);
|
||||
|
||||
assert_matches!(test_abortable(abort_registration).await, Err(Aborted));
|
||||
assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -961,11 +849,11 @@ mod tests {
|
||||
tokio::time::pause();
|
||||
let _abort_registration = channel
|
||||
.as_mut()
|
||||
.start_request(
|
||||
0,
|
||||
SystemTime::now() + Duration::from_millis(100),
|
||||
Span::current(),
|
||||
)
|
||||
.start_request(Request {
|
||||
id: 0,
|
||||
context: context::current(),
|
||||
message: (),
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
drop(tx);
|
||||
@@ -1001,9 +889,13 @@ mod tests {
|
||||
let (mut channel, mut tx) = test_channel::<(), ()>();
|
||||
|
||||
tokio::time::pause();
|
||||
let abort_registration = channel
|
||||
let req = channel
|
||||
.as_mut()
|
||||
.start_request(0, SystemTime::now(), Span::current())
|
||||
.start_request(Request {
|
||||
id: 0,
|
||||
context: context::current(),
|
||||
message: (),
|
||||
})
|
||||
.unwrap();
|
||||
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
|
||||
|
||||
@@ -1013,7 +905,7 @@ mod tests {
|
||||
channel.as_mut().poll_next(&mut noop_context()),
|
||||
Poll::Ready(Some(Ok(_)))
|
||||
);
|
||||
assert_matches!(test_abortable(abort_registration).await, Err(Aborted));
|
||||
assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -1022,7 +914,11 @@ mod tests {
|
||||
|
||||
channel
|
||||
.as_mut()
|
||||
.start_request(0, SystemTime::now(), Span::current())
|
||||
.start_request(Request {
|
||||
id: 0,
|
||||
context: context::current(),
|
||||
message: (),
|
||||
})
|
||||
.unwrap();
|
||||
assert_eq!(channel.in_flight_requests(), 1);
|
||||
channel
|
||||
@@ -1043,7 +939,11 @@ mod tests {
|
||||
requests
|
||||
.as_mut()
|
||||
.channel_pin_mut()
|
||||
.start_request(0, SystemTime::now(), Span::current())
|
||||
.start_request(Request {
|
||||
id: 0,
|
||||
context: context::current(),
|
||||
message: (),
|
||||
})
|
||||
.unwrap();
|
||||
requests
|
||||
.as_mut()
|
||||
@@ -1069,7 +969,11 @@ mod tests {
|
||||
requests
|
||||
.as_mut()
|
||||
.channel_pin_mut()
|
||||
.start_request(1, SystemTime::now(), Span::current())
|
||||
.start_request(Request {
|
||||
id: 1,
|
||||
context: context::current(),
|
||||
message: (),
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
assert_matches!(
|
||||
@@ -1086,7 +990,11 @@ mod tests {
|
||||
requests
|
||||
.as_mut()
|
||||
.channel_pin_mut()
|
||||
.start_request(0, SystemTime::now(), Span::current())
|
||||
.start_request(Request {
|
||||
id: 0,
|
||||
context: context::current(),
|
||||
message: (),
|
||||
})
|
||||
.unwrap();
|
||||
requests
|
||||
.as_mut()
|
||||
@@ -1101,7 +1009,11 @@ mod tests {
|
||||
requests
|
||||
.as_mut()
|
||||
.channel_pin_mut()
|
||||
.start_request(1, SystemTime::now(), Span::current())
|
||||
.start_request(Request {
|
||||
id: 1,
|
||||
context: context::current(),
|
||||
message: (),
|
||||
})
|
||||
.unwrap();
|
||||
requests
|
||||
.as_mut()
|
||||
|
||||
@@ -98,6 +98,11 @@ impl InFlightRequests {
|
||||
&mut self,
|
||||
cx: &mut Context,
|
||||
) -> Poll<Option<Result<u64, tokio::time::error::Error>>> {
|
||||
if self.deadlines.is_empty() {
|
||||
// TODO(https://github.com/tokio-rs/tokio/issues/4161)
|
||||
// This is a workaround for DelayQueue not always treating this case correctly.
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
self.deadlines.poll_expired(cx).map_ok(|expired| {
|
||||
if let Some(RequestData {
|
||||
abort_handle, span, ..
|
||||
@@ -184,12 +189,31 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn remove_request_doesnt_abort() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
assert!(in_flight_requests.deadlines.is_empty());
|
||||
|
||||
let abort_registration = in_flight_requests
|
||||
.start_request(0, SystemTime::now(), Span::current())
|
||||
.start_request(
|
||||
0,
|
||||
SystemTime::now() + std::time::Duration::from_secs(10),
|
||||
Span::current(),
|
||||
)
|
||||
.unwrap();
|
||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||
|
||||
// Precondition: Pending expiration
|
||||
assert_matches!(
|
||||
in_flight_requests.poll_expired(&mut noop_context()),
|
||||
Poll::Pending
|
||||
);
|
||||
assert!(!in_flight_requests.deadlines.is_empty());
|
||||
|
||||
assert_matches!(in_flight_requests.remove_request(0), Some(_));
|
||||
// Postcondition: No pending expirations
|
||||
assert!(in_flight_requests.deadlines.is_empty());
|
||||
assert_matches!(
|
||||
in_flight_requests.poll_expired(&mut noop_context()),
|
||||
Poll::Ready(None)
|
||||
);
|
||||
assert_matches!(
|
||||
abortable_future.poll_unpin(&mut noop_context()),
|
||||
Poll::Pending
|
||||
|
||||
49
tarpc/src/server/incoming.rs
Normal file
49
tarpc/src/server/incoming.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use super::{
|
||||
limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel},
|
||||
Channel,
|
||||
};
|
||||
use futures::prelude::*;
|
||||
use std::{fmt, hash::Hash};
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
use super::{tokio::TokioServerExecutor, Serve};
|
||||
|
||||
/// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel).
|
||||
pub trait Incoming<C>
|
||||
where
|
||||
Self: Sized + Stream<Item = C>,
|
||||
C: Channel,
|
||||
{
|
||||
/// Enforces channel per-key limits.
|
||||
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> MaxChannelsPerKey<Self, K, KF>
|
||||
where
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
KF: Fn(&C) -> K,
|
||||
{
|
||||
MaxChannelsPerKey::new(self, n, keymaker)
|
||||
}
|
||||
|
||||
/// Caps the number of concurrent requests per channel.
|
||||
fn max_concurrent_requests_per_channel(self, n: usize) -> MaxRequestsPerChannel<Self> {
|
||||
MaxRequestsPerChannel::new(self, n)
|
||||
}
|
||||
|
||||
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled
|
||||
/// concurrently by spawning on tokio's default executor, and each request will be also
|
||||
/// be spawned on tokio's default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
|
||||
where
|
||||
S: Serve<C::Req, Resp = C::Resp>,
|
||||
{
|
||||
TokioServerExecutor::new(self, serve)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, C> Incoming<C> for S
|
||||
where
|
||||
S: Sized + Stream<Item = C>,
|
||||
C: Channel,
|
||||
{
|
||||
}
|
||||
5
tarpc/src/server/limits.rs
Normal file
5
tarpc/src/server/limits.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
/// Provides functionality to limit the number of active channels.
|
||||
pub mod channels_per_key;
|
||||
|
||||
/// Provides a [channel](crate::server::Channel) that limits the number of in-flight requests.
|
||||
pub mod requests_per_channel;
|
||||
@@ -9,20 +9,23 @@ use crate::{
|
||||
util::Compact,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
|
||||
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
||||
use pin_project::pin_project;
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::{
|
||||
collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin,
|
||||
time::SystemTime,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, info, trace};
|
||||
|
||||
/// A single-threaded filter that drops channels based on per-key limits.
|
||||
/// An [`Incoming`](crate::server::incoming::Incoming) stream that drops new channels based on
|
||||
/// per-key limits.
|
||||
///
|
||||
/// The decision to drop a Channel is made once at the time the Channel materializes. Once a
|
||||
/// Channel is yielded, it will not be prematurely dropped.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct ChannelFilter<S, K, F>
|
||||
pub struct MaxChannelsPerKey<S, K, F>
|
||||
where
|
||||
K: Eq + Hash,
|
||||
{
|
||||
@@ -35,7 +38,7 @@ where
|
||||
keymaker: F,
|
||||
}
|
||||
|
||||
/// A channel that is tracked by a ChannelFilter.
|
||||
/// A channel that is tracked by [`MaxChannelsPerKey`].
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct TrackedChannel<C, K> {
|
||||
@@ -116,15 +119,6 @@ where
|
||||
fn transport(&self) -> &Self::Transport {
|
||||
self.inner.transport()
|
||||
}
|
||||
|
||||
fn start_request(
|
||||
mut self: Pin<&mut Self>,
|
||||
id: u64,
|
||||
deadline: SystemTime,
|
||||
span: tracing::Span,
|
||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
||||
self.inner_pin_mut().start_request(id, deadline, span)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, K> TrackedChannel<C, K> {
|
||||
@@ -139,7 +133,7 @@ impl<C, K> TrackedChannel<C, K> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, K, F> ChannelFilter<S, K, F>
|
||||
impl<S, K, F> MaxChannelsPerKey<S, K, F>
|
||||
where
|
||||
K: Eq + Hash,
|
||||
S: Stream,
|
||||
@@ -148,7 +142,7 @@ where
|
||||
/// Sheds new channels to stay under configured limits.
|
||||
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
|
||||
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel();
|
||||
ChannelFilter {
|
||||
MaxChannelsPerKey {
|
||||
listener: listener.fuse(),
|
||||
channels_per_key,
|
||||
dropped_keys,
|
||||
@@ -159,7 +153,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, K, F> ChannelFilter<S, K, F>
|
||||
impl<S, K, F> MaxChannelsPerKey<S, K, F>
|
||||
where
|
||||
S: Stream,
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
@@ -251,7 +245,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, K, F> Stream for ChannelFilter<S, K, F>
|
||||
impl<S, K, F> Stream for MaxChannelsPerKey<S, K, F>
|
||||
where
|
||||
S: Stream,
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
@@ -354,7 +348,7 @@ fn channel_filter_increment_channels_for_key() {
|
||||
key: &'static str,
|
||||
}
|
||||
let (_, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
|
||||
assert_eq!(Arc::strong_count(&tracker1), 1);
|
||||
@@ -375,7 +369,7 @@ fn channel_filter_handle_new_channel() {
|
||||
key: &'static str,
|
||||
}
|
||||
let (_, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
let channel1 = filter
|
||||
.as_mut()
|
||||
@@ -407,7 +401,7 @@ fn channel_filter_poll_listener() {
|
||||
key: &'static str,
|
||||
}
|
||||
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
|
||||
new_channels
|
||||
@@ -443,7 +437,7 @@ fn channel_filter_poll_closed_channels() {
|
||||
key: &'static str,
|
||||
}
|
||||
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
|
||||
new_channels
|
||||
@@ -471,7 +465,7 @@ fn channel_filter_stream() {
|
||||
key: &'static str,
|
||||
}
|
||||
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
|
||||
new_channels
|
||||
@@ -4,45 +4,49 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use super::{Channel, Config};
|
||||
use crate::{Response, ServerError};
|
||||
use futures::{future::AbortRegistration, prelude::*, ready, task::*};
|
||||
use crate::{
|
||||
server::{Channel, Config},
|
||||
Response, ServerError,
|
||||
};
|
||||
use futures::{prelude::*, ready, task::*};
|
||||
use pin_project::pin_project;
|
||||
use std::{io, pin::Pin, time::SystemTime};
|
||||
use tracing::Span;
|
||||
use std::{io, pin::Pin};
|
||||
|
||||
/// A [`Channel`] that limits the number of concurrent
|
||||
/// requests by throttling.
|
||||
/// A [`Channel`] that limits the number of concurrent requests by throttling.
|
||||
///
|
||||
/// Note that this is a very basic throttling heuristic. It is easy to set a number that is too low
|
||||
/// for the resources available to the server. For production use cases, a more advanced throttler
|
||||
/// is likely needed.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct Throttler<C> {
|
||||
pub struct MaxRequests<C> {
|
||||
max_in_flight_requests: usize,
|
||||
#[pin]
|
||||
inner: C,
|
||||
}
|
||||
|
||||
impl<C> Throttler<C> {
|
||||
impl<C> MaxRequests<C> {
|
||||
/// Returns the inner channel.
|
||||
pub fn get_ref(&self) -> &C {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Throttler<C>
|
||||
impl<C> MaxRequests<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
/// Returns a new `Throttler` that wraps the given channel and limits concurrent requests to
|
||||
/// Returns a new `MaxRequests` that wraps the given channel and limits concurrent requests to
|
||||
/// `max_in_flight_requests`.
|
||||
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
|
||||
Throttler {
|
||||
MaxRequests {
|
||||
max_in_flight_requests,
|
||||
inner,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Stream for Throttler<C>
|
||||
impl<C> Stream for MaxRequests<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
@@ -54,19 +58,18 @@ where
|
||||
ready!(self.as_mut().project().inner.poll_ready(cx)?);
|
||||
|
||||
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
|
||||
Some(request) => {
|
||||
tracing::debug!(
|
||||
rpc.trace_id = %request.context.trace_id(),
|
||||
Some(r) => {
|
||||
let _entered = r.span.enter();
|
||||
tracing::info!(
|
||||
in_flight_requests = self.as_mut().in_flight_requests(),
|
||||
max_in_flight_requests = *self.as_mut().project().max_in_flight_requests,
|
||||
"At in-flight request limit",
|
||||
"ThrottleRequest",
|
||||
);
|
||||
|
||||
self.as_mut().start_send(Response {
|
||||
request_id: request.id,
|
||||
request_id: r.request.id,
|
||||
message: Err(ServerError {
|
||||
kind: io::ErrorKind::WouldBlock,
|
||||
detail: Some("Server throttled the request.".into()),
|
||||
detail: "server throttled the request.".into(),
|
||||
}),
|
||||
})?;
|
||||
}
|
||||
@@ -77,7 +80,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Sink<Response<<C as Channel>::Resp>> for Throttler<C>
|
||||
impl<C> Sink<Response<<C as Channel>::Resp>> for MaxRequests<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
@@ -103,13 +106,13 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> AsRef<C> for Throttler<C> {
|
||||
impl<C> AsRef<C> for MaxRequests<C> {
|
||||
fn as_ref(&self) -> &C {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Channel for Throttler<C>
|
||||
impl<C> Channel for MaxRequests<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
@@ -128,27 +131,19 @@ where
|
||||
fn transport(&self) -> &Self::Transport {
|
||||
self.inner.transport()
|
||||
}
|
||||
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
id: u64,
|
||||
deadline: SystemTime,
|
||||
span: Span,
|
||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
||||
self.project().inner.start_request(id, deadline, span)
|
||||
}
|
||||
}
|
||||
|
||||
/// A stream of throttling channels.
|
||||
/// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on
|
||||
/// the number of in-flight requests.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct ThrottlerStream<S> {
|
||||
pub struct MaxRequestsPerChannel<S> {
|
||||
#[pin]
|
||||
inner: S,
|
||||
max_in_flight_requests: usize,
|
||||
}
|
||||
|
||||
impl<S> ThrottlerStream<S>
|
||||
impl<S> MaxRequestsPerChannel<S>
|
||||
where
|
||||
S: Stream,
|
||||
<S as Stream>::Item: Channel,
|
||||
@@ -161,16 +156,16 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Stream for ThrottlerStream<S>
|
||||
impl<S> Stream for MaxRequestsPerChannel<S>
|
||||
where
|
||||
S: Stream,
|
||||
<S as Stream>::Item: Channel,
|
||||
{
|
||||
type Item = Throttler<<S as Stream>::Item>;
|
||||
type Item = MaxRequests<<S as Stream>::Item>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
match ready!(self.as_mut().project().inner.poll_next(cx)) {
|
||||
Some(channel) => Poll::Ready(Some(Throttler::new(
|
||||
Some(channel) => Poll::Ready(Some(MaxRequests::new(
|
||||
channel,
|
||||
*self.project().max_in_flight_requests,
|
||||
))),
|
||||
@@ -183,19 +178,20 @@ where
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use crate::{
|
||||
server::{
|
||||
in_flight_requests::AlreadyExistsError,
|
||||
testing::{self, FakeChannel, PollExt},
|
||||
},
|
||||
Request,
|
||||
use crate::server::{
|
||||
testing::{self, FakeChannel, PollExt},
|
||||
TrackedRequest,
|
||||
};
|
||||
use pin_utils::pin_mut;
|
||||
use std::{marker::PhantomData, time::Duration};
|
||||
use std::{
|
||||
marker::PhantomData,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
use tracing::Span;
|
||||
|
||||
#[tokio::test]
|
||||
async fn throttler_in_flight_requests() {
|
||||
let throttler = Throttler {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
@@ -215,28 +211,9 @@ mod tests {
|
||||
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn throttler_start_request() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler
|
||||
.as_mut()
|
||||
.start_request(
|
||||
1,
|
||||
SystemTime::now() + Duration::from_secs(1),
|
||||
Span::current(),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(throttler.inner.in_flight_requests.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_done() {
|
||||
let throttler = Throttler {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
@@ -247,7 +224,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_some() -> io::Result<()> {
|
||||
let throttler = Throttler {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 1,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
@@ -259,7 +236,7 @@ mod tests {
|
||||
throttler
|
||||
.as_mut()
|
||||
.poll_next(&mut testing::cx())?
|
||||
.map(|r| r.map(|r| (r.id, r.message))),
|
||||
.map(|r| r.map(|r| (r.request.id, r.request.message))),
|
||||
Poll::Ready(Some((0, 1)))
|
||||
);
|
||||
Ok(())
|
||||
@@ -267,7 +244,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_throttled() {
|
||||
let throttler = Throttler {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
@@ -283,7 +260,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn throttler_poll_next_throttled_sink_not_ready() {
|
||||
let throttler = Throttler {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 0,
|
||||
inner: PendingSink::default::<isize, isize>(),
|
||||
};
|
||||
@@ -294,7 +271,8 @@ mod tests {
|
||||
ghost: PhantomData<fn(Out) -> In>,
|
||||
}
|
||||
impl PendingSink<(), ()> {
|
||||
pub fn default<Req, Resp>() -> PendingSink<io::Result<Request<Req>>, Response<Resp>> {
|
||||
pub fn default<Req, Resp>(
|
||||
) -> PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||
PendingSink { ghost: PhantomData }
|
||||
}
|
||||
}
|
||||
@@ -319,7 +297,7 @@ mod tests {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
impl<Req, Resp> Channel for PendingSink<io::Result<Request<Req>>, Response<Resp>> {
|
||||
impl<Req, Resp> Channel for PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||
type Req = Req;
|
||||
type Resp = Resp;
|
||||
type Transport = ();
|
||||
@@ -332,20 +310,12 @@ mod tests {
|
||||
fn transport(&self) -> &() {
|
||||
&()
|
||||
}
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
_id: u64,
|
||||
_deadline: SystemTime,
|
||||
_span: tracing::Span,
|
||||
) -> Result<AbortRegistration, AlreadyExistsError> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn throttler_start_send() {
|
||||
let throttler = Throttler {
|
||||
let throttler = MaxRequests {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
@@ -6,10 +6,10 @@
|
||||
|
||||
use crate::{
|
||||
context,
|
||||
server::{Channel, Config},
|
||||
server::{Channel, Config, TrackedRequest},
|
||||
Request, Response,
|
||||
};
|
||||
use futures::{future::AbortRegistration, task::*, Sink, Stream};
|
||||
use futures::{task::*, Sink, Stream};
|
||||
use pin_project::pin_project;
|
||||
use std::{collections::VecDeque, io, pin::Pin, time::SystemTime};
|
||||
use tracing::Span;
|
||||
@@ -62,7 +62,7 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> Channel for FakeChannel<io::Result<Request<Req>>, Response<Resp>>
|
||||
impl<Req, Resp> Channel for FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>>
|
||||
where
|
||||
Req: Unpin,
|
||||
{
|
||||
@@ -81,34 +81,28 @@ where
|
||||
fn transport(&self) -> &() {
|
||||
&()
|
||||
}
|
||||
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
id: u64,
|
||||
deadline: SystemTime,
|
||||
span: Span,
|
||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
||||
self.project()
|
||||
.in_flight_requests
|
||||
.start_request(id, deadline, span)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
|
||||
impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||
pub fn push_req(&mut self, id: u64, message: Req) {
|
||||
self.stream.push_back(Ok(Request {
|
||||
context: context::Context {
|
||||
deadline: SystemTime::UNIX_EPOCH,
|
||||
trace_context: Default::default(),
|
||||
let (_, abort_registration) = futures::future::AbortHandle::new_pair();
|
||||
self.stream.push_back(Ok(TrackedRequest {
|
||||
request: Request {
|
||||
context: context::Context {
|
||||
deadline: SystemTime::UNIX_EPOCH,
|
||||
trace_context: Default::default(),
|
||||
},
|
||||
id,
|
||||
message,
|
||||
},
|
||||
id,
|
||||
message,
|
||||
abort_registration,
|
||||
span: Span::none(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeChannel<(), ()> {
|
||||
pub fn default<Req, Resp>() -> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
|
||||
pub fn default<Req, Resp>() -> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||
FakeChannel {
|
||||
stream: Default::default(),
|
||||
sink: Default::default(),
|
||||
|
||||
111
tarpc/src/server/tokio.rs
Normal file
111
tarpc/src/server/tokio.rs
Normal file
@@ -0,0 +1,111 @@
|
||||
use super::{Channel, Requests, Serve};
|
||||
use futures::{prelude::*, ready, task::*};
|
||||
use pin_project::pin_project;
|
||||
use std::pin::Pin;
|
||||
|
||||
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
|
||||
/// for each new channel. Returned by
|
||||
/// [`Incoming::execute`](crate::server::incoming::Incoming::execute).
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct TokioServerExecutor<T, S> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
serve: S,
|
||||
}
|
||||
|
||||
impl<T, S> TokioServerExecutor<T, S> {
|
||||
pub(crate) fn new(inner: T, serve: S) -> Self {
|
||||
Self { inner, serve }
|
||||
}
|
||||
}
|
||||
|
||||
/// A future that drives the server by [spawning](tokio::spawn) each [response
|
||||
/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by
|
||||
/// [`Channel::execute`](crate::server::Channel::execute).
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct TokioChannelExecutor<T, S> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
serve: S,
|
||||
}
|
||||
|
||||
impl<T, S> TokioServerExecutor<T, S> {
|
||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
||||
self.as_mut().project().inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S> TokioChannelExecutor<T, S> {
|
||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
||||
self.as_mut().project().inner
|
||||
}
|
||||
}
|
||||
|
||||
// Send + 'static execution helper methods.
|
||||
|
||||
impl<C> Requests<C>
|
||||
where
|
||||
C: Channel,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
{
|
||||
/// Executes all requests using the given service function. Requests are handled concurrently
|
||||
/// by [spawning](::tokio::spawn) each handler on tokio's default executor.
|
||||
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
|
||||
where
|
||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
|
||||
{
|
||||
TokioChannelExecutor { inner: self, serve }
|
||||
}
|
||||
}
|
||||
|
||||
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
|
||||
where
|
||||
St: Sized + Stream<Item = C>,
|
||||
C: Channel + Send + 'static,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
||||
Se::Fut: Send,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
||||
tokio::spawn(channel.execute(self.serve.clone()));
|
||||
}
|
||||
tracing::info!("Server shutting down.");
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
|
||||
where
|
||||
C: Channel + 'static,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
||||
S::Fut: Send,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
||||
match response_handler {
|
||||
Ok(resp) => {
|
||||
let server = self.serve.clone();
|
||||
tokio::spawn(async move {
|
||||
resp.execute(server).await;
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Requests stream errored out: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
@@ -71,9 +71,9 @@ pub struct SpanId(u64);
|
||||
#[repr(u8)]
|
||||
pub enum SamplingDecision {
|
||||
/// The associated span was sampled by its creating process. Child spans must also be sampled.
|
||||
Sampled = opentelemetry::trace::TRACE_FLAG_SAMPLED,
|
||||
Sampled,
|
||||
/// The associated span was not sampled by its creating process.
|
||||
Unsampled = opentelemetry::trace::TRACE_FLAG_NOT_SAMPLED,
|
||||
Unsampled,
|
||||
}
|
||||
|
||||
impl Context {
|
||||
@@ -173,8 +173,8 @@ impl TryFrom<&tracing::Span> for Context {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&dyn opentelemetry::trace::Span> for Context {
|
||||
fn from(span: &dyn opentelemetry::trace::Span) -> Self {
|
||||
impl From<opentelemetry::trace::SpanRef<'_>> for Context {
|
||||
fn from(span: opentelemetry::trace::SpanRef<'_>) -> Self {
|
||||
let otel_ctx = span.span_context();
|
||||
Self {
|
||||
trace_id: TraceId::from(otel_ctx.trace_id()),
|
||||
@@ -184,6 +184,15 @@ impl From<&dyn opentelemetry::trace::Span> for Context {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SamplingDecision> for opentelemetry::trace::TraceFlags {
|
||||
fn from(decision: SamplingDecision) -> Self {
|
||||
match decision {
|
||||
SamplingDecision::Sampled => opentelemetry::trace::TraceFlags::SAMPLED,
|
||||
SamplingDecision::Unsampled => opentelemetry::trace::TraceFlags::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&opentelemetry::trace::SpanContext> for SamplingDecision {
|
||||
fn from(context: &opentelemetry::trace::SpanContext) -> Self {
|
||||
if context.is_sampled() {
|
||||
|
||||
@@ -151,7 +151,7 @@ impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
|
||||
mod tests {
|
||||
use crate::{
|
||||
client, context,
|
||||
server::{BaseChannel, Incoming},
|
||||
server::{incoming::Incoming, BaseChannel},
|
||||
transport::{
|
||||
self,
|
||||
channel::{Channel, UnboundedChannel},
|
||||
@@ -170,7 +170,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration() -> io::Result<()> {
|
||||
async fn integration() -> anyhow::Result<()> {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let (client_channel, server_channel) = transport::channel::unbounded();
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
use futures::prelude::*;
|
||||
use std::io;
|
||||
use tarpc::serde_transport;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{BaseChannel, Incoming},
|
||||
server::{incoming::Incoming, BaseChannel},
|
||||
};
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
@@ -33,7 +32,7 @@ impl ColorProtocol for ColorServer {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_call() -> io::Result<()> {
|
||||
async fn test_call() -> anyhow::Result<()> {
|
||||
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
|
||||
let addr = transport.local_addr();
|
||||
tokio::spawn(
|
||||
|
||||
@@ -7,7 +7,7 @@ use std::time::{Duration, SystemTime};
|
||||
use tarpc::{
|
||||
client::{self},
|
||||
context,
|
||||
server::{self, BaseChannel, Channel, Incoming},
|
||||
server::{self, incoming::Incoming, BaseChannel, Channel},
|
||||
transport::channel,
|
||||
};
|
||||
use tokio::join;
|
||||
|
||||
Reference in New Issue
Block a user