mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-02-23 15:49:54 +01:00
Compare commits
38 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1da6bcec57 | ||
|
|
75a5591158 | ||
|
|
9462aad3bf | ||
|
|
0964fc51ff | ||
|
|
27aacab432 | ||
|
|
3feb465ad3 | ||
|
|
66cdc99ae0 | ||
|
|
66419db6fd | ||
|
|
72d5dbba89 | ||
|
|
e75193c191 | ||
|
|
ce4fd49161 | ||
|
|
3c978c5bf6 | ||
|
|
6f419e9a9a | ||
|
|
b3eb8d0b7a | ||
|
|
3b422eb179 | ||
|
|
4b513bad73 | ||
|
|
e71e17866d | ||
|
|
7e3fbec077 | ||
|
|
e4bc5e8e32 | ||
|
|
bc982c5584 | ||
|
|
d440e12c19 | ||
|
|
bc8128af69 | ||
|
|
1d87c14262 | ||
|
|
ca929c2178 | ||
|
|
569039734b | ||
|
|
3d43310e6a | ||
|
|
d21cbddb0d | ||
|
|
25aa857edf | ||
|
|
0bb2e2bbbe | ||
|
|
dc376343d6 | ||
|
|
2e7d1f8a88 | ||
|
|
6314591c65 | ||
|
|
7dd7494420 | ||
|
|
6c10e3649f | ||
|
|
4c6dee13d2 | ||
|
|
e45abe953a | ||
|
|
dec3e491b5 | ||
|
|
6ce341cf79 |
48
.github/workflows/main.yml
vendored
48
.github/workflows/main.yml
vendored
@@ -1,4 +1,10 @@
|
||||
on: [push, pull_request]
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
|
||||
name: Continuous integration
|
||||
|
||||
@@ -7,27 +13,59 @@ jobs:
|
||||
name: Check
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cancel previous
|
||||
uses: styfle/cancel-workflow-action@0.7.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v1
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
target: mipsel-unknown-linux-gnu
|
||||
override: true
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: check
|
||||
args: --all-features
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: check
|
||||
args: --all-features --target mipsel-unknown-linux-gnu
|
||||
|
||||
test:
|
||||
name: Test Suite
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cancel previous
|
||||
uses: styfle/cancel-workflow-action@0.7.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v1
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
override: true
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features serde1
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features tokio1
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features serde-transport
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features tcp
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
@@ -37,6 +75,10 @@ jobs:
|
||||
name: Rustfmt
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cancel previous
|
||||
uses: styfle/cancel-workflow-action@0.7.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v1
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
@@ -53,6 +95,10 @@ jobs:
|
||||
name: Clippy
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cancel previous
|
||||
uses: styfle/cancel-workflow-action@0.7.0
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v1
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
|
||||
26
README.md
26
README.md
@@ -59,7 +59,7 @@ Some other features of tarpc:
|
||||
Add to your `Cargo.toml` dependencies:
|
||||
|
||||
```toml
|
||||
tarpc = "0.23.0"
|
||||
tarpc = "0.25"
|
||||
```
|
||||
|
||||
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||
@@ -68,12 +68,13 @@ Simply implement the generated service trait, and you're off to the races!
|
||||
|
||||
## Example
|
||||
|
||||
For this example, in addition to tarpc, also add two other dependencies to
|
||||
This example uses [tokio](https://tokio.rs), so add the following dependencies to
|
||||
your `Cargo.toml`:
|
||||
|
||||
```toml
|
||||
futures = "0.3"
|
||||
tokio = "0.3"
|
||||
futures = "1.0"
|
||||
tarpc = { version = "0.25", features = ["tokio1"] }
|
||||
tokio = { version = "1.0", features = ["macros"] }
|
||||
```
|
||||
|
||||
In the following example, we use an in-process channel for communication between
|
||||
@@ -90,7 +91,7 @@ use futures::{
|
||||
};
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{self, Handler},
|
||||
server::{self, Incoming},
|
||||
};
|
||||
use std::io;
|
||||
|
||||
@@ -125,7 +126,7 @@ impl World for HelloServer {
|
||||
```
|
||||
|
||||
Lastly let's write our `main` that will start the server. While this example uses an
|
||||
[in-process channel](rpc::transport::channel), tarpc also ships a generic [`serde_transport`]
|
||||
[in-process channel](transport::channel), tarpc also ships a generic [`serde_transport`]
|
||||
behind the `serde-transport` feature, with additional [TCP](serde_transport::tcp) functionality
|
||||
available behind the `tcp` feature.
|
||||
|
||||
@@ -134,16 +135,11 @@ available behind the `tcp` feature.
|
||||
async fn main() -> io::Result<()> {
|
||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
|
||||
let server = server::new(server::Config::default())
|
||||
// incoming() takes a stream of transports such as would be returned by
|
||||
// TcpListener::incoming (but a stream instead of an iterator).
|
||||
.incoming(stream::once(future::ready(server_transport)))
|
||||
.respond_with(HelloServer.serve());
|
||||
let server = server::BaseChannel::with_defaults(server_transport);
|
||||
tokio::spawn(server.execute(HelloServer.serve()));
|
||||
|
||||
tokio::spawn(server);
|
||||
|
||||
// WorldClient is generated by the macro. It has a constructor `new` that takes a config and
|
||||
// any Transport as input
|
||||
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||
// that takes a config and any Transport as input.
|
||||
let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
|
||||
|
||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||
|
||||
81
RELEASES.md
81
RELEASES.md
@@ -1,3 +1,84 @@
|
||||
## 0.25.0 (2021-03-10)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
#### Major server module refactoring
|
||||
|
||||
1. Renames
|
||||
|
||||
Some of the items in this module were renamed to be less generic:
|
||||
|
||||
- Handler => Incoming
|
||||
- ClientHandler => Requests
|
||||
- ResponseHandler => InFlightRequest
|
||||
- Channel::{respond_with => requests}
|
||||
|
||||
In the case of Handler: handler of *what*? Now it's a bit clearer that this is a stream of Channels
|
||||
(aka *incoming* connections).
|
||||
|
||||
Similarly, ClientHandler was a stream of requests over a single connection. Hopefully Requests
|
||||
better reflects that.
|
||||
|
||||
ResponseHandler was renamed InFlightRequest because it no longer contains the serving function.
|
||||
Instead, it is just the request, plus the response channel and an abort hook. As a result of this,
|
||||
Channel::respond_with underwent a big change: it used to take the serving function and return a
|
||||
ClientHandler; now it has been renamed Channel::requests and does not take any args.
|
||||
|
||||
2. Execute methods
|
||||
|
||||
All methods thats actually result in responses being generated have been consolidated into methods
|
||||
named `execute`:
|
||||
|
||||
- InFlightRequest::execute returns a future that completes when a response has been generated and
|
||||
sent to the server Channel.
|
||||
- Requests::execute automatically spawns response handlers for all requests over a single channel.
|
||||
- Channel::execute is a convenience for `channel.requests().execute()`.
|
||||
- Incoming::execute automatically spawns response handlers for all requests over all channels.
|
||||
|
||||
3. Removal of Server.
|
||||
|
||||
server::Server was removed, as it provided no value over the Incoming/Channel abstractions.
|
||||
Additionally, server::new was removed, since it just returned a Server.
|
||||
|
||||
#### Client RPC methods now take &self
|
||||
|
||||
This required the breaking change of removing the Client trait. The intent of the Client trait was
|
||||
to facilitate the decorator pattern by allowing users to create their own Clients that added
|
||||
behavior on top of the base client. Unfortunately, this trait had become a maintenance burden,
|
||||
consistently causing issues with lifetimes and the lack of generic associated types. Specifically,
|
||||
it meant that Client impls could not use async fns, which is no longer tenable today, with channel
|
||||
libraries moving to async fns.
|
||||
|
||||
#### Servers no longer send deadline-exceed responses.
|
||||
|
||||
The deadline-exceeded response was largely redundant, because the client
|
||||
shouldn't normally be waiting for such a response, anyway -- the normal
|
||||
client will automatically remove the in-flight request when it reaches
|
||||
the deadline.
|
||||
|
||||
This also allows for internalizing the expiration+cleanup logic entirely
|
||||
within BaseChannel, without having it leak into the Channel trait and
|
||||
requiring action taken by the Requests struct.
|
||||
|
||||
#### Clients no longer send cancel messages when the request deadline is exceeded.
|
||||
|
||||
The server already knows when the request deadline was exceeded, so the client didn't need to inform
|
||||
it.
|
||||
|
||||
### Fixes
|
||||
|
||||
- When a channel is dropped, all in-flight requests for that channel are now aborted.
|
||||
|
||||
## 0.24.1 (2020-12-28)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
Upgrades tokio to 1.0.
|
||||
|
||||
## 0.24.0 (2020-12-28)
|
||||
|
||||
This release was yanked.
|
||||
|
||||
## 0.23.0 (2020-10-19)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc-example-service"
|
||||
version = "0.6.0"
|
||||
version = "0.9.0"
|
||||
authors = ["Tim Kuehn <tikue@google.com>"]
|
||||
edition = "2018"
|
||||
license = "MIT"
|
||||
@@ -17,10 +17,8 @@ clap = "2.33"
|
||||
env_logger = "0.8"
|
||||
futures = "0.3"
|
||||
serde = { version = "1.0" }
|
||||
tarpc = { version = "0.23", path = "../tarpc", features = ["full"] }
|
||||
tokio = { version = "0.3", features = ["full"] }
|
||||
tokio-serde = { version = "0.6", features = ["json"] }
|
||||
tokio-util = { version = "0.4", features = ["codec"] }
|
||||
tarpc = { version = "0.25", path = "../tarpc", features = ["full"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
|
||||
[lib]
|
||||
name = "service"
|
||||
|
||||
@@ -6,8 +6,7 @@
|
||||
|
||||
use clap::{App, Arg};
|
||||
use std::{io, net::SocketAddr};
|
||||
use tarpc::{client, context};
|
||||
use tokio_serde::formats::Json;
|
||||
use tarpc::{client, context, tokio_serde::formats::Json};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
@@ -44,12 +43,11 @@ async fn main() -> io::Result<()> {
|
||||
let name = flags.value_of("name").unwrap().into();
|
||||
|
||||
let mut transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default);
|
||||
transport.config_mut().max_frame_length(4294967296);
|
||||
transport.config_mut().max_frame_length(usize::MAX);
|
||||
|
||||
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
|
||||
// config and any Transport as input.
|
||||
let mut client =
|
||||
service::WorldClient::new(client::Config::default(), transport.await?).spawn()?;
|
||||
let client = service::WorldClient::new(client::Config::default(), transport.await?).spawn()?;
|
||||
|
||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||
|
||||
@@ -13,9 +13,9 @@ use std::{
|
||||
};
|
||||
use tarpc::{
|
||||
context,
|
||||
server::{self, Channel, Handler},
|
||||
server::{self, Channel, Incoming},
|
||||
tokio_serde::formats::Json,
|
||||
};
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
// This is the type that implements the generated World trait. It is the business logic
|
||||
// and is used to start the server.
|
||||
@@ -58,7 +58,7 @@ async fn main() -> io::Result<()> {
|
||||
// JSON transport is provided by the json_transport tarpc module. It makes it easy
|
||||
// to start up a serde-powered json serialization strategy over TCP.
|
||||
let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?;
|
||||
listener.config_mut().max_frame_length(4294967296);
|
||||
listener.config_mut().max_frame_length(usize::MAX);
|
||||
listener
|
||||
// Ignore accept errors.
|
||||
.filter_map(|r| future::ready(r.ok()))
|
||||
@@ -69,7 +69,7 @@ async fn main() -> io::Result<()> {
|
||||
// the generated World trait.
|
||||
.map(|channel| {
|
||||
let server = HelloServer(channel.as_ref().as_ref().peer_addr().unwrap());
|
||||
channel.respond_with(server.serve()).execute()
|
||||
channel.requests().execute(server.serve())
|
||||
})
|
||||
// Max 10 channels.
|
||||
.buffer_unordered(10)
|
||||
|
||||
@@ -93,7 +93,7 @@ diff=""
|
||||
for file in $(git diff --name-only --cached);
|
||||
do
|
||||
if [ ${file: -3} == ".rs" ]; then
|
||||
diff="$diff$(cargo fmt -- --unstable-features --skip-children --check $file)"
|
||||
diff="$diff$(cargo fmt -- --check $file)"
|
||||
fi
|
||||
done
|
||||
if grep --quiet "^[-+]" <<< "$diff"; then
|
||||
|
||||
@@ -84,11 +84,6 @@ command -v rustup &>/dev/null
|
||||
if [ "$?" == 0 ]; then
|
||||
printf "${SUCCESS}\n"
|
||||
|
||||
check_toolchain nightly
|
||||
if [ ${TOOLCHAIN_RESULT} == 1 ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
try_run "Building ... " cargo +stable build --color=always
|
||||
try_run "Testing ... " cargo +stable test --color=always
|
||||
try_run "Testing with all features enabled ... " cargo +stable test --all-features --color=always
|
||||
@@ -97,6 +92,12 @@ if [ "$?" == 0 ]; then
|
||||
try_run "Running example \"$EXAMPLE\" ... " cargo +stable run --example $EXAMPLE
|
||||
done
|
||||
|
||||
check_toolchain nightly
|
||||
if [ ${TOOLCHAIN_RESULT} != 1 ]; then
|
||||
try_run "Running clippy ... " cargo +nightly clippy --color=always -Z unstable-options -- --deny warnings
|
||||
fi
|
||||
|
||||
|
||||
fi
|
||||
|
||||
exit $PREPUSH_RESULT
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc-plugins"
|
||||
version = "0.8.0"
|
||||
version = "0.10.0"
|
||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||
edition = "2018"
|
||||
license = "MIT"
|
||||
|
||||
@@ -215,6 +215,25 @@ impl Parse for DeriveSerde {
|
||||
}
|
||||
}
|
||||
|
||||
/// A helper attribute to avoid a direct dependency on Serde.
|
||||
///
|
||||
/// Adds the following annotations to the annotated item:
|
||||
///
|
||||
/// ```rust
|
||||
/// #[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
|
||||
/// #[serde(crate = "tarpc::serde")]
|
||||
/// # struct Foo;
|
||||
/// ```
|
||||
#[proc_macro_attribute]
|
||||
pub fn derive_serde(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let mut gen: proc_macro2::TokenStream = quote! {
|
||||
#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
|
||||
#[serde(crate = "tarpc::serde")]
|
||||
};
|
||||
gen.extend(proc_macro2::TokenStream::from(item));
|
||||
proc_macro::TokenStream::from(gen)
|
||||
}
|
||||
|
||||
/// Generates:
|
||||
/// - service trait
|
||||
/// - serve fn
|
||||
@@ -240,7 +259,10 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
|
||||
let response_fut_name = &format!("{}ResponseFut", ident.unraw());
|
||||
let derive_serialize = if derive_serde.0 {
|
||||
Some(quote!(#[derive(serde::Serialize, serde::Deserialize)]))
|
||||
Some(
|
||||
quote! {#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
|
||||
#[serde(crate = "tarpc::serde")]},
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -466,10 +488,11 @@ impl<'a> ServiceGenerator<'a> {
|
||||
|
||||
quote! {
|
||||
#( #attrs )*
|
||||
#vis trait #service_ident: Clone {
|
||||
#vis trait #service_ident: Sized {
|
||||
#( #types_and_fns )*
|
||||
|
||||
/// Returns a serving function to use with [tarpc::server::Channel::respond_with].
|
||||
/// Returns a serving function to use with
|
||||
/// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute).
|
||||
fn serve(self) -> #server_ident<Self> {
|
||||
#server_ident { service: self }
|
||||
}
|
||||
@@ -483,7 +506,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
/// A serving function to use with [tarpc::server::Channel::respond_with].
|
||||
/// A serving function to use with [tarpc::server::InFlightRequest::execute].
|
||||
#[derive(Clone)]
|
||||
#vis struct #server_ident<S> {
|
||||
service: S,
|
||||
@@ -646,27 +669,9 @@ impl<'a> ServiceGenerator<'a> {
|
||||
quote! {
|
||||
#[allow(unused)]
|
||||
#[derive(Clone, Debug)]
|
||||
/// The client stub that makes RPC calls to the server. Exposes a Future interface.
|
||||
#vis struct #client_ident<C = tarpc::client::Channel<#request_ident, #response_ident>>(C);
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_from_for_client(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
client_ident,
|
||||
request_ident,
|
||||
response_ident,
|
||||
..
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
impl<C> From<C> for #client_ident<C>
|
||||
where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
|
||||
{
|
||||
fn from(client: C) -> Self {
|
||||
#client_ident(client)
|
||||
}
|
||||
}
|
||||
/// The client stub that makes RPC calls to the server. All request methods return
|
||||
/// [Futures](std::future::Future).
|
||||
#vis struct #client_ident(tarpc::client::Channel<#request_ident, #response_ident>);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -685,7 +690,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
#vis fn new<T>(config: tarpc::client::Config, transport: T)
|
||||
-> tarpc::client::NewClient<
|
||||
Self,
|
||||
tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T>
|
||||
tarpc::client::RequestDispatch<#request_ident, #response_ident, T>
|
||||
>
|
||||
where
|
||||
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>>
|
||||
@@ -717,16 +722,14 @@ impl<'a> ServiceGenerator<'a> {
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
impl<C> #client_ident<C>
|
||||
where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
|
||||
{
|
||||
impl #client_ident {
|
||||
#(
|
||||
#[allow(unused)]
|
||||
#( #method_attrs )*
|
||||
#vis fn #method_idents(&mut self, ctx: tarpc::context::Context, #( #args ),*)
|
||||
#vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*)
|
||||
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
|
||||
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
|
||||
let resp = tarpc::Client::call(&mut self.0, ctx, request);
|
||||
let resp = self.0.call(ctx, request);
|
||||
async move {
|
||||
match resp.await? {
|
||||
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
|
||||
@@ -752,7 +755,6 @@ impl<'a> ToTokens for ServiceGenerator<'a> {
|
||||
self.impl_debug_for_response_future(),
|
||||
self.impl_future_for_response_future(),
|
||||
self.struct_client(),
|
||||
self.impl_from_for_client(),
|
||||
self.impl_client_new(),
|
||||
self.impl_client_rpc_methods(),
|
||||
])
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc"
|
||||
version = "0.23.0"
|
||||
version = "0.25.1"
|
||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||
edition = "2018"
|
||||
license = "MIT"
|
||||
@@ -16,9 +16,9 @@ description = "An RPC framework for Rust with a focus on ease of use."
|
||||
default = []
|
||||
|
||||
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"]
|
||||
tokio1 = []
|
||||
serde-transport = ["tokio-serde", "tokio-util/codec"]
|
||||
tcp = ["tokio/net", "tokio/stream"]
|
||||
tokio1 = ["tokio/rt-multi-thread"]
|
||||
serde-transport = ["serde1", "tokio1", "tokio-serde/json", "tokio-util/codec"]
|
||||
tcp = ["tokio/net"]
|
||||
|
||||
full = ["serde1", "tokio1", "serde-transport", "tcp"]
|
||||
|
||||
@@ -35,28 +35,33 @@ pin-project = "1.0"
|
||||
rand = "0.7"
|
||||
serde = { optional = true, version = "1.0", features = ["derive"] }
|
||||
static_assertions = "1.1.0"
|
||||
tarpc-plugins = { path = "../plugins", version = "0.8" }
|
||||
tokio = { version = "0.3" }
|
||||
tokio-util = { optional = true, version = "0.4" }
|
||||
tokio-serde = { optional = true, version = "0.6" }
|
||||
tarpc-plugins = { path = "../plugins", version = "0.10" }
|
||||
tokio = { version = "1", features = ["time"] }
|
||||
tokio-util = { version = "0.6.3", features = ["time"] }
|
||||
tokio-serde = { optional = true, version = "0.8" }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_matches = "1.4"
|
||||
bincode = "1.3"
|
||||
bytes = { version = "0.5", features = ["serde"] }
|
||||
bytes = { version = "1", features = ["serde"] }
|
||||
env_logger = "0.8"
|
||||
flate2 = "1.0"
|
||||
futures-test = "0.3"
|
||||
log = "0.4"
|
||||
pin-utils = "0.1.0-alpha"
|
||||
serde_bytes = "0.11"
|
||||
tokio = { version = "0.3", features = ["full"] }
|
||||
tokio-serde = { version = "0.6", features = ["json", "bincode"] }
|
||||
tokio = { version = "1", features = ["full", "test-util"] }
|
||||
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
|
||||
trybuild = "1.0"
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
|
||||
[[example]]
|
||||
name = "compression"
|
||||
required-features = ["serde-transport", "tcp"]
|
||||
|
||||
[[example]]
|
||||
name = "server_calling_server"
|
||||
required-features = ["full"]
|
||||
@@ -68,3 +73,15 @@ required-features = ["full"]
|
||||
[[example]]
|
||||
name = "pubsub"
|
||||
required-features = ["full"]
|
||||
|
||||
[[example]]
|
||||
name = "custom_transport"
|
||||
required-features = ["serde1", "tokio1", "serde-transport"]
|
||||
|
||||
[[test]]
|
||||
name = "service_functional"
|
||||
required-features = ["serde-transport"]
|
||||
|
||||
[[test]]
|
||||
name = "dataservice"
|
||||
required-features = ["serde-transport", "tcp"]
|
||||
|
||||
@@ -113,14 +113,12 @@ async fn main() -> anyhow::Result<()> {
|
||||
tokio::spawn(async move {
|
||||
let transport = incoming.next().await.unwrap().unwrap();
|
||||
BaseChannel::with_defaults(add_compression(transport))
|
||||
.respond_with(HelloServer.serve())
|
||||
.execute()
|
||||
.execute(HelloServer.serve())
|
||||
.await;
|
||||
});
|
||||
|
||||
let transport = tcp::connect(addr, Bincode::default).await?;
|
||||
let mut client =
|
||||
WorldClient::new(client::Config::default(), add_compression(transport)).spawn()?;
|
||||
let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn()?;
|
||||
|
||||
println!(
|
||||
"{}",
|
||||
|
||||
50
tarpc/examples/custom_transport.rs
Normal file
50
tarpc/examples/custom_transport.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
use futures::future;
|
||||
use tarpc::context::Context;
|
||||
use tarpc::serde_transport as transport;
|
||||
use tarpc::server::{BaseChannel, Channel};
|
||||
use tokio::net::{UnixListener, UnixStream};
|
||||
use tokio_serde::formats::Bincode;
|
||||
use tokio_util::codec::length_delimited::LengthDelimitedCodec;
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait PingService {
|
||||
async fn ping();
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Service;
|
||||
|
||||
impl PingService for Service {
|
||||
type PingFut = future::Ready<()>;
|
||||
|
||||
fn ping(self, _: Context) -> Self::PingFut {
|
||||
future::ready(())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> std::io::Result<()> {
|
||||
let bind_addr = "/tmp/tarpc_on_unix_example.sock";
|
||||
|
||||
let _ = std::fs::remove_file(bind_addr);
|
||||
|
||||
let listener = UnixListener::bind(bind_addr).unwrap();
|
||||
let codec_builder = LengthDelimitedCodec::builder();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (conn, _addr) = listener.accept().await.unwrap();
|
||||
let framed = codec_builder.new_framed(conn);
|
||||
let transport = transport::new(framed, Bincode::default());
|
||||
|
||||
let fut = BaseChannel::with_defaults(transport).execute(Service.serve());
|
||||
tokio::spawn(fut);
|
||||
}
|
||||
});
|
||||
|
||||
let conn = UnixStream::connect(bind_addr).await?;
|
||||
let transport = transport::new(codec_builder.new_framed(conn), Bincode::default());
|
||||
PingServiceClient::new(Default::default(), transport)
|
||||
.spawn()?
|
||||
.ping(tarpc::context::current())
|
||||
.await
|
||||
}
|
||||
@@ -105,11 +105,11 @@ impl Subscriber {
|
||||
) -> anyhow::Result<SubscriberHandle> {
|
||||
let publisher = tcp::connect(publisher_addr, Json::default).await?;
|
||||
let local_addr = publisher.local_addr()?;
|
||||
let mut handler = server::BaseChannel::with_defaults(publisher)
|
||||
.respond_with(Subscriber { local_addr, topics }.serve());
|
||||
// The first request is for the topics being subscriibed to.
|
||||
let mut handler = server::BaseChannel::with_defaults(publisher).requests();
|
||||
let subscriber = Subscriber { local_addr, topics };
|
||||
// The first request is for the topics being subscribed to.
|
||||
match handler.next().await {
|
||||
Some(init_topics) => init_topics?.await,
|
||||
Some(init_topics) => init_topics?.execute(subscriber.clone().serve()).await,
|
||||
None => {
|
||||
return Err(anyhow!(
|
||||
"[{}] Server never initialized the subscriber.",
|
||||
@@ -117,7 +117,7 @@ impl Subscriber {
|
||||
))
|
||||
}
|
||||
};
|
||||
let (handler, abort_handle) = future::abortable(handler.execute());
|
||||
let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve()));
|
||||
tokio::spawn(async move {
|
||||
match handler.await {
|
||||
Ok(()) | Err(future::Aborted) => info!("[{}] subscriber shutdown.", local_addr),
|
||||
@@ -162,8 +162,7 @@ impl Publisher {
|
||||
info!("[{}] publisher connected.", publisher.peer_addr().unwrap());
|
||||
|
||||
server::BaseChannel::with_defaults(publisher)
|
||||
.respond_with(self.serve())
|
||||
.execute()
|
||||
.execute(self.serve())
|
||||
.await
|
||||
});
|
||||
|
||||
@@ -204,7 +203,7 @@ impl Publisher {
|
||||
async fn initialize_subscription(
|
||||
&mut self,
|
||||
subscriber_addr: SocketAddr,
|
||||
mut subscriber: subscriber::SubscriberClient,
|
||||
subscriber: subscriber::SubscriberClient,
|
||||
) {
|
||||
// Populate the topics
|
||||
if let Ok(topics) = subscriber.topics(context::current()).await {
|
||||
@@ -306,7 +305,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut publisher = publisher::PublisherClient::new(
|
||||
let publisher = publisher::PublisherClient::new(
|
||||
client::Config::default(),
|
||||
tcp::connect(addrs.publisher, Json::default).await?,
|
||||
)
|
||||
|
||||
@@ -4,16 +4,12 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use futures::{
|
||||
future::{self, Ready},
|
||||
prelude::*,
|
||||
};
|
||||
use futures::future::{self, Ready};
|
||||
use std::io;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{BaseChannel, Channel},
|
||||
server::{self, Channel},
|
||||
};
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
/// This is the service definition. It looks a lot like a trait definition.
|
||||
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
@@ -40,40 +36,21 @@ impl World for HelloServer {
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> io::Result<()> {
|
||||
// tarpc_json_transport is provided by the associated crate json_transport. It makes it
|
||||
// easy to start up a serde-powered JSON serialization strategy over TCP.
|
||||
let mut transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default).await?;
|
||||
let addr = transport.local_addr();
|
||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
|
||||
let server = async move {
|
||||
// For this example, we're just going to wait for one connection.
|
||||
let client = transport.next().await.unwrap().unwrap();
|
||||
let server = server::BaseChannel::with_defaults(server_transport);
|
||||
tokio::spawn(server.execute(HelloServer.serve()));
|
||||
|
||||
// `Channel` is a trait representing a server-side connection. It is a trait to allow
|
||||
// for some channels to be instrumented: for example, to track the number of open connections.
|
||||
// BaseChannel is the most basic channel, simply wrapping a transport with no added
|
||||
// functionality.
|
||||
BaseChannel::with_defaults(client)
|
||||
// serve_world is generated by the tarpc::service attribute. It takes as input any type
|
||||
// implementing the generated World trait.
|
||||
.respond_with(HelloServer.serve())
|
||||
.execute()
|
||||
.await;
|
||||
};
|
||||
tokio::spawn(server);
|
||||
|
||||
let transport = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
|
||||
// WorldClient is generated by the tarpc::service attribute. It has a constructor `new` that
|
||||
// takes a config and any Transport as input.
|
||||
let mut client = WorldClient::new(client::Config::default(), transport).spawn()?;
|
||||
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||
// that takes a config and any Transport as input.
|
||||
let client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
|
||||
|
||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||
// specifies a deadline and trace information which can be helpful in debugging requests.
|
||||
let hello = client.hello(context::current(), "Stim".to_string()).await?;
|
||||
|
||||
eprintln!("{}", hello);
|
||||
println!("{}", hello);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use futures::{future, prelude::*};
|
||||
use std::io;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{Handler, Server},
|
||||
server::{BaseChannel, Incoming},
|
||||
};
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
@@ -46,7 +46,7 @@ struct DoubleServer {
|
||||
|
||||
#[tarpc::server]
|
||||
impl DoubleService for DoubleServer {
|
||||
async fn double(mut self, _: context::Context, x: i32) -> Result<i32, String> {
|
||||
async fn double(self, _: context::Context, x: i32) -> Result<i32, String> {
|
||||
self.add_client
|
||||
.add(context::current(), x, x)
|
||||
.await
|
||||
@@ -62,10 +62,10 @@ async fn main() -> io::Result<()> {
|
||||
.await?
|
||||
.filter_map(|r| future::ready(r.ok()));
|
||||
let addr = add_listener.get_ref().local_addr();
|
||||
let add_server = Server::default()
|
||||
.incoming(add_listener)
|
||||
let add_server = add_listener
|
||||
.map(BaseChannel::with_defaults)
|
||||
.take(1)
|
||||
.respond_with(AddServer.serve());
|
||||
.execute(AddServer.serve());
|
||||
tokio::spawn(add_server);
|
||||
|
||||
let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
@@ -75,14 +75,14 @@ async fn main() -> io::Result<()> {
|
||||
.await?
|
||||
.filter_map(|r| future::ready(r.ok()));
|
||||
let addr = double_listener.get_ref().local_addr();
|
||||
let double_server = tarpc::Server::default()
|
||||
.incoming(double_listener)
|
||||
let double_server = double_listener
|
||||
.map(BaseChannel::with_defaults)
|
||||
.take(1)
|
||||
.respond_with(DoubleServer { add_client }.serve());
|
||||
.execute(DoubleServer { add_client }.serve());
|
||||
tokio::spawn(double_server);
|
||||
|
||||
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
let mut double_client =
|
||||
let double_client =
|
||||
double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?;
|
||||
|
||||
for i in 1..=5 {
|
||||
|
||||
@@ -4,32 +4,93 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! Provides a client that connects to a server and sends multiplexed requests.
|
||||
|
||||
mod in_flight_requests;
|
||||
|
||||
use crate::{
|
||||
context,
|
||||
trace::SpanId,
|
||||
util::{Compact, TimeUntil},
|
||||
ClientMessage, PollContext, PollIo, Request, Response, Transport,
|
||||
context, trace::SpanId, ClientMessage, PollContext, PollIo, Request, Response, Transport,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
prelude::*,
|
||||
ready,
|
||||
stream::Fuse,
|
||||
task::*,
|
||||
};
|
||||
use log::{debug, info, trace};
|
||||
use pin_project::{pin_project, pinned_drop};
|
||||
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
||||
use in_flight_requests::InFlightRequests;
|
||||
use log::{info, trace};
|
||||
use pin_project::pin_project;
|
||||
use std::{
|
||||
io,
|
||||
convert::TryFrom,
|
||||
fmt, io,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
use super::{Config, NewClient};
|
||||
/// Settings that control the behavior of the client.
|
||||
#[derive(Clone, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub struct Config {
|
||||
/// The number of requests that can be in flight at once.
|
||||
/// `max_in_flight_requests` controls the size of the map used by the client
|
||||
/// for storing pending requests.
|
||||
pub max_in_flight_requests: usize,
|
||||
/// The number of requests that can be buffered client-side before being sent.
|
||||
/// `pending_requests_buffer` controls the size of the channel clients use
|
||||
/// to communicate with the request dispatch task.
|
||||
pub pending_request_buffer: usize,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Config {
|
||||
max_in_flight_requests: 1_000,
|
||||
pending_request_buffer: 100,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A channel and dispatch pair. The dispatch drives the sending and receiving of requests
|
||||
/// and must be polled continuously or spawned.
|
||||
pub struct NewClient<C, D> {
|
||||
/// The new client.
|
||||
pub client: C,
|
||||
/// The client's dispatch.
|
||||
pub dispatch: D,
|
||||
}
|
||||
|
||||
impl<C, D, E> NewClient<C, D>
|
||||
where
|
||||
D: Future<Output = Result<(), E>> + Send + 'static,
|
||||
E: std::fmt::Display,
|
||||
{
|
||||
/// Helper method to spawn the dispatch on the default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub fn spawn(self) -> io::Result<C> {
|
||||
use log::warn;
|
||||
|
||||
let dispatch = self
|
||||
.dispatch
|
||||
.unwrap_or_else(move |e| warn!("Connection broken: {}", e));
|
||||
tokio::spawn(dispatch);
|
||||
Ok(self.client)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, D> fmt::Debug for NewClient<C, D> {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(fmt, "NewClient")
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[allow(clippy::no_effect)]
|
||||
const CHECK_USIZE: () = {
|
||||
if std::mem::size_of::<usize>() > std::mem::size_of::<u64>() {
|
||||
// TODO: replace this with panic!() as soon as RFC 2345 gets stabilized
|
||||
["usize is too big to fit in u64"][42];
|
||||
}
|
||||
};
|
||||
|
||||
/// Handles communication from the client to request dispatch.
|
||||
#[derive(Debug)]
|
||||
@@ -38,7 +99,7 @@ pub struct Channel<Req, Resp> {
|
||||
/// Channel to send a cancel message to the dispatcher.
|
||||
cancellation: RequestCancellation,
|
||||
/// The ID to use for the next request to stage.
|
||||
next_request_id: Arc<AtomicU64>,
|
||||
next_request_id: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl<Req, Resp> Clone for Channel<Req, Resp> {
|
||||
@@ -51,106 +112,64 @@ impl<Req, Resp> Clone for Channel<Req, Resp> {
|
||||
}
|
||||
}
|
||||
|
||||
/// A future returned by [`Channel::send`] that resolves to a server response.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
struct Send<'a, Req, Resp> {
|
||||
#[pin]
|
||||
fut: MapOkDispatchResponse<SendMapErrConnectionReset<'a, Req, Resp>, Resp>,
|
||||
}
|
||||
|
||||
type SendMapErrConnectionReset<'a, Req, Resp> = MapErrConnectionReset<
|
||||
futures::sink::Send<'a, mpsc::Sender<DispatchRequest<Req, Resp>>, DispatchRequest<Req, Resp>>,
|
||||
>;
|
||||
|
||||
impl<'a, Req, Resp> Future for Send<'a, Req, Resp> {
|
||||
type Output = io::Result<DispatchResponse<Resp>>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.as_mut().project().fut.poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
/// A future returned by [`Channel::call`] that resolves to a server response.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct Call<'a, Req, Resp> {
|
||||
#[pin]
|
||||
fut: tokio::time::Timeout<AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>>,
|
||||
}
|
||||
|
||||
impl<'a, Req, Resp> Future for Call<'a, Req, Resp> {
|
||||
type Output = io::Result<Resp>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let resp = ready!(self.as_mut().project().fut.poll(cx));
|
||||
Poll::Ready(match resp {
|
||||
Ok(resp) => resp,
|
||||
Err(tokio::time::error::Elapsed { .. }) => Err(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"Client dropped expired request.".to_string(),
|
||||
)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> Channel<Req, Resp> {
|
||||
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
|
||||
/// resolves when the request is sent (not when the response is received).
|
||||
fn send(&mut self, mut ctx: context::Context, request: Req) -> Send<Req, Resp> {
|
||||
fn send(
|
||||
&self,
|
||||
mut ctx: context::Context,
|
||||
request: Req,
|
||||
) -> impl Future<Output = io::Result<DispatchResponse<Resp>>> + '_ {
|
||||
// Convert the context to the call context.
|
||||
ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
|
||||
ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
|
||||
|
||||
let (response_completion, response) = oneshot::channel();
|
||||
let cancellation = self.cancellation.clone();
|
||||
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
||||
Send {
|
||||
fut: MapOkDispatchResponse::new(
|
||||
MapErrConnectionReset::new(self.to_dispatch.send(DispatchRequest {
|
||||
let request_id =
|
||||
u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
|
||||
|
||||
// DispatchResponse impls Drop to cancel in-flight requests. It should be created before
|
||||
// sending out the request; otherwise, the response future could be dropped after the
|
||||
// request is sent out but before DispatchResponse is created, rendering the cancellation
|
||||
// logic inactive.
|
||||
let response = DispatchResponse {
|
||||
response,
|
||||
request_id,
|
||||
cancellation: Some(cancellation),
|
||||
ctx,
|
||||
};
|
||||
async move {
|
||||
self.to_dispatch
|
||||
.send(DispatchRequest {
|
||||
ctx,
|
||||
request_id,
|
||||
request,
|
||||
response_completion,
|
||||
})),
|
||||
DispatchResponse {
|
||||
response,
|
||||
complete: false,
|
||||
request_id,
|
||||
cancellation,
|
||||
ctx,
|
||||
},
|
||||
),
|
||||
})
|
||||
.await
|
||||
.map_err(|mpsc::error::SendError(_)| {
|
||||
io::Error::from(io::ErrorKind::ConnectionReset)
|
||||
})?;
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
|
||||
/// resolves to the response.
|
||||
pub fn call(&mut self, ctx: context::Context, request: Req) -> Call<Req, Resp> {
|
||||
let timeout = ctx.deadline.time_until();
|
||||
trace!(
|
||||
"[{}] Queuing request with timeout {:?}.",
|
||||
ctx.trace_id(),
|
||||
timeout,
|
||||
);
|
||||
|
||||
Call {
|
||||
fut: tokio::time::timeout(timeout, AndThenIdent::new(self.send(ctx, request))),
|
||||
}
|
||||
pub async fn call(&self, ctx: context::Context, request: Req) -> io::Result<Resp> {
|
||||
let dispatch_response = self.send(ctx, request).await?;
|
||||
dispatch_response.await
|
||||
}
|
||||
}
|
||||
|
||||
/// A server response that is completed by request dispatch when the corresponding response
|
||||
/// arrives off the wire.
|
||||
#[pin_project(PinnedDrop)]
|
||||
#[derive(Debug)]
|
||||
struct DispatchResponse<Resp> {
|
||||
response: oneshot::Receiver<Response<Resp>>,
|
||||
ctx: context::Context,
|
||||
complete: bool,
|
||||
cancellation: RequestCancellation,
|
||||
cancellation: Option<RequestCancellation>,
|
||||
request_id: u64,
|
||||
}
|
||||
|
||||
@@ -159,10 +178,10 @@ impl<Resp> Future for DispatchResponse<Resp> {
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> {
|
||||
let resp = ready!(self.response.poll_unpin(cx));
|
||||
self.complete = true;
|
||||
self.cancellation.take();
|
||||
Poll::Ready(match resp {
|
||||
Ok(resp) => Ok(resp.message?),
|
||||
Err(oneshot::Canceled) => {
|
||||
Err(oneshot::error::RecvError { .. }) => {
|
||||
// The oneshot is Canceled when the dispatch task ends. In that case,
|
||||
// there's nothing listening on the other side, so there's no point in
|
||||
// propagating cancellation.
|
||||
@@ -173,10 +192,9 @@ impl<Resp> Future for DispatchResponse<Resp> {
|
||||
}
|
||||
|
||||
// Cancels the request when dropped, if not already complete.
|
||||
#[pinned_drop]
|
||||
impl<Resp> PinnedDrop for DispatchResponse<Resp> {
|
||||
fn drop(mut self: Pin<&mut Self>) {
|
||||
if !self.complete {
|
||||
impl<Resp> Drop for DispatchResponse<Resp> {
|
||||
fn drop(&mut self) {
|
||||
if let Some(cancellation) = &mut self.cancellation {
|
||||
// The receiver needs to be closed to handle the edge case that the request has not
|
||||
// yet been received by the dispatch task. It is possible for the cancel message to
|
||||
// arrive before the request itself, in which case the request could get stuck in the
|
||||
@@ -188,8 +206,7 @@ impl<Resp> PinnedDrop for DispatchResponse<Resp> {
|
||||
// dispatch task misses an early-arriving cancellation message, then it will see the
|
||||
// receiver as closed.
|
||||
self.response.close();
|
||||
let request_id = self.request_id;
|
||||
self.cancellation.cancel(request_id);
|
||||
cancellation.cancel(self.request_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -205,20 +222,20 @@ where
|
||||
{
|
||||
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
|
||||
let (cancellation, canceled_requests) = cancellations();
|
||||
let canceled_requests = canceled_requests.fuse();
|
||||
let canceled_requests = canceled_requests;
|
||||
|
||||
NewClient {
|
||||
client: Channel {
|
||||
to_dispatch,
|
||||
cancellation,
|
||||
next_request_id: Arc::new(AtomicU64::new(0)),
|
||||
next_request_id: Arc::new(AtomicUsize::new(0)),
|
||||
},
|
||||
dispatch: RequestDispatch {
|
||||
config,
|
||||
canceled_requests,
|
||||
transport: transport.fuse(),
|
||||
in_flight_requests: FnvHashMap::default(),
|
||||
pending_requests: pending_requests.fuse(),
|
||||
in_flight_requests: InFlightRequests::default(),
|
||||
pending_requests,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -232,13 +249,11 @@ pub struct RequestDispatch<Req, Resp, C> {
|
||||
#[pin]
|
||||
transport: Fuse<C>,
|
||||
/// Requests waiting to be written to the wire.
|
||||
#[pin]
|
||||
pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>,
|
||||
pending_requests: mpsc::Receiver<DispatchRequest<Req, Resp>>,
|
||||
/// Requests that were dropped.
|
||||
#[pin]
|
||||
canceled_requests: Fuse<CanceledRequests>,
|
||||
canceled_requests: CanceledRequests,
|
||||
/// Requests already written to the wire that haven't yet received responses.
|
||||
in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>,
|
||||
in_flight_requests: InFlightRequests<Resp>,
|
||||
/// Configures limits to prevent unlimited resource usage.
|
||||
config: Config,
|
||||
}
|
||||
@@ -247,50 +262,70 @@ impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
|
||||
where
|
||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
||||
{
|
||||
fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests<Resp> {
|
||||
self.as_mut().project().in_flight_requests
|
||||
}
|
||||
|
||||
fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<C>> {
|
||||
self.as_mut().project().transport
|
||||
}
|
||||
|
||||
fn canceled_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut CanceledRequests {
|
||||
self.as_mut().project().canceled_requests
|
||||
}
|
||||
|
||||
fn pending_requests_mut<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
) -> &'a mut mpsc::Receiver<DispatchRequest<Req, Resp>> {
|
||||
self.as_mut().project().pending_requests
|
||||
}
|
||||
|
||||
fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
||||
Poll::Ready(
|
||||
match ready!(self.as_mut().project().transport.poll_next(cx)?) {
|
||||
Some(response) => {
|
||||
self.complete(response);
|
||||
Some(Ok(()))
|
||||
}
|
||||
None => None,
|
||||
},
|
||||
)
|
||||
Poll::Ready(match ready!(self.transport_pin_mut().poll_next(cx)?) {
|
||||
Some(response) => {
|
||||
self.complete(response);
|
||||
Some(Ok(()))
|
||||
}
|
||||
None => None,
|
||||
})
|
||||
}
|
||||
|
||||
fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
||||
enum ReceiverStatus {
|
||||
NotReady,
|
||||
Pending,
|
||||
Closed,
|
||||
}
|
||||
|
||||
let pending_requests_status = match self.as_mut().poll_next_request(cx)? {
|
||||
Poll::Ready(Some(dispatch_request)) => {
|
||||
self.as_mut().write_request(dispatch_request)?;
|
||||
return Poll::Ready(Some(Ok(())));
|
||||
}
|
||||
let pending_requests_status = match self.as_mut().poll_write_request(cx)? {
|
||||
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
|
||||
Poll::Ready(None) => ReceiverStatus::Closed,
|
||||
Poll::Pending => ReceiverStatus::NotReady,
|
||||
Poll::Pending => ReceiverStatus::Pending,
|
||||
};
|
||||
|
||||
let canceled_requests_status = match self.as_mut().poll_next_cancellation(cx)? {
|
||||
Poll::Ready(Some((context, request_id))) => {
|
||||
self.as_mut().write_cancel(context, request_id)?;
|
||||
return Poll::Ready(Some(Ok(())));
|
||||
}
|
||||
let canceled_requests_status = match self.as_mut().poll_write_cancel(cx)? {
|
||||
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
|
||||
Poll::Ready(None) => ReceiverStatus::Closed,
|
||||
Poll::Pending => ReceiverStatus::NotReady,
|
||||
Poll::Pending => ReceiverStatus::Pending,
|
||||
};
|
||||
|
||||
// Receiving Poll::Ready(None) when polling expired requests never indicates "Closed",
|
||||
// because there can temporarily be zero in-flight rquests. Therefore, there is no need to
|
||||
// track the status like is done with pending and cancelled requests.
|
||||
if let Poll::Ready(Some(_)) = self.in_flight_requests().poll_expired(cx)? {
|
||||
// Expired requests are considered complete; there is no compelling reason to send a
|
||||
// cancellation message to the server, since it will have already exhausted its
|
||||
// allotted processing time.
|
||||
return Poll::Ready(Some(Ok(())));
|
||||
}
|
||||
|
||||
match (pending_requests_status, canceled_requests_status) {
|
||||
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
|
||||
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
||||
ready!(self.transport_pin_mut().poll_flush(cx)?);
|
||||
Poll::Ready(None)
|
||||
}
|
||||
(ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
|
||||
(ReceiverStatus::Pending, _) | (_, ReceiverStatus::Pending) => {
|
||||
// No more messages to process, so flush any messages buffered in the transport.
|
||||
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
||||
ready!(self.transport_pin_mut().poll_flush(cx)?);
|
||||
|
||||
// Even if we fully-flush, we return Pending, because we have no more requests
|
||||
// or cancellations right now.
|
||||
@@ -300,14 +335,17 @@ where
|
||||
}
|
||||
|
||||
/// Yields the next pending request, if one is ready to be sent.
|
||||
///
|
||||
/// Note that a request will only be yielded if the transport is *ready* to be written to (i.e.
|
||||
/// start_send would succeed).
|
||||
fn poll_next_request(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<DispatchRequest<Req, Resp>> {
|
||||
if self.as_mut().project().in_flight_requests.len() >= self.config.max_in_flight_requests {
|
||||
if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
|
||||
info!(
|
||||
"At in-flight request capacity ({}/{}).",
|
||||
self.as_mut().project().in_flight_requests.len(),
|
||||
self.in_flight_requests().len(),
|
||||
self.config.max_in_flight_requests
|
||||
);
|
||||
|
||||
@@ -316,15 +354,12 @@ where
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
while let Poll::Pending = self.as_mut().project().transport.poll_ready(cx)? {
|
||||
// We can't yield a request-to-be-sent before the transport is capable of buffering it.
|
||||
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
||||
}
|
||||
ready!(self.ensure_writeable(cx)?);
|
||||
|
||||
loop {
|
||||
match ready!(self.as_mut().project().pending_requests.poll_next_unpin(cx)) {
|
||||
match ready!(self.pending_requests_mut().poll_recv(cx)) {
|
||||
Some(request) => {
|
||||
if request.response_completion.is_canceled() {
|
||||
if request.response_completion.is_closed() {
|
||||
trace!(
|
||||
"[{}] Request canceled before being sent.",
|
||||
request.ctx.trace_id()
|
||||
@@ -340,31 +375,20 @@ where
|
||||
}
|
||||
|
||||
/// Yields the next pending cancellation, and, if one is ready, cancels the associated request.
|
||||
///
|
||||
/// Note that a request to cancel will only be yielded if the transport is *ready* to be
|
||||
/// written to (i.e. start_send would succeed).
|
||||
fn poll_next_cancellation(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<(context::Context, u64)> {
|
||||
while let Poll::Pending = self.as_mut().project().transport.poll_ready(cx)? {
|
||||
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
||||
}
|
||||
ready!(self.ensure_writeable(cx)?);
|
||||
|
||||
loop {
|
||||
let cancellation = self
|
||||
.as_mut()
|
||||
.project()
|
||||
.canceled_requests
|
||||
.poll_next_unpin(cx);
|
||||
match ready!(cancellation) {
|
||||
match ready!(self.canceled_requests_mut().poll_next_unpin(cx)) {
|
||||
Some(request_id) => {
|
||||
if let Some(in_flight_data) = self
|
||||
.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&request_id)
|
||||
{
|
||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
||||
debug!("[{}] Removed request.", in_flight_data.ctx.trace_id());
|
||||
return Poll::Ready(Some(Ok((in_flight_data.ctx, request_id))));
|
||||
if let Some(ctx) = self.in_flight_requests().cancel_request(request_id) {
|
||||
return Poll::Ready(Some(Ok((ctx, request_id))));
|
||||
}
|
||||
}
|
||||
None => return Poll::Ready(None),
|
||||
@@ -372,10 +396,24 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn write_request(
|
||||
mut self: Pin<&mut Self>,
|
||||
dispatch_request: DispatchRequest<Req, Resp>,
|
||||
) -> io::Result<()> {
|
||||
/// Returns Ready if writing a message to the transport (i.e. via write_request or
|
||||
/// write_cancel) would not fail due to a full buffer. If the transport is not ready to be
|
||||
/// written to, flushes it until it is ready.
|
||||
fn ensure_writeable<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
||||
while self.transport_pin_mut().poll_ready(cx)?.is_pending() {
|
||||
ready!(self.transport_pin_mut().poll_flush(cx)?);
|
||||
}
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
|
||||
fn poll_write_request<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
||||
let dispatch_request = match ready!(self.as_mut().poll_next_request(cx)?) {
|
||||
Some(dispatch_request) => dispatch_request,
|
||||
None => return Poll::Ready(None),
|
||||
};
|
||||
// poll_next_request only returns Ready if there is room to buffer another request.
|
||||
// Therefore, we can call write_request without fear of erroring due to a full
|
||||
// buffer.
|
||||
let request_id = dispatch_request.request_id;
|
||||
let request = ClientMessage::Request(Request {
|
||||
id: request_id,
|
||||
@@ -385,54 +423,36 @@ where
|
||||
trace_context: dispatch_request.ctx.trace_context,
|
||||
},
|
||||
});
|
||||
self.as_mut().project().transport.start_send(request)?;
|
||||
self.as_mut().project().in_flight_requests.insert(
|
||||
request_id,
|
||||
InFlightData {
|
||||
ctx: dispatch_request.ctx,
|
||||
response_completion: dispatch_request.response_completion,
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
self.transport_pin_mut().start_send(request)?;
|
||||
self.in_flight_requests()
|
||||
.insert_request(
|
||||
request_id,
|
||||
dispatch_request.ctx,
|
||||
dispatch_request.response_completion,
|
||||
)
|
||||
.expect("Request IDs should be unique");
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
|
||||
fn write_cancel(
|
||||
mut self: Pin<&mut Self>,
|
||||
context: context::Context,
|
||||
request_id: u64,
|
||||
) -> io::Result<()> {
|
||||
fn poll_write_cancel<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
||||
let (context, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
|
||||
Some((context, request_id)) => (context, request_id),
|
||||
None => return Poll::Ready(None),
|
||||
};
|
||||
|
||||
let trace_id = *context.trace_id();
|
||||
let cancel = ClientMessage::Cancel {
|
||||
trace_context: context.trace_context,
|
||||
request_id,
|
||||
};
|
||||
self.as_mut().project().transport.start_send(cancel)?;
|
||||
self.transport_pin_mut().start_send(cancel)?;
|
||||
trace!("[{}] Cancel message sent.", trace_id);
|
||||
Ok(())
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
|
||||
/// Sends a server response to the client task that initiated the associated request.
|
||||
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
|
||||
if let Some(in_flight_data) = self
|
||||
.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&response.request_id)
|
||||
{
|
||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
||||
|
||||
trace!("[{}] Received response.", in_flight_data.ctx.trace_id());
|
||||
let _ = in_flight_data.response_completion.send(response);
|
||||
return true;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"No in-flight request found for request_id = {}.",
|
||||
response.request_id
|
||||
);
|
||||
|
||||
// If the response completion was absent, then the request was already canceled.
|
||||
false
|
||||
self.in_flight_requests().complete_request(response)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -457,13 +477,13 @@ where
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
(read, Poll::Ready(None)) => {
|
||||
if self.as_mut().project().in_flight_requests.is_empty() {
|
||||
if self.in_flight_requests().is_empty() {
|
||||
info!("Shutdown: write half closed, and no requests in flight.");
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
info!(
|
||||
"Shutdown: write half closed, and {} requests in flight.",
|
||||
self.as_mut().project().in_flight_requests.len()
|
||||
self.in_flight_requests().len()
|
||||
);
|
||||
match read {
|
||||
Poll::Ready(Some(())) => continue,
|
||||
@@ -481,16 +501,10 @@ where
|
||||
/// the lifecycle of the request.
|
||||
#[derive(Debug)]
|
||||
struct DispatchRequest<Req, Resp> {
|
||||
ctx: context::Context,
|
||||
request_id: u64,
|
||||
request: Req,
|
||||
response_completion: oneshot::Sender<Response<Resp>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InFlightData<Resp> {
|
||||
ctx: context::Context,
|
||||
response_completion: oneshot::Sender<Response<Resp>>,
|
||||
pub ctx: context::Context,
|
||||
pub request_id: u64,
|
||||
pub request: Req,
|
||||
pub response_completion: oneshot::Sender<Response<Resp>>,
|
||||
}
|
||||
|
||||
/// Sends request cancellation signals.
|
||||
@@ -507,14 +521,20 @@ fn cancellations() -> (RequestCancellation, CanceledRequests) {
|
||||
// bounded by the number of in-flight requests. Additionally, each request has a clone
|
||||
// of the sender, so the bounded channel would have the same behavior,
|
||||
// since it guarantees a slot.
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
(RequestCancellation(tx), CanceledRequests(rx))
|
||||
}
|
||||
|
||||
impl RequestCancellation {
|
||||
/// Cancels the request with ID `request_id`.
|
||||
fn cancel(&mut self, request_id: u64) {
|
||||
let _ = self.0.unbounded_send(request_id);
|
||||
let _ = self.0.send(request_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl CanceledRequests {
|
||||
fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<u64>> {
|
||||
self.0.poll_recv(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -522,184 +542,7 @@ impl Stream for CanceledRequests {
|
||||
type Item = u64;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
|
||||
self.0.poll_next_unpin(cx)
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
struct MapErrConnectionReset<Fut> {
|
||||
#[pin]
|
||||
future: Fut,
|
||||
finished: Option<()>,
|
||||
}
|
||||
|
||||
impl<Fut> MapErrConnectionReset<Fut> {
|
||||
fn new(future: Fut) -> MapErrConnectionReset<Fut> {
|
||||
MapErrConnectionReset {
|
||||
future,
|
||||
finished: Some(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fut> Future for MapErrConnectionReset<Fut>
|
||||
where
|
||||
Fut: TryFuture,
|
||||
{
|
||||
type Output = io::Result<Fut::Ok>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.as_mut().project().future.try_poll(cx) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(result) => {
|
||||
self.project().finished.take().expect(
|
||||
"MapErrConnectionReset must not be polled after it returned `Poll::Ready`",
|
||||
);
|
||||
Poll::Ready(result.map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
struct MapOkDispatchResponse<Fut, Resp> {
|
||||
#[pin]
|
||||
future: Fut,
|
||||
response: Option<DispatchResponse<Resp>>,
|
||||
}
|
||||
|
||||
impl<Fut, Resp> MapOkDispatchResponse<Fut, Resp> {
|
||||
fn new(future: Fut, response: DispatchResponse<Resp>) -> MapOkDispatchResponse<Fut, Resp> {
|
||||
MapOkDispatchResponse {
|
||||
future,
|
||||
response: Some(response),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fut, Resp> Future for MapOkDispatchResponse<Fut, Resp>
|
||||
where
|
||||
Fut: TryFuture,
|
||||
{
|
||||
type Output = Result<DispatchResponse<Resp>, Fut::Error>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.as_mut().project().future.try_poll(cx) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(result) => {
|
||||
let response = self
|
||||
.as_mut()
|
||||
.project()
|
||||
.response
|
||||
.take()
|
||||
.expect("MapOk must not be polled after it returned `Poll::Ready`");
|
||||
Poll::Ready(result.map(|_| response))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
struct AndThenIdent<Fut1, Fut2> {
|
||||
#[pin]
|
||||
try_chain: TryChain<Fut1, Fut2>,
|
||||
}
|
||||
|
||||
impl<Fut1, Fut2> AndThenIdent<Fut1, Fut2>
|
||||
where
|
||||
Fut1: TryFuture<Ok = Fut2>,
|
||||
Fut2: TryFuture,
|
||||
{
|
||||
/// Creates a new `Then`.
|
||||
fn new(future: Fut1) -> AndThenIdent<Fut1, Fut2> {
|
||||
AndThenIdent {
|
||||
try_chain: TryChain::new(future),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fut1, Fut2> Future for AndThenIdent<Fut1, Fut2>
|
||||
where
|
||||
Fut1: TryFuture<Ok = Fut2>,
|
||||
Fut2: TryFuture<Error = Fut1::Error>,
|
||||
{
|
||||
type Output = Result<Fut2::Ok, Fut2::Error>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.project().try_chain.poll(cx, |result| match result {
|
||||
Ok(ok) => TryChainAction::Future(ok),
|
||||
Err(err) => TryChainAction::Output(Err(err)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project(project = TryChainProj)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
#[derive(Debug)]
|
||||
enum TryChain<Fut1, Fut2> {
|
||||
First(#[pin] Fut1),
|
||||
Second(#[pin] Fut2),
|
||||
Empty,
|
||||
}
|
||||
|
||||
enum TryChainAction<Fut2>
|
||||
where
|
||||
Fut2: TryFuture,
|
||||
{
|
||||
Future(Fut2),
|
||||
Output(Result<Fut2::Ok, Fut2::Error>),
|
||||
}
|
||||
|
||||
impl<Fut1, Fut2> TryChain<Fut1, Fut2>
|
||||
where
|
||||
Fut1: TryFuture<Ok = Fut2>,
|
||||
Fut2: TryFuture,
|
||||
{
|
||||
fn new(fut1: Fut1) -> TryChain<Fut1, Fut2> {
|
||||
TryChain::First(fut1)
|
||||
}
|
||||
|
||||
fn poll<F>(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
f: F,
|
||||
) -> Poll<Result<Fut2::Ok, Fut2::Error>>
|
||||
where
|
||||
F: FnOnce(Result<Fut1::Ok, Fut1::Error>) -> TryChainAction<Fut2>,
|
||||
{
|
||||
let mut f = Some(f);
|
||||
|
||||
loop {
|
||||
let output = match self.as_mut().project() {
|
||||
TryChainProj::First(fut1) => {
|
||||
// Poll the first future
|
||||
match fut1.try_poll(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(output) => output,
|
||||
}
|
||||
}
|
||||
TryChainProj::Second(fut2) => {
|
||||
// Poll the second future
|
||||
return fut2.try_poll(cx);
|
||||
}
|
||||
TryChainProj::Empty => {
|
||||
panic!("future must not be polled after it returned `Poll::Ready`");
|
||||
}
|
||||
};
|
||||
|
||||
self.set(TryChain::Empty); // Drop fut1
|
||||
let f = f.take().unwrap();
|
||||
match f(output) {
|
||||
TryChainAction::Future(fut2) => self.set(TryChain::Second(fut2)),
|
||||
TryChainAction::Output(output) => return Poll::Ready(output),
|
||||
}
|
||||
}
|
||||
self.poll_recv(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -710,18 +553,14 @@ mod tests {
|
||||
RequestDispatch,
|
||||
};
|
||||
use crate::{
|
||||
client::Config,
|
||||
client::{in_flight_requests::InFlightRequests, Config},
|
||||
context,
|
||||
transport::{self, channel::UnboundedChannel},
|
||||
ClientMessage, Response,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
prelude::*,
|
||||
task::*,
|
||||
};
|
||||
use std::{pin::Pin, sync::atomic::AtomicU64, sync::Arc};
|
||||
use futures::{prelude::*, task::*};
|
||||
use std::{pin::Pin, sync::atomic::AtomicUsize, sync::Arc};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_response_cancels_on_drop() {
|
||||
@@ -729,13 +568,37 @@ mod tests {
|
||||
let (_, response) = oneshot::channel();
|
||||
drop(DispatchResponse::<u32> {
|
||||
response,
|
||||
cancellation,
|
||||
complete: false,
|
||||
cancellation: Some(cancellation),
|
||||
request_id: 3,
|
||||
ctx: context::current(),
|
||||
});
|
||||
// resp's drop() is run, which should send a cancel message.
|
||||
assert_eq!(canceled_requests.0.try_next().unwrap(), Some(3));
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(Some(3)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_response_doesnt_cancel_after_complete() {
|
||||
let (cancellation, mut canceled_requests) = cancellations();
|
||||
let (tx, response) = oneshot::channel();
|
||||
tx.send(Response {
|
||||
request_id: 0,
|
||||
message: Ok("well done"),
|
||||
})
|
||||
.unwrap();
|
||||
{
|
||||
DispatchResponse {
|
||||
response,
|
||||
cancellation: Some(cancellation),
|
||||
request_id: 3,
|
||||
ctx: context::current(),
|
||||
}
|
||||
.await
|
||||
.unwrap();
|
||||
// resp's drop() is run, but should not send a cancel message.
|
||||
}
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(None));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -800,7 +663,7 @@ mod tests {
|
||||
let req = send_request(&mut channel, "hi").await;
|
||||
|
||||
assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
|
||||
assert!(!dispatch.as_mut().project().in_flight_requests.is_empty());
|
||||
assert!(!dispatch.in_flight_requests().is_empty());
|
||||
|
||||
// Test that a request future dropped after it's processed by dispatch will cause the request
|
||||
// to be removed from the in-flight request map.
|
||||
@@ -810,7 +673,7 @@ mod tests {
|
||||
} else {
|
||||
panic!("Expected request to be cancelled")
|
||||
};
|
||||
assert!(dispatch.project().in_flight_requests.is_empty());
|
||||
assert!(dispatch.in_flight_requests().is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -836,14 +699,14 @@ mod tests {
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
let (to_dispatch, pending_requests) = mpsc::channel(1);
|
||||
let (cancel_tx, canceled_requests) = mpsc::unbounded();
|
||||
let (cancel_tx, canceled_requests) = mpsc::unbounded_channel();
|
||||
let (client_channel, server_channel) = transport::channel::unbounded();
|
||||
|
||||
let dispatch = RequestDispatch::<String, String, _> {
|
||||
transport: client_channel.fuse(),
|
||||
pending_requests: pending_requests.fuse(),
|
||||
canceled_requests: CanceledRequests(canceled_requests).fuse(),
|
||||
in_flight_requests: FnvHashMap::default(),
|
||||
pending_requests: pending_requests,
|
||||
canceled_requests: CanceledRequests(canceled_requests),
|
||||
in_flight_requests: InFlightRequests::default(),
|
||||
config: Config::default(),
|
||||
};
|
||||
|
||||
@@ -851,7 +714,7 @@ mod tests {
|
||||
let channel = Channel {
|
||||
to_dispatch,
|
||||
cancellation,
|
||||
next_request_id: Arc::new(AtomicU64::new(0)),
|
||||
next_request_id: Arc::new(AtomicUsize::new(0)),
|
||||
};
|
||||
|
||||
(dispatch, channel, server_channel)
|
||||
@@ -890,7 +753,7 @@ mod tests {
|
||||
match self {
|
||||
Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)),
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Ready(Some(Err(e))) => panic!(e.to_string()),
|
||||
Poll::Ready(Some(Err(e))) => panic!("{}", e.to_string()),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
@@ -899,7 +762,7 @@ mod tests {
|
||||
match self {
|
||||
Poll::Ready(Some(Ok(t))) => Some(t),
|
||||
Poll::Ready(None) => None,
|
||||
Poll::Ready(Some(Err(e))) => panic!(e.to_string()),
|
||||
Poll::Ready(Some(Err(e))) => panic!("{}", e.to_string()),
|
||||
Poll::Pending => panic!("Pending"),
|
||||
}
|
||||
}
|
||||
162
tarpc/src/client/in_flight_requests.rs
Normal file
162
tarpc/src/client/in_flight_requests.rs
Normal file
@@ -0,0 +1,162 @@
|
||||
use crate::{
|
||||
context,
|
||||
util::{Compact, TimeUntil},
|
||||
PollIo, Response, ServerError,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::ready;
|
||||
use log::{debug, trace};
|
||||
use std::{
|
||||
collections::hash_map,
|
||||
io,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_util::time::delay_queue::{self, DelayQueue};
|
||||
|
||||
/// Requests already written to the wire that haven't yet received responses.
|
||||
#[derive(Debug)]
|
||||
pub struct InFlightRequests<Resp> {
|
||||
request_data: FnvHashMap<u64, RequestData<Resp>>,
|
||||
deadlines: DelayQueue<u64>,
|
||||
}
|
||||
|
||||
impl<Resp> Default for InFlightRequests<Resp> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
request_data: Default::default(),
|
||||
deadlines: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RequestData<Resp> {
|
||||
ctx: context::Context,
|
||||
response_completion: oneshot::Sender<Response<Resp>>,
|
||||
/// The key to remove the timer for the request's deadline.
|
||||
deadline_key: delay_queue::Key,
|
||||
}
|
||||
|
||||
/// An error returned when an attempt is made to insert a request with an ID that is already in
|
||||
/// use.
|
||||
#[derive(Debug)]
|
||||
pub struct AlreadyExistsError;
|
||||
|
||||
impl<Resp> InFlightRequests<Resp> {
|
||||
/// Returns the number of in-flight requests.
|
||||
pub fn len(&self) -> usize {
|
||||
self.request_data.len()
|
||||
}
|
||||
|
||||
/// Returns true iff there are no requests in flight.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.request_data.is_empty()
|
||||
}
|
||||
|
||||
/// Starts a request, unless a request with the same ID is already in flight.
|
||||
pub fn insert_request(
|
||||
&mut self,
|
||||
request_id: u64,
|
||||
ctx: context::Context,
|
||||
response_completion: oneshot::Sender<Response<Resp>>,
|
||||
) -> Result<(), AlreadyExistsError> {
|
||||
match self.request_data.entry(request_id) {
|
||||
hash_map::Entry::Vacant(vacant) => {
|
||||
let timeout = ctx.deadline.time_until();
|
||||
trace!(
|
||||
"[{}] Queuing request with timeout {:?}.",
|
||||
ctx.trace_id(),
|
||||
timeout,
|
||||
);
|
||||
|
||||
let deadline_key = self.deadlines.insert(request_id, timeout);
|
||||
vacant.insert(RequestData {
|
||||
ctx,
|
||||
response_completion,
|
||||
deadline_key,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
hash_map::Entry::Occupied(_) => Err(AlreadyExistsError),
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes a request without aborting. Returns true iff the request was found.
|
||||
pub fn complete_request(&mut self, response: Response<Resp>) -> bool {
|
||||
if let Some(request_data) = self.request_data.remove(&response.request_id) {
|
||||
self.request_data.compact(0.1);
|
||||
|
||||
trace!("[{}] Received response.", request_data.ctx.trace_id());
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
request_data.complete(response);
|
||||
return true;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"No in-flight request found for request_id = {}.",
|
||||
response.request_id
|
||||
);
|
||||
|
||||
// If the response completion was absent, then the request was already canceled.
|
||||
false
|
||||
}
|
||||
|
||||
/// Cancels a request without completing (typically used when a request handle was dropped
|
||||
/// before the request completed).
|
||||
pub fn cancel_request(&mut self, request_id: u64) -> Option<context::Context> {
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
self.request_data.compact(0.1);
|
||||
trace!("[{}] Cancelling request.", request_data.ctx.trace_id());
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
Some(request_data.ctx)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Yields a request that has expired, completing it with a TimedOut error.
|
||||
/// The caller should send cancellation messages for any yielded request ID.
|
||||
pub fn poll_expired(&mut self, cx: &mut Context) -> PollIo<u64> {
|
||||
Poll::Ready(match ready!(self.deadlines.poll_expired(cx)) {
|
||||
Some(Ok(expired)) => {
|
||||
let request_id = expired.into_inner();
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
self.request_data.compact(0.1);
|
||||
request_data.complete(Self::deadline_exceeded_error(request_id));
|
||||
}
|
||||
Some(Ok(request_id))
|
||||
}
|
||||
Some(Err(e)) => Some(Err(io::Error::new(io::ErrorKind::Other, e))),
|
||||
None => None,
|
||||
})
|
||||
}
|
||||
|
||||
fn deadline_exceeded_error(request_id: u64) -> Response<Resp> {
|
||||
Response {
|
||||
request_id,
|
||||
message: Err(ServerError {
|
||||
kind: io::ErrorKind::TimedOut,
|
||||
detail: Some("Client dropped expired request.".to_string()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// When InFlightRequests is dropped, any outstanding requests are completed with a
|
||||
/// deadline-exceeded error.
|
||||
impl<Resp> Drop for InFlightRequests<Resp> {
|
||||
fn drop(&mut self) {
|
||||
let deadlines = &mut self.deadlines;
|
||||
for (_, request_data) in self.request_data.drain() {
|
||||
let expired = deadlines.remove(&request_data.deadline_key);
|
||||
request_data.complete(Self::deadline_exceeded_error(expired.into_inner()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Resp> RequestData<Resp> {
|
||||
fn complete(self, response: Response<Resp>) {
|
||||
let _ = self.response_completion.send(response);
|
||||
}
|
||||
}
|
||||
167
tarpc/src/lib.rs
167
tarpc/src/lib.rs
@@ -3,7 +3,6 @@
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! *Disclaimer*: This is not an official Google product.
|
||||
//!
|
||||
//! tarpc is an RPC framework for rust with a focus on ease of use. Defining a
|
||||
@@ -47,7 +46,7 @@
|
||||
//! Add to your `Cargo.toml` dependencies:
|
||||
//!
|
||||
//! ```toml
|
||||
//! tarpc = "0.23.0"
|
||||
//! tarpc = "0.25"
|
||||
//! ```
|
||||
//!
|
||||
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||
@@ -56,12 +55,13 @@
|
||||
//!
|
||||
//! ## Example
|
||||
//!
|
||||
//! For this example, in addition to tarpc, also add two other dependencies to
|
||||
//! This example uses [tokio](https://tokio.rs), so add the following dependencies to
|
||||
//! your `Cargo.toml`:
|
||||
//!
|
||||
//! ```toml
|
||||
//! futures = "0.3"
|
||||
//! tokio = "0.3"
|
||||
//! futures = "1.0"
|
||||
//! tarpc = { version = "0.25", features = ["tokio1"] }
|
||||
//! tokio = { version = "1.0", features = ["macros"] }
|
||||
//! ```
|
||||
//!
|
||||
//! In the following example, we use an in-process channel for communication between
|
||||
@@ -79,7 +79,7 @@
|
||||
//! };
|
||||
//! use tarpc::{
|
||||
//! client, context,
|
||||
//! server::{self, Handler},
|
||||
//! server::{self, Incoming},
|
||||
//! };
|
||||
//! use std::io;
|
||||
//!
|
||||
@@ -103,7 +103,7 @@
|
||||
//! # };
|
||||
//! # use tarpc::{
|
||||
//! # client, context,
|
||||
//! # server::{self, Handler},
|
||||
//! # server::{self, Incoming},
|
||||
//! # };
|
||||
//! # use std::io;
|
||||
//! # // This is the service definition. It looks a lot like a trait definition.
|
||||
@@ -131,7 +131,7 @@
|
||||
//! ```
|
||||
//!
|
||||
//! Lastly let's write our `main` that will start the server. While this example uses an
|
||||
//! [in-process channel](rpc::transport::channel), tarpc also ships a generic [`serde_transport`]
|
||||
//! [in-process channel](transport::channel), tarpc also ships a generic [`serde_transport`]
|
||||
//! behind the `serde-transport` feature, with additional [TCP](serde_transport::tcp) functionality
|
||||
//! available behind the `tcp` feature.
|
||||
//!
|
||||
@@ -143,7 +143,7 @@
|
||||
//! # };
|
||||
//! # use tarpc::{
|
||||
//! # client, context,
|
||||
//! # server::{self, Handler},
|
||||
//! # server::{self, Channel},
|
||||
//! # };
|
||||
//! # use std::io;
|
||||
//! # // This is the service definition. It looks a lot like a trait definition.
|
||||
@@ -165,20 +165,18 @@
|
||||
//! # future::ready(format!("Hello, {}!", name))
|
||||
//! # }
|
||||
//! # }
|
||||
//! # #[cfg(not(feature = "tokio1"))]
|
||||
//! # fn main() {}
|
||||
//! # #[cfg(feature = "tokio1")]
|
||||
//! #[tokio::main]
|
||||
//! async fn main() -> io::Result<()> {
|
||||
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
//!
|
||||
//! let server = server::new(server::Config::default())
|
||||
//! // incoming() takes a stream of transports such as would be returned by
|
||||
//! // TcpListener::incoming (but a stream instead of an iterator).
|
||||
//! .incoming(stream::once(future::ready(server_transport)))
|
||||
//! .respond_with(HelloServer.serve());
|
||||
//! let server = server::BaseChannel::with_defaults(server_transport);
|
||||
//! tokio::spawn(server.execute(HelloServer.serve()));
|
||||
//!
|
||||
//! tokio::spawn(server);
|
||||
//!
|
||||
//! // WorldClient is generated by the macro. It has a constructor `new` that takes a config and
|
||||
//! // any Transport as input
|
||||
//! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||
//! // that takes a config and any Transport as input.
|
||||
//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?;
|
||||
//!
|
||||
//! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||
@@ -200,8 +198,11 @@
|
||||
#![allow(clippy::type_complexity)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
pub mod rpc;
|
||||
pub use rpc::*;
|
||||
#[cfg(feature = "serde1")]
|
||||
pub use serde;
|
||||
|
||||
#[cfg(feature = "serde-transport")]
|
||||
pub use tokio_serde;
|
||||
|
||||
#[cfg(feature = "serde-transport")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "serde-transport")))]
|
||||
@@ -209,6 +210,9 @@ pub mod serde_transport;
|
||||
|
||||
pub mod trace;
|
||||
|
||||
#[cfg(feature = "serde1")]
|
||||
pub use tarpc_plugins::derive_serde;
|
||||
|
||||
/// The main macro that creates RPC services.
|
||||
///
|
||||
/// Rpc methods are specified, mirroring trait syntax:
|
||||
@@ -288,3 +292,126 @@ pub use tarpc_plugins::service;
|
||||
/// Note that this won't touch functions unless they have been annotated with
|
||||
/// `async`, meaning that this should not break existing code.
|
||||
pub use tarpc_plugins::server;
|
||||
|
||||
pub mod client;
|
||||
pub mod context;
|
||||
pub mod server;
|
||||
pub mod transport;
|
||||
pub(crate) mod util;
|
||||
|
||||
pub use crate::transport::sealed::Transport;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use futures::task::*;
|
||||
use std::{fmt::Display, io, time::SystemTime};
|
||||
|
||||
/// A message from a client to a server.
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[non_exhaustive]
|
||||
pub enum ClientMessage<T> {
|
||||
/// A request initiated by a user. The server responds to a request by invoking a
|
||||
/// service-provided request handler. The handler completes with a [`response`](Response), which
|
||||
/// the server sends back to the client.
|
||||
Request(Request<T>),
|
||||
/// A command to cancel an in-flight request, automatically sent by the client when a response
|
||||
/// future is dropped.
|
||||
///
|
||||
/// When received, the server will immediately cancel the main task (top-level future) of the
|
||||
/// request handler for the associated request. Any tasks spawned by the request handler will
|
||||
/// not be canceled, because the framework layer does not
|
||||
/// know about them.
|
||||
Cancel {
|
||||
/// The trace context associates the message with a specific chain of causally-related actions,
|
||||
/// possibly orchestrated across many distributed systems.
|
||||
#[cfg_attr(feature = "serde1", serde(default))]
|
||||
trace_context: trace::Context,
|
||||
/// The ID of the request to cancel.
|
||||
request_id: u64,
|
||||
},
|
||||
}
|
||||
|
||||
/// A request from a client to a server.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Request<T> {
|
||||
/// Trace context, deadline, and other cross-cutting concerns.
|
||||
pub context: context::Context,
|
||||
/// Uniquely identifies the request across all requests sent over a single channel.
|
||||
pub id: u64,
|
||||
/// The request body.
|
||||
pub message: T,
|
||||
}
|
||||
|
||||
/// A response from a server to a client.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Response<T> {
|
||||
/// The ID of the request being responded to.
|
||||
pub request_id: u64,
|
||||
/// The response body, or an error if the request failed.
|
||||
pub message: Result<T, ServerError>,
|
||||
}
|
||||
|
||||
/// An error response from a server to a client.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ServerError {
|
||||
#[cfg_attr(
|
||||
feature = "serde1",
|
||||
serde(serialize_with = "util::serde::serialize_io_error_kind_as_u32")
|
||||
)]
|
||||
#[cfg_attr(
|
||||
feature = "serde1",
|
||||
serde(deserialize_with = "util::serde::deserialize_io_error_kind_from_u32")
|
||||
)]
|
||||
/// The type of error that occurred to fail the request.
|
||||
pub kind: io::ErrorKind,
|
||||
/// A message describing more detail about the error that occurred.
|
||||
pub detail: Option<String>,
|
||||
}
|
||||
|
||||
impl From<ServerError> for io::Error {
|
||||
fn from(e: ServerError) -> io::Error {
|
||||
io::Error::new(e.kind, e.detail.unwrap_or_default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Request<T> {
|
||||
/// Returns the deadline for this request.
|
||||
pub fn deadline(&self) -> &SystemTime {
|
||||
&self.context.deadline
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) type PollIo<T> = Poll<Option<io::Result<T>>>;
|
||||
pub(crate) trait PollContext<T> {
|
||||
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static;
|
||||
|
||||
fn with_context<C, F>(self, f: F) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static,
|
||||
F: FnOnce() -> C;
|
||||
}
|
||||
|
||||
impl<T> PollContext<T> for PollIo<T> {
|
||||
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static,
|
||||
{
|
||||
self.map(|o| o.map(|r| r.context(context)))
|
||||
}
|
||||
|
||||
fn with_context<C, F>(self, f: F) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static,
|
||||
F: FnOnce() -> C,
|
||||
{
|
||||
self.map(|o| o.map(|r| r.with_context(f)))
|
||||
}
|
||||
}
|
||||
|
||||
148
tarpc/src/rpc.rs
148
tarpc/src/rpc.rs
@@ -1,148 +0,0 @@
|
||||
// Copyright 2018 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
#![deny(missing_docs, missing_debug_implementations)]
|
||||
|
||||
//! An RPC framework providing client and server.
|
||||
//!
|
||||
//! Features:
|
||||
//! * RPC deadlines, both client- and server-side.
|
||||
//! * Cascading cancellation (works with multiple hops).
|
||||
//! * Configurable limits
|
||||
//! * In-flight requests, both client and server-side.
|
||||
//! * Server-side limit is per-connection.
|
||||
//! * When the server reaches the in-flight request maximum, it returns a throttled error
|
||||
//! to the client.
|
||||
//! * When the client reaches the in-flight request max, messages are buffered up to a
|
||||
//! configurable maximum, beyond which the requests are back-pressured.
|
||||
//! * Server connections.
|
||||
//! * Total and per-IP limits.
|
||||
//! * When an incoming connection is accepted, if already at maximum, the connection is
|
||||
//! dropped.
|
||||
//! * Transport agnostic.
|
||||
|
||||
pub mod client;
|
||||
pub mod context;
|
||||
pub mod server;
|
||||
pub mod transport;
|
||||
pub(crate) mod util;
|
||||
|
||||
pub use crate::{client::Client, server::Server, trace, transport::sealed::Transport};
|
||||
|
||||
use anyhow::Context as _;
|
||||
use futures::task::*;
|
||||
use std::{fmt::Display, io, time::SystemTime};
|
||||
|
||||
/// A message from a client to a server.
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[non_exhaustive]
|
||||
pub enum ClientMessage<T> {
|
||||
/// A request initiated by a user. The server responds to a request by invoking a
|
||||
/// service-provided request handler. The handler completes with a [`response`](Response), which
|
||||
/// the server sends back to the client.
|
||||
Request(Request<T>),
|
||||
/// A command to cancel an in-flight request, automatically sent by the client when a response
|
||||
/// future is dropped.
|
||||
///
|
||||
/// When received, the server will immediately cancel the main task (top-level future) of the
|
||||
/// request handler for the associated request. Any tasks spawned by the request handler will
|
||||
/// not be canceled, because the framework layer does not
|
||||
/// know about them.
|
||||
Cancel {
|
||||
/// The trace context associates the message with a specific chain of causally-related actions,
|
||||
/// possibly orchestrated across many distributed systems.
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
trace_context: trace::Context,
|
||||
/// The ID of the request to cancel.
|
||||
request_id: u64,
|
||||
},
|
||||
}
|
||||
|
||||
/// A request from a client to a server.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Request<T> {
|
||||
/// Trace context, deadline, and other cross-cutting concerns.
|
||||
pub context: context::Context,
|
||||
/// Uniquely identifies the request across all requests sent over a single channel.
|
||||
pub id: u64,
|
||||
/// The request body.
|
||||
pub message: T,
|
||||
}
|
||||
|
||||
/// A response from a server to a client.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Response<T> {
|
||||
/// The ID of the request being responded to.
|
||||
pub request_id: u64,
|
||||
/// The response body, or an error if the request failed.
|
||||
pub message: Result<T, ServerError>,
|
||||
}
|
||||
|
||||
/// An error response from a server to a client.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ServerError {
|
||||
#[cfg_attr(
|
||||
feature = "serde1",
|
||||
serde(serialize_with = "util::serde::serialize_io_error_kind_as_u32")
|
||||
)]
|
||||
#[cfg_attr(
|
||||
feature = "serde1",
|
||||
serde(deserialize_with = "util::serde::deserialize_io_error_kind_from_u32")
|
||||
)]
|
||||
/// The type of error that occurred to fail the request.
|
||||
pub kind: io::ErrorKind,
|
||||
/// A message describing more detail about the error that occurred.
|
||||
pub detail: Option<String>,
|
||||
}
|
||||
|
||||
impl From<ServerError> for io::Error {
|
||||
fn from(e: ServerError) -> io::Error {
|
||||
io::Error::new(e.kind, e.detail.unwrap_or_default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Request<T> {
|
||||
/// Returns the deadline for this request.
|
||||
pub fn deadline(&self) -> &SystemTime {
|
||||
&self.context.deadline
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) type PollIo<T> = Poll<Option<io::Result<T>>>;
|
||||
pub(crate) trait PollContext<T> {
|
||||
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static;
|
||||
|
||||
fn with_context<C, F>(self, f: F) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static,
|
||||
F: FnOnce() -> C;
|
||||
}
|
||||
|
||||
impl<T> PollContext<T> for PollIo<T> {
|
||||
fn context<C>(self, context: C) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static,
|
||||
{
|
||||
self.map(|o| o.map(|r| r.context(context)))
|
||||
}
|
||||
|
||||
fn with_context<C, F>(self, f: F) -> Poll<Option<anyhow::Result<T>>>
|
||||
where
|
||||
C: Display + Send + Sync + 'static,
|
||||
F: FnOnce() -> C,
|
||||
{
|
||||
self.map(|o| o.map(|r| r.with_context(f)))
|
||||
}
|
||||
}
|
||||
@@ -1,155 +0,0 @@
|
||||
// Copyright 2018 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! Provides a client that connects to a server and sends multiplexed requests.
|
||||
|
||||
use crate::context;
|
||||
use futures::prelude::*;
|
||||
use std::io;
|
||||
|
||||
/// Provides a [`Client`] backed by a transport.
|
||||
pub mod channel;
|
||||
pub use channel::{new, Channel};
|
||||
|
||||
/// Sends multiplexed requests to, and receives responses from, a server.
|
||||
pub trait Client<'a, Req> {
|
||||
/// The response type.
|
||||
type Response;
|
||||
|
||||
/// The future response.
|
||||
type Future: Future<Output = io::Result<Self::Response>> + 'a;
|
||||
|
||||
/// Initiates a request, sending it to the dispatch task.
|
||||
///
|
||||
/// Returns a [`Future`] that resolves to this client and the future response
|
||||
/// once the request is successfully enqueued.
|
||||
///
|
||||
/// [`Future`]: futures::Future
|
||||
fn call(&'a mut self, ctx: context::Context, request: Req) -> Self::Future;
|
||||
|
||||
/// Returns a Client that applies a post-processing function to the returned response.
|
||||
fn map_response<F, R>(self, f: F) -> MapResponse<Self, F>
|
||||
where
|
||||
F: FnMut(Self::Response) -> R,
|
||||
Self: Sized,
|
||||
{
|
||||
MapResponse { inner: self, f }
|
||||
}
|
||||
|
||||
/// Returns a Client that applies a pre-processing function to the request.
|
||||
fn with_request<F, Req2>(self, f: F) -> WithRequest<Self, F>
|
||||
where
|
||||
F: FnMut(Req2) -> Req,
|
||||
Self: Sized,
|
||||
{
|
||||
WithRequest { inner: self, f }
|
||||
}
|
||||
}
|
||||
|
||||
/// A Client that applies a function to the returned response.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MapResponse<C, F> {
|
||||
inner: C,
|
||||
f: F,
|
||||
}
|
||||
|
||||
impl<'a, C, F, Req, Resp, Resp2> Client<'a, Req> for MapResponse<C, F>
|
||||
where
|
||||
C: Client<'a, Req, Response = Resp>,
|
||||
F: FnMut(Resp) -> Resp2 + 'a,
|
||||
{
|
||||
type Response = Resp2;
|
||||
type Future = futures::future::MapOk<<C as Client<'a, Req>>::Future, &'a mut F>;
|
||||
|
||||
fn call(&'a mut self, ctx: context::Context, request: Req) -> Self::Future {
|
||||
self.inner.call(ctx, request).map_ok(&mut self.f)
|
||||
}
|
||||
}
|
||||
|
||||
/// A Client that applies a pre-processing function to the request.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct WithRequest<C, F> {
|
||||
inner: C,
|
||||
f: F,
|
||||
}
|
||||
|
||||
impl<'a, C, F, Req, Req2, Resp> Client<'a, Req2> for WithRequest<C, F>
|
||||
where
|
||||
C: Client<'a, Req, Response = Resp>,
|
||||
F: FnMut(Req2) -> Req,
|
||||
{
|
||||
type Response = Resp;
|
||||
type Future = <C as Client<'a, Req>>::Future;
|
||||
|
||||
fn call(&'a mut self, ctx: context::Context, request: Req2) -> Self::Future {
|
||||
self.inner.call(ctx, (self.f)(request))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, Req, Resp> Client<'a, Req> for Channel<Req, Resp>
|
||||
where
|
||||
Req: 'a,
|
||||
Resp: 'a,
|
||||
{
|
||||
type Response = Resp;
|
||||
type Future = channel::Call<'a, Req, Resp>;
|
||||
|
||||
fn call(&'a mut self, ctx: context::Context, request: Req) -> channel::Call<'a, Req, Resp> {
|
||||
self.call(ctx, request)
|
||||
}
|
||||
}
|
||||
|
||||
/// Settings that control the behavior of the client.
|
||||
#[derive(Clone, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub struct Config {
|
||||
/// The number of requests that can be in flight at once.
|
||||
/// `max_in_flight_requests` controls the size of the map used by the client
|
||||
/// for storing pending requests.
|
||||
pub max_in_flight_requests: usize,
|
||||
/// The number of requests that can be buffered client-side before being sent.
|
||||
/// `pending_requests_buffer` controls the size of the channel clients use
|
||||
/// to communicate with the request dispatch task.
|
||||
pub pending_request_buffer: usize,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Config {
|
||||
max_in_flight_requests: 1_000,
|
||||
pending_request_buffer: 100,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A channel and dispatch pair. The dispatch drives the sending and receiving of requests
|
||||
/// and must be polled continuously or spawned.
|
||||
#[derive(Debug)]
|
||||
pub struct NewClient<C, D> {
|
||||
/// The new client.
|
||||
pub client: C,
|
||||
/// The client's dispatch.
|
||||
pub dispatch: D,
|
||||
}
|
||||
|
||||
impl<C, D, E> NewClient<C, D>
|
||||
where
|
||||
D: Future<Output = Result<(), E>> + Send + 'static,
|
||||
E: std::fmt::Display,
|
||||
{
|
||||
/// Helper method to spawn the dispatch on the default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub fn spawn(self) -> io::Result<C> {
|
||||
use log::error;
|
||||
|
||||
let dispatch = self
|
||||
.dispatch
|
||||
.unwrap_or_else(move |e| error!("Connection broken: {}", e));
|
||||
tokio::spawn(dispatch);
|
||||
Ok(self.client)
|
||||
}
|
||||
}
|
||||
@@ -1,703 +0,0 @@
|
||||
// Copyright 2018 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! Provides a server that concurrently handles many connections sending multiplexed requests.
|
||||
|
||||
use crate::{
|
||||
context, trace, util::Compact, util::TimeUntil, ClientMessage, PollIo, Request, Response,
|
||||
ServerError, Transport,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{
|
||||
channel::mpsc,
|
||||
future::{AbortHandle, AbortRegistration, Abortable},
|
||||
prelude::*,
|
||||
ready,
|
||||
stream::Fuse,
|
||||
task::*,
|
||||
};
|
||||
use humantime::format_rfc3339;
|
||||
use log::{debug, trace};
|
||||
use pin_project::pin_project;
|
||||
use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
|
||||
use tokio::time::Timeout;
|
||||
|
||||
mod filter;
|
||||
#[cfg(test)]
|
||||
mod testing;
|
||||
mod throttle;
|
||||
|
||||
pub use self::{
|
||||
filter::ChannelFilter,
|
||||
throttle::{Throttler, ThrottlerStream},
|
||||
};
|
||||
|
||||
/// Manages clients, serving multiplexed requests over each connection.
|
||||
#[derive(Debug)]
|
||||
pub struct Server<Req, Resp> {
|
||||
config: Config,
|
||||
ghost: PhantomData<(Req, Resp)>,
|
||||
}
|
||||
|
||||
impl<Req, Resp> Default for Server<Req, Resp> {
|
||||
fn default() -> Self {
|
||||
new(Config::default())
|
||||
}
|
||||
}
|
||||
|
||||
/// Settings that control the behavior of the server.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Config {
|
||||
/// The number of responses per client that can be buffered server-side before being sent.
|
||||
/// `pending_response_buffer` controls the buffer size of the channel that a server's
|
||||
/// response tasks use to send responses to the client handler task.
|
||||
pub pending_response_buffer: usize,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Config {
|
||||
pending_response_buffer: 100,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Returns a channel backed by `transport` and configured with `self`.
|
||||
pub fn channel<Req, Resp, T>(self, transport: T) -> BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
{
|
||||
BaseChannel::new(self, transport)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new server with configuration specified `config`.
|
||||
pub fn new<Req, Resp>(config: Config) -> Server<Req, Resp> {
|
||||
Server {
|
||||
config,
|
||||
ghost: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> Server<Req, Resp> {
|
||||
/// Returns the config for this server.
|
||||
pub fn config(&self) -> &Config {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Returns a stream of server channels.
|
||||
pub fn incoming<S, T>(self, listener: S) -> impl Stream<Item = BaseChannel<Req, Resp, T>>
|
||||
where
|
||||
S: Stream<Item = T>,
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
{
|
||||
listener.map(move |t| BaseChannel::new(self.config.clone(), t))
|
||||
}
|
||||
}
|
||||
|
||||
/// Basically a Fn(Req) -> impl Future<Output = Resp>;
|
||||
pub trait Serve<Req>: Sized + Clone {
|
||||
/// Type of response.
|
||||
type Resp;
|
||||
|
||||
/// Type of response future.
|
||||
type Fut: Future<Output = Self::Resp>;
|
||||
|
||||
/// Responds to a single request.
|
||||
fn serve(self, ctx: context::Context, req: Req) -> Self::Fut;
|
||||
}
|
||||
|
||||
impl<Req, Resp, Fut, F> Serve<Req> for F
|
||||
where
|
||||
F: FnOnce(context::Context, Req) -> Fut + Clone,
|
||||
Fut: Future<Output = Resp>,
|
||||
{
|
||||
type Resp = Resp;
|
||||
type Fut = Fut;
|
||||
|
||||
fn serve(self, ctx: context::Context, req: Req) -> Self::Fut {
|
||||
self(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
/// A utility trait enabling a stream to fluently chain a request handler.
|
||||
pub trait Handler<C>
|
||||
where
|
||||
Self: Sized + Stream<Item = C>,
|
||||
C: Channel,
|
||||
{
|
||||
/// Enforces channel per-key limits.
|
||||
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> filter::ChannelFilter<Self, K, KF>
|
||||
where
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
KF: Fn(&C) -> K,
|
||||
{
|
||||
ChannelFilter::new(self, n, keymaker)
|
||||
}
|
||||
|
||||
/// Caps the number of concurrent requests per channel.
|
||||
fn max_concurrent_requests_per_channel(self, n: usize) -> ThrottlerStream<Self> {
|
||||
ThrottlerStream::new(self, n)
|
||||
}
|
||||
|
||||
/// Responds to all requests with [`server::serve`](Serve).
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
fn respond_with<S>(self, server: S) -> Running<Self, S>
|
||||
where
|
||||
S: Serve<C::Req, Resp = C::Resp>,
|
||||
{
|
||||
Running {
|
||||
incoming: self,
|
||||
server,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, C> Handler<C> for S
|
||||
where
|
||||
S: Sized + Stream<Item = C>,
|
||||
C: Channel,
|
||||
{
|
||||
}
|
||||
|
||||
/// BaseChannel lifts a Transport to a Channel by tracking in-flight requests.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct BaseChannel<Req, Resp, T> {
|
||||
config: Config,
|
||||
/// Writes responses to the wire and reads requests off the wire.
|
||||
#[pin]
|
||||
transport: Fuse<T>,
|
||||
/// Number of requests currently being responded to.
|
||||
in_flight_requests: FnvHashMap<u64, AbortHandle>,
|
||||
/// Types the request and response.
|
||||
ghost: PhantomData<(Req, Resp)>,
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
{
|
||||
/// Creates a new channel backed by `transport` and configured with `config`.
|
||||
pub fn new(config: Config, transport: T) -> Self {
|
||||
BaseChannel {
|
||||
config,
|
||||
transport: transport.fuse(),
|
||||
in_flight_requests: FnvHashMap::default(),
|
||||
ghost: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new channel backed by `transport` and configured with the defaults.
|
||||
pub fn with_defaults(transport: T) -> Self {
|
||||
Self::new(Config::default(), transport)
|
||||
}
|
||||
|
||||
/// Returns the inner transport over which messages are sent and received.
|
||||
pub fn get_ref(&self) -> &T {
|
||||
self.transport.get_ref()
|
||||
}
|
||||
|
||||
/// Returns the inner transport over which messages are sent and received.
|
||||
pub fn get_pin_ref(self: Pin<&mut Self>) -> Pin<&mut T> {
|
||||
self.project().transport.get_pin_mut()
|
||||
}
|
||||
|
||||
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
|
||||
// It's possible the request was already completed, so it's fine
|
||||
// if this is None.
|
||||
if let Some(cancel_handle) = self
|
||||
.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&request_id)
|
||||
{
|
||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
||||
|
||||
cancel_handle.abort();
|
||||
let remaining = self.as_mut().project().in_flight_requests.len();
|
||||
trace!(
|
||||
"[{}] Request canceled. In-flight requests = {}",
|
||||
trace_context.trace_id,
|
||||
remaining,
|
||||
);
|
||||
} else {
|
||||
trace!(
|
||||
"[{}] Received cancellation, but response handler \
|
||||
is already complete.",
|
||||
trace_context.trace_id,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The server end of an open connection with a client, streaming in requests from, and sinking
|
||||
/// responses to, the client.
|
||||
///
|
||||
/// Channels are free to somewhat rely on the assumption that all in-flight requests are eventually
|
||||
/// either [cancelled](BaseChannel::cancel_request) or [responded to](Sink::start_send). Safety cannot
|
||||
/// rely on this assumption, but it is best for `Channel` users to always account for all outstanding
|
||||
/// requests.
|
||||
pub trait Channel
|
||||
where
|
||||
Self: Transport<Response<<Self as Channel>::Resp>, Request<<Self as Channel>::Req>>,
|
||||
{
|
||||
/// Type of request item.
|
||||
type Req;
|
||||
|
||||
/// Type of response sink item.
|
||||
type Resp;
|
||||
|
||||
/// Configuration of the channel.
|
||||
fn config(&self) -> &Config;
|
||||
|
||||
/// Returns the number of in-flight requests over this channel.
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize;
|
||||
|
||||
/// Caps the number of concurrent requests.
|
||||
fn max_concurrent_requests(self, n: usize) -> Throttler<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Throttler::new(self, n)
|
||||
}
|
||||
|
||||
/// Tells the Channel that request with ID `request_id` is being handled.
|
||||
/// The request will be tracked until a response with the same ID is sent
|
||||
/// to the Channel.
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration;
|
||||
|
||||
/// Respond to requests coming over the channel with `f`. Returns a future that drives the
|
||||
/// responses and resolves when the connection is closed.
|
||||
fn respond_with<S>(self, server: S) -> ClientHandler<Self, S>
|
||||
where
|
||||
S: Serve<Self::Req, Resp = Self::Resp>,
|
||||
Self: Sized,
|
||||
{
|
||||
let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer);
|
||||
let responses = responses.fuse();
|
||||
|
||||
ClientHandler {
|
||||
channel: self,
|
||||
server,
|
||||
pending_responses: responses,
|
||||
responses_tx,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
{
|
||||
type Item = io::Result<Request<Req>>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
match ready!(self.as_mut().project().transport.poll_next(cx)?) {
|
||||
Some(message) => match message {
|
||||
ClientMessage::Request(request) => {
|
||||
return Poll::Ready(Some(Ok(request)));
|
||||
}
|
||||
ClientMessage::Cancel {
|
||||
trace_context,
|
||||
request_id,
|
||||
} => {
|
||||
self.as_mut().cancel_request(&trace_context, request_id);
|
||||
}
|
||||
},
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> Sink<Response<Resp>> for BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
{
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().transport.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
|
||||
if self
|
||||
.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&response.request_id)
|
||||
.is_some()
|
||||
{
|
||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
||||
}
|
||||
|
||||
self.project().transport.start_send(response)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().transport.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.project().transport.poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
|
||||
fn as_ref(&self) -> &T {
|
||||
self.transport.get_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> Channel for BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
{
|
||||
type Req = Req;
|
||||
type Resp = Resp;
|
||||
|
||||
fn config(&self) -> &Config {
|
||||
&self.config
|
||||
}
|
||||
|
||||
fn in_flight_requests(mut self: Pin<&mut Self>) -> usize {
|
||||
self.as_mut().project().in_flight_requests.len()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
let (abort_handle, abort_registration) = AbortHandle::new_pair();
|
||||
assert!(self
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.insert(request_id, abort_handle)
|
||||
.is_none());
|
||||
abort_registration
|
||||
}
|
||||
}
|
||||
|
||||
/// A running handler serving all requests coming over a channel.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct ClientHandler<C, S>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
#[pin]
|
||||
channel: C,
|
||||
/// Responses waiting to be written to the wire.
|
||||
#[pin]
|
||||
pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>,
|
||||
/// Handed out to request handlers to fan in responses.
|
||||
#[pin]
|
||||
responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
|
||||
/// Server
|
||||
server: S,
|
||||
}
|
||||
|
||||
impl<C, S> ClientHandler<C, S>
|
||||
where
|
||||
C: Channel,
|
||||
S: Serve<C::Req, Resp = C::Resp>,
|
||||
{
|
||||
/// Returns the inner channel over which messages are sent and received.
|
||||
pub fn get_pin_channel(self: Pin<&mut Self>) -> Pin<&mut C> {
|
||||
self.project().channel
|
||||
}
|
||||
|
||||
fn pump_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<RequestHandler<S::Fut, C::Resp>> {
|
||||
match ready!(self.as_mut().project().channel.poll_next(cx)?) {
|
||||
Some(request) => Poll::Ready(Some(Ok(self.handle_request(request)))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn pump_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
read_half_closed: bool,
|
||||
) -> PollIo<()> {
|
||||
match self.as_mut().poll_next_response(cx)? {
|
||||
Poll::Ready(Some((ctx, response))) => {
|
||||
trace!(
|
||||
"[{}] Staging response. In-flight requests = {}.",
|
||||
ctx.trace_id(),
|
||||
self.as_mut().project().channel.in_flight_requests(),
|
||||
);
|
||||
self.as_mut().project().channel.start_send(response)?;
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
// Shutdown can't be done before we finish pumping out remaining responses.
|
||||
ready!(self.as_mut().project().channel.poll_flush(cx)?);
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Pending => {
|
||||
// No more requests to process, so flush any requests buffered in the transport.
|
||||
ready!(self.as_mut().project().channel.poll_flush(cx)?);
|
||||
|
||||
// Being here means there are no staged requests and all written responses are
|
||||
// fully flushed. So, if the read half is closed and there are no in-flight
|
||||
// requests, then we can close the write half.
|
||||
if read_half_closed && self.as_mut().project().channel.in_flight_requests() == 0 {
|
||||
Poll::Ready(None)
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_next_response(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<(context::Context, Response<C::Resp>)> {
|
||||
// Ensure there's room to write a response.
|
||||
while let Poll::Pending = self.as_mut().project().channel.poll_ready(cx)? {
|
||||
ready!(self.as_mut().project().channel.poll_flush(cx)?);
|
||||
}
|
||||
|
||||
match ready!(self.as_mut().project().pending_responses.poll_next(cx)) {
|
||||
Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))),
|
||||
None => {
|
||||
// This branch likely won't happen, since the ClientHandler is holding a Sender.
|
||||
Poll::Ready(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_request(
|
||||
mut self: Pin<&mut Self>,
|
||||
request: Request<C::Req>,
|
||||
) -> RequestHandler<S::Fut, C::Resp> {
|
||||
let request_id = request.id;
|
||||
let deadline = request.context.deadline;
|
||||
let timeout = deadline.time_until();
|
||||
trace!(
|
||||
"[{}] Received request with deadline {} (timeout {:?}).",
|
||||
request.context.trace_id(),
|
||||
format_rfc3339(deadline),
|
||||
timeout,
|
||||
);
|
||||
let ctx = request.context;
|
||||
let request = request.message;
|
||||
|
||||
let response = self.as_mut().project().server.clone().serve(ctx, request);
|
||||
let response = Resp {
|
||||
state: RespState::PollResp,
|
||||
request_id,
|
||||
ctx,
|
||||
deadline,
|
||||
f: tokio::time::timeout(timeout, response),
|
||||
response: None,
|
||||
response_tx: self.as_mut().project().responses_tx.clone(),
|
||||
};
|
||||
let abort_registration = self.as_mut().project().channel.start_request(request_id);
|
||||
RequestHandler {
|
||||
resp: Abortable::new(response, abort_registration),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A future fulfilling a single client request.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct RequestHandler<F, R> {
|
||||
#[pin]
|
||||
resp: Abortable<Resp<F, R>>,
|
||||
}
|
||||
|
||||
impl<F, R> Future for RequestHandler<F, R>
|
||||
where
|
||||
F: Future<Output = R>,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
let _ = ready!(self.project().resp.poll(cx));
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
struct Resp<F, R> {
|
||||
state: RespState,
|
||||
request_id: u64,
|
||||
ctx: context::Context,
|
||||
deadline: SystemTime,
|
||||
#[pin]
|
||||
f: Timeout<F>,
|
||||
response: Option<Response<R>>,
|
||||
#[pin]
|
||||
response_tx: mpsc::Sender<(context::Context, Response<R>)>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
enum RespState {
|
||||
PollResp,
|
||||
PollReady,
|
||||
PollFlush,
|
||||
}
|
||||
|
||||
impl<F, R> Future for Resp<F, R>
|
||||
where
|
||||
F: Future<Output = R>,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
loop {
|
||||
match self.as_mut().project().state {
|
||||
RespState::PollResp => {
|
||||
let result = ready!(self.as_mut().project().f.poll(cx));
|
||||
*self.as_mut().project().response = Some(Response {
|
||||
request_id: self.request_id,
|
||||
message: match result {
|
||||
Ok(message) => Ok(message),
|
||||
Err(tokio::time::error::Elapsed { .. }) => {
|
||||
debug!(
|
||||
"[{}] Response did not complete before deadline of {}s.",
|
||||
self.ctx.trace_id(),
|
||||
format_rfc3339(self.deadline)
|
||||
);
|
||||
// No point in responding, since the client will have dropped the
|
||||
// request.
|
||||
Err(ServerError {
|
||||
kind: io::ErrorKind::TimedOut,
|
||||
detail: Some(format!(
|
||||
"Response did not complete before deadline of {}s.",
|
||||
format_rfc3339(self.deadline)
|
||||
)),
|
||||
})
|
||||
}
|
||||
},
|
||||
});
|
||||
*self.as_mut().project().state = RespState::PollReady;
|
||||
}
|
||||
RespState::PollReady => {
|
||||
let ready = ready!(self.as_mut().project().response_tx.poll_ready(cx));
|
||||
if ready.is_err() {
|
||||
return Poll::Ready(());
|
||||
}
|
||||
let resp = (self.ctx, self.as_mut().project().response.take().unwrap());
|
||||
if self
|
||||
.as_mut()
|
||||
.project()
|
||||
.response_tx
|
||||
.start_send(resp)
|
||||
.is_err()
|
||||
{
|
||||
return Poll::Ready(());
|
||||
}
|
||||
*self.as_mut().project().state = RespState::PollFlush;
|
||||
}
|
||||
RespState::PollFlush => {
|
||||
let ready = ready!(self.as_mut().project().response_tx.poll_flush(cx));
|
||||
if ready.is_err() {
|
||||
return Poll::Ready(());
|
||||
}
|
||||
return Poll::Ready(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, S> Stream for ClientHandler<C, S>
|
||||
where
|
||||
C: Channel,
|
||||
S: Serve<C::Req, Resp = C::Resp>,
|
||||
{
|
||||
type Item = io::Result<RequestHandler<S::Fut, C::Resp>>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
let read = self.as_mut().pump_read(cx)?;
|
||||
let read_closed = matches!(read, Poll::Ready(None));
|
||||
match (read, self.as_mut().pump_write(cx, read_closed)?) {
|
||||
(Poll::Ready(None), Poll::Ready(None)) => {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
(Poll::Ready(Some(request_handler)), _) => {
|
||||
return Poll::Ready(Some(Ok(request_handler)));
|
||||
}
|
||||
(_, Poll::Ready(Some(()))) => {}
|
||||
_ => {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send + 'static execution helper methods.
|
||||
|
||||
impl<C, S> ClientHandler<C, S>
|
||||
where
|
||||
C: Channel + 'static,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
|
||||
S::Fut: Send + 'static,
|
||||
{
|
||||
/// Runs the client handler until completion by [spawning](tokio::spawn) each
|
||||
/// request handler onto the default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub fn execute(self) -> impl Future<Output = ()> {
|
||||
self.try_for_each(|request_handler| async {
|
||||
tokio::spawn(request_handler);
|
||||
Ok(())
|
||||
})
|
||||
.map_ok(|()| log::info!("ClientHandler finished."))
|
||||
.unwrap_or_else(|e| log::info!("ClientHandler errored out: {}", e))
|
||||
}
|
||||
}
|
||||
|
||||
/// A future that drives the server by [spawning](tokio::spawn) channels and request handlers on the default
|
||||
/// executor.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub struct Running<St, Se> {
|
||||
#[pin]
|
||||
incoming: St,
|
||||
server: Se,
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
impl<St, C, Se> Future for Running<St, Se>
|
||||
where
|
||||
St: Sized + Stream<Item = C>,
|
||||
C: Channel + Send + 'static,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
||||
Se::Fut: Send + 'static,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
while let Some(channel) = ready!(self.as_mut().project().incoming.poll_next(cx)) {
|
||||
tokio::spawn(
|
||||
channel
|
||||
.respond_with(self.as_mut().project().server.clone())
|
||||
.execute(),
|
||||
);
|
||||
}
|
||||
log::info!("Server shutting down.");
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
@@ -14,10 +14,7 @@ use serde::{Deserialize, Serialize};
|
||||
use std::{error::Error, io, pin::Pin};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_serde::{Framed as SerdeFramed, *};
|
||||
use tokio_util::codec::{
|
||||
length_delimited::{self, LengthDelimitedCodec},
|
||||
Framed,
|
||||
};
|
||||
use tokio_util::codec::{length_delimited::LengthDelimitedCodec, Framed};
|
||||
|
||||
/// A transport that serializes to, and deserializes from, a byte stream.
|
||||
#[pin_project]
|
||||
@@ -130,6 +127,7 @@ pub mod tcp {
|
||||
futures::ready,
|
||||
std::{marker::PhantomData, net::SocketAddr},
|
||||
tokio::net::{TcpListener, TcpStream, ToSocketAddrs},
|
||||
tokio_util::codec::length_delimited,
|
||||
};
|
||||
|
||||
mod private {
|
||||
@@ -296,89 +294,66 @@ mod tests {
|
||||
Context::from_waker(&noop_waker_ref())
|
||||
}
|
||||
|
||||
struct TestIo(Cursor<Vec<u8>>);
|
||||
|
||||
impl AsyncRead for TestIo {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for TestIo {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn close() {
|
||||
let (tx, _rx) = crate::transport::channel::bounded::<(), ()>(0);
|
||||
pin_mut!(tx);
|
||||
assert_matches!(tx.as_mut().poll_close(&mut ctx()), Poll::Ready(Ok(())));
|
||||
assert_matches!(tx.as_mut().start_send(()), Err(_));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stream() {
|
||||
struct TestIo(Cursor<&'static [u8]>);
|
||||
|
||||
impl AsyncRead for TestIo {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
AsyncRead::poll_read(Pin::new(self.0.get_mut()), cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for TestIo {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
let data = b"\x00\x00\x00\x18\"Test one, check check.\"";
|
||||
let data: &[u8] = b"\x00\x00\x00\x18\"Test one, check check.\"";
|
||||
let transport = Transport::from((
|
||||
TestIo(Cursor::new(data)),
|
||||
TestIo(Cursor::new(Vec::from(data))),
|
||||
SymmetricalJson::<String>::default(),
|
||||
));
|
||||
pin_mut!(transport);
|
||||
|
||||
assert_matches!(
|
||||
transport.poll_next(&mut ctx()),
|
||||
transport.as_mut().poll_next(&mut ctx()),
|
||||
Poll::Ready(Some(Ok(ref s))) if s == "Test one, check check.");
|
||||
assert_matches!(transport.as_mut().poll_next(&mut ctx()), Poll::Ready(None));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sink() {
|
||||
struct TestIo<'a>(&'a mut Vec<u8>);
|
||||
|
||||
impl<'a> AsyncRead for TestIo<'a> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> AsyncWrite for TestIo<'a> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
AsyncWrite::poll_write(Pin::new(&mut *self.0), cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
AsyncWrite::poll_flush(Pin::new(&mut *self.0), cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
AsyncWrite::poll_shutdown(Pin::new(&mut *self.0), cx)
|
||||
}
|
||||
}
|
||||
|
||||
let mut writer = vec![];
|
||||
let transport =
|
||||
Transport::from((TestIo(&mut writer), SymmetricalJson::<String>::default()));
|
||||
pin_mut!(transport);
|
||||
let writer = Cursor::new(vec![]);
|
||||
let mut transport = Box::pin(Transport::from((
|
||||
TestIo(writer),
|
||||
SymmetricalJson::<String>::default(),
|
||||
)));
|
||||
|
||||
assert_matches!(
|
||||
transport.as_mut().poll_ready(&mut ctx()),
|
||||
@@ -390,7 +365,32 @@ mod tests {
|
||||
.start_send("Test one, check check.".into()),
|
||||
Ok(())
|
||||
);
|
||||
assert_matches!(transport.poll_flush(&mut ctx()), Poll::Ready(Ok(())));
|
||||
assert_eq!(writer, b"\x00\x00\x00\x18\"Test one, check check.\"");
|
||||
assert_matches!(
|
||||
transport.as_mut().poll_flush(&mut ctx()),
|
||||
Poll::Ready(Ok(()))
|
||||
);
|
||||
assert_eq!(
|
||||
transport.get_ref().0.get_ref(),
|
||||
b"\x00\x00\x00\x18\"Test one, check check.\""
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(tcp)]
|
||||
#[tokio::test]
|
||||
async fn tcp() -> io::Result<()> {
|
||||
use super::tcp;
|
||||
|
||||
let mut listener = tcp::listen("0.0.0.0:0", SymmetricalJson::<String>::default).await?;
|
||||
let addr = listener.local_addr();
|
||||
tokio::spawn(async move {
|
||||
let mut transport = listener.next().await.unwrap().unwrap();
|
||||
let message = transport.next().await.unwrap().unwrap();
|
||||
transport.send(message).await.unwrap();
|
||||
});
|
||||
let mut transport = tcp::connect(addr, SymmetricalJson::<String>::default).await?;
|
||||
transport.send(String::from("test")).await?;
|
||||
assert_matches!(transport.next().await, Some(Ok(s)) if s == "test");
|
||||
assert_matches!(transport.next().await, None);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
1049
tarpc/src/server.rs
Normal file
1049
tarpc/src/server.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -9,13 +9,15 @@ use crate::{
|
||||
util::Compact,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{channel::mpsc, future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
|
||||
use futures::{future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
|
||||
use log::{debug, info, trace};
|
||||
use pin_project::pin_project;
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::{
|
||||
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin,
|
||||
collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin,
|
||||
time::SystemTime,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// A single-threaded filter that drops channels based on per-key limits.
|
||||
#[pin_project]
|
||||
@@ -27,9 +29,7 @@ where
|
||||
#[pin]
|
||||
listener: Fuse<S>,
|
||||
channels_per_key: u32,
|
||||
#[pin]
|
||||
dropped_keys: mpsc::UnboundedReceiver<K>,
|
||||
#[pin]
|
||||
dropped_keys_tx: mpsc::UnboundedSender<K>,
|
||||
key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
|
||||
keymaker: F,
|
||||
@@ -53,7 +53,7 @@ struct Tracker<K> {
|
||||
impl<K> Drop for Tracker<K> {
|
||||
fn drop(&mut self) {
|
||||
// Don't care if the listener is dropped.
|
||||
let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap());
|
||||
let _ = self.dropped_keys.send(self.key.take().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,8 +63,8 @@ where
|
||||
{
|
||||
type Item = <C as Stream>::Item;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
self.channel().poll_next(cx)
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
self.inner_pin_mut().poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,20 +74,20 @@ where
|
||||
{
|
||||
type Error = C::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.channel().poll_ready(cx)
|
||||
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner_pin_mut().poll_ready(cx)
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
|
||||
self.channel().start_send(item)
|
||||
fn start_send(mut self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
|
||||
self.inner_pin_mut().start_send(item)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.channel().poll_flush(cx)
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner_pin_mut().poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.channel().poll_close(cx)
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner_pin_mut().poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,12 +108,16 @@ where
|
||||
self.inner.config()
|
||||
}
|
||||
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
self.project().inner.in_flight_requests()
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
self.inner.in_flight_requests()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
self.project().inner.start_request(request_id)
|
||||
fn start_request(
|
||||
mut self: Pin<&mut Self>,
|
||||
id: u64,
|
||||
deadline: SystemTime,
|
||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
||||
self.inner_pin_mut().start_request(id, deadline)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,8 +128,8 @@ impl<C, K> TrackedChannel<C, K> {
|
||||
}
|
||||
|
||||
/// Returns the pinned inner channel.
|
||||
fn channel<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut C> {
|
||||
self.project().inner
|
||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> {
|
||||
self.as_mut().project().inner
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,7 +141,7 @@ where
|
||||
{
|
||||
/// Sheds new channels to stay under configured limits.
|
||||
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
|
||||
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded();
|
||||
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel();
|
||||
ChannelFilter {
|
||||
listener: listener.fuse(),
|
||||
channels_per_key,
|
||||
@@ -155,6 +159,10 @@ where
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
F: Fn(&S::Item) -> K,
|
||||
{
|
||||
fn listener_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<S>> {
|
||||
self.as_mut().project().listener
|
||||
}
|
||||
|
||||
fn handle_new_channel(
|
||||
mut self: Pin<&mut Self>,
|
||||
stream: S::Item,
|
||||
@@ -166,7 +174,7 @@ where
|
||||
"[{}] Opening channel ({}/{}) channels for key.",
|
||||
key,
|
||||
Arc::strong_count(&tracker),
|
||||
self.as_mut().project().channels_per_key
|
||||
self.channels_per_key
|
||||
);
|
||||
|
||||
Ok(TrackedChannel {
|
||||
@@ -175,15 +183,14 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
|
||||
let channels_per_key = self.channels_per_key;
|
||||
let dropped_keys = self.dropped_keys_tx.clone();
|
||||
let key_counts = &mut self.as_mut().project().key_counts;
|
||||
match key_counts.entry(key.clone()) {
|
||||
fn increment_channels_for_key(self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
|
||||
let self_ = self.project();
|
||||
let dropped_keys = self_.dropped_keys_tx;
|
||||
match self_.key_counts.entry(key.clone()) {
|
||||
Entry::Vacant(vacant) => {
|
||||
let tracker = Arc::new(Tracker {
|
||||
key: Some(key),
|
||||
dropped_keys,
|
||||
dropped_keys: dropped_keys.clone(),
|
||||
});
|
||||
|
||||
vacant.insert(Arc::downgrade(&tracker));
|
||||
@@ -191,17 +198,17 @@ where
|
||||
}
|
||||
Entry::Occupied(mut o) => {
|
||||
let count = o.get().strong_count();
|
||||
if count >= channels_per_key.try_into().unwrap() {
|
||||
if count >= TryFrom::try_from(*self_.channels_per_key).unwrap() {
|
||||
info!(
|
||||
"[{}] Opened max channels from key ({}/{}).",
|
||||
key, count, channels_per_key
|
||||
key, count, self_.channels_per_key
|
||||
);
|
||||
Err(key)
|
||||
} else {
|
||||
Ok(o.get().upgrade().unwrap_or_else(|| {
|
||||
let tracker = Arc::new(Tracker {
|
||||
key: Some(key),
|
||||
dropped_keys,
|
||||
dropped_keys: dropped_keys.clone(),
|
||||
});
|
||||
|
||||
*o.get_mut() = Arc::downgrade(&tracker);
|
||||
@@ -216,18 +223,19 @@ where
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
|
||||
match ready!(self.as_mut().project().listener.poll_next_unpin(cx)) {
|
||||
match ready!(self.listener_pin_mut().poll_next_unpin(cx)) {
|
||||
Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_closed_channels(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
match ready!(self.as_mut().project().dropped_keys.poll_next_unpin(cx)) {
|
||||
fn poll_closed_channels(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
let self_ = self.project();
|
||||
match ready!(self_.dropped_keys.poll_recv(cx)) {
|
||||
Some(key) => {
|
||||
debug!("All channels dropped for key [{}]", key);
|
||||
self.as_mut().project().key_counts.remove(&key);
|
||||
self.as_mut().project().key_counts.compact(0.1);
|
||||
self_.key_counts.remove(&key);
|
||||
self_.key_counts.compact(0.1);
|
||||
Poll::Ready(())
|
||||
}
|
||||
None => unreachable!("Holding a copy of closed_channels and didn't close it."),
|
||||
@@ -268,7 +276,6 @@ where
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn ctx() -> Context<'static> {
|
||||
use futures::task::*;
|
||||
@@ -280,12 +287,12 @@ fn ctx() -> Context<'static> {
|
||||
fn tracker_drop() {
|
||||
use assert_matches::assert_matches;
|
||||
|
||||
let (tx, mut rx) = mpsc::unbounded();
|
||||
let (tx, mut rx) = mpsc::unbounded_channel();
|
||||
Tracker {
|
||||
key: Some(1),
|
||||
dropped_keys: tx,
|
||||
};
|
||||
assert_matches!(rx.try_next(), Ok(Some(1)));
|
||||
assert_matches!(rx.poll_recv(&mut ctx()), Poll::Ready(Some(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -293,8 +300,8 @@ fn tracked_channel_stream() {
|
||||
use assert_matches::assert_matches;
|
||||
use pin_utils::pin_mut;
|
||||
|
||||
let (chan_tx, chan) = mpsc::unbounded();
|
||||
let (dropped_keys, _) = mpsc::unbounded();
|
||||
let (chan_tx, chan) = futures::channel::mpsc::unbounded();
|
||||
let (dropped_keys, _) = mpsc::unbounded_channel();
|
||||
let channel = TrackedChannel {
|
||||
inner: chan,
|
||||
tracker: Arc::new(Tracker {
|
||||
@@ -313,8 +320,8 @@ fn tracked_channel_sink() {
|
||||
use assert_matches::assert_matches;
|
||||
use pin_utils::pin_mut;
|
||||
|
||||
let (chan, mut chan_rx) = mpsc::unbounded();
|
||||
let (dropped_keys, _) = mpsc::unbounded();
|
||||
let (chan, mut chan_rx) = futures::channel::mpsc::unbounded();
|
||||
let (dropped_keys, _) = mpsc::unbounded_channel();
|
||||
let channel = TrackedChannel {
|
||||
inner: chan,
|
||||
tracker: Arc::new(Tracker {
|
||||
@@ -338,7 +345,7 @@ fn channel_filter_increment_channels_for_key() {
|
||||
struct TestChannel {
|
||||
key: &'static str,
|
||||
}
|
||||
let (_, listener) = mpsc::unbounded();
|
||||
let (_, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
|
||||
@@ -359,7 +366,7 @@ fn channel_filter_handle_new_channel() {
|
||||
struct TestChannel {
|
||||
key: &'static str,
|
||||
}
|
||||
let (_, listener) = mpsc::unbounded();
|
||||
let (_, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
let channel1 = filter
|
||||
@@ -391,7 +398,7 @@ fn channel_filter_poll_listener() {
|
||||
struct TestChannel {
|
||||
key: &'static str,
|
||||
}
|
||||
let (new_channels, listener) = mpsc::unbounded();
|
||||
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
|
||||
@@ -427,7 +434,7 @@ fn channel_filter_poll_closed_channels() {
|
||||
struct TestChannel {
|
||||
key: &'static str,
|
||||
}
|
||||
let (new_channels, listener) = mpsc::unbounded();
|
||||
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
|
||||
@@ -455,7 +462,7 @@ fn channel_filter_stream() {
|
||||
struct TestChannel {
|
||||
key: &'static str,
|
||||
}
|
||||
let (new_channels, listener) = mpsc::unbounded();
|
||||
let (new_channels, listener) = futures::channel::mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
|
||||
192
tarpc/src/server/in_flight_requests.rs
Normal file
192
tarpc/src/server/in_flight_requests.rs
Normal file
@@ -0,0 +1,192 @@
|
||||
use crate::{
|
||||
util::{Compact, TimeUntil},
|
||||
PollIo,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{
|
||||
future::{AbortHandle, AbortRegistration},
|
||||
ready,
|
||||
};
|
||||
use std::{
|
||||
collections::hash_map,
|
||||
io,
|
||||
task::{Context, Poll},
|
||||
time::SystemTime,
|
||||
};
|
||||
use tokio_util::time::delay_queue::{self, DelayQueue};
|
||||
|
||||
/// A data structure that tracks in-flight requests. It aborts requests,
|
||||
/// either on demand or when a request deadline expires.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct InFlightRequests {
|
||||
request_data: FnvHashMap<u64, RequestData>,
|
||||
deadlines: DelayQueue<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Data needed to clean up a single in-flight request.
|
||||
struct RequestData {
|
||||
/// Aborts the response handler for the associated request.
|
||||
abort_handle: AbortHandle,
|
||||
/// The key to remove the timer for the request's deadline.
|
||||
deadline_key: delay_queue::Key,
|
||||
}
|
||||
|
||||
/// An error returned when a request attempted to start with the same ID as a request already
|
||||
/// in flight.
|
||||
#[derive(Debug)]
|
||||
pub struct AlreadyExistsError;
|
||||
|
||||
impl InFlightRequests {
|
||||
/// Returns the number of in-flight requests.
|
||||
pub fn len(&self) -> usize {
|
||||
self.request_data.len()
|
||||
}
|
||||
|
||||
/// Starts a request, unless a request with the same ID is already in flight.
|
||||
pub fn start_request(
|
||||
&mut self,
|
||||
request_id: u64,
|
||||
deadline: SystemTime,
|
||||
) -> Result<AbortRegistration, AlreadyExistsError> {
|
||||
match self.request_data.entry(request_id) {
|
||||
hash_map::Entry::Vacant(vacant) => {
|
||||
let timeout = deadline.time_until();
|
||||
let (abort_handle, abort_registration) = AbortHandle::new_pair();
|
||||
let deadline_key = self.deadlines.insert(request_id, timeout);
|
||||
vacant.insert(RequestData {
|
||||
abort_handle,
|
||||
deadline_key,
|
||||
});
|
||||
Ok(abort_registration)
|
||||
}
|
||||
hash_map::Entry::Occupied(_) => Err(AlreadyExistsError),
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancels an in-flight request. Returns true iff the request was found.
|
||||
pub fn cancel_request(&mut self, request_id: u64) -> bool {
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
self.request_data.compact(0.1);
|
||||
|
||||
request_data.abort_handle.abort();
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes a request without aborting. Returns true iff the request was found.
|
||||
/// This method should be used when a response is being sent.
|
||||
pub fn remove_request(&mut self, request_id: u64) -> bool {
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
self.request_data.compact(0.1);
|
||||
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Yields a request that has expired, aborting any ongoing processing of that request.
|
||||
pub fn poll_expired(&mut self, cx: &mut Context) -> PollIo<u64> {
|
||||
Poll::Ready(match ready!(self.deadlines.poll_expired(cx)) {
|
||||
Some(Ok(expired)) => {
|
||||
if let Some(request_data) = self.request_data.remove(expired.get_ref()) {
|
||||
self.request_data.compact(0.1);
|
||||
request_data.abort_handle.abort();
|
||||
}
|
||||
Some(Ok(expired.into_inner()))
|
||||
}
|
||||
Some(Err(e)) => Some(Err(io::Error::new(io::ErrorKind::Other, e))),
|
||||
None => None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// When InFlightRequests is dropped, any outstanding requests are aborted.
|
||||
impl Drop for InFlightRequests {
|
||||
fn drop(&mut self) {
|
||||
self.request_data
|
||||
.values()
|
||||
.for_each(|request_data| request_data.abort_handle.abort())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
use {
|
||||
assert_matches::assert_matches,
|
||||
futures::{
|
||||
future::{pending, Abortable},
|
||||
FutureExt,
|
||||
},
|
||||
futures_test::task::noop_context,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_request_increases_len() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
in_flight_requests
|
||||
.start_request(0, SystemTime::now())
|
||||
.unwrap();
|
||||
assert_eq!(in_flight_requests.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn polling_expired_aborts() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
let abort_registration = in_flight_requests
|
||||
.start_request(0, SystemTime::now())
|
||||
.unwrap();
|
||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||
|
||||
tokio::time::pause();
|
||||
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
|
||||
|
||||
assert_matches!(
|
||||
in_flight_requests.poll_expired(&mut noop_context()),
|
||||
Poll::Ready(Some(Ok(_)))
|
||||
);
|
||||
assert_matches!(
|
||||
abortable_future.poll_unpin(&mut noop_context()),
|
||||
Poll::Ready(Err(_))
|
||||
);
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancel_request_aborts() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
let abort_registration = in_flight_requests
|
||||
.start_request(0, SystemTime::now())
|
||||
.unwrap();
|
||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||
|
||||
assert_eq!(in_flight_requests.cancel_request(0), true);
|
||||
assert_matches!(
|
||||
abortable_future.poll_unpin(&mut noop_context()),
|
||||
Poll::Ready(Err(_))
|
||||
);
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remove_request_doesnt_abort() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
let abort_registration = in_flight_requests
|
||||
.start_request(0, SystemTime::now())
|
||||
.unwrap();
|
||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||
|
||||
assert_eq!(in_flight_requests.remove_request(0), true);
|
||||
assert_matches!(
|
||||
abortable_future.poll_unpin(&mut noop_context()),
|
||||
Poll::Pending
|
||||
);
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
}
|
||||
@@ -4,14 +4,12 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use crate::server::{Channel, Config};
|
||||
use crate::{context, Request, Response};
|
||||
use fnv::FnvHashSet;
|
||||
use futures::{
|
||||
future::{AbortHandle, AbortRegistration},
|
||||
task::*,
|
||||
Sink, Stream,
|
||||
use crate::{
|
||||
context,
|
||||
server::{Channel, Config},
|
||||
Request, Response,
|
||||
};
|
||||
use futures::{future::AbortRegistration, task::*, Sink, Stream};
|
||||
use pin_project::pin_project;
|
||||
use std::collections::VecDeque;
|
||||
use std::io;
|
||||
@@ -25,7 +23,7 @@ pub(crate) struct FakeChannel<In, Out> {
|
||||
#[pin]
|
||||
pub sink: VecDeque<Out>,
|
||||
pub config: Config,
|
||||
pub in_flight_requests: FnvHashSet<u64>,
|
||||
pub in_flight_requests: super::in_flight_requests::InFlightRequests,
|
||||
}
|
||||
|
||||
impl<In, Out> Stream for FakeChannel<In, Out>
|
||||
@@ -50,7 +48,7 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
|
||||
self.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&response.request_id);
|
||||
.remove_request(response.request_id);
|
||||
self.project()
|
||||
.sink
|
||||
.start_send(response)
|
||||
@@ -77,13 +75,18 @@ where
|
||||
&self.config
|
||||
}
|
||||
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
self.in_flight_requests.len()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, id: u64) -> AbortRegistration {
|
||||
self.project().in_flight_requests.insert(id);
|
||||
AbortHandle::new_pair().1
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
id: u64,
|
||||
deadline: SystemTime,
|
||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
||||
self.project()
|
||||
.in_flight_requests
|
||||
.start_request(id, deadline)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use crate::{Response, ServerError};
|
||||
use futures::{future::AbortRegistration, prelude::*, ready, task::*};
|
||||
use log::debug;
|
||||
use pin_project::pin_project;
|
||||
use std::{io, pin::Pin};
|
||||
use std::{io, pin::Pin, time::SystemTime};
|
||||
|
||||
/// A [`Channel`] that limits the number of concurrent
|
||||
/// requests by throttling.
|
||||
@@ -36,8 +36,8 @@ where
|
||||
/// `max_in_flight_requests`.
|
||||
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
|
||||
Throttler {
|
||||
inner,
|
||||
max_in_flight_requests,
|
||||
inner,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -113,16 +113,20 @@ where
|
||||
type Req = <C as Channel>::Req;
|
||||
type Resp = <C as Channel>::Resp;
|
||||
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
self.project().inner.in_flight_requests()
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
self.inner.in_flight_requests()
|
||||
}
|
||||
|
||||
fn config(&self) -> &Config {
|
||||
self.inner.config()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
self.project().inner.start_request(request_id)
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
id: u64,
|
||||
deadline: SystemTime,
|
||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
||||
self.project().inner.start_request(id, deadline)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,10 +177,10 @@ use crate::Request;
|
||||
#[cfg(test)]
|
||||
use pin_utils::pin_mut;
|
||||
#[cfg(test)]
|
||||
use std::marker::PhantomData;
|
||||
use std::{marker::PhantomData, time::Duration};
|
||||
|
||||
#[test]
|
||||
fn throttler_in_flight_requests() {
|
||||
#[tokio::test]
|
||||
async fn throttler_in_flight_requests() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
@@ -184,20 +188,27 @@ fn throttler_in_flight_requests() {
|
||||
|
||||
pin_mut!(throttler);
|
||||
for i in 0..5 {
|
||||
throttler.inner.in_flight_requests.insert(i);
|
||||
throttler
|
||||
.inner
|
||||
.in_flight_requests
|
||||
.start_request(i, SystemTime::now() + Duration::from_secs(1))
|
||||
.unwrap();
|
||||
}
|
||||
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_start_request() {
|
||||
#[tokio::test]
|
||||
async fn throttler_start_request() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler.as_mut().start_request(1);
|
||||
throttler
|
||||
.as_mut()
|
||||
.start_request(1, SystemTime::now() + Duration::from_secs(1))
|
||||
.unwrap();
|
||||
assert_eq!(throttler.inner.in_flight_requests.len(), 1);
|
||||
}
|
||||
|
||||
@@ -292,24 +303,32 @@ fn throttler_poll_next_throttled_sink_not_ready() {
|
||||
fn config(&self) -> &Config {
|
||||
unimplemented!()
|
||||
}
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
0
|
||||
}
|
||||
fn start_request(self: Pin<&mut Self>, _: u64) -> AbortRegistration {
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
_id: u64,
|
||||
_deadline: SystemTime,
|
||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_start_send() {
|
||||
#[tokio::test]
|
||||
async fn throttler_start_send() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler.inner.in_flight_requests.insert(0);
|
||||
throttler
|
||||
.inner
|
||||
.in_flight_requests
|
||||
.start_request(0, SystemTime::now() + Duration::from_secs(1))
|
||||
.unwrap();
|
||||
throttler
|
||||
.as_mut()
|
||||
.start_send(Response {
|
||||
@@ -317,7 +336,7 @@ fn throttler_start_send() {
|
||||
message: Ok(1),
|
||||
})
|
||||
.unwrap();
|
||||
assert!(throttler.inner.in_flight_requests.is_empty());
|
||||
assert_eq!(throttler.inner.in_flight_requests.len(), 0);
|
||||
assert_eq!(
|
||||
throttler.inner.sink.get(0),
|
||||
Some(&Response {
|
||||
@@ -27,7 +27,7 @@ use std::{
|
||||
/// Consists of a span identifying an event, an optional parent span identifying a causal event
|
||||
/// that triggered the current span, and a trace with which all related spans are associated.
|
||||
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Context {
|
||||
/// An identifier of the trace associated with the current context. A trace ID is typically
|
||||
/// created at a root span and passed along through all causal events.
|
||||
@@ -47,12 +47,12 @@ pub struct Context {
|
||||
/// A 128-bit UUID identifying a trace. All spans caused by the same originating span share the
|
||||
/// same trace ID.
|
||||
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct TraceId(u128);
|
||||
|
||||
/// A 64-bit identifier of a span within a trace. The identifier is unique within the span's trace.
|
||||
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct SpanId(u64);
|
||||
|
||||
impl Context {
|
||||
|
||||
@@ -7,10 +7,11 @@
|
||||
//! Transports backed by in-memory channels.
|
||||
|
||||
use crate::PollIo;
|
||||
use futures::{channel::mpsc, task::*, Sink, Stream};
|
||||
use futures::{task::*, Sink, Stream};
|
||||
use pin_project::pin_project;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
|
||||
/// [`Sink`].
|
||||
@@ -18,8 +19,8 @@ pub fn unbounded<SinkItem, Item>() -> (
|
||||
UnboundedChannel<SinkItem, Item>,
|
||||
UnboundedChannel<Item, SinkItem>,
|
||||
) {
|
||||
let (tx1, rx2) = mpsc::unbounded();
|
||||
let (tx2, rx1) = mpsc::unbounded();
|
||||
let (tx1, rx2) = mpsc::unbounded_channel();
|
||||
let (tx2, rx1) = mpsc::unbounded_channel();
|
||||
(
|
||||
UnboundedChannel { tx: tx1, rx: rx1 },
|
||||
UnboundedChannel { tx: tx2, rx: rx2 },
|
||||
@@ -28,60 +29,125 @@ pub fn unbounded<SinkItem, Item>() -> (
|
||||
|
||||
/// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender)
|
||||
/// and [`UnboundedReceiver`](mpsc::UnboundedReceiver).
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct UnboundedChannel<Item, SinkItem> {
|
||||
#[pin]
|
||||
rx: mpsc::UnboundedReceiver<Item>,
|
||||
#[pin]
|
||||
tx: mpsc::UnboundedSender<SinkItem>,
|
||||
}
|
||||
|
||||
impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
|
||||
type Item = Result<Item, io::Error>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Item> {
|
||||
self.rx.poll_recv(cx).map(|option| option.map(Ok))
|
||||
}
|
||||
}
|
||||
|
||||
impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(if self.tx.is_closed() {
|
||||
Err(io::Error::from(io::ErrorKind::NotConnected))
|
||||
} else {
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
|
||||
self.tx
|
||||
.send(item)
|
||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
// UnboundedSender requires no flushing.
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
// UnboundedSender can't initiate closure.
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns two channel peers with buffer equal to `capacity`. Each [`Stream`] yields items sent
|
||||
/// through the other's [`Sink`].
|
||||
pub fn bounded<SinkItem, Item>(
|
||||
capacity: usize,
|
||||
) -> (Channel<SinkItem, Item>, Channel<Item, SinkItem>) {
|
||||
let (tx1, rx2) = futures::channel::mpsc::channel(capacity);
|
||||
let (tx2, rx1) = futures::channel::mpsc::channel(capacity);
|
||||
(Channel { tx: tx1, rx: rx1 }, Channel { tx: tx2, rx: rx2 })
|
||||
}
|
||||
|
||||
/// A bi-directional channel backed by a [`Sender`](futures::channel::mpsc::Sender)
|
||||
/// and [`Receiver`](futures::channel::mpsc::Receiver).
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct Channel<Item, SinkItem> {
|
||||
#[pin]
|
||||
rx: futures::channel::mpsc::Receiver<Item>,
|
||||
#[pin]
|
||||
tx: futures::channel::mpsc::Sender<SinkItem>,
|
||||
}
|
||||
|
||||
impl<Item, SinkItem> Stream for Channel<Item, SinkItem> {
|
||||
type Item = Result<Item, io::Error>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Item> {
|
||||
self.project().rx.poll_next(cx).map(|option| option.map(Ok))
|
||||
}
|
||||
}
|
||||
|
||||
impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
|
||||
impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.project()
|
||||
.tx
|
||||
.poll_ready(cx)
|
||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
||||
.map_err(convert_send_err_to_io)
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
|
||||
self.project()
|
||||
.tx
|
||||
.start_send(item)
|
||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
||||
.map_err(convert_send_err_to_io)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.project()
|
||||
.tx
|
||||
.poll_flush(cx)
|
||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
||||
.map_err(convert_send_err_to_io)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.project()
|
||||
.tx
|
||||
.poll_close(cx)
|
||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
||||
.map_err(convert_send_err_to_io)
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_send_err_to_io(e: futures::channel::mpsc::SendError) -> io::Error {
|
||||
if e.is_disconnected() {
|
||||
io::Error::from(io::ErrorKind::NotConnected)
|
||||
} else if e.is_full() {
|
||||
io::Error::from(io::ErrorKind::WouldBlock)
|
||||
} else {
|
||||
io::Error::new(io::ErrorKind::Other, e)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(feature = "tokio1")]
|
||||
mod tests {
|
||||
use crate::{
|
||||
client, context,
|
||||
server::{Handler, Server},
|
||||
server::{BaseChannel, Incoming},
|
||||
transport,
|
||||
};
|
||||
use assert_matches::assert_matches;
|
||||
@@ -89,16 +155,15 @@ mod tests {
|
||||
use log::trace;
|
||||
use std::io;
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[tokio::test]
|
||||
async fn integration() -> io::Result<()> {
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
let (client_channel, server_channel) = transport::channel::unbounded();
|
||||
tokio::spawn(
|
||||
Server::default()
|
||||
.incoming(stream::once(future::ready(server_channel)))
|
||||
.respond_with(|_ctx, request: String| {
|
||||
stream::once(future::ready(server_channel))
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(|_ctx, request: String| {
|
||||
future::ready(request.parse::<u64>().map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
@@ -108,7 +173,7 @@ mod tests {
|
||||
}),
|
||||
);
|
||||
|
||||
let mut client = client::new(client::Config::default(), client_channel).spawn()?;
|
||||
let client = client::new(client::Config::default(), client_channel).spawn()?;
|
||||
|
||||
let response1 = client.call(context::current(), "123".into()).await?;
|
||||
let response2 = client.call(context::current(), "abc".into()).await?;
|
||||
@@ -10,8 +10,8 @@ use std::{
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
|
||||
#[cfg(feature = "serde1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "serde1")))]
|
||||
pub mod serde;
|
||||
|
||||
/// Extension trait for [SystemTimes](SystemTime) in the future, i.e. deadlines.
|
||||
@@ -1,4 +1,4 @@
|
||||
#[tarpc::service]
|
||||
#[tarpc::service(derive_serde = false)]
|
||||
trait World {
|
||||
async fn hello(name: String) -> String;
|
||||
}
|
||||
|
||||
@@ -9,11 +9,3 @@ error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not a
|
||||
|
|
||||
10 | fn hello(name: String) -> String {
|
||||
| ^^
|
||||
|
||||
error[E0433]: failed to resolve: use of undeclared type or module `serde`
|
||||
--> $DIR/tarpc_server_missing_async.rs:1:1
|
||||
|
|
||||
1 | #[tarpc::service]
|
||||
| ^^^^^^^^^^^^^^^^^ use of undeclared type or module `serde`
|
||||
|
|
||||
= note: this error originates in an attribute macro (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||
|
||||
56
tarpc/tests/dataservice.rs
Normal file
56
tarpc/tests/dataservice.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
use futures::prelude::*;
|
||||
use std::io;
|
||||
use tarpc::serde_transport;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{BaseChannel, Incoming},
|
||||
};
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
#[tarpc::derive_serde]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum TestData {
|
||||
Black,
|
||||
White,
|
||||
}
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait ColorProtocol {
|
||||
async fn get_opposite_color(color: TestData) -> TestData;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ColorServer;
|
||||
|
||||
#[tarpc::server]
|
||||
impl ColorProtocol for ColorServer {
|
||||
async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData {
|
||||
match color {
|
||||
TestData::White => TestData::Black,
|
||||
TestData::Black => TestData::White,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_call() -> io::Result<()> {
|
||||
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
|
||||
let addr = transport.local_addr();
|
||||
tokio::spawn(
|
||||
transport
|
||||
.take(1)
|
||||
.filter_map(|r| async { r.ok() })
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(ColorServer.serve()),
|
||||
);
|
||||
|
||||
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
let client = ColorProtocolClient::new(client::Config::default(), transport).spawn()?;
|
||||
|
||||
let color = client
|
||||
.get_opposite_color(context::current(), TestData::White)
|
||||
.await?;
|
||||
assert_eq!(color, TestData::Black);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -3,15 +3,17 @@ use futures::{
|
||||
future::{join_all, ready, Ready},
|
||||
prelude::*,
|
||||
};
|
||||
use std::io;
|
||||
use std::{
|
||||
io,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
use tarpc::{
|
||||
client::{self},
|
||||
context, serde_transport,
|
||||
server::{self, BaseChannel, Channel, Handler},
|
||||
context,
|
||||
server::{self, BaseChannel, Channel, Incoming},
|
||||
transport::channel,
|
||||
};
|
||||
use tokio::join;
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
#[tarpc_plugins::service]
|
||||
trait Service {
|
||||
@@ -44,11 +46,11 @@ async fn sequential() -> io::Result<()> {
|
||||
|
||||
tokio::spawn(
|
||||
BaseChannel::new(server::Config::default(), rx)
|
||||
.respond_with(Server.serve())
|
||||
.execute(),
|
||||
.requests()
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let mut client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
|
||||
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
|
||||
assert_matches!(
|
||||
@@ -58,21 +60,77 @@ async fn sequential() -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde1")]
|
||||
#[tokio::test]
|
||||
async fn serde() -> io::Result<()> {
|
||||
async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
|
||||
#[tarpc_plugins::service]
|
||||
trait Loop {
|
||||
async fn r#loop();
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct LoopServer;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AllHandlersComplete;
|
||||
|
||||
#[tarpc::server]
|
||||
impl Loop for LoopServer {
|
||||
async fn r#loop(self, _: context::Context) {
|
||||
loop {
|
||||
futures::pending!();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
let transport = serde_transport::tcp::listen("localhost:56789", Json::default).await?;
|
||||
let (tx, rx) = channel::unbounded();
|
||||
|
||||
// Set up a client that initiates a long-lived request.
|
||||
// The request will complete in error when the server drops the connection.
|
||||
tokio::spawn(async move {
|
||||
let client = LoopClient::new(client::Config::default(), tx)
|
||||
.spawn()
|
||||
.unwrap();
|
||||
|
||||
let mut ctx = context::current();
|
||||
ctx.deadline = SystemTime::now() + Duration::from_secs(60 * 60);
|
||||
let _ = client.r#loop(ctx).await;
|
||||
});
|
||||
|
||||
let mut requests = BaseChannel::with_defaults(rx).requests();
|
||||
// Reading a request should trigger the request being registered with BaseChannel.
|
||||
let first_request = requests.next().await.unwrap()?;
|
||||
// Dropping the channel should trigger cleanup of outstanding requests.
|
||||
drop(requests);
|
||||
// In-flight requests should be aborted by channel cleanup.
|
||||
// The first and only request sent by the client is `loop`, which is an infinite loop
|
||||
// on the server side, so if cleanup was not triggered, this line should hang indefinitely.
|
||||
first_request.execute(LoopServer.serve()).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
|
||||
#[tokio::test]
|
||||
async fn serde() -> io::Result<()> {
|
||||
use tarpc::serde_transport;
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
let transport = tarpc::serde_transport::tcp::listen("localhost:56789", Json::default).await?;
|
||||
let addr = transport.local_addr();
|
||||
tokio::spawn(
|
||||
tarpc::Server::default()
|
||||
.incoming(transport.take(1).filter_map(|r| async { r.ok() }))
|
||||
.respond_with(Server.serve()),
|
||||
transport
|
||||
.take(1)
|
||||
.filter_map(|r| async { r.ok() })
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
let mut client = ServiceClient::new(client::Config::default(), transport).spawn()?;
|
||||
let client = ServiceClient::new(client::Config::default(), transport).spawn()?;
|
||||
|
||||
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
|
||||
assert_matches!(
|
||||
@@ -89,21 +147,16 @@ async fn concurrent() -> io::Result<()> {
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
tokio::spawn(
|
||||
tarpc::Server::default()
|
||||
.incoming(stream::once(ready(rx)))
|
||||
.respond_with(Server.serve()),
|
||||
stream::once(ready(rx))
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
|
||||
let mut c = client.clone();
|
||||
let req1 = c.add(context::current(), 1, 2);
|
||||
|
||||
let mut c = client.clone();
|
||||
let req2 = c.add(context::current(), 3, 4);
|
||||
|
||||
let mut c = client.clone();
|
||||
let req3 = c.hey(context::current(), "Tim".to_string());
|
||||
let req1 = client.add(context::current(), 1, 2);
|
||||
let req2 = client.add(context::current(), 3, 4);
|
||||
let req3 = client.hey(context::current(), "Tim".to_string());
|
||||
|
||||
assert_matches!(req1.await, Ok(3));
|
||||
assert_matches!(req2.await, Ok(7));
|
||||
@@ -118,21 +171,16 @@ async fn concurrent_join() -> io::Result<()> {
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
tokio::spawn(
|
||||
tarpc::Server::default()
|
||||
.incoming(stream::once(ready(rx)))
|
||||
.respond_with(Server.serve()),
|
||||
stream::once(ready(rx))
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
|
||||
let mut c = client.clone();
|
||||
let req1 = c.add(context::current(), 1, 2);
|
||||
|
||||
let mut c = client.clone();
|
||||
let req2 = c.add(context::current(), 3, 4);
|
||||
|
||||
let mut c = client.clone();
|
||||
let req3 = c.hey(context::current(), "Tim".to_string());
|
||||
let req1 = client.add(context::current(), 1, 2);
|
||||
let req2 = client.add(context::current(), 3, 4);
|
||||
let req3 = client.hey(context::current(), "Tim".to_string());
|
||||
|
||||
let (resp1, resp2, resp3) = join!(req1, req2, req3);
|
||||
assert_matches!(resp1, Ok(3));
|
||||
@@ -148,18 +196,15 @@ async fn concurrent_join_all() -> io::Result<()> {
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
tokio::spawn(
|
||||
tarpc::Server::default()
|
||||
.incoming(stream::once(ready(rx)))
|
||||
.respond_with(Server.serve()),
|
||||
stream::once(ready(rx))
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn()?;
|
||||
|
||||
let mut c1 = client.clone();
|
||||
let mut c2 = client.clone();
|
||||
|
||||
let req1 = c1.add(context::current(), 1, 2);
|
||||
let req2 = c2.add(context::current(), 3, 4);
|
||||
let req1 = client.add(context::current(), 1, 2);
|
||||
let req2 = client.add(context::current(), 3, 4);
|
||||
|
||||
let responses = join_all(vec![req1, req2]).await;
|
||||
assert_matches!(responses[0], Ok(3));
|
||||
@@ -167,3 +212,38 @@ async fn concurrent_join_all() -> io::Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn counter() -> io::Result<()> {
|
||||
#[tarpc::service]
|
||||
trait Counter {
|
||||
async fn count() -> u32;
|
||||
}
|
||||
|
||||
struct CountService(u32);
|
||||
|
||||
impl Counter for &mut CountService {
|
||||
type CountFut = futures::future::Ready<u32>;
|
||||
|
||||
fn count(self, _: context::Context) -> Self::CountFut {
|
||||
self.0 += 1;
|
||||
futures::future::ready(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
tokio::spawn(async {
|
||||
let mut requests = BaseChannel::with_defaults(rx).requests();
|
||||
let mut counter = CountService(0);
|
||||
|
||||
while let Some(Ok(request)) = requests.next().await {
|
||||
request.execute(counter.serve()).await;
|
||||
}
|
||||
});
|
||||
|
||||
let client = CounterClient::new(client::Config::default(), tx).spawn()?;
|
||||
assert_matches!(client.count(context::current()).await, Ok(1));
|
||||
assert_matches!(client.count(context::current()).await, Ok(2));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user