mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-02-23 15:49:54 +01:00
Compare commits
36 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9c86ce157 | ||
|
|
d4f579542d | ||
|
|
5011dbe057 | ||
|
|
4aa90ee933 | ||
|
|
eb91000fed | ||
|
|
006a9f3af1 | ||
|
|
a3d00b07da | ||
|
|
d62706e62c | ||
|
|
b92dd154bc | ||
|
|
a6758fd1f9 | ||
|
|
2c241cc809 | ||
|
|
263ef8a897 | ||
|
|
d50290a21c | ||
|
|
26988cb833 | ||
|
|
6cf18a1caf | ||
|
|
84932df9b4 | ||
|
|
8dc3711a80 | ||
|
|
7c5afa97bb | ||
|
|
324df5cd15 | ||
|
|
3264979993 | ||
|
|
dd63fb59bf | ||
|
|
f4db8cc5b4 | ||
|
|
e9ba350496 | ||
|
|
e6d779e70b | ||
|
|
ce5f8cfb0c | ||
|
|
4b69dc8db5 | ||
|
|
866db2a2cd | ||
|
|
bed85e2827 | ||
|
|
93f3880025 | ||
|
|
878f594d5b | ||
|
|
aa9bbad109 | ||
|
|
7e872ce925 | ||
|
|
62541b709d | ||
|
|
8c43f94fb6 | ||
|
|
7fa4e5064d | ||
|
|
94db7610bb |
75
.github/workflows/main.yml
vendored
75
.github/workflows/main.yml
vendored
@@ -18,20 +18,8 @@ jobs:
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
target: mipsel-unknown-linux-gnu
|
||||
override: true
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: check
|
||||
args: --all-features
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: check
|
||||
args: --all-features --target mipsel-unknown-linux-gnu
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- run: cargo check --all-features
|
||||
|
||||
test:
|
||||
name: Test Suite
|
||||
@@ -42,34 +30,13 @@ jobs:
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
override: true
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features serde1
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features tokio1
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features serde-transport
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --manifest-path tarpc/Cargo.toml --features tcp
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --all-features
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- run: cargo test
|
||||
- run: cargo test --manifest-path tarpc/Cargo.toml --features serde1
|
||||
- run: cargo test --manifest-path tarpc/Cargo.toml --features tokio1
|
||||
- run: cargo test --manifest-path tarpc/Cargo.toml --features serde-transport
|
||||
- run: cargo test --manifest-path tarpc/Cargo.toml --features tcp
|
||||
- run: cargo test --all-features
|
||||
|
||||
fmt:
|
||||
name: Rustfmt
|
||||
@@ -80,16 +47,10 @@ jobs:
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions-rs/toolchain@v1
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
override: true
|
||||
- run: rustup component add rustfmt
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: fmt
|
||||
args: --all -- --check
|
||||
components: rustfmt
|
||||
- run: cargo fmt --all -- --check
|
||||
|
||||
clippy:
|
||||
name: Clippy
|
||||
@@ -100,13 +61,7 @@ jobs:
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions-rs/toolchain@v1
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
override: true
|
||||
- run: rustup component add clippy
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: clippy
|
||||
args: --all-features -- -D warnings
|
||||
components: clippy
|
||||
- run: cargo clippy --all-features -- -D warnings
|
||||
|
||||
23
README.md
23
README.md
@@ -67,7 +67,7 @@ Some other features of tarpc:
|
||||
Add to your `Cargo.toml` dependencies:
|
||||
|
||||
```toml
|
||||
tarpc = "0.31"
|
||||
tarpc = "0.34"
|
||||
```
|
||||
|
||||
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||
@@ -83,7 +83,7 @@ your `Cargo.toml`:
|
||||
anyhow = "1.0"
|
||||
futures = "0.3"
|
||||
tarpc = { version = "0.31", features = ["tokio1"] }
|
||||
tokio = { version = "1.0", features = ["macros"] }
|
||||
tokio = { version = "1.0", features = ["rt-multi-thread", "macros"] }
|
||||
```
|
||||
|
||||
In the following example, we use an in-process channel for communication between
|
||||
@@ -93,14 +93,10 @@ For a more real-world example, see [example-service](example-service).
|
||||
First, let's set up the dependencies and service definition.
|
||||
|
||||
```rust
|
||||
|
||||
use futures::{
|
||||
future::{self, Ready},
|
||||
prelude::*,
|
||||
};
|
||||
use futures::future::{self, Ready};
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{self, incoming::Incoming, Channel},
|
||||
server::{self, Channel},
|
||||
};
|
||||
|
||||
// This is the service definition. It looks a lot like a trait definition.
|
||||
@@ -122,13 +118,8 @@ implement it for our Server struct.
|
||||
struct HelloServer;
|
||||
|
||||
impl World for HelloServer {
|
||||
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and
|
||||
// an associated type representing the future output by the fn.
|
||||
|
||||
type HelloFut = Ready<String>;
|
||||
|
||||
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
||||
future::ready(format!("Hello, {name}!"))
|
||||
async fn hello(self, _: context::Context, name: String) -> String {
|
||||
format!("Hello, {name}!")
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -148,7 +139,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||
// that takes a config and any Transport as input.
|
||||
let mut client = WorldClient::new(client::Config::default(), client_transport).spawn();
|
||||
let client = WorldClient::new(client::Config::default(), client_transport).spawn();
|
||||
|
||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||
|
||||
44
RELEASES.md
44
RELEASES.md
@@ -1,3 +1,47 @@
|
||||
## 0.34.0 (2023-12-29)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- `#[tarpc::server]` is no more! Service traits now use async fns.
|
||||
- `Channel::execute` no longer spawns request handlers. Async-fn-in-traits makes it impossible to
|
||||
add a Send bound to the future returned by `Serve::serve`. Instead, `Channel::execute` returns a
|
||||
stream of futures, where each future is a request handler. To achieve the former behavior:
|
||||
```rust
|
||||
channel.execute(server.serve())
|
||||
.for_each(|rpc| { tokio::spawn(rpc); })
|
||||
```
|
||||
|
||||
### New Features
|
||||
|
||||
- Request hooks are added to the serve trait, so that it's easy to hook in cross-cutting
|
||||
functionality like throttling, authorization, etc.
|
||||
- The Client trait is back! This makes it possible to hook in generic client functionality like load
|
||||
balancing, retries, etc.
|
||||
|
||||
## 0.33.0 (2023-04-01)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
Opentelemetry dependency version increased to 0.18.
|
||||
|
||||
## 0.32.0 (2023-03-24)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- As part of a fix to return more channel errors in RPC results, a few error types have changed:
|
||||
|
||||
0. `client::RpcError::Disconnected` was split into the following errors:
|
||||
- Shutdown: the client was shutdown, either intentionally or due to an error. If due to an
|
||||
error, pending RPCs should see the more specific errors below.
|
||||
- Send: an RPC message failed to send over the transport. Only the RPC that failed to be sent
|
||||
will see this error.
|
||||
- Receive: a fatal error occurred while receiving from the transport. All in-flight RPCs will
|
||||
receive this error.
|
||||
0. `client::ChannelError` and `server::ChannelError` are unified in `tarpc::ChannelError`.
|
||||
Previously, server transport errors would not indicate during which activity the transport
|
||||
error occurred. Now, just like the client already was, it will be specific: reading, readying,
|
||||
sending, flushing, or closing.
|
||||
|
||||
## 0.31.0 (2022-11-03)
|
||||
|
||||
### New Features
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc-example-service"
|
||||
version = "0.13.0"
|
||||
version = "0.15.0"
|
||||
rust-version = "1.56"
|
||||
authors = ["Tim Kuehn <tikue@google.com>"]
|
||||
edition = "2021"
|
||||
@@ -18,14 +18,15 @@ anyhow = "1.0"
|
||||
clap = { version = "3.0.0-rc.9", features = ["derive"] }
|
||||
log = "0.4"
|
||||
futures = "0.3"
|
||||
opentelemetry = { version = "0.17", features = ["rt-tokio"] }
|
||||
opentelemetry-jaeger = { version = "0.16", features = ["rt-tokio"] }
|
||||
opentelemetry = { version = "0.21.0" }
|
||||
opentelemetry-jaeger = { version = "0.20.0", features = ["rt-tokio"] }
|
||||
rand = "0.8"
|
||||
tarpc = { version = "0.31", path = "../tarpc", features = ["full"] }
|
||||
tarpc = { version = "0.34", path = "../tarpc", features = ["full"] }
|
||||
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
|
||||
tracing = { version = "0.1" }
|
||||
tracing-opentelemetry = "0.17"
|
||||
tracing-subscriber = {version = "0.3", features = ["env-filter"]}
|
||||
tracing-opentelemetry = "0.22.0"
|
||||
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
|
||||
opentelemetry_sdk = "0.21.1"
|
||||
|
||||
[lib]
|
||||
name = "service"
|
||||
|
||||
15
example-service/README.md
Normal file
15
example-service/README.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# Example
|
||||
|
||||
Example service to demonstrate how to set up `tarpc` with [Jaeger](https://www.jaegertracing.io). To see traces Jaeger, run the following with `RUST_LOG=trace`.
|
||||
|
||||
## Server
|
||||
|
||||
```bash
|
||||
cargo run --bin server -- --port 50051
|
||||
```
|
||||
|
||||
## Client
|
||||
|
||||
```bash
|
||||
cargo run --bin client -- --server-addr "[::1]:50051" --name "Bob"
|
||||
```
|
||||
@@ -26,7 +26,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
let flags = Flags::parse();
|
||||
init_tracing("Tarpc Example Client")?;
|
||||
|
||||
let transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
|
||||
let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
|
||||
transport.config_mut().max_frame_length(usize::MAX);
|
||||
|
||||
// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
|
||||
// config and any Transport as input.
|
||||
@@ -42,7 +43,10 @@ async fn main() -> anyhow::Result<()> {
|
||||
.instrument(tracing::info_span!("Two Hellos"))
|
||||
.await;
|
||||
|
||||
tracing::info!("{:?}", hello);
|
||||
match hello {
|
||||
Ok(hello) => tracing::info!("{hello:?}"),
|
||||
Err(e) => tracing::warn!("{:?}", anyhow::Error::from(e)),
|
||||
}
|
||||
|
||||
// Let the background span processor finish.
|
||||
sleep(Duration::from_micros(1)).await;
|
||||
|
||||
@@ -19,10 +19,10 @@ pub trait World {
|
||||
pub fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
||||
|
||||
let tracer = opentelemetry_jaeger::new_pipeline()
|
||||
let tracer = opentelemetry_jaeger::new_agent_pipeline()
|
||||
.with_service_name(service_name)
|
||||
.with_max_packet_size(2usize.pow(13))
|
||||
.install_batch(opentelemetry::runtime::Tokio)?;
|
||||
.install_batch(opentelemetry_sdk::runtime::Tokio)?;
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::EnvFilter::from_default_env())
|
||||
|
||||
@@ -34,7 +34,6 @@ struct Flags {
|
||||
#[derive(Clone)]
|
||||
struct HelloServer(SocketAddr);
|
||||
|
||||
#[tarpc::server]
|
||||
impl World for HelloServer {
|
||||
async fn hello(self, _: context::Context, name: String) -> String {
|
||||
let sleep_time =
|
||||
@@ -44,6 +43,10 @@ impl World for HelloServer {
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||
tokio::spawn(fut);
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let flags = Flags::parse();
|
||||
@@ -66,7 +69,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
// the generated World trait.
|
||||
.map(|channel| {
|
||||
let server = HelloServer(channel.transport().peer_addr().unwrap());
|
||||
channel.execute(server.serve())
|
||||
channel.execute(server.serve()).for_each(spawn)
|
||||
})
|
||||
// Max 10 channels.
|
||||
.buffer_unordered(10)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc-plugins"
|
||||
version = "0.12.0"
|
||||
version = "0.13.0"
|
||||
rust-version = "1.56"
|
||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||
edition = "2021"
|
||||
|
||||
@@ -12,18 +12,18 @@ extern crate quote;
|
||||
extern crate syn;
|
||||
|
||||
use proc_macro::TokenStream;
|
||||
use proc_macro2::{Span, TokenStream as TokenStream2};
|
||||
use proc_macro2::TokenStream as TokenStream2;
|
||||
use quote::{format_ident, quote, ToTokens};
|
||||
use syn::{
|
||||
braced,
|
||||
ext::IdentExt,
|
||||
parenthesized,
|
||||
parse::{Parse, ParseStream},
|
||||
parse_macro_input, parse_quote, parse_str,
|
||||
parse_macro_input, parse_quote,
|
||||
spanned::Spanned,
|
||||
token::Comma,
|
||||
Attribute, FnArg, Ident, ImplItem, ImplItemMethod, ImplItemType, ItemImpl, Lit, LitBool,
|
||||
MetaNameValue, Pat, PatType, ReturnType, Token, Type, Visibility,
|
||||
Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type,
|
||||
Visibility,
|
||||
};
|
||||
|
||||
/// Accumulates multiple errors into a result.
|
||||
@@ -220,15 +220,15 @@ impl Parse for DeriveSerde {
|
||||
/// Adds the following annotations to the annotated item:
|
||||
///
|
||||
/// ```rust
|
||||
/// #[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
|
||||
/// #[derive(::tarpc::serde::Serialize, ::tarpc::serde::Deserialize)]
|
||||
/// #[serde(crate = "tarpc::serde")]
|
||||
/// # struct Foo;
|
||||
/// ```
|
||||
#[proc_macro_attribute]
|
||||
pub fn derive_serde(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let mut gen: proc_macro2::TokenStream = quote! {
|
||||
#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
|
||||
#[serde(crate = "tarpc::serde")]
|
||||
#[derive(::tarpc::serde::Serialize, ::tarpc::serde::Deserialize)]
|
||||
#[serde(crate = "::tarpc::serde")]
|
||||
};
|
||||
gen.extend(proc_macro2::TokenStream::from(item));
|
||||
proc_macro::TokenStream::from(gen)
|
||||
@@ -257,11 +257,10 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
.map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
|
||||
.collect();
|
||||
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
|
||||
let response_fut_name = &format!("{}ResponseFut", ident.unraw());
|
||||
let derive_serialize = if derive_serde.0 {
|
||||
Some(
|
||||
quote! {#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
|
||||
#[serde(crate = "tarpc::serde")]},
|
||||
quote! {#[derive(::tarpc::serde::Serialize, ::tarpc::serde::Deserialize)]
|
||||
#[serde(crate = "::tarpc::serde")]},
|
||||
)
|
||||
} else {
|
||||
None
|
||||
@@ -274,10 +273,9 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
ServiceGenerator {
|
||||
response_fut_name,
|
||||
service_ident: ident,
|
||||
client_stub_ident: &format_ident!("{}Stub", ident),
|
||||
server_ident: &format_ident!("Serve{}", ident),
|
||||
response_fut_ident: &Ident::new(response_fut_name, ident.span()),
|
||||
client_ident: &format_ident!("{}Client", ident),
|
||||
request_ident: &format_ident!("{}Request", ident),
|
||||
response_ident: &format_ident!("{}Response", ident),
|
||||
@@ -304,137 +302,18 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
.zip(camel_case_fn_names.iter())
|
||||
.map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
|
||||
.collect::<Vec<_>>(),
|
||||
future_types: &camel_case_fn_names
|
||||
.iter()
|
||||
.map(|name| parse_str(&format!("{name}Fut")).unwrap())
|
||||
.collect::<Vec<_>>(),
|
||||
derive_serialize: derive_serialize.as_ref(),
|
||||
}
|
||||
.into_token_stream()
|
||||
.into()
|
||||
}
|
||||
|
||||
/// generate an identifier consisting of the method name to CamelCase with
|
||||
/// Fut appended to it.
|
||||
fn associated_type_for_rpc(method: &ImplItemMethod) -> String {
|
||||
snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut"
|
||||
}
|
||||
|
||||
/// Transforms an async function into a sync one, returning a type declaration
|
||||
/// for the return type (a future).
|
||||
fn transform_method(method: &mut ImplItemMethod) -> ImplItemType {
|
||||
method.sig.asyncness = None;
|
||||
|
||||
// get either the return type or ().
|
||||
let ret = match &method.sig.output {
|
||||
ReturnType::Default => quote!(()),
|
||||
ReturnType::Type(_, ret) => quote!(#ret),
|
||||
};
|
||||
|
||||
let fut_name = associated_type_for_rpc(method);
|
||||
let fut_name_ident = Ident::new(&fut_name, method.sig.ident.span());
|
||||
|
||||
// generate the updated return signature.
|
||||
method.sig.output = parse_quote! {
|
||||
-> ::core::pin::Pin<Box<
|
||||
dyn ::core::future::Future<Output = #ret> + ::core::marker::Send
|
||||
>>
|
||||
};
|
||||
|
||||
// transform the body of the method into Box::pin(async move { body }).
|
||||
let block = method.block.clone();
|
||||
method.block = parse_quote! [{
|
||||
Box::pin(async move
|
||||
#block
|
||||
)
|
||||
}];
|
||||
|
||||
// generate and return type declaration for return type.
|
||||
let t: ImplItemType = parse_quote! {
|
||||
type #fut_name_ident = ::core::pin::Pin<Box<dyn ::core::future::Future<Output = #ret> + ::core::marker::Send>>;
|
||||
};
|
||||
|
||||
t
|
||||
}
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
let mut item = syn::parse_macro_input!(input as ItemImpl);
|
||||
let span = item.span();
|
||||
|
||||
// the generated type declarations
|
||||
let mut types: Vec<ImplItemType> = Vec::new();
|
||||
let mut expected_non_async_types: Vec<(&ImplItemMethod, String)> = Vec::new();
|
||||
let mut found_non_async_types: Vec<&ImplItemType> = Vec::new();
|
||||
|
||||
for inner in &mut item.items {
|
||||
match inner {
|
||||
ImplItem::Method(method) => {
|
||||
if method.sig.asyncness.is_some() {
|
||||
// if this function is declared async, transform it into a regular function
|
||||
let typedecl = transform_method(method);
|
||||
types.push(typedecl);
|
||||
} else {
|
||||
// If it's not async, keep track of all required associated types for better
|
||||
// error reporting.
|
||||
expected_non_async_types.push((method, associated_type_for_rpc(method)));
|
||||
}
|
||||
}
|
||||
ImplItem::Type(typedecl) => found_non_async_types.push(typedecl),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(e) =
|
||||
verify_types_were_provided(span, &expected_non_async_types, &found_non_async_types)
|
||||
{
|
||||
return TokenStream::from(e.to_compile_error());
|
||||
}
|
||||
|
||||
// add the type declarations into the impl block
|
||||
for t in types.into_iter() {
|
||||
item.items.push(syn::ImplItem::Type(t));
|
||||
}
|
||||
|
||||
TokenStream::from(quote!(#item))
|
||||
}
|
||||
|
||||
fn verify_types_were_provided(
|
||||
span: Span,
|
||||
expected: &[(&ImplItemMethod, String)],
|
||||
provided: &[&ImplItemType],
|
||||
) -> syn::Result<()> {
|
||||
let mut result = Ok(());
|
||||
for (method, expected) in expected {
|
||||
if !provided.iter().any(|typedecl| typedecl.ident == expected) {
|
||||
let mut e = syn::Error::new(
|
||||
span,
|
||||
format!("not all trait items implemented, missing: `{expected}`"),
|
||||
);
|
||||
let fn_span = method.sig.fn_token.span();
|
||||
e.extend(syn::Error::new(
|
||||
fn_span.join(method.sig.ident.span()).unwrap_or(fn_span),
|
||||
format!(
|
||||
"hint: `#[tarpc::server]` only rewrites async fns, and `fn {}` is not async",
|
||||
method.sig.ident
|
||||
),
|
||||
));
|
||||
match result {
|
||||
Ok(_) => result = Err(e),
|
||||
Err(ref mut error) => error.extend(Some(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
// Things needed to generate the service items: trait, serve impl, request/response enums, and
|
||||
// the client stub.
|
||||
struct ServiceGenerator<'a> {
|
||||
service_ident: &'a Ident,
|
||||
client_stub_ident: &'a Ident,
|
||||
server_ident: &'a Ident,
|
||||
response_fut_ident: &'a Ident,
|
||||
response_fut_name: &'a str,
|
||||
client_ident: &'a Ident,
|
||||
request_ident: &'a Ident,
|
||||
response_ident: &'a Ident,
|
||||
@@ -442,7 +321,6 @@ struct ServiceGenerator<'a> {
|
||||
attrs: &'a [Attribute],
|
||||
rpcs: &'a [RpcMethod],
|
||||
camel_case_idents: &'a [Ident],
|
||||
future_types: &'a [Type],
|
||||
method_idents: &'a [&'a Ident],
|
||||
request_names: &'a [String],
|
||||
method_attrs: &'a [&'a [Attribute]],
|
||||
@@ -458,49 +336,53 @@ impl<'a> ServiceGenerator<'a> {
|
||||
attrs,
|
||||
rpcs,
|
||||
vis,
|
||||
future_types,
|
||||
return_types,
|
||||
service_ident,
|
||||
client_stub_ident,
|
||||
request_ident,
|
||||
response_ident,
|
||||
server_ident,
|
||||
..
|
||||
} = self;
|
||||
|
||||
let types_and_fns = rpcs
|
||||
let rpc_fns = rpcs
|
||||
.iter()
|
||||
.zip(future_types.iter())
|
||||
.zip(return_types.iter())
|
||||
.map(
|
||||
|(
|
||||
(
|
||||
RpcMethod {
|
||||
attrs, ident, args, ..
|
||||
},
|
||||
future_type,
|
||||
),
|
||||
RpcMethod {
|
||||
attrs, ident, args, ..
|
||||
},
|
||||
output,
|
||||
)| {
|
||||
let ty_doc = format!("The response future returned by [`{service_ident}::{ident}`].");
|
||||
quote! {
|
||||
#[doc = #ty_doc]
|
||||
type #future_type: std::future::Future<Output = #output>;
|
||||
|
||||
#( #attrs )*
|
||||
fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type;
|
||||
async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output;
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let stub_doc = format!("The stub trait for service [`{service_ident}`].");
|
||||
quote! {
|
||||
#( #attrs )*
|
||||
#vis trait #service_ident: Sized {
|
||||
#( #types_and_fns )*
|
||||
#vis trait #service_ident: ::core::marker::Sized {
|
||||
#( #rpc_fns )*
|
||||
|
||||
/// Returns a serving function to use with
|
||||
/// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute).
|
||||
/// [InFlightRequest::execute](::tarpc::server::InFlightRequest::execute).
|
||||
fn serve(self) -> #server_ident<Self> {
|
||||
#server_ident { service: self }
|
||||
}
|
||||
}
|
||||
|
||||
#[doc = #stub_doc]
|
||||
#vis trait #client_stub_ident: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident> {
|
||||
}
|
||||
|
||||
impl<S> #client_stub_ident for S
|
||||
where S: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident>
|
||||
{
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -510,7 +392,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
/// A serving function to use with [tarpc::server::InFlightRequest::execute].
|
||||
/// A serving function to use with [::tarpc::server::InFlightRequest::execute].
|
||||
#[derive(Clone)]
|
||||
#vis struct #server_ident<S> {
|
||||
service: S,
|
||||
@@ -524,7 +406,6 @@ impl<'a> ServiceGenerator<'a> {
|
||||
server_ident,
|
||||
service_ident,
|
||||
response_ident,
|
||||
response_fut_ident,
|
||||
camel_case_idents,
|
||||
arg_pats,
|
||||
method_idents,
|
||||
@@ -533,14 +414,14 @@ impl<'a> ServiceGenerator<'a> {
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
impl<S> tarpc::server::Serve<#request_ident> for #server_ident<S>
|
||||
impl<S> ::tarpc::server::Serve for #server_ident<S>
|
||||
where S: #service_ident
|
||||
{
|
||||
type Req = #request_ident;
|
||||
type Resp = #response_ident;
|
||||
type Fut = #response_fut_ident<S>;
|
||||
|
||||
fn method(&self, req: &#request_ident) -> Option<&'static str> {
|
||||
Some(match req {
|
||||
fn method(&self, req: &#request_ident) -> ::core::option::Option<&'static str> {
|
||||
::core::option::Option::Some(match req {
|
||||
#(
|
||||
#request_ident::#camel_case_idents{..} => {
|
||||
#request_names
|
||||
@@ -549,15 +430,16 @@ impl<'a> ServiceGenerator<'a> {
|
||||
})
|
||||
}
|
||||
|
||||
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
|
||||
async fn serve(self, ctx: ::tarpc::context::Context, req: #request_ident)
|
||||
-> ::core::result::Result<#response_ident, ::tarpc::ServerError> {
|
||||
match req {
|
||||
#(
|
||||
#request_ident::#camel_case_idents{ #( #arg_pats ),* } => {
|
||||
#response_fut_ident::#camel_case_idents(
|
||||
::core::result::Result::Ok(#response_ident::#camel_case_idents(
|
||||
#service_ident::#method_idents(
|
||||
self.service, ctx, #( #arg_pats ),*
|
||||
)
|
||||
)
|
||||
).await
|
||||
))
|
||||
}
|
||||
)*
|
||||
}
|
||||
@@ -608,73 +490,6 @@ impl<'a> ServiceGenerator<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn enum_response_future(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
vis,
|
||||
service_ident,
|
||||
response_fut_ident,
|
||||
camel_case_idents,
|
||||
future_types,
|
||||
..
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
/// A future resolving to a server response.
|
||||
#[allow(missing_docs)]
|
||||
#vis enum #response_fut_ident<S: #service_ident> {
|
||||
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_debug_for_response_future(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
service_ident,
|
||||
response_fut_ident,
|
||||
response_fut_name,
|
||||
..
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
impl<S: #service_ident> std::fmt::Debug for #response_fut_ident<S> {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
fmt.debug_struct(#response_fut_name).finish()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_future_for_response_future(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
service_ident,
|
||||
response_fut_ident,
|
||||
response_ident,
|
||||
camel_case_idents,
|
||||
..
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
impl<S: #service_ident> std::future::Future for #response_fut_ident<S> {
|
||||
type Output = #response_ident;
|
||||
|
||||
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>)
|
||||
-> std::task::Poll<#response_ident>
|
||||
{
|
||||
unsafe {
|
||||
match std::pin::Pin::get_unchecked_mut(self) {
|
||||
#(
|
||||
#response_fut_ident::#camel_case_idents(resp) =>
|
||||
std::pin::Pin::new_unchecked(resp)
|
||||
.poll(cx)
|
||||
.map(#response_ident::#camel_case_idents),
|
||||
)*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn struct_client(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
vis,
|
||||
@@ -688,8 +503,10 @@ impl<'a> ServiceGenerator<'a> {
|
||||
#[allow(unused)]
|
||||
#[derive(Clone, Debug)]
|
||||
/// The client stub that makes RPC calls to the server. All request methods return
|
||||
/// [Futures](std::future::Future).
|
||||
#vis struct #client_ident(tarpc::client::Channel<#request_ident, #response_ident>);
|
||||
/// [Futures](::core::future::Future).
|
||||
#vis struct #client_ident<
|
||||
Stub = ::tarpc::client::Channel<#request_ident, #response_ident>
|
||||
>(Stub);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -705,20 +522,31 @@ impl<'a> ServiceGenerator<'a> {
|
||||
quote! {
|
||||
impl #client_ident {
|
||||
/// Returns a new client stub that sends requests over the given transport.
|
||||
#vis fn new<T>(config: tarpc::client::Config, transport: T)
|
||||
-> tarpc::client::NewClient<
|
||||
#vis fn new<T>(config: ::tarpc::client::Config, transport: T)
|
||||
-> ::tarpc::client::NewClient<
|
||||
Self,
|
||||
tarpc::client::RequestDispatch<#request_ident, #response_ident, T>
|
||||
::tarpc::client::RequestDispatch<#request_ident, #response_ident, T>
|
||||
>
|
||||
where
|
||||
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>>
|
||||
T: ::tarpc::Transport<::tarpc::ClientMessage<#request_ident>, ::tarpc::Response<#response_ident>>
|
||||
{
|
||||
let new_client = tarpc::client::new(config, transport);
|
||||
tarpc::client::NewClient {
|
||||
let new_client = ::tarpc::client::new(config, transport);
|
||||
::tarpc::client::NewClient {
|
||||
client: #client_ident(new_client.client),
|
||||
dispatch: new_client.dispatch,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Stub> ::core::convert::From<Stub> for #client_ident<Stub>
|
||||
where Stub: ::tarpc::client::stub::Stub<
|
||||
Req = #request_ident,
|
||||
Resp = #response_ident>
|
||||
{
|
||||
/// Returns a new client stub that sends requests over the given transport.
|
||||
fn from(stub: Stub) -> Self {
|
||||
#client_ident(stub)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
@@ -741,18 +569,22 @@ impl<'a> ServiceGenerator<'a> {
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
impl #client_ident {
|
||||
impl<Stub> #client_ident<Stub>
|
||||
where Stub: ::tarpc::client::stub::Stub<
|
||||
Req = #request_ident,
|
||||
Resp = #response_ident>
|
||||
{
|
||||
#(
|
||||
#[allow(unused)]
|
||||
#( #method_attrs )*
|
||||
#vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*)
|
||||
-> impl std::future::Future<Output = Result<#return_types, tarpc::client::RpcError>> + '_ {
|
||||
#vis fn #method_idents(&self, ctx: ::tarpc::context::Context, #( #args ),*)
|
||||
-> impl ::core::future::Future<Output = ::core::result::Result<#return_types, ::tarpc::client::RpcError>> + '_ {
|
||||
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
|
||||
let resp = self.0.call(ctx, #request_names, request);
|
||||
async move {
|
||||
match resp.await? {
|
||||
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
|
||||
_ => unreachable!(),
|
||||
#response_ident::#camel_case_idents(msg) => ::core::result::Result::Ok(msg),
|
||||
_ => ::core::unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -770,9 +602,6 @@ impl<'a> ToTokens for ServiceGenerator<'a> {
|
||||
self.impl_serve_for_server(),
|
||||
self.enum_request(),
|
||||
self.enum_response(),
|
||||
self.enum_response_future(),
|
||||
self.impl_debug_for_response_future(),
|
||||
self.impl_future_for_response_future(),
|
||||
self.struct_client(),
|
||||
self.impl_client_new(),
|
||||
self.impl_client_rpc_methods(),
|
||||
|
||||
@@ -1,8 +1,3 @@
|
||||
use assert_type_eq::assert_type_eq;
|
||||
use futures::Future;
|
||||
use std::pin::Pin;
|
||||
use tarpc::context;
|
||||
|
||||
// these need to be out here rather than inside the function so that the
|
||||
// assert_type_eq macro can pick them up.
|
||||
#[tarpc::service]
|
||||
@@ -12,42 +7,6 @@ trait Foo {
|
||||
async fn baz();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn type_generation_works() {
|
||||
#[tarpc::server]
|
||||
impl Foo for () {
|
||||
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
|
||||
(s, i)
|
||||
}
|
||||
|
||||
async fn bar(self, _: context::Context, s: String) -> String {
|
||||
s
|
||||
}
|
||||
|
||||
async fn baz(self, _: context::Context) {}
|
||||
}
|
||||
|
||||
// the assert_type_eq macro can only be used once per block.
|
||||
{
|
||||
assert_type_eq!(
|
||||
<() as Foo>::TwoPartFut,
|
||||
Pin<Box<dyn Future<Output = (String, i32)> + Send>>
|
||||
);
|
||||
}
|
||||
{
|
||||
assert_type_eq!(
|
||||
<() as Foo>::BarFut,
|
||||
Pin<Box<dyn Future<Output = String> + Send>>
|
||||
);
|
||||
}
|
||||
{
|
||||
assert_type_eq!(
|
||||
<() as Foo>::BazFut,
|
||||
Pin<Box<dyn Future<Output = ()> + Send>>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
#[test]
|
||||
fn raw_idents_work() {
|
||||
@@ -59,24 +18,6 @@ fn raw_idents_work() {
|
||||
async fn r#fn(r#impl: r#yield) -> r#yield;
|
||||
async fn r#async();
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl r#trait for () {
|
||||
async fn r#await(
|
||||
self,
|
||||
_: context::Context,
|
||||
r#struct: r#yield,
|
||||
r#enum: i32,
|
||||
) -> (r#yield, i32) {
|
||||
(r#struct, r#enum)
|
||||
}
|
||||
|
||||
async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
|
||||
r#impl
|
||||
}
|
||||
|
||||
async fn r#async(self, _: context::Context) {}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -100,45 +41,4 @@ fn syntax() {
|
||||
#[doc = "attr"]
|
||||
async fn one_arg_implicit_return_error(one: String);
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl Syntax for () {
|
||||
#[deny(warnings)]
|
||||
#[allow(non_snake_case)]
|
||||
async fn TestCamelCaseDoesntConflict(self, _: context::Context) {}
|
||||
|
||||
async fn hello(self, _: context::Context) -> String {
|
||||
String::new()
|
||||
}
|
||||
|
||||
async fn attr(self, _: context::Context, _s: String) -> String {
|
||||
String::new()
|
||||
}
|
||||
|
||||
async fn no_args_no_return(self, _: context::Context) {}
|
||||
|
||||
async fn no_args(self, _: context::Context) -> () {}
|
||||
|
||||
async fn one_arg(self, _: context::Context, _one: String) -> i32 {
|
||||
0
|
||||
}
|
||||
|
||||
async fn two_args_no_return(self, _: context::Context, _one: String, _two: u64) {}
|
||||
|
||||
async fn two_args(self, _: context::Context, _one: String, _two: u64) -> String {
|
||||
String::new()
|
||||
}
|
||||
|
||||
async fn no_args_ret_error(self, _: context::Context) -> i32 {
|
||||
0
|
||||
}
|
||||
|
||||
async fn one_arg_ret_error(self, _: context::Context, _one: String) -> String {
|
||||
String::new()
|
||||
}
|
||||
|
||||
async fn no_arg_implicit_return_error(self, _: context::Context) {}
|
||||
|
||||
async fn one_arg_implicit_return_error(self, _: context::Context, _one: String) {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,6 @@ use tarpc::context;
|
||||
|
||||
#[test]
|
||||
fn att_service_trait() {
|
||||
use futures::future::{ready, Ready};
|
||||
|
||||
#[tarpc::service]
|
||||
trait Foo {
|
||||
async fn two_part(s: String, i: i32) -> (String, i32);
|
||||
@@ -12,19 +10,16 @@ fn att_service_trait() {
|
||||
}
|
||||
|
||||
impl Foo for () {
|
||||
type TwoPartFut = Ready<(String, i32)>;
|
||||
fn two_part(self, _: context::Context, s: String, i: i32) -> Self::TwoPartFut {
|
||||
ready((s, i))
|
||||
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
|
||||
(s, i)
|
||||
}
|
||||
|
||||
type BarFut = Ready<String>;
|
||||
fn bar(self, _: context::Context, s: String) -> Self::BarFut {
|
||||
ready(s)
|
||||
async fn bar(self, _: context::Context, s: String) -> String {
|
||||
s
|
||||
}
|
||||
|
||||
type BazFut = Ready<()>;
|
||||
fn baz(self, _: context::Context) -> Self::BazFut {
|
||||
ready(())
|
||||
async fn baz(self, _: context::Context) {
|
||||
()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -32,8 +27,6 @@ fn att_service_trait() {
|
||||
#[allow(non_camel_case_types)]
|
||||
#[test]
|
||||
fn raw_idents() {
|
||||
use futures::future::{ready, Ready};
|
||||
|
||||
type r#yield = String;
|
||||
|
||||
#[tarpc::service]
|
||||
@@ -44,19 +37,21 @@ fn raw_idents() {
|
||||
}
|
||||
|
||||
impl r#trait for () {
|
||||
type AwaitFut = Ready<(r#yield, i32)>;
|
||||
fn r#await(self, _: context::Context, r#struct: r#yield, r#enum: i32) -> Self::AwaitFut {
|
||||
ready((r#struct, r#enum))
|
||||
async fn r#await(
|
||||
self,
|
||||
_: context::Context,
|
||||
r#struct: r#yield,
|
||||
r#enum: i32,
|
||||
) -> (r#yield, i32) {
|
||||
(r#struct, r#enum)
|
||||
}
|
||||
|
||||
type FnFut = Ready<r#yield>;
|
||||
fn r#fn(self, _: context::Context, r#impl: r#yield) -> Self::FnFut {
|
||||
ready(r#impl)
|
||||
async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
|
||||
r#impl
|
||||
}
|
||||
|
||||
type AsyncFut = Ready<()>;
|
||||
fn r#async(self, _: context::Context) -> Self::AsyncFut {
|
||||
ready(())
|
||||
async fn r#async(self, _: context::Context) {
|
||||
()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc"
|
||||
version = "0.31.0"
|
||||
version = "0.34.0"
|
||||
rust-version = "1.58.0"
|
||||
authors = [
|
||||
"Adam Wright <adam.austin.wright@gmail.com>",
|
||||
@@ -19,7 +19,7 @@ description = "An RPC framework for Rust with a focus on ease of use."
|
||||
[features]
|
||||
default = []
|
||||
|
||||
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"]
|
||||
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive", "serde/rc"]
|
||||
tokio1 = ["tokio/rt"]
|
||||
serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
|
||||
serde-transport-json = ["tokio-serde/json"]
|
||||
@@ -49,7 +49,7 @@ pin-project = "1.0"
|
||||
rand = "0.8"
|
||||
serde = { optional = true, version = "1.0", features = ["derive"] }
|
||||
static_assertions = "1.1.0"
|
||||
tarpc-plugins = { path = "../plugins", version = "0.12" }
|
||||
tarpc-plugins = { path = "../plugins", version = "0.13" }
|
||||
thiserror = "1.0"
|
||||
tokio = { version = "1", features = ["time"] }
|
||||
tokio-util = { version = "0.7.3", features = ["time"] }
|
||||
@@ -58,8 +58,8 @@ tracing = { version = "0.1", default-features = false, features = [
|
||||
"attributes",
|
||||
"log",
|
||||
] }
|
||||
tracing-opentelemetry = { version = "0.17.2", default-features = false }
|
||||
opentelemetry = { version = "0.17.0", default-features = false }
|
||||
tracing-opentelemetry = { version = "0.18.0", default-features = false }
|
||||
opentelemetry = { version = "0.18.0", default-features = false }
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
@@ -68,16 +68,19 @@ bincode = "1.3"
|
||||
bytes = { version = "1", features = ["serde"] }
|
||||
flate2 = "1.0"
|
||||
futures-test = "0.3"
|
||||
opentelemetry = { version = "0.17.0", default-features = false, features = [
|
||||
opentelemetry = { version = "0.18.0", default-features = false, features = [
|
||||
"rt-tokio",
|
||||
] }
|
||||
opentelemetry-jaeger = { version = "0.16.0", features = ["rt-tokio"] }
|
||||
opentelemetry-jaeger = { version = "0.17.0", features = ["rt-tokio"] }
|
||||
pin-utils = "0.1.0-alpha"
|
||||
serde_bytes = "0.11"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
tokio = { version = "1", features = ["full", "test-util"] }
|
||||
tokio = { version = "1", features = ["full", "test-util", "tracing"] }
|
||||
console-subscriber = "0.1"
|
||||
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
|
||||
trybuild = "1.0"
|
||||
tokio-rustls = "0.23"
|
||||
rustls-pemfile = "1.0"
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
@@ -103,6 +106,10 @@ required-features = ["full"]
|
||||
name = "custom_transport"
|
||||
required-features = ["serde1", "tokio1", "serde-transport"]
|
||||
|
||||
[[example]]
|
||||
name = "tls_over_tcp"
|
||||
required-features = ["full"]
|
||||
|
||||
[[test]]
|
||||
name = "service_functional"
|
||||
required-features = ["serde-transport"]
|
||||
|
||||
11
tarpc/examples/certs/eddsa/client.cert
Normal file
11
tarpc/examples/certs/eddsa/client.cert
Normal file
@@ -0,0 +1,11 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBlDCCAUagAwIBAgICAxUwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk
|
||||
RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw
|
||||
NjEwMTEwNFowGjEYMBYGA1UEAwwPcG9ueXRvd24gY2xpZW50MCowBQYDK2VwAyEA
|
||||
NTKuLume19IhJfEFd/5OZUuYDKZH6xvy4AGver17OoejgZswgZgwDAYDVR0TAQH/
|
||||
BAIwADALBgNVHQ8EBAMCBsAwFgYDVR0lAQH/BAwwCgYIKwYBBQUHAwIwHQYDVR0O
|
||||
BBYEFDjdrlMu4tyw5MHtbg7WnzSGRBpFMEQGA1UdIwQ9MDuAFHIl7fHKWP6/l8FE
|
||||
fI2YEIM3oHxKoSCkHjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQYIBezAF
|
||||
BgMrZXADQQCaahfj/QLxoCOpvl6y0ZQ9CpojPqBnxV3460j5nUOp040Va2MpF137
|
||||
izCBY7LwgUE/YG6E+kH30G4jMEnqVEYK
|
||||
-----END CERTIFICATE-----
|
||||
19
tarpc/examples/certs/eddsa/client.chain
Normal file
19
tarpc/examples/certs/eddsa/client.chain
Normal file
@@ -0,0 +1,19 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE
|
||||
U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD
|
||||
DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh
|
||||
AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU
|
||||
ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG
|
||||
AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU
|
||||
oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc
|
||||
zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg=
|
||||
-----END CERTIFICATE-----
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG
|
||||
A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0
|
||||
MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh
|
||||
ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU
|
||||
phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR
|
||||
W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC
|
||||
t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB
|
||||
-----END CERTIFICATE-----
|
||||
3
tarpc/examples/certs/eddsa/client.key
Normal file
3
tarpc/examples/certs/eddsa/client.key
Normal file
@@ -0,0 +1,3 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MC4CAQAwBQYDK2VwBCIEIIJX9ThTHpVS1SNZb6HP4myg4fRInIVGunTRdgnc+weH
|
||||
-----END PRIVATE KEY-----
|
||||
12
tarpc/examples/certs/eddsa/end.cert
Normal file
12
tarpc/examples/certs/eddsa/end.cert
Normal file
@@ -0,0 +1,12 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBuDCCAWqgAwIBAgICAcgwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk
|
||||
RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw
|
||||
NjEwMTEwNFowGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wKjAFBgMrZXADIQDc
|
||||
RLl3/N2tPoWnzBV3noVn/oheEl8IUtiY11Vg/QXTUKOBwDCBvTAMBgNVHRMBAf8E
|
||||
AjAAMAsGA1UdDwQEAwIGwDAdBgNVHQ4EFgQUk7U2mnxedNWBAH84BsNy5si3ZQow
|
||||
RAYDVR0jBD0wO4AUciXt8cpY/r+XwUR8jZgQgzegfEqhIKQeMBwxGjAYBgNVBAMM
|
||||
EXBvbnl0b3duIEVkRFNBIENBggF7MDsGA1UdEQQ0MDKCDnRlc3RzZXJ2ZXIuY29t
|
||||
ghVzZWNvbmQudGVzdHNlcnZlci5jb22CCWxvY2FsaG9zdDAFBgMrZXADQQCFWIcF
|
||||
9FiztCuUNzgXDNu5kshuflt0RjkjWpGlWzQjGoYM2IvYhNVPeqnCiY92gqwDSBtq
|
||||
amD2TBup4eNUCsQB
|
||||
-----END CERTIFICATE-----
|
||||
19
tarpc/examples/certs/eddsa/end.chain
Normal file
19
tarpc/examples/certs/eddsa/end.chain
Normal file
@@ -0,0 +1,19 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE
|
||||
U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD
|
||||
DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh
|
||||
AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU
|
||||
ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG
|
||||
AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU
|
||||
oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc
|
||||
zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg=
|
||||
-----END CERTIFICATE-----
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG
|
||||
A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0
|
||||
MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh
|
||||
ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU
|
||||
phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR
|
||||
W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC
|
||||
t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB
|
||||
-----END CERTIFICATE-----
|
||||
3
tarpc/examples/certs/eddsa/end.key
Normal file
3
tarpc/examples/certs/eddsa/end.key
Normal file
@@ -0,0 +1,3 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MC4CAQAwBQYDK2VwBCIEIMU6xGVe8JTpZ3bN/wajHfw6pEHt0Rd7wPBxds9eEFy2
|
||||
-----END PRIVATE KEY-----
|
||||
@@ -1,5 +1,11 @@
|
||||
// Copyright 2022 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression};
|
||||
use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt};
|
||||
use futures::{prelude::*, Sink, SinkExt, Stream, StreamExt, TryStreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_bytes::ByteBuf;
|
||||
use std::{io, io::Read, io::Write};
|
||||
@@ -99,13 +105,16 @@ pub trait World {
|
||||
#[derive(Clone, Debug)]
|
||||
struct HelloServer;
|
||||
|
||||
#[tarpc::server]
|
||||
impl World for HelloServer {
|
||||
async fn hello(self, _: context::Context, name: String) -> String {
|
||||
format!("Hey, {name}!")
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||
tokio::spawn(fut);
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let mut incoming = tcp::listen("localhost:0", Bincode::default).await?;
|
||||
@@ -114,6 +123,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
let transport = incoming.next().await.unwrap().unwrap();
|
||||
BaseChannel::with_defaults(add_compression(transport))
|
||||
.execute(HelloServer.serve())
|
||||
.for_each(spawn)
|
||||
.await;
|
||||
});
|
||||
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
// Copyright 2022 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use futures::prelude::*;
|
||||
use tarpc::context::Context;
|
||||
use tarpc::serde_transport as transport;
|
||||
use tarpc::server::{BaseChannel, Channel};
|
||||
@@ -13,7 +20,6 @@ pub trait PingService {
|
||||
#[derive(Clone)]
|
||||
struct Service;
|
||||
|
||||
#[tarpc::server]
|
||||
impl PingService for Service {
|
||||
async fn ping(self, _: Context) {}
|
||||
}
|
||||
@@ -26,13 +32,18 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let listener = UnixListener::bind(bind_addr).unwrap();
|
||||
let codec_builder = LengthDelimitedCodec::builder();
|
||||
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||
tokio::spawn(fut);
|
||||
}
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (conn, _addr) = listener.accept().await.unwrap();
|
||||
let framed = codec_builder.new_framed(conn);
|
||||
let transport = transport::new(framed, Bincode::default());
|
||||
|
||||
let fut = BaseChannel::with_defaults(transport).execute(Service.serve());
|
||||
let fut = BaseChannel::with_defaults(transport)
|
||||
.execute(Service.serve())
|
||||
.for_each(spawn);
|
||||
tokio::spawn(fut);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -79,7 +79,6 @@ struct Subscriber {
|
||||
topics: Vec<String>,
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl subscriber::Subscriber for Subscriber {
|
||||
async fn topics(self, _: context::Context) -> Vec<String> {
|
||||
self.topics.clone()
|
||||
@@ -117,7 +116,8 @@ impl Subscriber {
|
||||
))
|
||||
}
|
||||
};
|
||||
let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve()));
|
||||
let (handler, abort_handle) =
|
||||
future::abortable(handler.execute(subscriber.serve()).for_each(spawn));
|
||||
tokio::spawn(async move {
|
||||
match handler.await {
|
||||
Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."),
|
||||
@@ -143,6 +143,10 @@ struct PublisherAddrs {
|
||||
subscriptions: SocketAddr,
|
||||
}
|
||||
|
||||
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||
tokio::spawn(fut);
|
||||
}
|
||||
|
||||
impl Publisher {
|
||||
async fn start(self) -> io::Result<PublisherAddrs> {
|
||||
let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?;
|
||||
@@ -162,6 +166,7 @@ impl Publisher {
|
||||
|
||||
server::BaseChannel::with_defaults(publisher)
|
||||
.execute(self.serve())
|
||||
.for_each(spawn)
|
||||
.await
|
||||
});
|
||||
|
||||
@@ -257,7 +262,6 @@ impl Publisher {
|
||||
}
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl publisher::Publisher for Publisher {
|
||||
async fn publish(self, _: context::Context, topic: String, message: String) {
|
||||
info!("received message to publish.");
|
||||
@@ -282,7 +286,7 @@ impl publisher::Publisher for Publisher {
|
||||
/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
|
||||
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
||||
let tracer = opentelemetry_jaeger::new_pipeline()
|
||||
let tracer = opentelemetry_jaeger::new_agent_pipeline()
|
||||
.with_service_name(service_name)
|
||||
.with_max_packet_size(2usize.pow(13))
|
||||
.install_batch(opentelemetry::runtime::Tokio)?;
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use futures::future::{self, Ready};
|
||||
use futures::prelude::*;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{self, Channel},
|
||||
@@ -23,22 +23,21 @@ pub trait World {
|
||||
struct HelloServer;
|
||||
|
||||
impl World for HelloServer {
|
||||
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and
|
||||
// an associated type representing the future output by the fn.
|
||||
|
||||
type HelloFut = Ready<String>;
|
||||
|
||||
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
||||
future::ready(format!("Hello, {name}!"))
|
||||
async fn hello(self, _: context::Context, name: String) -> String {
|
||||
format!("Hello, {name}!")
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||
tokio::spawn(fut);
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
|
||||
let server = server::BaseChannel::with_defaults(server_transport);
|
||||
tokio::spawn(server.execute(HelloServer.serve()));
|
||||
tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn));
|
||||
|
||||
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||
// that takes a config and any Transport as input.
|
||||
|
||||
150
tarpc/examples/tls_over_tcp.rs
Normal file
150
tarpc/examples/tls_over_tcp.rs
Normal file
@@ -0,0 +1,150 @@
|
||||
// Copyright 2023 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use futures::prelude::*;
|
||||
use rustls_pemfile::certs;
|
||||
use std::io::{BufReader, Cursor};
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient;
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::rustls::{self, RootCertStore};
|
||||
use tokio_rustls::{TlsAcceptor, TlsConnector};
|
||||
|
||||
use tarpc::context::Context;
|
||||
use tarpc::serde_transport as transport;
|
||||
use tarpc::server::{BaseChannel, Channel};
|
||||
use tarpc::tokio_serde::formats::Bincode;
|
||||
use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec;
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait PingService {
|
||||
async fn ping() -> String;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Service;
|
||||
|
||||
impl PingService for Service {
|
||||
async fn ping(self, _: Context) -> String {
|
||||
"🔒".to_owned()
|
||||
}
|
||||
}
|
||||
|
||||
// certs were generated with openssl 3 https://github.com/rustls/rustls/tree/main/test-ca
|
||||
// used on client-side for server tls
|
||||
const END_CHAIN: &str = include_str!("certs/eddsa/end.chain");
|
||||
// used on client-side for client-auth
|
||||
const CLIENT_PRIVATEKEY_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.key");
|
||||
const CLIENT_CERT_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.cert");
|
||||
|
||||
// used on server-side for server tls
|
||||
const END_CERT: &str = include_str!("certs/eddsa/end.cert");
|
||||
const END_PRIVATEKEY: &str = include_str!("certs/eddsa/end.key");
|
||||
// used on server-side for client-auth
|
||||
const CLIENT_CHAIN_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.chain");
|
||||
|
||||
pub fn load_certs(data: &str) -> Vec<rustls::Certificate> {
|
||||
certs(&mut BufReader::new(Cursor::new(data)))
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn load_private_key(key: &str) -> rustls::PrivateKey {
|
||||
let mut reader = BufReader::new(Cursor::new(key));
|
||||
loop {
|
||||
match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") {
|
||||
Some(rustls_pemfile::Item::RSAKey(key)) => return rustls::PrivateKey(key),
|
||||
Some(rustls_pemfile::Item::PKCS8Key(key)) => return rustls::PrivateKey(key),
|
||||
Some(rustls_pemfile::Item::ECKey(key)) => return rustls::PrivateKey(key),
|
||||
None => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
panic!("no keys found in {:?} (encrypted keys not supported)", key);
|
||||
}
|
||||
|
||||
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||
tokio::spawn(fut);
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// -------------------- start here to setup tls tcp tokio stream --------------------------
|
||||
// ref certs and loading from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/tests/test.rs
|
||||
// ref basic tls server setup from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/server/src/main.rs
|
||||
let cert = load_certs(END_CERT);
|
||||
let key = load_private_key(END_PRIVATEKEY);
|
||||
let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), 5000);
|
||||
|
||||
// ------------- server side client_auth cert loading start
|
||||
let mut client_auth_roots = RootCertStore::empty();
|
||||
for root in load_certs(CLIENT_CHAIN_CLIENT_AUTH) {
|
||||
client_auth_roots.add(&root).unwrap();
|
||||
}
|
||||
let client_auth = AllowAnyAuthenticatedClient::new(client_auth_roots);
|
||||
// ------------- server side client_auth cert loading end
|
||||
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_client_cert_verifier(client_auth) // use .with_no_client_auth() instead if you don't want client-auth
|
||||
.with_single_cert(cert, key)
|
||||
.unwrap();
|
||||
let acceptor = TlsAcceptor::from(Arc::new(config));
|
||||
let listener = TcpListener::bind(&server_addr).await.unwrap();
|
||||
let codec_builder = LengthDelimitedCodec::builder();
|
||||
|
||||
// ref ./custom_transport.rs server side
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (stream, _peer_addr) = listener.accept().await.unwrap();
|
||||
let tls_stream = acceptor.accept(stream).await.unwrap();
|
||||
let framed = codec_builder.new_framed(tls_stream);
|
||||
|
||||
let transport = transport::new(framed, Bincode::default());
|
||||
|
||||
let fut = BaseChannel::with_defaults(transport)
|
||||
.execute(Service.serve())
|
||||
.for_each(spawn);
|
||||
tokio::spawn(fut);
|
||||
}
|
||||
});
|
||||
|
||||
// ---------------------- client connection ---------------------
|
||||
// tls client connection from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
for root in load_certs(END_CHAIN) {
|
||||
root_store.add(&root).unwrap();
|
||||
}
|
||||
|
||||
let client_auth_private_key = load_private_key(CLIENT_PRIVATEKEY_CLIENT_AUTH);
|
||||
let client_auth_certs = load_certs(CLIENT_CERT_CLIENT_AUTH);
|
||||
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_single_cert(client_auth_certs, client_auth_private_key)?; // use .with_no_client_auth() instead if you don't want client-auth
|
||||
|
||||
let domain = rustls::ServerName::try_from("localhost")?;
|
||||
let connector = TlsConnector::from(Arc::new(config));
|
||||
|
||||
let stream = TcpStream::connect(server_addr).await?;
|
||||
let stream = connector.connect(domain, stream).await?;
|
||||
|
||||
let transport = transport::new(codec_builder.new_framed(stream), Bincode::default());
|
||||
let answer = PingServiceClient::new(Default::default(), transport)
|
||||
.spawn()
|
||||
.ping(tarpc::context::current())
|
||||
.await?;
|
||||
|
||||
println!("ping answer: {answer}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -4,13 +4,34 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use crate::{add::Add as AddService, double::Double as DoubleService};
|
||||
use futures::{future, prelude::*};
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{incoming::Incoming, BaseChannel},
|
||||
tokio_serde::formats::Json,
|
||||
use crate::{
|
||||
add::{Add as AddService, AddStub},
|
||||
double::Double as DoubleService,
|
||||
};
|
||||
use futures::{future, prelude::*};
|
||||
use std::{
|
||||
io,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use tarpc::{
|
||||
client::{
|
||||
self,
|
||||
stub::{load_balance, retry},
|
||||
RpcError,
|
||||
},
|
||||
context, serde_transport,
|
||||
server::{
|
||||
incoming::{spawn_incoming, Incoming},
|
||||
request_hook::{self, BeforeRequestList},
|
||||
BaseChannel,
|
||||
},
|
||||
tokio_serde::formats::Json,
|
||||
ClientMessage, Response, ServerError, Transport,
|
||||
};
|
||||
use tokio::net::TcpStream;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
pub mod add {
|
||||
@@ -32,7 +53,6 @@ pub mod double {
|
||||
#[derive(Clone)]
|
||||
struct AddServer;
|
||||
|
||||
#[tarpc::server]
|
||||
impl AddService for AddServer {
|
||||
async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
|
||||
x + y
|
||||
@@ -40,12 +60,14 @@ impl AddService for AddServer {
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct DoubleServer {
|
||||
add_client: add::AddClient,
|
||||
struct DoubleServer<Stub> {
|
||||
add_client: add::AddClient<Stub>,
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl DoubleService for DoubleServer {
|
||||
impl<Stub> DoubleService for DoubleServer<Stub>
|
||||
where
|
||||
Stub: AddStub + Clone + Send + Sync + 'static,
|
||||
{
|
||||
async fn double(self, _: context::Context, x: i32) -> Result<i32, String> {
|
||||
self.add_client
|
||||
.add(context::current(), x, x)
|
||||
@@ -55,7 +77,7 @@ impl DoubleService for DoubleServer {
|
||||
}
|
||||
|
||||
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||
let tracer = opentelemetry_jaeger::new_pipeline()
|
||||
let tracer = opentelemetry_jaeger::new_agent_pipeline()
|
||||
.with_service_name(service_name)
|
||||
.with_auto_split_batch(true)
|
||||
.with_max_packet_size(2usize.pow(13))
|
||||
@@ -70,32 +92,88 @@ fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn listen_on_random_port<Item, SinkItem>() -> anyhow::Result<(
|
||||
impl Stream<Item = serde_transport::Transport<TcpStream, Item, SinkItem, Json<Item, SinkItem>>>,
|
||||
std::net::SocketAddr,
|
||||
)>
|
||||
where
|
||||
Item: for<'de> serde::Deserialize<'de>,
|
||||
SinkItem: serde::Serialize,
|
||||
{
|
||||
let listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
||||
.await?
|
||||
.filter_map(|r| future::ready(r.ok()))
|
||||
.take(1);
|
||||
let addr = listener.get_ref().get_ref().local_addr();
|
||||
Ok((listener, addr))
|
||||
}
|
||||
|
||||
fn make_stub<Req, Resp, const N: usize>(
|
||||
backends: [impl Transport<ClientMessage<Arc<Req>>, Response<Resp>> + Send + Sync + 'static; N],
|
||||
) -> retry::Retry<
|
||||
impl Fn(&Result<Resp, RpcError>, u32) -> bool + Clone,
|
||||
load_balance::RoundRobin<client::Channel<Arc<Req>, Resp>>,
|
||||
>
|
||||
where
|
||||
Req: Send + Sync + 'static,
|
||||
Resp: Send + Sync + 'static,
|
||||
{
|
||||
let stub = load_balance::RoundRobin::new(
|
||||
backends
|
||||
.into_iter()
|
||||
.map(|transport| tarpc::client::new(client::Config::default(), transport).spawn())
|
||||
.collect(),
|
||||
);
|
||||
let stub = retry::Retry::new(stub, |resp, attempts| {
|
||||
if let Err(e) = resp {
|
||||
tracing::warn!("Got an error: {e:?}");
|
||||
attempts < 3
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
stub
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
init_tracing("tarpc_tracing_example")?;
|
||||
|
||||
let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
||||
.await?
|
||||
.filter_map(|r| future::ready(r.ok()));
|
||||
let addr = add_listener.get_ref().local_addr();
|
||||
let add_server = add_listener
|
||||
.map(BaseChannel::with_defaults)
|
||||
.take(1)
|
||||
.execute(AddServer.serve());
|
||||
tokio::spawn(add_server);
|
||||
let (add_listener1, addr1) = listen_on_random_port().await?;
|
||||
let (add_listener2, addr2) = listen_on_random_port().await?;
|
||||
let something_bad_happened = Arc::new(AtomicBool::new(false));
|
||||
let server = request_hook::before()
|
||||
.then_fn(move |_: &mut _, _: &_| {
|
||||
let something_bad_happened = something_bad_happened.clone();
|
||||
async move {
|
||||
if something_bad_happened.fetch_xor(true, Ordering::Relaxed) {
|
||||
Err(ServerError::new(
|
||||
io::ErrorKind::NotFound,
|
||||
"Gamma Ray!".into(),
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
})
|
||||
.serving(AddServer.serve());
|
||||
let add_server = add_listener1
|
||||
.chain(add_listener2)
|
||||
.map(BaseChannel::with_defaults);
|
||||
tokio::spawn(spawn_incoming(add_server.execute(server)));
|
||||
|
||||
let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn();
|
||||
let add_client = add::AddClient::from(make_stub([
|
||||
tarpc::serde_transport::tcp::connect(addr1, Json::default).await?,
|
||||
tarpc::serde_transport::tcp::connect(addr2, Json::default).await?,
|
||||
]));
|
||||
|
||||
let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
||||
.await?
|
||||
.filter_map(|r| future::ready(r.ok()));
|
||||
let addr = double_listener.get_ref().local_addr();
|
||||
let double_server = double_listener
|
||||
.map(BaseChannel::with_defaults)
|
||||
.take(1)
|
||||
.execute(DoubleServer { add_client }.serve());
|
||||
tokio::spawn(double_server);
|
||||
let double_server = double_listener.map(BaseChannel::with_defaults).take(1);
|
||||
let server = DoubleServer { add_client }.serve();
|
||||
tokio::spawn(spawn_incoming(double_server.execute(server)));
|
||||
|
||||
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
let double_client =
|
||||
|
||||
@@ -7,18 +7,18 @@
|
||||
//! Provides a client that connects to a server and sends multiplexed requests.
|
||||
|
||||
mod in_flight_requests;
|
||||
pub mod stub;
|
||||
|
||||
use crate::{
|
||||
cancellations::{cancellations, CanceledRequests, RequestCancellation},
|
||||
context, trace, ClientMessage, Request, Response, ServerError, Transport,
|
||||
context, trace, ChannelError, ClientMessage, Request, Response, ServerError, Transport,
|
||||
};
|
||||
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
||||
use in_flight_requests::{DeadlineExceededError, InFlightRequests};
|
||||
use in_flight_requests::InFlightRequests;
|
||||
use pin_project::pin_project;
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
error::Error,
|
||||
fmt, mem,
|
||||
fmt,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
@@ -124,7 +124,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
pub async fn call(
|
||||
&self,
|
||||
mut ctx: context::Context,
|
||||
request_name: &str,
|
||||
request_name: &'static str,
|
||||
request: Req,
|
||||
) -> Result<Resp, RpcError> {
|
||||
let span = Span::current();
|
||||
@@ -147,6 +147,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
response: &mut response,
|
||||
request_id,
|
||||
cancellation: &self.cancellation,
|
||||
cancel: true,
|
||||
};
|
||||
self.to_dispatch
|
||||
.send(DispatchRequest {
|
||||
@@ -157,7 +158,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
response_completion,
|
||||
})
|
||||
.await
|
||||
.map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?;
|
||||
.map_err(|mpsc::error::SendError(_)| RpcError::Shutdown)?;
|
||||
response_guard.response().await
|
||||
}
|
||||
}
|
||||
@@ -165,19 +166,25 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
/// A server response that is completed by request dispatch when the corresponding response
|
||||
/// arrives off the wire.
|
||||
struct ResponseGuard<'a, Resp> {
|
||||
response: &'a mut oneshot::Receiver<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
response: &'a mut oneshot::Receiver<Result<Resp, RpcError>>,
|
||||
cancellation: &'a RequestCancellation,
|
||||
request_id: u64,
|
||||
cancel: bool,
|
||||
}
|
||||
|
||||
/// An error that can occur in the processing of an RPC. This is not request-specific errors but
|
||||
/// rather cross-cutting errors that can always occur.
|
||||
#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum RpcError {
|
||||
/// The client disconnected from the server.
|
||||
#[error("the client disconnected from the server")]
|
||||
Disconnected,
|
||||
#[error("the connection to the server was already shutdown")]
|
||||
Shutdown,
|
||||
/// The client failed to send the request.
|
||||
#[error("the client failed to send the request")]
|
||||
Send(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
/// An error occurred while waiting for the server response.
|
||||
#[error("an error occurred while waiting for the server response")]
|
||||
Receive(#[source] Arc<dyn std::error::Error + Send + Sync + 'static>),
|
||||
/// The request exceeded its deadline.
|
||||
#[error("the request exceeded its deadline")]
|
||||
DeadlineExceeded,
|
||||
@@ -186,24 +193,18 @@ pub enum RpcError {
|
||||
Server(#[from] ServerError),
|
||||
}
|
||||
|
||||
impl From<DeadlineExceededError> for RpcError {
|
||||
fn from(_: DeadlineExceededError) -> Self {
|
||||
RpcError::DeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
impl<Resp> ResponseGuard<'_, Resp> {
|
||||
async fn response(mut self) -> Result<Resp, RpcError> {
|
||||
let response = (&mut self.response).await;
|
||||
// Cancel drop logic once a response has been received.
|
||||
mem::forget(self);
|
||||
self.cancel = false;
|
||||
match response {
|
||||
Ok(resp) => Ok(resp?.message?),
|
||||
Ok(response) => response,
|
||||
Err(oneshot::error::RecvError { .. }) => {
|
||||
// The oneshot is Canceled when the dispatch task ends. In that case,
|
||||
// there's nothing listening on the other side, so there's no point in
|
||||
// propagating cancellation.
|
||||
Err(RpcError::Disconnected)
|
||||
Err(RpcError::Shutdown)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -223,7 +224,9 @@ impl<Resp> Drop for ResponseGuard<'_, Resp> {
|
||||
// dispatch task misses an early-arriving cancellation message, then it will see the
|
||||
// receiver as closed.
|
||||
self.response.close();
|
||||
self.cancellation.cancel(self.request_id);
|
||||
if self.cancel {
|
||||
self.cancellation.cancel(self.request_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,7 +241,6 @@ where
|
||||
{
|
||||
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
|
||||
let (cancellation, canceled_requests) = cancellations();
|
||||
let canceled_requests = canceled_requests;
|
||||
|
||||
NewClient {
|
||||
client: Channel {
|
||||
@@ -270,42 +272,18 @@ pub struct RequestDispatch<Req, Resp, C> {
|
||||
/// Requests that were dropped.
|
||||
canceled_requests: CanceledRequests,
|
||||
/// Requests already written to the wire that haven't yet received responses.
|
||||
in_flight_requests: InFlightRequests<Resp>,
|
||||
in_flight_requests: InFlightRequests<Result<Resp, RpcError>>,
|
||||
/// Configures limits to prevent unlimited resource usage.
|
||||
config: Config,
|
||||
}
|
||||
|
||||
/// Critical errors that result in a Channel disconnecting.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ChannelError<E>
|
||||
where
|
||||
E: Error + Send + Sync + 'static,
|
||||
{
|
||||
/// Could not read from the transport.
|
||||
#[error("could not read from the transport")]
|
||||
Read(#[source] E),
|
||||
/// Could not ready the transport for writes.
|
||||
#[error("could not ready the transport for writes")]
|
||||
Ready(#[source] E),
|
||||
/// Could not write to the transport.
|
||||
#[error("could not write to the transport")]
|
||||
Write(#[source] E),
|
||||
/// Could not flush the transport.
|
||||
#[error("could not flush the transport")]
|
||||
Flush(#[source] E),
|
||||
/// Could not close the write end of the transport.
|
||||
#[error("could not close the write end of the transport")]
|
||||
Close(#[source] E),
|
||||
/// Could not poll expired requests.
|
||||
#[error("could not poll expired requests")]
|
||||
Timer(#[source] tokio::time::error::Error),
|
||||
}
|
||||
|
||||
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
|
||||
where
|
||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
||||
{
|
||||
fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests<Resp> {
|
||||
fn in_flight_requests<'a>(
|
||||
self: &'a mut Pin<&mut Self>,
|
||||
) -> &'a mut InFlightRequests<Result<Resp, RpcError>> {
|
||||
self.as_mut().project().in_flight_requests
|
||||
}
|
||||
|
||||
@@ -365,7 +343,17 @@ where
|
||||
) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
|
||||
self.transport_pin_mut()
|
||||
.poll_next(cx)
|
||||
.map_err(ChannelError::Read)
|
||||
.map_err(|e| {
|
||||
let e = Arc::new(e);
|
||||
for span in self
|
||||
.in_flight_requests()
|
||||
.complete_all_requests(|| Err(RpcError::Receive(e.clone())))
|
||||
{
|
||||
let _entered = span.enter();
|
||||
tracing::info!("ReceiveError");
|
||||
}
|
||||
ChannelError::Read(e)
|
||||
})
|
||||
.map_ok(|response| {
|
||||
self.complete(response);
|
||||
})
|
||||
@@ -395,7 +383,10 @@ where
|
||||
// Receiving Poll::Ready(None) when polling expired requests never indicates "Closed",
|
||||
// because there can temporarily be zero in-flight rquests. Therefore, there is no need to
|
||||
// track the status like is done with pending and cancelled requests.
|
||||
if let Poll::Ready(Some(_)) = self.in_flight_requests().poll_expired(cx) {
|
||||
if let Poll::Ready(Some(_)) = self
|
||||
.in_flight_requests()
|
||||
.poll_expired(cx, || Err(RpcError::DeadlineExceeded))
|
||||
{
|
||||
// Expired requests are considered complete; there is no compelling reason to send a
|
||||
// cancellation message to the server, since it will have already exhausted its
|
||||
// allotted processing time.
|
||||
@@ -506,11 +497,10 @@ where
|
||||
Some(dispatch_request) => dispatch_request,
|
||||
None => return Poll::Ready(None),
|
||||
};
|
||||
let entered = span.enter();
|
||||
let _entered = span.enter();
|
||||
// poll_next_request only returns Ready if there is room to buffer another request.
|
||||
// Therefore, we can call write_request without fear of erroring due to a full
|
||||
// buffer.
|
||||
let request_id = request_id;
|
||||
let request = ClientMessage::Request(Request {
|
||||
id: request_id,
|
||||
message: request,
|
||||
@@ -519,13 +509,16 @@ where
|
||||
trace_context: ctx.trace_context,
|
||||
},
|
||||
});
|
||||
self.start_send(request)?;
|
||||
tracing::info!("SendRequest");
|
||||
drop(entered);
|
||||
|
||||
self.in_flight_requests()
|
||||
.insert_request(request_id, ctx, span, response_completion)
|
||||
.insert_request(request_id, ctx, span.clone(), response_completion)
|
||||
.expect("Request IDs should be unique");
|
||||
match self.start_send(request) {
|
||||
Ok(()) => tracing::info!("SendRequest"),
|
||||
Err(e) => {
|
||||
self.in_flight_requests()
|
||||
.complete_request(request_id, Err(RpcError::Send(Box::new(e))));
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
|
||||
@@ -550,7 +543,15 @@ where
|
||||
|
||||
/// Sends a server response to the client task that initiated the associated request.
|
||||
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
|
||||
self.in_flight_requests().complete_request(response)
|
||||
if let Some(span) = self.in_flight_requests().complete_request(
|
||||
response.request_id,
|
||||
response.message.map_err(RpcError::Server),
|
||||
) {
|
||||
let _entered = span.enter();
|
||||
tracing::info!("ReceiveResponse");
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -599,30 +600,37 @@ struct DispatchRequest<Req, Resp> {
|
||||
pub span: Span,
|
||||
pub request_id: u64,
|
||||
pub request: Req,
|
||||
pub response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
pub response_completion: oneshot::Sender<Result<Resp, RpcError>>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard};
|
||||
use super::{
|
||||
cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError,
|
||||
};
|
||||
use crate::{
|
||||
client::{
|
||||
in_flight_requests::{DeadlineExceededError, InFlightRequests},
|
||||
Config,
|
||||
},
|
||||
context,
|
||||
client::{in_flight_requests::InFlightRequests, Config},
|
||||
context::{self, current},
|
||||
transport::{self, channel::UnboundedChannel},
|
||||
ClientMessage, Response,
|
||||
ChannelError, ClientMessage, Response,
|
||||
};
|
||||
use assert_matches::assert_matches;
|
||||
use futures::{prelude::*, task::*};
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
fmt::Display,
|
||||
marker::PhantomData,
|
||||
pin::Pin,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
sync::Arc,
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{
|
||||
mpsc::{self},
|
||||
oneshot,
|
||||
};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tracing::Span;
|
||||
|
||||
#[tokio::test]
|
||||
@@ -643,7 +651,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
|
||||
assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp");
|
||||
assert_matches!(rx.try_recv(), Ok(Ok(resp)) if resp == "Resp");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -654,6 +662,7 @@ mod tests {
|
||||
response: &mut response,
|
||||
cancellation: &cancellation,
|
||||
request_id: 3,
|
||||
cancel: true,
|
||||
});
|
||||
// resp's drop() is run, which should send a cancel message.
|
||||
let cx = &mut Context::from_waker(noop_waker_ref());
|
||||
@@ -674,6 +683,7 @@ mod tests {
|
||||
response: &mut response,
|
||||
cancellation: &cancellation,
|
||||
request_id: 3,
|
||||
cancel: true,
|
||||
}
|
||||
.response()
|
||||
.await
|
||||
@@ -775,6 +785,185 @@ mod tests {
|
||||
assert!(dispatch.as_mut().poll_next_request(cx).is_pending());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_shutdown_error() {
|
||||
let _ = tracing_subscriber::fmt().with_test_writer().try_init();
|
||||
let (dispatch, mut channel, _) = set_up();
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
// send succeeds
|
||||
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
drop(dispatch);
|
||||
// error on receive
|
||||
assert_matches!(resp.response().await, Err(RpcError::Shutdown));
|
||||
let (dispatch, channel, _) = set_up();
|
||||
drop(dispatch);
|
||||
// error on send
|
||||
let resp = channel
|
||||
.call(current(), "test_request", "hi".to_string())
|
||||
.await;
|
||||
assert_matches!(resp, Err(RpcError::Shutdown));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transport_error_write() {
|
||||
let cause = TransportError::Write;
|
||||
let (mut dispatch, mut channel, mut cx) = setup_always_err(cause);
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
|
||||
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
assert!(dispatch.as_mut().poll(&mut cx).is_pending());
|
||||
let res = resp.response().await;
|
||||
assert_matches!(res, Err(RpcError::Send(_)));
|
||||
let client_error: anyhow::Error = res.unwrap_err().into();
|
||||
let mut chain = client_error.chain();
|
||||
chain.next(); // original RpcError
|
||||
assert_eq!(
|
||||
chain
|
||||
.next()
|
||||
.unwrap()
|
||||
.downcast_ref::<ChannelError<TransportError>>(),
|
||||
Some(&ChannelError::Write(cause))
|
||||
);
|
||||
assert_eq!(
|
||||
client_error.root_cause().downcast_ref::<TransportError>(),
|
||||
Some(&cause)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transport_error_read() {
|
||||
let cause = TransportError::Read;
|
||||
let (mut dispatch, mut channel, mut cx) = setup_always_err(cause);
|
||||
let (tx, mut rx) = oneshot::channel();
|
||||
let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
|
||||
assert_eq!(
|
||||
dispatch.as_mut().pump_write(&mut cx),
|
||||
Poll::Ready(Some(Ok(())))
|
||||
);
|
||||
assert_eq!(
|
||||
dispatch.as_mut().pump_read(&mut cx),
|
||||
Poll::Ready(Some(Err(ChannelError::Read(Arc::new(cause)))))
|
||||
);
|
||||
assert_matches!(resp.response().await, Err(RpcError::Receive(_)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transport_error_ready() {
|
||||
let cause = TransportError::Ready;
|
||||
let (mut dispatch, _, mut cx) = setup_always_err(cause);
|
||||
assert_eq!(
|
||||
dispatch.as_mut().poll(&mut cx),
|
||||
Poll::Ready(Err(ChannelError::Ready(cause)))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transport_error_flush() {
|
||||
let cause = TransportError::Flush;
|
||||
let (mut dispatch, _, mut cx) = setup_always_err(cause);
|
||||
assert_eq!(
|
||||
dispatch.as_mut().poll(&mut cx),
|
||||
Poll::Ready(Err(ChannelError::Flush(cause)))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transport_error_close() {
|
||||
let cause = TransportError::Close;
|
||||
let (mut dispatch, channel, mut cx) = setup_always_err(cause);
|
||||
drop(channel);
|
||||
assert_eq!(
|
||||
dispatch.as_mut().poll(&mut cx),
|
||||
Poll::Ready(Err(ChannelError::Close(cause)))
|
||||
);
|
||||
}
|
||||
|
||||
fn setup_always_err(
|
||||
cause: TransportError,
|
||||
) -> (
|
||||
Pin<Box<RequestDispatch<String, String, AlwaysErrorTransport<String>>>>,
|
||||
Channel<String, String>,
|
||||
Context<'static>,
|
||||
) {
|
||||
let (to_dispatch, pending_requests) = mpsc::channel(1);
|
||||
let (cancellation, canceled_requests) = cancellations();
|
||||
let transport: AlwaysErrorTransport<String> = AlwaysErrorTransport(cause, PhantomData);
|
||||
let dispatch = Box::pin(RequestDispatch::<String, String, _> {
|
||||
transport: transport.fuse(),
|
||||
pending_requests,
|
||||
canceled_requests,
|
||||
in_flight_requests: InFlightRequests::default(),
|
||||
config: Config::default(),
|
||||
});
|
||||
let channel = Channel {
|
||||
to_dispatch,
|
||||
cancellation,
|
||||
next_request_id: Arc::new(AtomicUsize::new(0)),
|
||||
};
|
||||
let cx = Context::from_waker(noop_waker_ref());
|
||||
(dispatch, channel, cx)
|
||||
}
|
||||
|
||||
struct AlwaysErrorTransport<I>(TransportError, PhantomData<I>);
|
||||
|
||||
#[derive(Debug, Error, PartialEq, Eq, Clone, Copy)]
|
||||
enum TransportError {
|
||||
Read,
|
||||
Ready,
|
||||
Write,
|
||||
Flush,
|
||||
Close,
|
||||
}
|
||||
|
||||
impl Display for TransportError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(&format!("{self:?}"))
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: Clone, S> Sink<S> for AlwaysErrorTransport<I> {
|
||||
type Error = TransportError;
|
||||
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
match self.0 {
|
||||
TransportError::Ready => Poll::Ready(Err(self.0)),
|
||||
TransportError::Flush => Poll::Pending,
|
||||
_ => Poll::Ready(Ok(())),
|
||||
}
|
||||
}
|
||||
fn start_send(self: Pin<&mut Self>, _: S) -> Result<(), Self::Error> {
|
||||
if matches!(self.0, TransportError::Write) {
|
||||
Err(self.0)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
if matches!(self.0, TransportError::Flush) {
|
||||
Poll::Ready(Err(self.0))
|
||||
} else {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
if matches!(self.0, TransportError::Close) {
|
||||
Poll::Ready(Err(self.0))
|
||||
} else {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: Clone> Stream for AlwaysErrorTransport<I> {
|
||||
type Item = Result<Response<I>, TransportError>;
|
||||
fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
if matches!(self.0, TransportError::Read) {
|
||||
Poll::Ready(Some(Err(self.0)))
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn set_up() -> (
|
||||
Pin<
|
||||
Box<
|
||||
@@ -814,8 +1003,8 @@ mod tests {
|
||||
async fn send_request<'a>(
|
||||
channel: &'a mut Channel<String, String>,
|
||||
request: &str,
|
||||
response_completion: oneshot::Sender<Result<Response<String>, DeadlineExceededError>>,
|
||||
response: &'a mut oneshot::Receiver<Result<Response<String>, DeadlineExceededError>>,
|
||||
response_completion: oneshot::Sender<Result<String, RpcError>>,
|
||||
response: &'a mut oneshot::Receiver<Result<String, RpcError>>,
|
||||
) -> ResponseGuard<'a, String> {
|
||||
let request_id =
|
||||
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
|
||||
@@ -830,6 +1019,7 @@ mod tests {
|
||||
response,
|
||||
cancellation: &channel.cancellation,
|
||||
request_id,
|
||||
cancel: true,
|
||||
};
|
||||
channel.to_dispatch.send(request).await.unwrap();
|
||||
response_guard
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use crate::{
|
||||
context,
|
||||
util::{Compact, TimeUntil},
|
||||
Response,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use std::{
|
||||
@@ -28,17 +27,11 @@ impl<Resp> Default for InFlightRequests<Resp> {
|
||||
}
|
||||
}
|
||||
|
||||
/// The request exceeded its deadline.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
#[non_exhaustive]
|
||||
#[error("the request exceeded its deadline")]
|
||||
pub struct DeadlineExceededError;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RequestData<Resp> {
|
||||
struct RequestData<Res> {
|
||||
ctx: context::Context,
|
||||
span: Span,
|
||||
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
response_completion: oneshot::Sender<Res>,
|
||||
/// The key to remove the timer for the request's deadline.
|
||||
deadline_key: delay_queue::Key,
|
||||
}
|
||||
@@ -48,7 +41,7 @@ struct RequestData<Resp> {
|
||||
#[derive(Debug)]
|
||||
pub struct AlreadyExistsError;
|
||||
|
||||
impl<Resp> InFlightRequests<Resp> {
|
||||
impl<Res> InFlightRequests<Res> {
|
||||
/// Returns the number of in-flight requests.
|
||||
pub fn len(&self) -> usize {
|
||||
self.request_data.len()
|
||||
@@ -65,7 +58,7 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
request_id: u64,
|
||||
ctx: context::Context,
|
||||
span: Span,
|
||||
response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
response_completion: oneshot::Sender<Res>,
|
||||
) -> Result<(), AlreadyExistsError> {
|
||||
match self.request_data.entry(request_id) {
|
||||
hash_map::Entry::Vacant(vacant) => {
|
||||
@@ -84,23 +77,31 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
}
|
||||
|
||||
/// Removes a request without aborting. Returns true iff the request was found.
|
||||
pub fn complete_request(&mut self, response: Response<Resp>) -> bool {
|
||||
if let Some(request_data) = self.request_data.remove(&response.request_id) {
|
||||
let _entered = request_data.span.enter();
|
||||
tracing::info!("ReceiveResponse");
|
||||
pub fn complete_request(&mut self, request_id: u64, result: Res) -> Option<Span> {
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
self.request_data.compact(0.1);
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
let _ = request_data.response_completion.send(Ok(response));
|
||||
return true;
|
||||
let _ = request_data.response_completion.send(result);
|
||||
return Some(request_data.span);
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
"No in-flight request found for request_id = {}.",
|
||||
response.request_id
|
||||
);
|
||||
tracing::debug!("No in-flight request found for request_id = {request_id}.");
|
||||
|
||||
// If the response completion was absent, then the request was already canceled.
|
||||
false
|
||||
None
|
||||
}
|
||||
|
||||
/// Completes all requests using the provided function.
|
||||
/// Returns Spans for all completes requests.
|
||||
pub fn complete_all_requests<'a>(
|
||||
&'a mut self,
|
||||
mut result: impl FnMut() -> Res + 'a,
|
||||
) -> impl Iterator<Item = Span> + 'a {
|
||||
self.deadlines.clear();
|
||||
self.request_data.drain().map(move |(_, request_data)| {
|
||||
let _ = request_data.response_completion.send(result());
|
||||
request_data.span
|
||||
})
|
||||
}
|
||||
|
||||
/// Cancels a request without completing (typically used when a request handle was dropped
|
||||
@@ -117,16 +118,18 @@ impl<Resp> InFlightRequests<Resp> {
|
||||
|
||||
/// Yields a request that has expired, completing it with a TimedOut error.
|
||||
/// The caller should send cancellation messages for any yielded request ID.
|
||||
pub fn poll_expired(&mut self, cx: &mut Context) -> Poll<Option<u64>> {
|
||||
pub fn poll_expired(
|
||||
&mut self,
|
||||
cx: &mut Context,
|
||||
expired_error: impl Fn() -> Res,
|
||||
) -> Poll<Option<u64>> {
|
||||
self.deadlines.poll_expired(cx).map(|expired| {
|
||||
let request_id = expired?.into_inner();
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
let _entered = request_data.span.enter();
|
||||
tracing::error!("DeadlineExceeded");
|
||||
self.request_data.compact(0.1);
|
||||
let _ = request_data
|
||||
.response_completion
|
||||
.send(Err(DeadlineExceededError));
|
||||
let _ = request_data.response_completion.send(expired_error());
|
||||
}
|
||||
Some(request_id)
|
||||
})
|
||||
|
||||
45
tarpc/src/client/stub.rs
Normal file
45
tarpc/src/client/stub.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
//! Provides a Stub trait, implemented by types that can call remote services.
|
||||
|
||||
use crate::{
|
||||
client::{Channel, RpcError},
|
||||
context,
|
||||
};
|
||||
|
||||
pub mod load_balance;
|
||||
pub mod retry;
|
||||
|
||||
#[cfg(test)]
|
||||
mod mock;
|
||||
|
||||
/// A connection to a remote service.
|
||||
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
|
||||
#[allow(async_fn_in_trait)]
|
||||
pub trait Stub {
|
||||
/// The service request type.
|
||||
type Req;
|
||||
|
||||
/// The service response type.
|
||||
type Resp;
|
||||
|
||||
/// Calls a remote service.
|
||||
async fn call(
|
||||
&self,
|
||||
ctx: context::Context,
|
||||
request_name: &'static str,
|
||||
request: Self::Req,
|
||||
) -> Result<Self::Resp, RpcError>;
|
||||
}
|
||||
|
||||
impl<Req, Resp> Stub for Channel<Req, Resp> {
|
||||
type Req = Req;
|
||||
type Resp = Resp;
|
||||
|
||||
async fn call(
|
||||
&self,
|
||||
ctx: context::Context,
|
||||
request_name: &'static str,
|
||||
request: Req,
|
||||
) -> Result<Self::Resp, RpcError> {
|
||||
Self::call(self, ctx, request_name, request).await
|
||||
}
|
||||
}
|
||||
279
tarpc/src/client/stub/load_balance.rs
Normal file
279
tarpc/src/client/stub/load_balance.rs
Normal file
@@ -0,0 +1,279 @@
|
||||
//! Provides load-balancing [Stubs](crate::client::stub::Stub).
|
||||
|
||||
pub use consistent_hash::ConsistentHash;
|
||||
pub use round_robin::RoundRobin;
|
||||
|
||||
/// Provides a stub that load-balances with a simple round-robin strategy.
|
||||
mod round_robin {
|
||||
use crate::{
|
||||
client::{stub, RpcError},
|
||||
context,
|
||||
};
|
||||
use cycle::AtomicCycle;
|
||||
|
||||
impl<Stub> stub::Stub for RoundRobin<Stub>
|
||||
where
|
||||
Stub: stub::Stub,
|
||||
{
|
||||
type Req = Stub::Req;
|
||||
type Resp = Stub::Resp;
|
||||
|
||||
async fn call(
|
||||
&self,
|
||||
ctx: context::Context,
|
||||
request_name: &'static str,
|
||||
request: Self::Req,
|
||||
) -> Result<Stub::Resp, RpcError> {
|
||||
let next = self.stubs.next();
|
||||
next.call(ctx, request_name, request).await
|
||||
}
|
||||
}
|
||||
|
||||
/// A Stub that load-balances across backing stubs by round robin.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RoundRobin<Stub> {
|
||||
stubs: AtomicCycle<Stub>,
|
||||
}
|
||||
|
||||
impl<Stub> RoundRobin<Stub>
|
||||
where
|
||||
Stub: stub::Stub,
|
||||
{
|
||||
/// Returns a new RoundRobin stub.
|
||||
pub fn new(stubs: Vec<Stub>) -> Self {
|
||||
Self {
|
||||
stubs: AtomicCycle::new(stubs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod cycle {
|
||||
use std::sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
};
|
||||
|
||||
/// Cycles endlessly and atomically over a collection of elements of type T.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AtomicCycle<T>(Arc<State<T>>);
|
||||
|
||||
#[derive(Debug)]
|
||||
struct State<T> {
|
||||
elements: Vec<T>,
|
||||
next: AtomicUsize,
|
||||
}
|
||||
|
||||
impl<T> AtomicCycle<T> {
|
||||
pub fn new(elements: Vec<T>) -> Self {
|
||||
Self(Arc::new(State {
|
||||
elements,
|
||||
next: Default::default(),
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn next(&self) -> &T {
|
||||
self.0.next()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> State<T> {
|
||||
pub fn next(&self) -> &T {
|
||||
let next = self.next.fetch_add(1, Ordering::Relaxed);
|
||||
&self.elements[next % self.elements.len()]
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cycle() {
|
||||
let cycle = AtomicCycle::new(vec![1, 2, 3]);
|
||||
assert_eq!(cycle.next(), &1);
|
||||
assert_eq!(cycle.next(), &2);
|
||||
assert_eq!(cycle.next(), &3);
|
||||
assert_eq!(cycle.next(), &1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides a stub that load-balances with a consistent hashing strategy.
|
||||
///
|
||||
/// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use
|
||||
/// the same stub.
|
||||
mod consistent_hash {
|
||||
use crate::{
|
||||
client::{stub, RpcError},
|
||||
context,
|
||||
};
|
||||
use std::{
|
||||
collections::hash_map::RandomState,
|
||||
hash::{BuildHasher, Hash, Hasher},
|
||||
num::TryFromIntError,
|
||||
};
|
||||
|
||||
impl<Stub, S> stub::Stub for ConsistentHash<Stub, S>
|
||||
where
|
||||
Stub: stub::Stub,
|
||||
Stub::Req: Hash,
|
||||
S: BuildHasher,
|
||||
{
|
||||
type Req = Stub::Req;
|
||||
type Resp = Stub::Resp;
|
||||
|
||||
async fn call(
|
||||
&self,
|
||||
ctx: context::Context,
|
||||
request_name: &'static str,
|
||||
request: Self::Req,
|
||||
) -> Result<Stub::Resp, RpcError> {
|
||||
let index = usize::try_from(self.hash_request(&request) % self.stubs_len).expect(
|
||||
"invariant broken: stubs_len is not larger than a usize, \
|
||||
so the hash modulo stubs_len should always fit in a usize",
|
||||
);
|
||||
let next = &self.stubs[index];
|
||||
next.call(ctx, request_name, request).await
|
||||
}
|
||||
}
|
||||
|
||||
/// A Stub that load-balances across backing stubs by round robin.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ConsistentHash<Stub, S = RandomState> {
|
||||
stubs: Vec<Stub>,
|
||||
stubs_len: u64,
|
||||
hasher: S,
|
||||
}
|
||||
|
||||
impl<Stub> ConsistentHash<Stub, RandomState>
|
||||
where
|
||||
Stub: stub::Stub,
|
||||
Stub::Req: Hash,
|
||||
{
|
||||
/// Returns a new RoundRobin stub.
|
||||
/// Returns an err if the length of `stubs` overflows a u64.
|
||||
pub fn new(stubs: Vec<Stub>) -> Result<Self, TryFromIntError> {
|
||||
Ok(Self {
|
||||
stubs_len: stubs.len().try_into()?,
|
||||
stubs,
|
||||
hasher: RandomState::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<Stub, S> ConsistentHash<Stub, S>
|
||||
where
|
||||
Stub: stub::Stub,
|
||||
Stub::Req: Hash,
|
||||
S: BuildHasher,
|
||||
{
|
||||
/// Returns a new RoundRobin stub.
|
||||
/// Returns an err if the length of `stubs` overflows a u64.
|
||||
pub fn with_hasher(stubs: Vec<Stub>, hasher: S) -> Result<Self, TryFromIntError> {
|
||||
Ok(Self {
|
||||
stubs_len: stubs.len().try_into()?,
|
||||
stubs,
|
||||
hasher,
|
||||
})
|
||||
}
|
||||
|
||||
fn hash_request(&self, req: &Stub::Req) -> u64 {
|
||||
let mut hasher = self.hasher.build_hasher();
|
||||
req.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::ConsistentHash;
|
||||
use crate::{
|
||||
client::stub::{mock::Mock, Stub},
|
||||
context,
|
||||
};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
hash::{BuildHasher, Hash, Hasher},
|
||||
rc::Rc,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test() -> anyhow::Result<()> {
|
||||
let stub = ConsistentHash::<_, FakeHasherBuilder>::with_hasher(
|
||||
vec![
|
||||
// For easier reading of the assertions made in this test, each Mock's response
|
||||
// value is equal to a hash value that should map to its index: 3 % 3 = 0, 1 %
|
||||
// 3 = 1, etc.
|
||||
Mock::new([('a', 3), ('b', 3), ('c', 3)]),
|
||||
Mock::new([('a', 1), ('b', 1), ('c', 1)]),
|
||||
Mock::new([('a', 2), ('b', 2), ('c', 2)]),
|
||||
],
|
||||
FakeHasherBuilder::new([('a', 1), ('b', 2), ('c', 3)]),
|
||||
)?;
|
||||
|
||||
for _ in 0..2 {
|
||||
let resp = stub.call(context::current(), "", 'a').await?;
|
||||
assert_eq!(resp, 1);
|
||||
|
||||
let resp = stub.call(context::current(), "", 'b').await?;
|
||||
assert_eq!(resp, 2);
|
||||
|
||||
let resp = stub.call(context::current(), "", 'c').await?;
|
||||
assert_eq!(resp, 3);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct HashRecorder(Vec<u8>);
|
||||
impl Hasher for HashRecorder {
|
||||
fn write(&mut self, bytes: &[u8]) {
|
||||
self.0 = Vec::from(bytes);
|
||||
}
|
||||
fn finish(&self) -> u64 {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
struct FakeHasherBuilder {
|
||||
recorded_hashes: Rc<HashMap<Vec<u8>, u64>>,
|
||||
}
|
||||
|
||||
struct FakeHasher {
|
||||
recorded_hashes: Rc<HashMap<Vec<u8>, u64>>,
|
||||
output: u64,
|
||||
}
|
||||
|
||||
impl BuildHasher for FakeHasherBuilder {
|
||||
type Hasher = FakeHasher;
|
||||
|
||||
fn build_hasher(&self) -> Self::Hasher {
|
||||
FakeHasher {
|
||||
recorded_hashes: self.recorded_hashes.clone(),
|
||||
output: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeHasherBuilder {
|
||||
fn new<T: Hash, const N: usize>(fake_hashes: [(T, u64); N]) -> Self {
|
||||
let mut recorded_hashes = HashMap::new();
|
||||
for (to_hash, fake_hash) in fake_hashes {
|
||||
let mut recorder = HashRecorder(vec![]);
|
||||
to_hash.hash(&mut recorder);
|
||||
recorded_hashes.insert(recorder.0, fake_hash);
|
||||
}
|
||||
Self {
|
||||
recorded_hashes: Rc::new(recorded_hashes),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Hasher for FakeHasher {
|
||||
fn write(&mut self, bytes: &[u8]) {
|
||||
if let Some(hash) = self.recorded_hashes.get(bytes) {
|
||||
self.output = *hash;
|
||||
}
|
||||
}
|
||||
fn finish(&self) -> u64 {
|
||||
self.output
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
49
tarpc/src/client/stub/mock.rs
Normal file
49
tarpc/src/client/stub/mock.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use crate::{
|
||||
client::{stub::Stub, RpcError},
|
||||
context, ServerError,
|
||||
};
|
||||
use std::{collections::HashMap, hash::Hash, io};
|
||||
|
||||
/// A mock stub that returns user-specified responses.
|
||||
pub struct Mock<Req, Resp> {
|
||||
responses: HashMap<Req, Resp>,
|
||||
}
|
||||
|
||||
impl<Req, Resp> Mock<Req, Resp>
|
||||
where
|
||||
Req: Eq + Hash,
|
||||
{
|
||||
/// Returns a new mock, mocking the specified (request, response) pairs.
|
||||
pub fn new<const N: usize>(responses: [(Req, Resp); N]) -> Self {
|
||||
Self {
|
||||
responses: HashMap::from(responses),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> Stub for Mock<Req, Resp>
|
||||
where
|
||||
Req: Eq + Hash,
|
||||
Resp: Clone,
|
||||
{
|
||||
type Req = Req;
|
||||
type Resp = Resp;
|
||||
|
||||
async fn call(
|
||||
&self,
|
||||
_: context::Context,
|
||||
_: &'static str,
|
||||
request: Self::Req,
|
||||
) -> Result<Resp, RpcError> {
|
||||
self.responses
|
||||
.get(&request)
|
||||
.cloned()
|
||||
.map(Ok)
|
||||
.unwrap_or_else(|| {
|
||||
Err(RpcError::Server(ServerError {
|
||||
kind: io::ErrorKind::NotFound,
|
||||
detail: "mock (request, response) entry not found".into(),
|
||||
}))
|
||||
})
|
||||
}
|
||||
}
|
||||
56
tarpc/src/client/stub/retry.rs
Normal file
56
tarpc/src/client/stub/retry.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
//! Provides a stub that retries requests based on response contents..
|
||||
|
||||
use crate::{
|
||||
client::{stub, RpcError},
|
||||
context,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
impl<Stub, Req, F> stub::Stub for Retry<F, Stub>
|
||||
where
|
||||
Stub: stub::Stub<Req = Arc<Req>>,
|
||||
F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool,
|
||||
{
|
||||
type Req = Req;
|
||||
type Resp = Stub::Resp;
|
||||
|
||||
async fn call(
|
||||
&self,
|
||||
ctx: context::Context,
|
||||
request_name: &'static str,
|
||||
request: Self::Req,
|
||||
) -> Result<Stub::Resp, RpcError> {
|
||||
let request = Arc::new(request);
|
||||
for i in 1.. {
|
||||
let result = self
|
||||
.stub
|
||||
.call(ctx, request_name, Arc::clone(&request))
|
||||
.await;
|
||||
if (self.should_retry)(&result, i) {
|
||||
tracing::trace!("Retrying on attempt {i}");
|
||||
continue;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
unreachable!("Wow, that was a lot of attempts!");
|
||||
}
|
||||
}
|
||||
|
||||
/// A Stub that retries requests based on response contents.
|
||||
/// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Retry<F, Stub> {
|
||||
should_retry: F,
|
||||
stub: Stub,
|
||||
}
|
||||
|
||||
impl<Stub, Req, F> Retry<F, Stub>
|
||||
where
|
||||
Stub: stub::Stub<Req = Arc<Req>>,
|
||||
F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool,
|
||||
{
|
||||
/// Creates a new Retry stub that delegates calls to the underlying `stub`.
|
||||
pub fn new(stub: Stub, should_retry: F) -> Self {
|
||||
Self { stub, should_retry }
|
||||
}
|
||||
}
|
||||
113
tarpc/src/lib.rs
113
tarpc/src/lib.rs
@@ -126,13 +126,9 @@
|
||||
//! struct HelloServer;
|
||||
//!
|
||||
//! impl World for HelloServer {
|
||||
//! // Each defined rpc generates two items in the trait, a fn that serves the RPC, and
|
||||
//! // an associated type representing the future output by the fn.
|
||||
//!
|
||||
//! type HelloFut = Ready<String>;
|
||||
//!
|
||||
//! fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
||||
//! future::ready(format!("Hello, {name}!"))
|
||||
//! // Each defined rpc generates an async fn that serves the RPC
|
||||
//! async fn hello(self, _: context::Context, name: String) -> String {
|
||||
//! format!("Hello, {name}!")
|
||||
//! }
|
||||
//! }
|
||||
//! ```
|
||||
@@ -164,11 +160,9 @@
|
||||
//! # #[derive(Clone)]
|
||||
//! # struct HelloServer;
|
||||
//! # impl World for HelloServer {
|
||||
//! # // Each defined rpc generates two items in the trait, a fn that serves the RPC, and
|
||||
//! # // an associated type representing the future output by the fn.
|
||||
//! # type HelloFut = Ready<String>;
|
||||
//! # fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
||||
//! # future::ready(format!("Hello, {name}!"))
|
||||
//! // Each defined rpc generates an async fn that serves the RPC
|
||||
//! # async fn hello(self, _: context::Context, name: String) -> String {
|
||||
//! # format!("Hello, {name}!")
|
||||
//! # }
|
||||
//! # }
|
||||
//! # #[cfg(not(feature = "tokio1"))]
|
||||
@@ -179,7 +173,12 @@
|
||||
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||
//!
|
||||
//! let server = server::BaseChannel::with_defaults(server_transport);
|
||||
//! tokio::spawn(server.execute(HelloServer.serve()));
|
||||
//! tokio::spawn(
|
||||
//! server.execute(HelloServer.serve())
|
||||
//! // Handle all requests concurrently.
|
||||
//! .for_each(|response| async move {
|
||||
//! tokio::spawn(response);
|
||||
//! }));
|
||||
//!
|
||||
//! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||
//! // that takes a config and any Transport as input.
|
||||
@@ -200,6 +199,7 @@
|
||||
//!
|
||||
//! Use `cargo doc` as you normally would to see the documentation created for all
|
||||
//! items expanded by a `service!` invocation.
|
||||
|
||||
#![deny(missing_docs)]
|
||||
#![allow(clippy::type_complexity)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
@@ -244,62 +244,6 @@ pub use tarpc_plugins::derive_serde;
|
||||
/// * `fn new_stub` -- creates a new Client stub.
|
||||
pub use tarpc_plugins::service;
|
||||
|
||||
/// A utility macro that can be used for RPC server implementations.
|
||||
///
|
||||
/// Syntactic sugar to make using async functions in the server implementation
|
||||
/// easier. It does this by rewriting code like this, which would normally not
|
||||
/// compile because async functions are disallowed in trait implementations:
|
||||
///
|
||||
/// ```rust
|
||||
/// # use tarpc::context;
|
||||
/// # use std::net::SocketAddr;
|
||||
/// #[tarpc::service]
|
||||
/// trait World {
|
||||
/// async fn hello(name: String) -> String;
|
||||
/// }
|
||||
///
|
||||
/// #[derive(Clone)]
|
||||
/// struct HelloServer(SocketAddr);
|
||||
///
|
||||
/// #[tarpc::server]
|
||||
/// impl World for HelloServer {
|
||||
/// async fn hello(self, _: context::Context, name: String) -> String {
|
||||
/// format!("Hello, {name}! You are connected from {:?}.", self.0)
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// Into code like this, which matches the service trait definition:
|
||||
///
|
||||
/// ```rust
|
||||
/// # use tarpc::context;
|
||||
/// # use std::pin::Pin;
|
||||
/// # use futures::Future;
|
||||
/// # use std::net::SocketAddr;
|
||||
/// #[derive(Clone)]
|
||||
/// struct HelloServer(SocketAddr);
|
||||
///
|
||||
/// #[tarpc::service]
|
||||
/// trait World {
|
||||
/// async fn hello(name: String) -> String;
|
||||
/// }
|
||||
///
|
||||
/// impl World for HelloServer {
|
||||
/// type HelloFut = Pin<Box<dyn Future<Output = String> + Send>>;
|
||||
///
|
||||
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
|
||||
/// + Send>> {
|
||||
/// Box::pin(async move {
|
||||
/// format!("Hello, {name}! You are connected from {:?}.", self.0)
|
||||
/// })
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// Note that this won't touch functions unless they have been annotated with
|
||||
/// `async`, meaning that this should not break existing code.
|
||||
pub use tarpc_plugins::server;
|
||||
|
||||
pub(crate) mod cancellations;
|
||||
pub mod client;
|
||||
pub mod context;
|
||||
@@ -311,6 +255,7 @@ pub use crate::transport::sealed::Transport;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use futures::task::*;
|
||||
use std::sync::Arc;
|
||||
use std::{error::Error, fmt::Display, io, time::SystemTime};
|
||||
|
||||
/// A message from a client to a server.
|
||||
@@ -383,6 +328,36 @@ pub struct ServerError {
|
||||
pub detail: String,
|
||||
}
|
||||
|
||||
/// Critical errors that result in a Channel disconnecting.
|
||||
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
|
||||
pub enum ChannelError<E>
|
||||
where
|
||||
E: Error + Send + Sync + 'static,
|
||||
{
|
||||
/// Could not read from the transport.
|
||||
#[error("could not read from the transport")]
|
||||
Read(#[source] Arc<E>),
|
||||
/// Could not ready the transport for writes.
|
||||
#[error("could not ready the transport for writes")]
|
||||
Ready(#[source] E),
|
||||
/// Could not write to the transport.
|
||||
#[error("could not write to the transport")]
|
||||
Write(#[source] E),
|
||||
/// Could not flush the transport.
|
||||
#[error("could not flush the transport")]
|
||||
Flush(#[source] E),
|
||||
/// Could not close the write end of the transport.
|
||||
#[error("could not close the write end of the transport")]
|
||||
Close(#[source] E),
|
||||
}
|
||||
|
||||
impl ServerError {
|
||||
/// Returns a new server error with `kind` and `detail`.
|
||||
pub fn new(kind: io::ErrorKind, detail: String) -> ServerError {
|
||||
Self { kind, detail }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Request<T> {
|
||||
/// Returns the deadline for this request.
|
||||
pub fn deadline(&self) -> &SystemTime {
|
||||
|
||||
@@ -129,14 +129,6 @@ pub mod tcp {
|
||||
tokio_util::codec::length_delimited,
|
||||
};
|
||||
|
||||
mod private {
|
||||
use super::*;
|
||||
|
||||
pub trait Sealed {}
|
||||
|
||||
impl<Item, SinkItem, Codec> Sealed for Transport<TcpStream, Item, SinkItem, Codec> {}
|
||||
}
|
||||
|
||||
impl<Item, SinkItem, Codec> Transport<TcpStream, Item, SinkItem, Codec> {
|
||||
/// Returns the peer address of the underlying TcpStream.
|
||||
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
|
||||
@@ -218,7 +210,19 @@ pub mod tcp {
|
||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||
CodecFn: Fn() -> Codec,
|
||||
{
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
listen_on(TcpListener::bind(addr).await?, codec_fn).await
|
||||
}
|
||||
|
||||
/// Wrap accepted connections from `listener` in TCP transports.
|
||||
pub async fn listen_on<Item, SinkItem, Codec, CodecFn>(
|
||||
listener: TcpListener,
|
||||
codec_fn: CodecFn,
|
||||
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
|
||||
where
|
||||
Item: for<'de> Deserialize<'de>,
|
||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||
CodecFn: Fn() -> Codec,
|
||||
{
|
||||
let local_addr = listener.local_addr()?;
|
||||
Ok(Incoming {
|
||||
listener,
|
||||
@@ -372,7 +376,19 @@ pub mod unix {
|
||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||
CodecFn: Fn() -> Codec,
|
||||
{
|
||||
let listener = UnixListener::bind(path)?;
|
||||
listen_on(UnixListener::bind(path)?, codec_fn).await
|
||||
}
|
||||
|
||||
/// Wrap accepted connections from `listener` in Unix Domain Socket transports.
|
||||
pub async fn listen_on<Item, SinkItem, Codec, CodecFn>(
|
||||
listener: UnixListener,
|
||||
codec_fn: CodecFn,
|
||||
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
|
||||
where
|
||||
Item: for<'de> Deserialize<'de>,
|
||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||
CodecFn: Fn() -> Codec,
|
||||
{
|
||||
let local_addr = listener.local_addr()?;
|
||||
Ok(Incoming {
|
||||
listener,
|
||||
@@ -545,7 +561,7 @@ pub mod unix {
|
||||
mod tests {
|
||||
use super::Transport;
|
||||
use assert_matches::assert_matches;
|
||||
use futures::{task::*, Sink, Stream};
|
||||
use futures::{task::*, Sink, SinkExt, Stream, StreamExt};
|
||||
use pin_utils::pin_mut;
|
||||
use std::{
|
||||
io::{self, Cursor},
|
||||
@@ -639,7 +655,7 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(tcp)]
|
||||
#[cfg(feature = "tcp")]
|
||||
#[tokio::test]
|
||||
async fn tcp() -> io::Result<()> {
|
||||
use super::tcp;
|
||||
@@ -658,11 +674,30 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "tcp")]
|
||||
#[tokio::test]
|
||||
async fn tcp_on_existing_transport() -> io::Result<()> {
|
||||
use super::tcp;
|
||||
|
||||
let transport = tokio::net::TcpListener::bind("0.0.0.0:0").await?;
|
||||
let mut listener = tcp::listen_on(transport, SymmetricalJson::<String>::default).await?;
|
||||
let addr = listener.local_addr();
|
||||
tokio::spawn(async move {
|
||||
let mut transport = listener.next().await.unwrap().unwrap();
|
||||
let message = transport.next().await.unwrap().unwrap();
|
||||
transport.send(message).await.unwrap();
|
||||
});
|
||||
let mut transport = tcp::connect(addr, SymmetricalJson::<String>::default).await?;
|
||||
transport.send(String::from("test")).await?;
|
||||
assert_matches!(transport.next().await, Some(Ok(s)) if s == "test");
|
||||
assert_matches!(transport.next().await, None);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(all(unix, feature = "unix"))]
|
||||
#[tokio::test]
|
||||
async fn uds() -> io::Result<()> {
|
||||
use super::unix;
|
||||
use super::*;
|
||||
|
||||
let sock = unix::TempPathBuf::with_random("uds");
|
||||
let mut listener = unix::listen(&sock, SymmetricalJson::<String>::default).await?;
|
||||
@@ -677,4 +712,24 @@ mod tests {
|
||||
assert_matches!(transport.next().await, None);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(all(unix, feature = "unix"))]
|
||||
#[tokio::test]
|
||||
async fn uds_on_existing_transport() -> io::Result<()> {
|
||||
use super::unix;
|
||||
|
||||
let sock = unix::TempPathBuf::with_random("uds");
|
||||
let transport = tokio::net::UnixListener::bind(&sock)?;
|
||||
let mut listener = unix::listen_on(transport, SymmetricalJson::<String>::default).await?;
|
||||
tokio::spawn(async move {
|
||||
let mut transport = listener.next().await.unwrap().unwrap();
|
||||
let message = transport.next().await.unwrap().unwrap();
|
||||
transport.send(message).await.unwrap();
|
||||
});
|
||||
let mut transport = unix::connect(&sock, SymmetricalJson::<String>::default).await?;
|
||||
transport.send(String::from("test")).await?;
|
||||
assert_matches!(transport.next().await, Some(Ok(s)) if s == "test");
|
||||
assert_matches!(transport.next().await, None);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
use crate::{
|
||||
cancellations::{cancellations, CanceledRequests, RequestCancellation},
|
||||
context::{self, SpanExt},
|
||||
trace, ClientMessage, Request, Response, Transport,
|
||||
trace, ChannelError, ClientMessage, Request, Response, ServerError, Transport,
|
||||
};
|
||||
use ::tokio::sync::mpsc;
|
||||
use futures::{
|
||||
@@ -21,17 +21,11 @@ use futures::{
|
||||
};
|
||||
use in_flight_requests::{AlreadyExistsError, InFlightRequests};
|
||||
use pin_project::pin_project;
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
error::Error,
|
||||
fmt,
|
||||
marker::PhantomData,
|
||||
mem::{self, ManuallyDrop},
|
||||
pin::Pin,
|
||||
};
|
||||
use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc};
|
||||
use tracing::{info_span, instrument::Instrument, Span};
|
||||
|
||||
mod in_flight_requests;
|
||||
pub mod request_hook;
|
||||
#[cfg(test)]
|
||||
mod testing;
|
||||
|
||||
@@ -41,10 +35,9 @@ pub mod limits;
|
||||
/// Provides helper methods for streams of Channels.
|
||||
pub mod incoming;
|
||||
|
||||
/// Provides convenience functionality for tokio-enabled applications.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub mod tokio;
|
||||
use request_hook::{
|
||||
AfterRequest, BeforeRequest, HookThenServe, HookThenServeThenHook, ServeThenHook,
|
||||
};
|
||||
|
||||
/// Settings that control the behavior of [channels](Channel).
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -74,32 +67,204 @@ impl Config {
|
||||
}
|
||||
|
||||
/// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`.
|
||||
pub trait Serve<Req> {
|
||||
#[allow(async_fn_in_trait)]
|
||||
pub trait Serve {
|
||||
/// Type of request.
|
||||
type Req;
|
||||
|
||||
/// Type of response.
|
||||
type Resp;
|
||||
|
||||
/// Type of response future.
|
||||
type Fut: Future<Output = Self::Resp>;
|
||||
/// Responds to a single request.
|
||||
async fn serve(self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError>;
|
||||
|
||||
/// Extracts a method name from the request.
|
||||
fn method(&self, _request: &Req) -> Option<&'static str> {
|
||||
fn method(&self, _request: &Self::Req) -> Option<&'static str> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Responds to a single request.
|
||||
fn serve(self, ctx: context::Context, req: Req) -> Self::Fut;
|
||||
/// Runs a hook before execution of the request.
|
||||
///
|
||||
/// If the hook returns an error, the request will not be executed and the error will be
|
||||
/// returned instead.
|
||||
///
|
||||
/// The hook can also modify the request context. This could be used, for example, to enforce a
|
||||
/// maximum deadline on all requests.
|
||||
///
|
||||
/// Any type that implements [`BeforeRequest`] can be used as the hook. Types that implement
|
||||
/// `FnMut(&mut Context, &RequestType) -> impl Future<Output = Result<(), ServerError>>` can
|
||||
/// also be used.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use futures::{executor::block_on, future};
|
||||
/// use tarpc::{context, ServerError, server::{Serve, serve}};
|
||||
/// use std::io;
|
||||
///
|
||||
/// let serve = serve(|_ctx, i| async move { Ok(i + 1) })
|
||||
/// .before(|_ctx: &mut context::Context, req: &i32| {
|
||||
/// future::ready(
|
||||
/// if *req == 1 {
|
||||
/// Err(ServerError::new(
|
||||
/// io::ErrorKind::Other,
|
||||
/// format!("I don't like {req}")))
|
||||
/// } else {
|
||||
/// Ok(())
|
||||
/// })
|
||||
/// });
|
||||
/// let response = serve.serve(context::current(), 1);
|
||||
/// assert!(block_on(response).is_err());
|
||||
/// ```
|
||||
fn before<Hook>(self, hook: Hook) -> HookThenServe<Self, Hook>
|
||||
where
|
||||
Hook: BeforeRequest<Self::Req>,
|
||||
Self: Sized,
|
||||
{
|
||||
HookThenServe::new(self, hook)
|
||||
}
|
||||
|
||||
/// Runs a hook after completion of a request.
|
||||
///
|
||||
/// The hook can modify the request context and the response.
|
||||
///
|
||||
/// Any type that implements [`AfterRequest`] can be used as the hook. Types that implement
|
||||
/// `FnMut(&mut Context, &mut Result<ResponseType, ServerError>) -> impl Future<Output = ()>`
|
||||
/// can also be used.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use futures::{executor::block_on, future};
|
||||
/// use tarpc::{context, ServerError, server::{Serve, serve}};
|
||||
/// use std::io;
|
||||
///
|
||||
/// let serve = serve(
|
||||
/// |_ctx, i| async move {
|
||||
/// if i == 1 {
|
||||
/// Err(ServerError::new(
|
||||
/// io::ErrorKind::Other,
|
||||
/// format!("{i} is the loneliest number")))
|
||||
/// } else {
|
||||
/// Ok(i + 1)
|
||||
/// }
|
||||
/// })
|
||||
/// .after(|_ctx: &mut context::Context, resp: &mut Result<i32, ServerError>| {
|
||||
/// if let Err(e) = resp {
|
||||
/// eprintln!("server error: {e:?}");
|
||||
/// }
|
||||
/// future::ready(())
|
||||
/// });
|
||||
///
|
||||
/// let response = serve.serve(context::current(), 1);
|
||||
/// assert!(block_on(response).is_err());
|
||||
/// ```
|
||||
fn after<Hook>(self, hook: Hook) -> ServeThenHook<Self, Hook>
|
||||
where
|
||||
Hook: AfterRequest<Self::Resp>,
|
||||
Self: Sized,
|
||||
{
|
||||
ServeThenHook::new(self, hook)
|
||||
}
|
||||
|
||||
/// Runs a hook before and after execution of the request.
|
||||
///
|
||||
/// If the hook returns an error, the request will not be executed and the error will be
|
||||
/// returned instead.
|
||||
///
|
||||
/// The hook can also modify the request context and the response. This could be used, for
|
||||
/// example, to enforce a maximum deadline on all requests.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use futures::{executor::block_on, future};
|
||||
/// use tarpc::{
|
||||
/// context, ServerError, server::{Serve, serve, request_hook::{BeforeRequest, AfterRequest}}
|
||||
/// };
|
||||
/// use std::{io, time::Instant};
|
||||
///
|
||||
/// struct PrintLatency(Instant);
|
||||
///
|
||||
/// impl<Req> BeforeRequest<Req> for PrintLatency {
|
||||
/// async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> {
|
||||
/// self.0 = Instant::now();
|
||||
/// Ok(())
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// impl<Resp> AfterRequest<Resp> for PrintLatency {
|
||||
/// async fn after(
|
||||
/// &mut self,
|
||||
/// _: &mut context::Context,
|
||||
/// _: &mut Result<Resp, ServerError>,
|
||||
/// ) {
|
||||
/// tracing::info!("Elapsed: {:?}", self.0.elapsed());
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// let serve = serve(|_ctx, i| async move {
|
||||
/// Ok(i + 1)
|
||||
/// }).before_and_after(PrintLatency(Instant::now()));
|
||||
/// let response = serve.serve(context::current(), 1);
|
||||
/// assert!(block_on(response).is_ok());
|
||||
/// ```
|
||||
fn before_and_after<Hook>(
|
||||
self,
|
||||
hook: Hook,
|
||||
) -> HookThenServeThenHook<Self::Req, Self::Resp, Self, Hook>
|
||||
where
|
||||
Hook: BeforeRequest<Self::Req> + AfterRequest<Self::Resp>,
|
||||
Self: Sized,
|
||||
{
|
||||
HookThenServeThenHook::new(self, hook)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, Fut, F> Serve<Req> for F
|
||||
/// A Serve wrapper around a Fn.
|
||||
#[derive(Debug)]
|
||||
pub struct ServeFn<Req, Resp, F> {
|
||||
f: F,
|
||||
data: PhantomData<fn(Req) -> Resp>,
|
||||
}
|
||||
|
||||
impl<Req, Resp, F> Clone for ServeFn<Req, Resp, F>
|
||||
where
|
||||
F: Clone,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
f: self.f.clone(),
|
||||
data: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, F> Copy for ServeFn<Req, Resp, F> where F: Copy {}
|
||||
|
||||
/// Creates a [`Serve`] wrapper around a `FnOnce(context::Context, Req) -> impl Future<Output =
|
||||
/// Result<Resp, ServerError>>`.
|
||||
pub fn serve<Req, Resp, Fut, F>(f: F) -> ServeFn<Req, Resp, F>
|
||||
where
|
||||
F: FnOnce(context::Context, Req) -> Fut,
|
||||
Fut: Future<Output = Resp>,
|
||||
Fut: Future<Output = Result<Resp, ServerError>>,
|
||||
{
|
||||
type Resp = Resp;
|
||||
type Fut = Fut;
|
||||
ServeFn {
|
||||
f,
|
||||
data: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn serve(self, ctx: context::Context, req: Req) -> Self::Fut {
|
||||
self(ctx, req)
|
||||
impl<Req, Resp, Fut, F> Serve for ServeFn<Req, Resp, F>
|
||||
where
|
||||
F: FnOnce(context::Context, Req) -> Fut,
|
||||
Fut: Future<Output = Result<Resp, ServerError>>,
|
||||
{
|
||||
type Req = Req;
|
||||
type Resp = Resp;
|
||||
|
||||
async fn serve(self, ctx: context::Context, req: Req) -> Result<Resp, ServerError> {
|
||||
(self.f)(ctx, req).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,7 +292,7 @@ pub struct BaseChannel<Req, Resp, T> {
|
||||
/// Holds data necessary to clean up in-flight requests.
|
||||
in_flight_requests: InFlightRequests,
|
||||
/// Types the request and response.
|
||||
ghost: PhantomData<(Req, Resp)>,
|
||||
ghost: PhantomData<(fn() -> Req, fn(Resp))>,
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
|
||||
@@ -208,10 +373,11 @@ where
|
||||
Ok(TrackedRequest {
|
||||
abort_registration,
|
||||
span,
|
||||
response_guard: ManuallyDrop::new(ResponseGuard {
|
||||
response_guard: ResponseGuard {
|
||||
request_id: request.id,
|
||||
request_cancellation: self.request_cancellation.clone(),
|
||||
}),
|
||||
cancel: false,
|
||||
},
|
||||
request,
|
||||
})
|
||||
}
|
||||
@@ -240,7 +406,7 @@ pub struct TrackedRequest<Req> {
|
||||
/// A span representing the server processing of this request.
|
||||
pub span: Span,
|
||||
/// An inert response guard. Becomes active in an InFlightRequest.
|
||||
pub response_guard: ManuallyDrop<ResponseGuard>,
|
||||
pub response_guard: ResponseGuard,
|
||||
}
|
||||
|
||||
/// The server end of an open connection with a client, receiving requests from, and sending
|
||||
@@ -313,6 +479,34 @@ where
|
||||
/// This is a terminal operation. After calling `requests`, the channel cannot be retrieved,
|
||||
/// and the only way to complete requests is via [`Requests::execute`] or
|
||||
/// [`InFlightRequest::execute`].
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tarpc::{
|
||||
/// context,
|
||||
/// client::{self, NewClient},
|
||||
/// server::{self, BaseChannel, Channel, serve},
|
||||
/// transport,
|
||||
/// };
|
||||
/// use futures::prelude::*;
|
||||
///
|
||||
/// #[tokio::main]
|
||||
/// async fn main() {
|
||||
/// let (tx, rx) = transport::channel::unbounded();
|
||||
/// let server = BaseChannel::new(server::Config::default(), rx);
|
||||
/// let NewClient { client, dispatch } = client::new(client::Config::default(), tx);
|
||||
/// tokio::spawn(dispatch);
|
||||
///
|
||||
/// let mut requests = server.requests();
|
||||
/// tokio::spawn(async move {
|
||||
/// while let Some(Ok(request)) = requests.next().await {
|
||||
/// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) })));
|
||||
/// }
|
||||
/// });
|
||||
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
|
||||
/// }
|
||||
/// ```
|
||||
fn requests(self) -> Requests<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
@@ -326,37 +520,47 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs the channel until completion by executing all requests using the given service
|
||||
/// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's
|
||||
/// default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
fn execute<S>(self, serve: S) -> self::tokio::TokioChannelExecutor<Requests<Self>, S>
|
||||
/// Returns a stream of request execution futures. Each future represents an in-flight request
|
||||
/// being responded to by the server. The futures must be awaited or spawned to complete their
|
||||
/// requests.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport};
|
||||
/// use futures::prelude::*;
|
||||
/// use tracing_subscriber::prelude::*;
|
||||
///
|
||||
/// #[derive(PartialEq, Eq, Debug)]
|
||||
/// struct MyInt(i32);
|
||||
///
|
||||
/// # #[cfg(not(feature = "tokio1"))]
|
||||
/// # fn main() {}
|
||||
/// # #[cfg(feature = "tokio1")]
|
||||
/// #[tokio::main]
|
||||
/// async fn main() {
|
||||
/// let (tx, rx) = transport::channel::unbounded();
|
||||
/// let client = client::new(client::Config::default(), tx).spawn();
|
||||
/// let channel = BaseChannel::with_defaults(rx);
|
||||
/// tokio::spawn(
|
||||
/// channel.execute(serve(|_, MyInt(i)| async move { Ok(MyInt(i + 1)) }))
|
||||
/// .for_each(|response| async move {
|
||||
/// tokio::spawn(response);
|
||||
/// }));
|
||||
/// assert_eq!(
|
||||
/// client.call(context::current(), "AddOne", MyInt(1)).await.unwrap(),
|
||||
/// MyInt(2));
|
||||
/// }
|
||||
/// ```
|
||||
fn execute<S>(self, serve: S) -> impl Stream<Item = impl Future<Output = ()>>
|
||||
where
|
||||
Self: Sized,
|
||||
S: Serve<Self::Req, Resp = Self::Resp> + Send + 'static,
|
||||
S::Fut: Send,
|
||||
Self::Req: Send + 'static,
|
||||
Self::Resp: Send + 'static,
|
||||
S: Serve<Req = Self::Req, Resp = Self::Resp> + Clone,
|
||||
{
|
||||
self.requests().execute(serve)
|
||||
}
|
||||
}
|
||||
|
||||
/// Critical errors that result in a Channel disconnecting.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ChannelError<E>
|
||||
where
|
||||
E: Error + Send + Sync + 'static,
|
||||
{
|
||||
/// An error occurred reading from, or writing to, the transport.
|
||||
#[error("an error occurred in the transport: {0}")]
|
||||
Transport(#[source] E),
|
||||
/// An error occurred while polling expired requests.
|
||||
#[error("an error occurred while polling expired requests: {0}")]
|
||||
Timer(#[source] ::tokio::time::error::Error),
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
@@ -413,7 +617,7 @@ where
|
||||
let request_status = match self
|
||||
.transport_pin_mut()
|
||||
.poll_next(cx)
|
||||
.map_err(ChannelError::Transport)?
|
||||
.map_err(|e| ChannelError::Read(Arc::new(e)))?
|
||||
{
|
||||
Poll::Ready(Some(message)) => match message {
|
||||
ClientMessage::Request(request) => {
|
||||
@@ -445,15 +649,17 @@ where
|
||||
Poll::Pending => Pending,
|
||||
};
|
||||
|
||||
tracing::trace!(
|
||||
"Expired requests: {:?}, Inbound: {:?}",
|
||||
expiration_status,
|
||||
request_status
|
||||
);
|
||||
match cancellation_status
|
||||
let status = cancellation_status
|
||||
.combine(expiration_status)
|
||||
.combine(request_status)
|
||||
{
|
||||
.combine(request_status);
|
||||
|
||||
tracing::trace!(
|
||||
"Cancellations: {cancellation_status:?}, \
|
||||
Expired requests: {expiration_status:?}, \
|
||||
Inbound: {request_status:?}, \
|
||||
Overall: {status:?}",
|
||||
);
|
||||
match status {
|
||||
Ready => continue,
|
||||
Closed => return Poll::Ready(None),
|
||||
Pending => return Poll::Pending,
|
||||
@@ -473,7 +679,7 @@ where
|
||||
self.project()
|
||||
.transport
|
||||
.poll_ready(cx)
|
||||
.map_err(ChannelError::Transport)
|
||||
.map_err(ChannelError::Ready)
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
|
||||
@@ -486,7 +692,7 @@ where
|
||||
self.project()
|
||||
.transport
|
||||
.start_send(response)
|
||||
.map_err(ChannelError::Transport)
|
||||
.map_err(ChannelError::Write)
|
||||
} else {
|
||||
// If the request isn't tracked anymore, there's no need to send the response.
|
||||
Ok(())
|
||||
@@ -498,14 +704,14 @@ where
|
||||
self.project()
|
||||
.transport
|
||||
.poll_flush(cx)
|
||||
.map_err(ChannelError::Transport)
|
||||
.map_err(ChannelError::Flush)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.project()
|
||||
.transport
|
||||
.poll_close(cx)
|
||||
.map_err(ChannelError::Transport)
|
||||
.map_err(ChannelError::Close)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -581,13 +787,19 @@ where
|
||||
request,
|
||||
abort_registration,
|
||||
span,
|
||||
response_guard,
|
||||
mut response_guard,
|
||||
}| {
|
||||
// The response guard becomes active once in an InFlightRequest.
|
||||
response_guard.cancel = true;
|
||||
{
|
||||
let _entered = span.enter();
|
||||
tracing::info!("BeginRequest");
|
||||
}
|
||||
InFlightRequest {
|
||||
request,
|
||||
abort_registration,
|
||||
span,
|
||||
response_guard: ManuallyDrop::into_inner(response_guard),
|
||||
response_guard,
|
||||
response_tx: self.responses_tx.clone(),
|
||||
}
|
||||
},
|
||||
@@ -657,6 +869,51 @@ where
|
||||
}
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
|
||||
/// Returns a stream of request execution futures. Each future represents an in-flight request
|
||||
/// being responded to by the server. The futures must be awaited or spawned to complete their
|
||||
/// requests.
|
||||
///
|
||||
/// If the channel encounters an error, the stream is terminated and the error is logged.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport};
|
||||
/// use futures::prelude::*;
|
||||
///
|
||||
/// # #[cfg(not(feature = "tokio1"))]
|
||||
/// # fn main() {}
|
||||
/// # #[cfg(feature = "tokio1")]
|
||||
/// #[tokio::main]
|
||||
/// async fn main() {
|
||||
/// let (tx, rx) = transport::channel::unbounded();
|
||||
/// let requests = BaseChannel::new(server::Config::default(), rx).requests();
|
||||
/// let client = client::new(client::Config::default(), tx).spawn();
|
||||
/// tokio::spawn(
|
||||
/// requests.execute(serve(|_, i| async move { Ok(i + 1) }))
|
||||
/// .for_each(|response| async move {
|
||||
/// tokio::spawn(response);
|
||||
/// }));
|
||||
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
|
||||
/// }
|
||||
/// ```
|
||||
pub fn execute<S>(self, serve: S) -> impl Stream<Item = impl Future<Output = ()>>
|
||||
where
|
||||
S: Serve<Req = C::Req, Resp = C::Resp> + Clone,
|
||||
{
|
||||
self.take_while(|result| {
|
||||
if let Err(e) = result {
|
||||
tracing::warn!("Requests stream errored out: {}", e);
|
||||
}
|
||||
futures::future::ready(result.is_ok())
|
||||
})
|
||||
.filter_map(|result| async move { result.ok() })
|
||||
.map(move |request| {
|
||||
let serve = serve.clone();
|
||||
request.execute(serve)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> fmt::Debug for Requests<C>
|
||||
@@ -674,11 +931,14 @@ where
|
||||
pub struct ResponseGuard {
|
||||
request_cancellation: RequestCancellation,
|
||||
request_id: u64,
|
||||
cancel: bool,
|
||||
}
|
||||
|
||||
impl Drop for ResponseGuard {
|
||||
fn drop(&mut self) {
|
||||
self.request_cancellation.cancel(self.request_id);
|
||||
if self.cancel {
|
||||
self.request_cancellation.cancel(self.request_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -715,13 +975,43 @@ impl<Req, Res> InFlightRequest<Req, Res> {
|
||||
///
|
||||
/// If the returned Future is dropped before completion, a cancellation message will be sent to
|
||||
/// the Channel to clean up associated request state.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tarpc::{
|
||||
/// context,
|
||||
/// client::{self, NewClient},
|
||||
/// server::{self, BaseChannel, Channel, serve},
|
||||
/// transport,
|
||||
/// };
|
||||
/// use futures::prelude::*;
|
||||
///
|
||||
/// #[tokio::main]
|
||||
/// async fn main() {
|
||||
/// let (tx, rx) = transport::channel::unbounded();
|
||||
/// let server = BaseChannel::new(server::Config::default(), rx);
|
||||
/// let NewClient { client, dispatch } = client::new(client::Config::default(), tx);
|
||||
/// tokio::spawn(dispatch);
|
||||
///
|
||||
/// tokio::spawn(async move {
|
||||
/// let mut requests = server.requests();
|
||||
/// while let Some(Ok(in_flight_request)) = requests.next().await {
|
||||
/// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) })).await;
|
||||
/// }
|
||||
///
|
||||
/// });
|
||||
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
pub async fn execute<S>(self, serve: S)
|
||||
where
|
||||
S: Serve<Req, Resp = Res>,
|
||||
S: Serve<Req = Req, Resp = Res>,
|
||||
{
|
||||
let Self {
|
||||
response_tx,
|
||||
response_guard,
|
||||
mut response_guard,
|
||||
abort_registration,
|
||||
span,
|
||||
request:
|
||||
@@ -732,15 +1022,14 @@ impl<Req, Res> InFlightRequest<Req, Res> {
|
||||
},
|
||||
} = self;
|
||||
let method = serve.method(&message);
|
||||
span.record("otel.name", &method.unwrap_or(""));
|
||||
span.record("otel.name", method.unwrap_or(""));
|
||||
let _ = Abortable::new(
|
||||
async move {
|
||||
tracing::info!("BeginRequest");
|
||||
let response = serve.serve(context, message).await;
|
||||
let message = serve.serve(context, message).await;
|
||||
tracing::info!("CompleteRequest");
|
||||
let response = Response {
|
||||
request_id,
|
||||
message: Ok(response),
|
||||
message,
|
||||
};
|
||||
let _ = response_tx.send(response).await;
|
||||
tracing::info!("BufferResponse");
|
||||
@@ -752,10 +1041,17 @@ impl<Req, Res> InFlightRequest<Req, Res> {
|
||||
// Request processing has completed, meaning either the channel canceled the request or
|
||||
// a request was sent back to the channel. Either way, the channel will clean up the
|
||||
// request data, so the request does not need to be canceled.
|
||||
mem::forget(response_guard);
|
||||
response_guard.cancel = false;
|
||||
}
|
||||
}
|
||||
|
||||
fn print_err(e: &(dyn Error + 'static)) -> String {
|
||||
anyhow::Chain::new(e)
|
||||
.map(|e| e.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(": ")
|
||||
}
|
||||
|
||||
impl<C> Stream for Requests<C>
|
||||
where
|
||||
C: Channel,
|
||||
@@ -764,17 +1060,33 @@ where
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
let read = self.as_mut().pump_read(cx)?;
|
||||
let read = self.as_mut().pump_read(cx).map_err(|e| {
|
||||
tracing::trace!("read: {}", print_err(&e));
|
||||
e
|
||||
})?;
|
||||
let read_closed = matches!(read, Poll::Ready(None));
|
||||
match (read, self.as_mut().pump_write(cx, read_closed)?) {
|
||||
let write = self.as_mut().pump_write(cx, read_closed).map_err(|e| {
|
||||
tracing::trace!("write: {}", print_err(&e));
|
||||
e
|
||||
})?;
|
||||
match (read, write) {
|
||||
(Poll::Ready(None), Poll::Ready(None)) => {
|
||||
tracing::trace!("read: Poll::Ready(None), write: Poll::Ready(None)");
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
(Poll::Ready(Some(request_handler)), _) => {
|
||||
tracing::trace!("read: Poll::Ready(Some), write: _");
|
||||
return Poll::Ready(Some(Ok(request_handler)));
|
||||
}
|
||||
(_, Poll::Ready(Some(()))) => {}
|
||||
_ => {
|
||||
(_, Poll::Ready(Some(()))) => {
|
||||
tracing::trace!("read: _, write: Poll::Ready(Some)");
|
||||
}
|
||||
(read @ Poll::Pending, write) | (read, write @ Poll::Pending) => {
|
||||
tracing::trace!(
|
||||
"read pending: {}, write pending: {}",
|
||||
read.is_pending(),
|
||||
write.is_pending()
|
||||
);
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
@@ -784,11 +1096,14 @@ where
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{in_flight_requests::AlreadyExistsError, BaseChannel, Channel, Config, Requests};
|
||||
use super::{
|
||||
in_flight_requests::AlreadyExistsError, serve, AfterRequest, BaseChannel, BeforeRequest,
|
||||
Channel, Config, Requests, Serve,
|
||||
};
|
||||
use crate::{
|
||||
context, trace,
|
||||
transport::channel::{self, UnboundedChannel},
|
||||
ClientMessage, Request, Response,
|
||||
ClientMessage, Request, Response, ServerError,
|
||||
};
|
||||
use assert_matches::assert_matches;
|
||||
use futures::{
|
||||
@@ -797,7 +1112,12 @@ mod tests {
|
||||
Future,
|
||||
};
|
||||
use futures_test::task::noop_context;
|
||||
use std::{pin::Pin, task::Poll};
|
||||
use std::{
|
||||
io,
|
||||
pin::Pin,
|
||||
task::Poll,
|
||||
time::{Duration, Instant, SystemTime},
|
||||
};
|
||||
|
||||
fn test_channel<Req, Resp>() -> (
|
||||
Pin<Box<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>,
|
||||
@@ -858,6 +1178,89 @@ mod tests {
|
||||
Abortable::new(pending(), abort_registration)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_serve() {
|
||||
let serve = serve(|_, i| async move { Ok(i) });
|
||||
assert_matches!(serve.serve(context::current(), 7).await, Ok(7));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn serve_before_mutates_context() -> anyhow::Result<()> {
|
||||
struct SetDeadline(SystemTime);
|
||||
impl<Req> BeforeRequest<Req> for SetDeadline {
|
||||
async fn before(
|
||||
&mut self,
|
||||
ctx: &mut context::Context,
|
||||
_: &Req,
|
||||
) -> Result<(), ServerError> {
|
||||
ctx.deadline = self.0;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
let some_time = SystemTime::UNIX_EPOCH + Duration::from_secs(37);
|
||||
let some_other_time = SystemTime::UNIX_EPOCH + Duration::from_secs(83);
|
||||
|
||||
let serve = serve(move |ctx: context::Context, i| async move {
|
||||
assert_eq!(ctx.deadline, some_time);
|
||||
Ok(i)
|
||||
});
|
||||
let deadline_hook = serve.before(SetDeadline(some_time));
|
||||
let mut ctx = context::current();
|
||||
ctx.deadline = some_other_time;
|
||||
deadline_hook.serve(ctx, 7).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn serve_before_and_after() -> anyhow::Result<()> {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
struct PrintLatency {
|
||||
start: Instant,
|
||||
}
|
||||
impl PrintLatency {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
start: Instant::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<Req> BeforeRequest<Req> for PrintLatency {
|
||||
async fn before(
|
||||
&mut self,
|
||||
_: &mut context::Context,
|
||||
_: &Req,
|
||||
) -> Result<(), ServerError> {
|
||||
self.start = Instant::now();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
impl<Resp> AfterRequest<Resp> for PrintLatency {
|
||||
async fn after(&mut self, _: &mut context::Context, _: &mut Result<Resp, ServerError>) {
|
||||
tracing::info!("Elapsed: {:?}", self.start.elapsed());
|
||||
}
|
||||
}
|
||||
|
||||
let serve = serve(move |_: context::Context, i| async move { Ok(i) });
|
||||
serve
|
||||
.before_and_after(PrintLatency::new())
|
||||
.serve(context::current(), 7)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn serve_before_error_aborts_request() -> anyhow::Result<()> {
|
||||
let serve = serve(|_, _| async { panic!("Shouldn't get here") });
|
||||
let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async {
|
||||
Err(ServerError::new(io::ErrorKind::Other, "oops".into()))
|
||||
});
|
||||
let resp: Result<i32, _> = deadline_hook.serve(context::current(), 7).await;
|
||||
assert_matches!(resp, Err(_));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn base_channel_start_send_duplicate_request_returns_error() {
|
||||
let (mut channel, _tx) = test_channel::<(), ()>();
|
||||
@@ -1058,7 +1461,7 @@ mod tests {
|
||||
Poll::Ready(Some(Ok(request))) => request,
|
||||
result => panic!("Unexpected result: {:?}", result),
|
||||
};
|
||||
request.execute(|_, _| async {}).await;
|
||||
request.execute(serve(|_, _| async { Ok(()) })).await;
|
||||
assert!(requests
|
||||
.as_mut()
|
||||
.channel_pin_mut()
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
use super::{
|
||||
limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel},
|
||||
Channel,
|
||||
Channel, Serve,
|
||||
};
|
||||
use futures::prelude::*;
|
||||
use std::{fmt, hash::Hash};
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
use super::{tokio::TokioServerExecutor, Serve};
|
||||
|
||||
/// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel).
|
||||
pub trait Incoming<C>
|
||||
where
|
||||
@@ -28,16 +25,62 @@ where
|
||||
MaxRequestsPerChannel::new(self, n)
|
||||
}
|
||||
|
||||
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled
|
||||
/// concurrently by spawning on tokio's default executor, and each request will be also
|
||||
/// be spawned on tokio's default executor.
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
|
||||
/// Returns a stream of channels in execution. Each channel in execution is a stream of
|
||||
/// futures, where each future is an in-flight request being rsponded to.
|
||||
fn execute<S>(
|
||||
self,
|
||||
serve: S,
|
||||
) -> impl Stream<Item = impl Stream<Item = impl Future<Output = ()>>>
|
||||
where
|
||||
S: Serve<C::Req, Resp = C::Resp>,
|
||||
S: Serve<Req = C::Req, Resp = C::Resp> + Clone,
|
||||
{
|
||||
TokioServerExecutor::new(self, serve)
|
||||
self.map(move |channel| channel.execute(serve.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
/// Spawns all channels-in-execution, delegating to the tokio runtime to manage their completion.
|
||||
/// Each channel is spawned, and each request from each channel is spawned.
|
||||
/// Note that this function is generic over any stream-of-streams-of-futures, but it is intended
|
||||
/// for spawning streams of channels.
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use tarpc::{
|
||||
/// context,
|
||||
/// client::{self, NewClient},
|
||||
/// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve},
|
||||
/// transport,
|
||||
/// };
|
||||
/// use futures::prelude::*;
|
||||
///
|
||||
/// #[tokio::main]
|
||||
/// async fn main() {
|
||||
/// let (tx, rx) = transport::channel::unbounded();
|
||||
/// let NewClient { client, dispatch } = client::new(client::Config::default(), tx);
|
||||
/// tokio::spawn(dispatch);
|
||||
///
|
||||
/// let incoming = stream::once(async move {
|
||||
/// BaseChannel::new(server::Config::default(), rx)
|
||||
/// }).execute(serve(|_, i| async move { Ok(i + 1) }));
|
||||
/// tokio::spawn(spawn_incoming(incoming));
|
||||
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
|
||||
/// }
|
||||
/// ```
|
||||
pub async fn spawn_incoming(
|
||||
incoming: impl Stream<
|
||||
Item = impl Stream<Item = impl Future<Output = ()> + Send + 'static> + Send + 'static,
|
||||
>,
|
||||
) {
|
||||
use futures::pin_mut;
|
||||
pin_mut!(incoming);
|
||||
while let Some(channel) = incoming.next().await {
|
||||
tokio::spawn(async move {
|
||||
pin_mut!(channel);
|
||||
while let Some(request) = channel.next().await {
|
||||
tokio::spawn(request);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
25
tarpc/src/server/request_hook.rs
Normal file
25
tarpc/src/server/request_hook.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright 2022 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! Hooks for horizontal functionality that can run either before or after a request is executed.
|
||||
|
||||
/// A request hook that runs before a request is executed.
|
||||
mod before;
|
||||
|
||||
/// A request hook that runs after a request is completed.
|
||||
mod after;
|
||||
|
||||
/// A request hook that runs both before a request is executed and after it is completed.
|
||||
mod before_and_after;
|
||||
|
||||
pub use {
|
||||
after::{AfterRequest, ServeThenHook},
|
||||
before::{
|
||||
before, BeforeRequest, BeforeRequestCons, BeforeRequestList, BeforeRequestNil,
|
||||
HookThenServe,
|
||||
},
|
||||
before_and_after::HookThenServeThenHook,
|
||||
};
|
||||
72
tarpc/src/server/request_hook/after.rs
Normal file
72
tarpc/src/server/request_hook/after.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright 2022 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! Provides a hook that runs after request execution.
|
||||
|
||||
use crate::{context, server::Serve, ServerError};
|
||||
use futures::prelude::*;
|
||||
|
||||
/// A hook that runs after request execution.
|
||||
#[allow(async_fn_in_trait)]
|
||||
pub trait AfterRequest<Resp> {
|
||||
/// The function that is called after request execution.
|
||||
///
|
||||
/// The hook can modify the request context and the response.
|
||||
async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result<Resp, ServerError>);
|
||||
}
|
||||
|
||||
impl<F, Fut, Resp> AfterRequest<Resp> for F
|
||||
where
|
||||
F: FnMut(&mut context::Context, &mut Result<Resp, ServerError>) -> Fut,
|
||||
Fut: Future<Output = ()>,
|
||||
{
|
||||
async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result<Resp, ServerError>) {
|
||||
self(ctx, resp).await
|
||||
}
|
||||
}
|
||||
|
||||
/// A Service function that runs a hook after request execution.
|
||||
pub struct ServeThenHook<Serv, Hook> {
|
||||
serve: Serv,
|
||||
hook: Hook,
|
||||
}
|
||||
|
||||
impl<Serv, Hook> ServeThenHook<Serv, Hook> {
|
||||
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
|
||||
Self { serve, hook }
|
||||
}
|
||||
}
|
||||
|
||||
impl<Serv: Clone, Hook: Clone> Clone for ServeThenHook<Serv, Hook> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
serve: self.serve.clone(),
|
||||
hook: self.hook.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Serv, Hook> Serve for ServeThenHook<Serv, Hook>
|
||||
where
|
||||
Serv: Serve,
|
||||
Hook: AfterRequest<Serv::Resp>,
|
||||
{
|
||||
type Req = Serv::Req;
|
||||
type Resp = Serv::Resp;
|
||||
|
||||
async fn serve(
|
||||
self,
|
||||
mut ctx: context::Context,
|
||||
req: Serv::Req,
|
||||
) -> Result<Serv::Resp, ServerError> {
|
||||
let ServeThenHook {
|
||||
serve, mut hook, ..
|
||||
} = self;
|
||||
let mut resp = serve.serve(ctx, req).await;
|
||||
hook.after(&mut ctx, &mut resp).await;
|
||||
resp
|
||||
}
|
||||
}
|
||||
210
tarpc/src/server/request_hook/before.rs
Normal file
210
tarpc/src/server/request_hook/before.rs
Normal file
@@ -0,0 +1,210 @@
|
||||
// Copyright 2022 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! Provides a hook that runs before request execution.
|
||||
|
||||
use crate::{context, server::Serve, ServerError};
|
||||
use futures::prelude::*;
|
||||
|
||||
/// A hook that runs before request execution.
|
||||
#[allow(async_fn_in_trait)]
|
||||
pub trait BeforeRequest<Req> {
|
||||
/// The function that is called before request execution.
|
||||
///
|
||||
/// If this function returns an error, the request will not be executed and the error will be
|
||||
/// returned instead.
|
||||
///
|
||||
/// This function can also modify the request context. This could be used, for example, to
|
||||
/// enforce a maximum deadline on all requests.
|
||||
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError>;
|
||||
}
|
||||
|
||||
/// A list of hooks that run in order before request execution.
|
||||
pub trait BeforeRequestList<Req>: BeforeRequest<Req> {
|
||||
/// The hook returned by `BeforeRequestList::then`.
|
||||
type Then<Next>: BeforeRequest<Req>
|
||||
where
|
||||
Next: BeforeRequest<Req>;
|
||||
|
||||
/// Returns a hook that, when run, runs two hooks, first `self` and then `next`.
|
||||
fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next>;
|
||||
|
||||
/// Same as `then`, but helps the compiler with type inference when Next is a closure.
|
||||
fn then_fn<
|
||||
Next: FnMut(&mut context::Context, &Req) -> Fut,
|
||||
Fut: Future<Output = Result<(), ServerError>>,
|
||||
>(
|
||||
self,
|
||||
next: Next,
|
||||
) -> Self::Then<Next>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
self.then(next)
|
||||
}
|
||||
|
||||
/// The service fn returned by `BeforeRequestList::serving`.
|
||||
type Serve<S: Serve<Req = Req>>: Serve<Req = Req>;
|
||||
|
||||
/// Runs the list of request hooks before execution of the given serve fn.
|
||||
/// This is equivalent to `serve.before(before_request_chain)` but may be syntactically nicer.
|
||||
fn serving<S: Serve<Req = Req>>(self, serve: S) -> Self::Serve<S>;
|
||||
}
|
||||
|
||||
impl<F, Fut, Req> BeforeRequest<Req> for F
|
||||
where
|
||||
F: FnMut(&mut context::Context, &Req) -> Fut,
|
||||
Fut: Future<Output = Result<(), ServerError>>,
|
||||
{
|
||||
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> {
|
||||
self(ctx, req).await
|
||||
}
|
||||
}
|
||||
|
||||
/// A Service function that runs a hook before request execution.
|
||||
#[derive(Clone)]
|
||||
pub struct HookThenServe<Serv, Hook> {
|
||||
serve: Serv,
|
||||
hook: Hook,
|
||||
}
|
||||
|
||||
impl<Serv, Hook> HookThenServe<Serv, Hook> {
|
||||
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
|
||||
Self { serve, hook }
|
||||
}
|
||||
}
|
||||
|
||||
impl<Serv, Hook> Serve for HookThenServe<Serv, Hook>
|
||||
where
|
||||
Serv: Serve,
|
||||
Hook: BeforeRequest<Serv::Req>,
|
||||
{
|
||||
type Req = Serv::Req;
|
||||
type Resp = Serv::Resp;
|
||||
|
||||
async fn serve(
|
||||
self,
|
||||
mut ctx: context::Context,
|
||||
req: Self::Req,
|
||||
) -> Result<Serv::Resp, ServerError> {
|
||||
let HookThenServe {
|
||||
serve, mut hook, ..
|
||||
} = self;
|
||||
hook.before(&mut ctx, &req).await?;
|
||||
serve.serve(ctx, req).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a request hook builder that runs a series of hooks before request execution.
|
||||
///
|
||||
/// Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use futures::{executor::block_on, future};
|
||||
/// use tarpc::{context, ServerError, server::{Serve, serve, request_hook::{self,
|
||||
/// BeforeRequest, BeforeRequestList}}};
|
||||
/// use std::{cell::Cell, io};
|
||||
///
|
||||
/// let i = Cell::new(0);
|
||||
/// let serve = request_hook::before()
|
||||
/// .then_fn(|_, _| async {
|
||||
/// assert!(i.get() == 0);
|
||||
/// i.set(1);
|
||||
/// Ok(())
|
||||
/// })
|
||||
/// .then_fn(|_, _| async {
|
||||
/// assert!(i.get() == 1);
|
||||
/// i.set(2);
|
||||
/// Ok(())
|
||||
/// })
|
||||
/// .serving(serve(|_ctx, i| async move { Ok(i + 1) }));
|
||||
/// let response = serve.clone().serve(context::current(), 1);
|
||||
/// assert!(block_on(response).is_ok());
|
||||
/// assert!(i.get() == 2);
|
||||
/// ```
|
||||
pub fn before() -> BeforeRequestNil {
|
||||
BeforeRequestNil
|
||||
}
|
||||
|
||||
/// A list of hooks that run in order before a request is executed.
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct BeforeRequestCons<First, Rest>(First, Rest);
|
||||
|
||||
/// A noop hook that runs before a request is executed.
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct BeforeRequestNil;
|
||||
|
||||
impl<Req, First: BeforeRequest<Req>, Rest: BeforeRequest<Req>> BeforeRequest<Req>
|
||||
for BeforeRequestCons<First, Rest>
|
||||
{
|
||||
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> {
|
||||
let BeforeRequestCons(first, rest) = self;
|
||||
first.before(ctx, req).await?;
|
||||
rest.before(ctx, req).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req> BeforeRequest<Req> for BeforeRequestNil {
|
||||
async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, First: BeforeRequest<Req>, Rest: BeforeRequestList<Req>> BeforeRequestList<Req>
|
||||
for BeforeRequestCons<First, Rest>
|
||||
{
|
||||
type Then<Next> = BeforeRequestCons<First, Rest::Then<Next>> where Next: BeforeRequest<Req>;
|
||||
|
||||
fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next> {
|
||||
let BeforeRequestCons(first, rest) = self;
|
||||
BeforeRequestCons(first, rest.then(next))
|
||||
}
|
||||
|
||||
type Serve<S: Serve<Req = Req>> = HookThenServe<S, Self>;
|
||||
|
||||
fn serving<S: Serve<Req = Req>>(self, serve: S) -> Self::Serve<S> {
|
||||
HookThenServe::new(serve, self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req> BeforeRequestList<Req> for BeforeRequestNil {
|
||||
type Then<Next> = BeforeRequestCons<Next, BeforeRequestNil> where Next: BeforeRequest<Req>;
|
||||
|
||||
fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next> {
|
||||
BeforeRequestCons(next, BeforeRequestNil)
|
||||
}
|
||||
|
||||
type Serve<S: Serve<Req = Req>> = S;
|
||||
|
||||
fn serving<S: Serve<Req = Req>>(self, serve: S) -> S {
|
||||
serve
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn before_request_list() {
|
||||
use crate::server::serve;
|
||||
use futures::executor::block_on;
|
||||
use std::cell::Cell;
|
||||
|
||||
let i = Cell::new(0);
|
||||
let serve = before()
|
||||
.then_fn(|_, _| async {
|
||||
assert!(i.get() == 0);
|
||||
i.set(1);
|
||||
Ok(())
|
||||
})
|
||||
.then_fn(|_, _| async {
|
||||
assert!(i.get() == 1);
|
||||
i.set(2);
|
||||
Ok(())
|
||||
})
|
||||
.serving(serve(|_ctx, i| async move { Ok(i + 1) }));
|
||||
let response = serve.clone().serve(context::current(), 1);
|
||||
assert!(block_on(response).is_ok());
|
||||
assert!(i.get() == 2);
|
||||
}
|
||||
57
tarpc/src/server/request_hook/before_and_after.rs
Normal file
57
tarpc/src/server/request_hook/before_and_after.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
// Copyright 2022 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
//! Provides a hook that runs both before and after request execution.
|
||||
|
||||
use super::{after::AfterRequest, before::BeforeRequest};
|
||||
use crate::{context, server::Serve, ServerError};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// A Service function that runs a hook both before and after request execution.
|
||||
pub struct HookThenServeThenHook<Req, Resp, Serv, Hook> {
|
||||
serve: Serv,
|
||||
hook: Hook,
|
||||
fns: PhantomData<(fn(Req), fn(Resp))>,
|
||||
}
|
||||
|
||||
impl<Req, Resp, Serv, Hook> HookThenServeThenHook<Req, Resp, Serv, Hook> {
|
||||
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
|
||||
Self {
|
||||
serve,
|
||||
hook,
|
||||
fns: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, Serv: Clone, Hook: Clone> Clone for HookThenServeThenHook<Req, Resp, Serv, Hook> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
serve: self.serve.clone(),
|
||||
hook: self.hook.clone(),
|
||||
fns: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, Serv, Hook> Serve for HookThenServeThenHook<Req, Resp, Serv, Hook>
|
||||
where
|
||||
Serv: Serve<Req = Req, Resp = Resp>,
|
||||
Hook: BeforeRequest<Req> + AfterRequest<Resp>,
|
||||
{
|
||||
type Req = Req;
|
||||
type Resp = Resp;
|
||||
|
||||
async fn serve(self, mut ctx: context::Context, req: Req) -> Result<Serv::Resp, ServerError> {
|
||||
let HookThenServeThenHook {
|
||||
serve, mut hook, ..
|
||||
} = self;
|
||||
hook.before(&mut ctx, &req).await?;
|
||||
let mut resp = serve.serve(ctx, req).await;
|
||||
hook.after(&mut ctx, &mut resp).await;
|
||||
resp
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,7 @@ use crate::{
|
||||
};
|
||||
use futures::{task::*, Sink, Stream};
|
||||
use pin_project::pin_project;
|
||||
use std::{collections::VecDeque, io, mem::ManuallyDrop, pin::Pin, time::SystemTime};
|
||||
use std::{collections::VecDeque, io, pin::Pin, time::SystemTime};
|
||||
use tracing::Span;
|
||||
|
||||
#[pin_project]
|
||||
@@ -101,10 +101,11 @@ impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
||||
},
|
||||
abort_registration,
|
||||
span: Span::none(),
|
||||
response_guard: ManuallyDrop::new(ResponseGuard {
|
||||
response_guard: ResponseGuard {
|
||||
request_cancellation,
|
||||
request_id: id,
|
||||
}),
|
||||
cancel: false,
|
||||
},
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
use super::{Channel, Requests, Serve};
|
||||
use futures::{prelude::*, ready, task::*};
|
||||
use pin_project::pin_project;
|
||||
use std::pin::Pin;
|
||||
|
||||
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
|
||||
/// for each new channel. Returned by
|
||||
/// [`Incoming::execute`](crate::server::incoming::Incoming::execute).
|
||||
#[must_use]
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct TokioServerExecutor<T, S> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
serve: S,
|
||||
}
|
||||
|
||||
impl<T, S> TokioServerExecutor<T, S> {
|
||||
pub(crate) fn new(inner: T, serve: S) -> Self {
|
||||
Self { inner, serve }
|
||||
}
|
||||
}
|
||||
|
||||
/// A future that drives the server by [spawning](tokio::spawn) each [response
|
||||
/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by
|
||||
/// [`Channel::execute`](crate::server::Channel::execute).
|
||||
#[must_use]
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct TokioChannelExecutor<T, S> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
serve: S,
|
||||
}
|
||||
|
||||
impl<T, S> TokioServerExecutor<T, S> {
|
||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
||||
self.as_mut().project().inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S> TokioChannelExecutor<T, S> {
|
||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
||||
self.as_mut().project().inner
|
||||
}
|
||||
}
|
||||
|
||||
// Send + 'static execution helper methods.
|
||||
|
||||
impl<C> Requests<C>
|
||||
where
|
||||
C: Channel,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
{
|
||||
/// Executes all requests using the given service function. Requests are handled concurrently
|
||||
/// by [spawning](::tokio::spawn) each handler on tokio's default executor.
|
||||
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
|
||||
where
|
||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
|
||||
{
|
||||
TokioChannelExecutor { inner: self, serve }
|
||||
}
|
||||
}
|
||||
|
||||
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
|
||||
where
|
||||
St: Sized + Stream<Item = C>,
|
||||
C: Channel + Send + 'static,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
||||
Se::Fut: Send,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
||||
tokio::spawn(channel.execute(self.serve.clone()));
|
||||
}
|
||||
tracing::info!("Server shutting down.");
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
|
||||
where
|
||||
C: Channel + 'static,
|
||||
C::Req: Send + 'static,
|
||||
C::Resp: Send + 'static,
|
||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
||||
S::Fut: Send,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
||||
match response_handler {
|
||||
Ok(resp) => {
|
||||
let server = self.serve.clone();
|
||||
tokio::spawn(async move {
|
||||
resp.execute(server).await;
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Requests stream errored out: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
@@ -14,9 +14,15 @@ use tokio::sync::mpsc;
|
||||
/// Errors that occur in the sending or receiving of messages over a channel.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ChannelError {
|
||||
/// An error occurred sending over the channel.
|
||||
#[error("an error occurred sending over the channel")]
|
||||
/// An error occurred readying to send into the channel.
|
||||
#[error("an error occurred readying to send into the channel")]
|
||||
Ready(#[source] Box<dyn Error + Send + Sync + 'static>),
|
||||
/// An error occurred sending into the channel.
|
||||
#[error("an error occurred sending into the channel")]
|
||||
Send(#[source] Box<dyn Error + Send + Sync + 'static>),
|
||||
/// An error occurred receiving from the channel.
|
||||
#[error("an error occurred receiving from the channel")]
|
||||
Receive(#[source] Box<dyn Error + Send + Sync + 'static>),
|
||||
}
|
||||
|
||||
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
|
||||
@@ -48,7 +54,10 @@ impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Item, ChannelError>>> {
|
||||
self.rx.poll_recv(cx).map(|option| option.map(Ok))
|
||||
self.rx
|
||||
.poll_recv(cx)
|
||||
.map(|option| option.map(Ok))
|
||||
.map_err(ChannelError::Receive)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,7 +68,7 @@ impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(if self.tx.is_closed() {
|
||||
Err(ChannelError::Send(CLOSED_MESSAGE.into()))
|
||||
Err(ChannelError::Ready(CLOSED_MESSAGE.into()))
|
||||
} else {
|
||||
Ok(())
|
||||
})
|
||||
@@ -110,7 +119,11 @@ impl<Item, SinkItem> Stream for Channel<Item, SinkItem> {
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Item, ChannelError>>> {
|
||||
self.project().rx.poll_next(cx).map(|option| option.map(Ok))
|
||||
self.project()
|
||||
.rx
|
||||
.poll_next(cx)
|
||||
.map(|option| option.map(Ok))
|
||||
.map_err(ChannelError::Receive)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,7 +134,7 @@ impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
|
||||
self.project()
|
||||
.tx
|
||||
.poll_ready(cx)
|
||||
.map_err(|e| ChannelError::Send(Box::new(e)))
|
||||
.map_err(|e| ChannelError::Ready(Box::new(e)))
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||
@@ -146,16 +159,17 @@ impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg(all(test, feature = "tokio1"))]
|
||||
mod tests {
|
||||
use crate::{
|
||||
client, context,
|
||||
server::{incoming::Incoming, BaseChannel},
|
||||
client::{self, RpcError},
|
||||
context,
|
||||
server::{incoming::Incoming, serve, BaseChannel},
|
||||
transport::{
|
||||
self,
|
||||
channel::{Channel, UnboundedChannel},
|
||||
},
|
||||
ServerError,
|
||||
};
|
||||
use assert_matches::assert_matches;
|
||||
use futures::{prelude::*, stream};
|
||||
@@ -177,25 +191,28 @@ mod tests {
|
||||
tokio::spawn(
|
||||
stream::once(future::ready(server_channel))
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(|_ctx, request: String| {
|
||||
future::ready(request.parse::<u64>().map_err(|_| {
|
||||
io::Error::new(
|
||||
.execute(serve(|_ctx, request: String| async move {
|
||||
request.parse::<u64>().map_err(|_| {
|
||||
ServerError::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
format!("{request:?} is not an int"),
|
||||
)
|
||||
}))
|
||||
})
|
||||
}))
|
||||
.for_each(|channel| async move {
|
||||
tokio::spawn(channel.for_each(|response| response));
|
||||
}),
|
||||
);
|
||||
|
||||
let client = client::new(client::Config::default(), client_channel).spawn();
|
||||
|
||||
let response1 = client.call(context::current(), "", "123".into()).await?;
|
||||
let response2 = client.call(context::current(), "", "abc".into()).await?;
|
||||
let response1 = client.call(context::current(), "", "123".into()).await;
|
||||
let response2 = client.call(context::current(), "", "abc".into()).await;
|
||||
|
||||
trace!("response1: {:?}, response2: {:?}", response1, response2);
|
||||
|
||||
assert_matches!(response1, Ok(123));
|
||||
assert_matches!(response2, Err(ref e) if e.kind() == io::ErrorKind::InvalidInput);
|
||||
assert_matches!(response2, Err(RpcError::Server(e)) if e.kind == io::ErrorKind::InvalidInput);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
fn ui() {
|
||||
let t = trybuild::TestCases::new();
|
||||
t.compile_fail("tests/compile_fail/*.rs");
|
||||
#[cfg(feature = "tokio1")]
|
||||
t.compile_fail("tests/compile_fail/tokio/*.rs");
|
||||
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
|
||||
t.compile_fail("tests/compile_fail/serde_transport/*.rs");
|
||||
}
|
||||
|
||||
@@ -2,10 +2,14 @@ error: unused `RequestDispatch` that must be used
|
||||
--> tests/compile_fail/must_use_request_dispatch.rs:13:9
|
||||
|
|
||||
13 | WorldClient::new(client::Config::default(), client_transport).dispatch;
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
note: the lint level is defined here
|
||||
--> tests/compile_fail/must_use_request_dispatch.rs:11:12
|
||||
|
|
||||
11 | #[deny(unused_must_use)]
|
||||
| ^^^^^^^^^^^^^^^
|
||||
help: use `let _ = ...` to ignore the resulting value
|
||||
|
|
||||
13 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch;
|
||||
| +++++++
|
||||
|
||||
@@ -2,10 +2,14 @@ error: unused `tarpc::serde_transport::tcp::Connect` that must be used
|
||||
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:7:9
|
||||
|
|
||||
7 | serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
note: the lint level is defined here
|
||||
--> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:5:12
|
||||
|
|
||||
5 | #[deny(unused_must_use)]
|
||||
| ^^^^^^^^^^^^^^^
|
||||
help: use `let _ = ...` to ignore the resulting value
|
||||
|
|
||||
7 | let _ = serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
|
||||
| +++++++
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
#[tarpc::service(derive_serde = false)]
|
||||
trait World {
|
||||
async fn hello(name: String) -> String;
|
||||
}
|
||||
|
||||
struct HelloServer;
|
||||
|
||||
#[tarpc::server]
|
||||
impl World for HelloServer {
|
||||
fn hello(name: String) -> String {
|
||||
format!("Hello, {name}!", name)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {}
|
||||
@@ -1,11 +0,0 @@
|
||||
error: not all trait items implemented, missing: `HelloFut`
|
||||
--> $DIR/tarpc_server_missing_async.rs:9:1
|
||||
|
|
||||
9 | impl World for HelloServer {
|
||||
| ^^^^
|
||||
|
||||
error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async
|
||||
--> $DIR/tarpc_server_missing_async.rs:10:5
|
||||
|
|
||||
10 | fn hello(name: String) -> String {
|
||||
| ^^
|
||||
@@ -1,29 +0,0 @@
|
||||
use tarpc::{
|
||||
context,
|
||||
server::{self, Channel},
|
||||
};
|
||||
|
||||
#[tarpc::service]
|
||||
trait World {
|
||||
async fn hello(name: String) -> String;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct HelloServer;
|
||||
|
||||
#[tarpc::server]
|
||||
impl World for HelloServer {
|
||||
async fn hello(self, _: context::Context, name: String) -> String {
|
||||
format!("Hello, {name}!")
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let (_, server_transport) = tarpc::transport::channel::unbounded();
|
||||
let server = server::BaseChannel::with_defaults(server_transport);
|
||||
|
||||
#[deny(unused_must_use)]
|
||||
{
|
||||
server.execute(HelloServer.serve());
|
||||
}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
error: unused `TokioChannelExecutor` that must be used
|
||||
--> tests/compile_fail/tokio/must_use_channel_executor.rs:27:9
|
||||
|
|
||||
27 | server.execute(HelloServer.serve());
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
note: the lint level is defined here
|
||||
--> tests/compile_fail/tokio/must_use_channel_executor.rs:25:12
|
||||
|
|
||||
25 | #[deny(unused_must_use)]
|
||||
| ^^^^^^^^^^^^^^^
|
||||
@@ -1,30 +0,0 @@
|
||||
use futures::stream::once;
|
||||
use tarpc::{
|
||||
context,
|
||||
server::{self, incoming::Incoming},
|
||||
};
|
||||
|
||||
#[tarpc::service]
|
||||
trait World {
|
||||
async fn hello(name: String) -> String;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct HelloServer;
|
||||
|
||||
#[tarpc::server]
|
||||
impl World for HelloServer {
|
||||
async fn hello(self, _: context::Context, name: String) -> String {
|
||||
format!("Hello, {name}!")
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let (_, server_transport) = tarpc::transport::channel::unbounded();
|
||||
let server = once(async move { server::BaseChannel::with_defaults(server_transport) });
|
||||
|
||||
#[deny(unused_must_use)]
|
||||
{
|
||||
server.execute(HelloServer.serve());
|
||||
}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
error: unused `TokioServerExecutor` that must be used
|
||||
--> tests/compile_fail/tokio/must_use_server_executor.rs:28:9
|
||||
|
|
||||
28 | server.execute(HelloServer.serve());
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
||||
note: the lint level is defined here
|
||||
--> tests/compile_fail/tokio/must_use_server_executor.rs:26:12
|
||||
|
|
||||
26 | #[deny(unused_must_use)]
|
||||
| ^^^^^^^^^^^^^^^
|
||||
@@ -21,7 +21,6 @@ pub trait ColorProtocol {
|
||||
#[derive(Clone)]
|
||||
struct ColorServer;
|
||||
|
||||
#[tarpc::server]
|
||||
impl ColorProtocol for ColorServer {
|
||||
async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData {
|
||||
match color {
|
||||
@@ -31,6 +30,11 @@ impl ColorProtocol for ColorServer {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||
tokio::spawn(fut);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_call() -> anyhow::Result<()> {
|
||||
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
|
||||
@@ -40,7 +44,9 @@ async fn test_call() -> anyhow::Result<()> {
|
||||
.take(1)
|
||||
.filter_map(|r| async { r.ok() })
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(ColorServer.serve()),
|
||||
.execute(ColorServer.serve())
|
||||
.map(|channel| channel.for_each(spawn))
|
||||
.for_each(spawn),
|
||||
);
|
||||
|
||||
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
|
||||
17
tarpc/tests/proc_macro_hygene.rs
Normal file
17
tarpc/tests/proc_macro_hygene.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
#![no_implicit_prelude]
|
||||
extern crate tarpc as some_random_other_name;
|
||||
|
||||
#[cfg(feature = "serde1")]
|
||||
mod serde1_feature {
|
||||
#[::tarpc::derive_serde]
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum TestData {
|
||||
Black,
|
||||
White,
|
||||
}
|
||||
}
|
||||
|
||||
#[::tarpc::service]
|
||||
pub trait ColorProtocol {
|
||||
async fn get_opposite_color(color: u8) -> u8;
|
||||
}
|
||||
@@ -1,13 +1,13 @@
|
||||
use assert_matches::assert_matches;
|
||||
use futures::{
|
||||
future::{join_all, ready, Ready},
|
||||
future::{join_all, ready},
|
||||
prelude::*,
|
||||
};
|
||||
use std::time::{Duration, SystemTime};
|
||||
use tarpc::{
|
||||
client::{self},
|
||||
context,
|
||||
server::{self, incoming::Incoming, BaseChannel, Channel},
|
||||
server::{incoming::Incoming, BaseChannel, Channel},
|
||||
transport::channel,
|
||||
};
|
||||
use tokio::join;
|
||||
@@ -22,39 +22,29 @@ trait Service {
|
||||
struct Server;
|
||||
|
||||
impl Service for Server {
|
||||
type AddFut = Ready<i32>;
|
||||
|
||||
fn add(self, _: context::Context, x: i32, y: i32) -> Self::AddFut {
|
||||
ready(x + y)
|
||||
async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
|
||||
x + y
|
||||
}
|
||||
|
||||
type HeyFut = Ready<String>;
|
||||
|
||||
fn hey(self, _: context::Context, name: String) -> Self::HeyFut {
|
||||
ready(format!("Hey, {name}."))
|
||||
async fn hey(self, _: context::Context, name: String) -> String {
|
||||
format!("Hey, {name}.")
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sequential() -> anyhow::Result<()> {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
|
||||
async fn sequential() {
|
||||
let (tx, rx) = tarpc::transport::channel::unbounded();
|
||||
let client = client::new(client::Config::default(), tx).spawn();
|
||||
let channel = BaseChannel::with_defaults(rx);
|
||||
tokio::spawn(
|
||||
BaseChannel::new(server::Config::default(), rx)
|
||||
.requests()
|
||||
.execute(Server.serve()),
|
||||
channel
|
||||
.execute(tarpc::server::serve(|_, i| async move { Ok(i + 1) }))
|
||||
.for_each(|response| response),
|
||||
);
|
||||
assert_eq!(
|
||||
client.call(context::current(), "AddOne", 1).await.unwrap(),
|
||||
2
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||
|
||||
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
|
||||
assert_matches!(
|
||||
client.hey(context::current(), "Tim".into()).await,
|
||||
Ok(ref s) if s == "Hey, Tim.");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -70,7 +60,6 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> {
|
||||
#[derive(Debug)]
|
||||
struct AllHandlersComplete;
|
||||
|
||||
#[tarpc::server]
|
||||
impl Loop for LoopServer {
|
||||
async fn r#loop(self, _: context::Context) {
|
||||
loop {
|
||||
@@ -121,7 +110,9 @@ async fn serde_tcp() -> anyhow::Result<()> {
|
||||
.take(1)
|
||||
.filter_map(|r| async { r.ok() })
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
.execute(Server.serve())
|
||||
.map(|channel| channel.for_each(spawn))
|
||||
.for_each(spawn),
|
||||
);
|
||||
|
||||
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
||||
@@ -151,7 +142,9 @@ async fn serde_uds() -> anyhow::Result<()> {
|
||||
.take(1)
|
||||
.filter_map(|r| async { r.ok() })
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
.execute(Server.serve())
|
||||
.map(|channel| channel.for_each(spawn))
|
||||
.for_each(spawn),
|
||||
);
|
||||
|
||||
let transport = serde_transport::unix::connect(&sock, Json::default).await?;
|
||||
@@ -175,7 +168,9 @@ async fn concurrent() -> anyhow::Result<()> {
|
||||
tokio::spawn(
|
||||
stream::once(ready(rx))
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
.execute(Server.serve())
|
||||
.map(|channel| channel.for_each(spawn))
|
||||
.for_each(spawn),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||
@@ -199,7 +194,9 @@ async fn concurrent_join() -> anyhow::Result<()> {
|
||||
tokio::spawn(
|
||||
stream::once(ready(rx))
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
.execute(Server.serve())
|
||||
.map(|channel| channel.for_each(spawn))
|
||||
.for_each(spawn),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||
@@ -216,15 +213,20 @@ async fn concurrent_join() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||
tokio::spawn(fut);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn concurrent_join_all() -> anyhow::Result<()> {
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
tokio::spawn(
|
||||
stream::once(ready(rx))
|
||||
.map(BaseChannel::with_defaults)
|
||||
.execute(Server.serve()),
|
||||
BaseChannel::with_defaults(rx)
|
||||
.execute(Server.serve())
|
||||
.for_each(spawn),
|
||||
);
|
||||
|
||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||
@@ -249,11 +251,9 @@ async fn counter() -> anyhow::Result<()> {
|
||||
struct CountService(u32);
|
||||
|
||||
impl Counter for &mut CountService {
|
||||
type CountFut = futures::future::Ready<u32>;
|
||||
|
||||
fn count(self, _: context::Context) -> Self::CountFut {
|
||||
async fn count(self, _: context::Context) -> u32 {
|
||||
self.0 += 1;
|
||||
futures::future::ready(self.0)
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user