mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-02-23 15:49:54 +01:00
Compare commits
27 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9c86ce157 | ||
|
|
d4f579542d | ||
|
|
5011dbe057 | ||
|
|
4aa90ee933 | ||
|
|
eb91000fed | ||
|
|
006a9f3af1 | ||
|
|
a3d00b07da | ||
|
|
d62706e62c | ||
|
|
b92dd154bc | ||
|
|
a6758fd1f9 | ||
|
|
2c241cc809 | ||
|
|
263ef8a897 | ||
|
|
d50290a21c | ||
|
|
26988cb833 | ||
|
|
6cf18a1caf | ||
|
|
84932df9b4 | ||
|
|
8dc3711a80 | ||
|
|
7c5afa97bb | ||
|
|
324df5cd15 | ||
|
|
3264979993 | ||
|
|
dd63fb59bf | ||
|
|
f4db8cc5b4 | ||
|
|
e9ba350496 | ||
|
|
e6d779e70b | ||
|
|
ce5f8cfb0c | ||
|
|
4b69dc8db5 | ||
|
|
866db2a2cd |
3
.github/workflows/main.yml
vendored
3
.github/workflows/main.yml
vendored
@@ -19,10 +19,7 @@ jobs:
|
|||||||
access_token: ${{ github.token }}
|
access_token: ${{ github.token }}
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v3
|
||||||
- uses: dtolnay/rust-toolchain@stable
|
- uses: dtolnay/rust-toolchain@stable
|
||||||
with:
|
|
||||||
targets: mipsel-unknown-linux-gnu
|
|
||||||
- run: cargo check --all-features
|
- run: cargo check --all-features
|
||||||
- run: cargo check --all-features --target mipsel-unknown-linux-gnu
|
|
||||||
|
|
||||||
test:
|
test:
|
||||||
name: Test Suite
|
name: Test Suite
|
||||||
|
|||||||
23
README.md
23
README.md
@@ -67,7 +67,7 @@ Some other features of tarpc:
|
|||||||
Add to your `Cargo.toml` dependencies:
|
Add to your `Cargo.toml` dependencies:
|
||||||
|
|
||||||
```toml
|
```toml
|
||||||
tarpc = "0.32"
|
tarpc = "0.34"
|
||||||
```
|
```
|
||||||
|
|
||||||
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||||
@@ -83,7 +83,7 @@ your `Cargo.toml`:
|
|||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
tarpc = { version = "0.31", features = ["tokio1"] }
|
tarpc = { version = "0.31", features = ["tokio1"] }
|
||||||
tokio = { version = "1.0", features = ["macros"] }
|
tokio = { version = "1.0", features = ["rt-multi-thread", "macros"] }
|
||||||
```
|
```
|
||||||
|
|
||||||
In the following example, we use an in-process channel for communication between
|
In the following example, we use an in-process channel for communication between
|
||||||
@@ -93,14 +93,10 @@ For a more real-world example, see [example-service](example-service).
|
|||||||
First, let's set up the dependencies and service definition.
|
First, let's set up the dependencies and service definition.
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
|
use futures::future::{self, Ready};
|
||||||
use futures::{
|
|
||||||
future::{self, Ready},
|
|
||||||
prelude::*,
|
|
||||||
};
|
|
||||||
use tarpc::{
|
use tarpc::{
|
||||||
client, context,
|
client, context,
|
||||||
server::{self, incoming::Incoming, Channel},
|
server::{self, Channel},
|
||||||
};
|
};
|
||||||
|
|
||||||
// This is the service definition. It looks a lot like a trait definition.
|
// This is the service definition. It looks a lot like a trait definition.
|
||||||
@@ -122,13 +118,8 @@ implement it for our Server struct.
|
|||||||
struct HelloServer;
|
struct HelloServer;
|
||||||
|
|
||||||
impl World for HelloServer {
|
impl World for HelloServer {
|
||||||
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and
|
async fn hello(self, _: context::Context, name: String) -> String {
|
||||||
// an associated type representing the future output by the fn.
|
format!("Hello, {name}!")
|
||||||
|
|
||||||
type HelloFut = Ready<String>;
|
|
||||||
|
|
||||||
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
|
||||||
future::ready(format!("Hello, {name}!"))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
@@ -148,7 +139,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||||
// that takes a config and any Transport as input.
|
// that takes a config and any Transport as input.
|
||||||
let mut client = WorldClient::new(client::Config::default(), client_transport).spawn();
|
let client = WorldClient::new(client::Config::default(), client_transport).spawn();
|
||||||
|
|
||||||
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same
|
||||||
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
// args as defined, with the addition of a Context, which is always the first arg. The Context
|
||||||
|
|||||||
26
RELEASES.md
26
RELEASES.md
@@ -1,3 +1,29 @@
|
|||||||
|
## 0.34.0 (2023-12-29)
|
||||||
|
|
||||||
|
### Breaking Changes
|
||||||
|
|
||||||
|
- `#[tarpc::server]` is no more! Service traits now use async fns.
|
||||||
|
- `Channel::execute` no longer spawns request handlers. Async-fn-in-traits makes it impossible to
|
||||||
|
add a Send bound to the future returned by `Serve::serve`. Instead, `Channel::execute` returns a
|
||||||
|
stream of futures, where each future is a request handler. To achieve the former behavior:
|
||||||
|
```rust
|
||||||
|
channel.execute(server.serve())
|
||||||
|
.for_each(|rpc| { tokio::spawn(rpc); })
|
||||||
|
```
|
||||||
|
|
||||||
|
### New Features
|
||||||
|
|
||||||
|
- Request hooks are added to the serve trait, so that it's easy to hook in cross-cutting
|
||||||
|
functionality like throttling, authorization, etc.
|
||||||
|
- The Client trait is back! This makes it possible to hook in generic client functionality like load
|
||||||
|
balancing, retries, etc.
|
||||||
|
|
||||||
|
## 0.33.0 (2023-04-01)
|
||||||
|
|
||||||
|
### Breaking Changes
|
||||||
|
|
||||||
|
Opentelemetry dependency version increased to 0.18.
|
||||||
|
|
||||||
## 0.32.0 (2023-03-24)
|
## 0.32.0 (2023-03-24)
|
||||||
|
|
||||||
### Breaking Changes
|
### Breaking Changes
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "tarpc-example-service"
|
name = "tarpc-example-service"
|
||||||
version = "0.14.0"
|
version = "0.15.0"
|
||||||
rust-version = "1.56"
|
rust-version = "1.56"
|
||||||
authors = ["Tim Kuehn <tikue@google.com>"]
|
authors = ["Tim Kuehn <tikue@google.com>"]
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
@@ -18,14 +18,15 @@ anyhow = "1.0"
|
|||||||
clap = { version = "3.0.0-rc.9", features = ["derive"] }
|
clap = { version = "3.0.0-rc.9", features = ["derive"] }
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
opentelemetry = { version = "0.17", features = ["rt-tokio"] }
|
opentelemetry = { version = "0.21.0" }
|
||||||
opentelemetry-jaeger = { version = "0.16", features = ["rt-tokio"] }
|
opentelemetry-jaeger = { version = "0.20.0", features = ["rt-tokio"] }
|
||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
tarpc = { version = "0.32", path = "../tarpc", features = ["full"] }
|
tarpc = { version = "0.34", path = "../tarpc", features = ["full"] }
|
||||||
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
|
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
|
||||||
tracing = { version = "0.1" }
|
tracing = { version = "0.1" }
|
||||||
tracing-opentelemetry = "0.17"
|
tracing-opentelemetry = "0.22.0"
|
||||||
tracing-subscriber = {version = "0.3", features = ["env-filter"]}
|
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
|
||||||
|
opentelemetry_sdk = "0.21.1"
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
name = "service"
|
name = "service"
|
||||||
|
|||||||
15
example-service/README.md
Normal file
15
example-service/README.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# Example
|
||||||
|
|
||||||
|
Example service to demonstrate how to set up `tarpc` with [Jaeger](https://www.jaegertracing.io). To see traces Jaeger, run the following with `RUST_LOG=trace`.
|
||||||
|
|
||||||
|
## Server
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --bin server -- --port 50051
|
||||||
|
```
|
||||||
|
|
||||||
|
## Client
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --bin client -- --server-addr "[::1]:50051" --name "Bob"
|
||||||
|
```
|
||||||
@@ -19,10 +19,10 @@ pub trait World {
|
|||||||
pub fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
pub fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||||
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
||||||
|
|
||||||
let tracer = opentelemetry_jaeger::new_pipeline()
|
let tracer = opentelemetry_jaeger::new_agent_pipeline()
|
||||||
.with_service_name(service_name)
|
.with_service_name(service_name)
|
||||||
.with_max_packet_size(2usize.pow(13))
|
.with_max_packet_size(2usize.pow(13))
|
||||||
.install_batch(opentelemetry::runtime::Tokio)?;
|
.install_batch(opentelemetry_sdk::runtime::Tokio)?;
|
||||||
|
|
||||||
tracing_subscriber::registry()
|
tracing_subscriber::registry()
|
||||||
.with(tracing_subscriber::EnvFilter::from_default_env())
|
.with(tracing_subscriber::EnvFilter::from_default_env())
|
||||||
|
|||||||
@@ -34,7 +34,6 @@ struct Flags {
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct HelloServer(SocketAddr);
|
struct HelloServer(SocketAddr);
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl World for HelloServer {
|
impl World for HelloServer {
|
||||||
async fn hello(self, _: context::Context, name: String) -> String {
|
async fn hello(self, _: context::Context, name: String) -> String {
|
||||||
let sleep_time =
|
let sleep_time =
|
||||||
@@ -44,6 +43,10 @@ impl World for HelloServer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
let flags = Flags::parse();
|
let flags = Flags::parse();
|
||||||
@@ -66,7 +69,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
// the generated World trait.
|
// the generated World trait.
|
||||||
.map(|channel| {
|
.map(|channel| {
|
||||||
let server = HelloServer(channel.transport().peer_addr().unwrap());
|
let server = HelloServer(channel.transport().peer_addr().unwrap());
|
||||||
channel.execute(server.serve())
|
channel.execute(server.serve()).for_each(spawn)
|
||||||
})
|
})
|
||||||
// Max 10 channels.
|
// Max 10 channels.
|
||||||
.buffer_unordered(10)
|
.buffer_unordered(10)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "tarpc-plugins"
|
name = "tarpc-plugins"
|
||||||
version = "0.12.0"
|
version = "0.13.0"
|
||||||
rust-version = "1.56"
|
rust-version = "1.56"
|
||||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|||||||
@@ -12,18 +12,18 @@ extern crate quote;
|
|||||||
extern crate syn;
|
extern crate syn;
|
||||||
|
|
||||||
use proc_macro::TokenStream;
|
use proc_macro::TokenStream;
|
||||||
use proc_macro2::{Span, TokenStream as TokenStream2};
|
use proc_macro2::TokenStream as TokenStream2;
|
||||||
use quote::{format_ident, quote, ToTokens};
|
use quote::{format_ident, quote, ToTokens};
|
||||||
use syn::{
|
use syn::{
|
||||||
braced,
|
braced,
|
||||||
ext::IdentExt,
|
ext::IdentExt,
|
||||||
parenthesized,
|
parenthesized,
|
||||||
parse::{Parse, ParseStream},
|
parse::{Parse, ParseStream},
|
||||||
parse_macro_input, parse_quote, parse_str,
|
parse_macro_input, parse_quote,
|
||||||
spanned::Spanned,
|
spanned::Spanned,
|
||||||
token::Comma,
|
token::Comma,
|
||||||
Attribute, FnArg, Ident, ImplItem, ImplItemMethod, ImplItemType, ItemImpl, Lit, LitBool,
|
Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type,
|
||||||
MetaNameValue, Pat, PatType, ReturnType, Token, Type, Visibility,
|
Visibility,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Accumulates multiple errors into a result.
|
/// Accumulates multiple errors into a result.
|
||||||
@@ -220,15 +220,15 @@ impl Parse for DeriveSerde {
|
|||||||
/// Adds the following annotations to the annotated item:
|
/// Adds the following annotations to the annotated item:
|
||||||
///
|
///
|
||||||
/// ```rust
|
/// ```rust
|
||||||
/// #[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
|
/// #[derive(::tarpc::serde::Serialize, ::tarpc::serde::Deserialize)]
|
||||||
/// #[serde(crate = "tarpc::serde")]
|
/// #[serde(crate = "tarpc::serde")]
|
||||||
/// # struct Foo;
|
/// # struct Foo;
|
||||||
/// ```
|
/// ```
|
||||||
#[proc_macro_attribute]
|
#[proc_macro_attribute]
|
||||||
pub fn derive_serde(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
pub fn derive_serde(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||||
let mut gen: proc_macro2::TokenStream = quote! {
|
let mut gen: proc_macro2::TokenStream = quote! {
|
||||||
#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
|
#[derive(::tarpc::serde::Serialize, ::tarpc::serde::Deserialize)]
|
||||||
#[serde(crate = "tarpc::serde")]
|
#[serde(crate = "::tarpc::serde")]
|
||||||
};
|
};
|
||||||
gen.extend(proc_macro2::TokenStream::from(item));
|
gen.extend(proc_macro2::TokenStream::from(item));
|
||||||
proc_macro::TokenStream::from(gen)
|
proc_macro::TokenStream::from(gen)
|
||||||
@@ -257,11 +257,10 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
|||||||
.map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
|
.map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
|
||||||
.collect();
|
.collect();
|
||||||
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
|
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
|
||||||
let response_fut_name = &format!("{}ResponseFut", ident.unraw());
|
|
||||||
let derive_serialize = if derive_serde.0 {
|
let derive_serialize = if derive_serde.0 {
|
||||||
Some(
|
Some(
|
||||||
quote! {#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
|
quote! {#[derive(::tarpc::serde::Serialize, ::tarpc::serde::Deserialize)]
|
||||||
#[serde(crate = "tarpc::serde")]},
|
#[serde(crate = "::tarpc::serde")]},
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
@@ -274,10 +273,9 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
|||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
ServiceGenerator {
|
ServiceGenerator {
|
||||||
response_fut_name,
|
|
||||||
service_ident: ident,
|
service_ident: ident,
|
||||||
|
client_stub_ident: &format_ident!("{}Stub", ident),
|
||||||
server_ident: &format_ident!("Serve{}", ident),
|
server_ident: &format_ident!("Serve{}", ident),
|
||||||
response_fut_ident: &Ident::new(response_fut_name, ident.span()),
|
|
||||||
client_ident: &format_ident!("{}Client", ident),
|
client_ident: &format_ident!("{}Client", ident),
|
||||||
request_ident: &format_ident!("{}Request", ident),
|
request_ident: &format_ident!("{}Request", ident),
|
||||||
response_ident: &format_ident!("{}Response", ident),
|
response_ident: &format_ident!("{}Response", ident),
|
||||||
@@ -304,137 +302,18 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
|||||||
.zip(camel_case_fn_names.iter())
|
.zip(camel_case_fn_names.iter())
|
||||||
.map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
|
.map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
future_types: &camel_case_fn_names
|
|
||||||
.iter()
|
|
||||||
.map(|name| parse_str(&format!("{name}Fut")).unwrap())
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
derive_serialize: derive_serialize.as_ref(),
|
derive_serialize: derive_serialize.as_ref(),
|
||||||
}
|
}
|
||||||
.into_token_stream()
|
.into_token_stream()
|
||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// generate an identifier consisting of the method name to CamelCase with
|
|
||||||
/// Fut appended to it.
|
|
||||||
fn associated_type_for_rpc(method: &ImplItemMethod) -> String {
|
|
||||||
snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut"
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Transforms an async function into a sync one, returning a type declaration
|
|
||||||
/// for the return type (a future).
|
|
||||||
fn transform_method(method: &mut ImplItemMethod) -> ImplItemType {
|
|
||||||
method.sig.asyncness = None;
|
|
||||||
|
|
||||||
// get either the return type or ().
|
|
||||||
let ret = match &method.sig.output {
|
|
||||||
ReturnType::Default => quote!(()),
|
|
||||||
ReturnType::Type(_, ret) => quote!(#ret),
|
|
||||||
};
|
|
||||||
|
|
||||||
let fut_name = associated_type_for_rpc(method);
|
|
||||||
let fut_name_ident = Ident::new(&fut_name, method.sig.ident.span());
|
|
||||||
|
|
||||||
// generate the updated return signature.
|
|
||||||
method.sig.output = parse_quote! {
|
|
||||||
-> ::core::pin::Pin<Box<
|
|
||||||
dyn ::core::future::Future<Output = #ret> + ::core::marker::Send
|
|
||||||
>>
|
|
||||||
};
|
|
||||||
|
|
||||||
// transform the body of the method into Box::pin(async move { body }).
|
|
||||||
let block = method.block.clone();
|
|
||||||
method.block = parse_quote! [{
|
|
||||||
Box::pin(async move
|
|
||||||
#block
|
|
||||||
)
|
|
||||||
}];
|
|
||||||
|
|
||||||
// generate and return type declaration for return type.
|
|
||||||
let t: ImplItemType = parse_quote! {
|
|
||||||
type #fut_name_ident = ::core::pin::Pin<Box<dyn ::core::future::Future<Output = #ret> + ::core::marker::Send>>;
|
|
||||||
};
|
|
||||||
|
|
||||||
t
|
|
||||||
}
|
|
||||||
|
|
||||||
#[proc_macro_attribute]
|
|
||||||
pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream {
|
|
||||||
let mut item = syn::parse_macro_input!(input as ItemImpl);
|
|
||||||
let span = item.span();
|
|
||||||
|
|
||||||
// the generated type declarations
|
|
||||||
let mut types: Vec<ImplItemType> = Vec::new();
|
|
||||||
let mut expected_non_async_types: Vec<(&ImplItemMethod, String)> = Vec::new();
|
|
||||||
let mut found_non_async_types: Vec<&ImplItemType> = Vec::new();
|
|
||||||
|
|
||||||
for inner in &mut item.items {
|
|
||||||
match inner {
|
|
||||||
ImplItem::Method(method) => {
|
|
||||||
if method.sig.asyncness.is_some() {
|
|
||||||
// if this function is declared async, transform it into a regular function
|
|
||||||
let typedecl = transform_method(method);
|
|
||||||
types.push(typedecl);
|
|
||||||
} else {
|
|
||||||
// If it's not async, keep track of all required associated types for better
|
|
||||||
// error reporting.
|
|
||||||
expected_non_async_types.push((method, associated_type_for_rpc(method)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ImplItem::Type(typedecl) => found_non_async_types.push(typedecl),
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Err(e) =
|
|
||||||
verify_types_were_provided(span, &expected_non_async_types, &found_non_async_types)
|
|
||||||
{
|
|
||||||
return TokenStream::from(e.to_compile_error());
|
|
||||||
}
|
|
||||||
|
|
||||||
// add the type declarations into the impl block
|
|
||||||
for t in types.into_iter() {
|
|
||||||
item.items.push(syn::ImplItem::Type(t));
|
|
||||||
}
|
|
||||||
|
|
||||||
TokenStream::from(quote!(#item))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn verify_types_were_provided(
|
|
||||||
span: Span,
|
|
||||||
expected: &[(&ImplItemMethod, String)],
|
|
||||||
provided: &[&ImplItemType],
|
|
||||||
) -> syn::Result<()> {
|
|
||||||
let mut result = Ok(());
|
|
||||||
for (method, expected) in expected {
|
|
||||||
if !provided.iter().any(|typedecl| typedecl.ident == expected) {
|
|
||||||
let mut e = syn::Error::new(
|
|
||||||
span,
|
|
||||||
format!("not all trait items implemented, missing: `{expected}`"),
|
|
||||||
);
|
|
||||||
let fn_span = method.sig.fn_token.span();
|
|
||||||
e.extend(syn::Error::new(
|
|
||||||
fn_span.join(method.sig.ident.span()).unwrap_or(fn_span),
|
|
||||||
format!(
|
|
||||||
"hint: `#[tarpc::server]` only rewrites async fns, and `fn {}` is not async",
|
|
||||||
method.sig.ident
|
|
||||||
),
|
|
||||||
));
|
|
||||||
match result {
|
|
||||||
Ok(_) => result = Err(e),
|
|
||||||
Err(ref mut error) => error.extend(Some(e)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
// Things needed to generate the service items: trait, serve impl, request/response enums, and
|
// Things needed to generate the service items: trait, serve impl, request/response enums, and
|
||||||
// the client stub.
|
// the client stub.
|
||||||
struct ServiceGenerator<'a> {
|
struct ServiceGenerator<'a> {
|
||||||
service_ident: &'a Ident,
|
service_ident: &'a Ident,
|
||||||
|
client_stub_ident: &'a Ident,
|
||||||
server_ident: &'a Ident,
|
server_ident: &'a Ident,
|
||||||
response_fut_ident: &'a Ident,
|
|
||||||
response_fut_name: &'a str,
|
|
||||||
client_ident: &'a Ident,
|
client_ident: &'a Ident,
|
||||||
request_ident: &'a Ident,
|
request_ident: &'a Ident,
|
||||||
response_ident: &'a Ident,
|
response_ident: &'a Ident,
|
||||||
@@ -442,7 +321,6 @@ struct ServiceGenerator<'a> {
|
|||||||
attrs: &'a [Attribute],
|
attrs: &'a [Attribute],
|
||||||
rpcs: &'a [RpcMethod],
|
rpcs: &'a [RpcMethod],
|
||||||
camel_case_idents: &'a [Ident],
|
camel_case_idents: &'a [Ident],
|
||||||
future_types: &'a [Type],
|
|
||||||
method_idents: &'a [&'a Ident],
|
method_idents: &'a [&'a Ident],
|
||||||
request_names: &'a [String],
|
request_names: &'a [String],
|
||||||
method_attrs: &'a [&'a [Attribute]],
|
method_attrs: &'a [&'a [Attribute]],
|
||||||
@@ -458,49 +336,53 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
attrs,
|
attrs,
|
||||||
rpcs,
|
rpcs,
|
||||||
vis,
|
vis,
|
||||||
future_types,
|
|
||||||
return_types,
|
return_types,
|
||||||
service_ident,
|
service_ident,
|
||||||
|
client_stub_ident,
|
||||||
|
request_ident,
|
||||||
|
response_ident,
|
||||||
server_ident,
|
server_ident,
|
||||||
..
|
..
|
||||||
} = self;
|
} = self;
|
||||||
|
|
||||||
let types_and_fns = rpcs
|
let rpc_fns = rpcs
|
||||||
.iter()
|
.iter()
|
||||||
.zip(future_types.iter())
|
|
||||||
.zip(return_types.iter())
|
.zip(return_types.iter())
|
||||||
.map(
|
.map(
|
||||||
|(
|
|(
|
||||||
(
|
RpcMethod {
|
||||||
RpcMethod {
|
attrs, ident, args, ..
|
||||||
attrs, ident, args, ..
|
},
|
||||||
},
|
|
||||||
future_type,
|
|
||||||
),
|
|
||||||
output,
|
output,
|
||||||
)| {
|
)| {
|
||||||
let ty_doc = format!("The response future returned by [`{service_ident}::{ident}`].");
|
|
||||||
quote! {
|
quote! {
|
||||||
#[doc = #ty_doc]
|
|
||||||
type #future_type: std::future::Future<Output = #output>;
|
|
||||||
|
|
||||||
#( #attrs )*
|
#( #attrs )*
|
||||||
fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type;
|
async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let stub_doc = format!("The stub trait for service [`{service_ident}`].");
|
||||||
quote! {
|
quote! {
|
||||||
#( #attrs )*
|
#( #attrs )*
|
||||||
#vis trait #service_ident: Sized {
|
#vis trait #service_ident: ::core::marker::Sized {
|
||||||
#( #types_and_fns )*
|
#( #rpc_fns )*
|
||||||
|
|
||||||
/// Returns a serving function to use with
|
/// Returns a serving function to use with
|
||||||
/// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute).
|
/// [InFlightRequest::execute](::tarpc::server::InFlightRequest::execute).
|
||||||
fn serve(self) -> #server_ident<Self> {
|
fn serve(self) -> #server_ident<Self> {
|
||||||
#server_ident { service: self }
|
#server_ident { service: self }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[doc = #stub_doc]
|
||||||
|
#vis trait #client_stub_ident: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident> {
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> #client_stub_ident for S
|
||||||
|
where S: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident>
|
||||||
|
{
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -510,7 +392,7 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
} = self;
|
} = self;
|
||||||
|
|
||||||
quote! {
|
quote! {
|
||||||
/// A serving function to use with [tarpc::server::InFlightRequest::execute].
|
/// A serving function to use with [::tarpc::server::InFlightRequest::execute].
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
#vis struct #server_ident<S> {
|
#vis struct #server_ident<S> {
|
||||||
service: S,
|
service: S,
|
||||||
@@ -524,7 +406,6 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
server_ident,
|
server_ident,
|
||||||
service_ident,
|
service_ident,
|
||||||
response_ident,
|
response_ident,
|
||||||
response_fut_ident,
|
|
||||||
camel_case_idents,
|
camel_case_idents,
|
||||||
arg_pats,
|
arg_pats,
|
||||||
method_idents,
|
method_idents,
|
||||||
@@ -533,14 +414,14 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
} = self;
|
} = self;
|
||||||
|
|
||||||
quote! {
|
quote! {
|
||||||
impl<S> tarpc::server::Serve<#request_ident> for #server_ident<S>
|
impl<S> ::tarpc::server::Serve for #server_ident<S>
|
||||||
where S: #service_ident
|
where S: #service_ident
|
||||||
{
|
{
|
||||||
|
type Req = #request_ident;
|
||||||
type Resp = #response_ident;
|
type Resp = #response_ident;
|
||||||
type Fut = #response_fut_ident<S>;
|
|
||||||
|
|
||||||
fn method(&self, req: &#request_ident) -> Option<&'static str> {
|
fn method(&self, req: &#request_ident) -> ::core::option::Option<&'static str> {
|
||||||
Some(match req {
|
::core::option::Option::Some(match req {
|
||||||
#(
|
#(
|
||||||
#request_ident::#camel_case_idents{..} => {
|
#request_ident::#camel_case_idents{..} => {
|
||||||
#request_names
|
#request_names
|
||||||
@@ -549,15 +430,16 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
|
async fn serve(self, ctx: ::tarpc::context::Context, req: #request_ident)
|
||||||
|
-> ::core::result::Result<#response_ident, ::tarpc::ServerError> {
|
||||||
match req {
|
match req {
|
||||||
#(
|
#(
|
||||||
#request_ident::#camel_case_idents{ #( #arg_pats ),* } => {
|
#request_ident::#camel_case_idents{ #( #arg_pats ),* } => {
|
||||||
#response_fut_ident::#camel_case_idents(
|
::core::result::Result::Ok(#response_ident::#camel_case_idents(
|
||||||
#service_ident::#method_idents(
|
#service_ident::#method_idents(
|
||||||
self.service, ctx, #( #arg_pats ),*
|
self.service, ctx, #( #arg_pats ),*
|
||||||
)
|
).await
|
||||||
)
|
))
|
||||||
}
|
}
|
||||||
)*
|
)*
|
||||||
}
|
}
|
||||||
@@ -608,73 +490,6 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn enum_response_future(&self) -> TokenStream2 {
|
|
||||||
let &Self {
|
|
||||||
vis,
|
|
||||||
service_ident,
|
|
||||||
response_fut_ident,
|
|
||||||
camel_case_idents,
|
|
||||||
future_types,
|
|
||||||
..
|
|
||||||
} = self;
|
|
||||||
|
|
||||||
quote! {
|
|
||||||
/// A future resolving to a server response.
|
|
||||||
#[allow(missing_docs)]
|
|
||||||
#vis enum #response_fut_ident<S: #service_ident> {
|
|
||||||
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn impl_debug_for_response_future(&self) -> TokenStream2 {
|
|
||||||
let &Self {
|
|
||||||
service_ident,
|
|
||||||
response_fut_ident,
|
|
||||||
response_fut_name,
|
|
||||||
..
|
|
||||||
} = self;
|
|
||||||
|
|
||||||
quote! {
|
|
||||||
impl<S: #service_ident> std::fmt::Debug for #response_fut_ident<S> {
|
|
||||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
||||||
fmt.debug_struct(#response_fut_name).finish()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn impl_future_for_response_future(&self) -> TokenStream2 {
|
|
||||||
let &Self {
|
|
||||||
service_ident,
|
|
||||||
response_fut_ident,
|
|
||||||
response_ident,
|
|
||||||
camel_case_idents,
|
|
||||||
..
|
|
||||||
} = self;
|
|
||||||
|
|
||||||
quote! {
|
|
||||||
impl<S: #service_ident> std::future::Future for #response_fut_ident<S> {
|
|
||||||
type Output = #response_ident;
|
|
||||||
|
|
||||||
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>)
|
|
||||||
-> std::task::Poll<#response_ident>
|
|
||||||
{
|
|
||||||
unsafe {
|
|
||||||
match std::pin::Pin::get_unchecked_mut(self) {
|
|
||||||
#(
|
|
||||||
#response_fut_ident::#camel_case_idents(resp) =>
|
|
||||||
std::pin::Pin::new_unchecked(resp)
|
|
||||||
.poll(cx)
|
|
||||||
.map(#response_ident::#camel_case_idents),
|
|
||||||
)*
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn struct_client(&self) -> TokenStream2 {
|
fn struct_client(&self) -> TokenStream2 {
|
||||||
let &Self {
|
let &Self {
|
||||||
vis,
|
vis,
|
||||||
@@ -688,8 +503,10 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
/// The client stub that makes RPC calls to the server. All request methods return
|
/// The client stub that makes RPC calls to the server. All request methods return
|
||||||
/// [Futures](std::future::Future).
|
/// [Futures](::core::future::Future).
|
||||||
#vis struct #client_ident(tarpc::client::Channel<#request_ident, #response_ident>);
|
#vis struct #client_ident<
|
||||||
|
Stub = ::tarpc::client::Channel<#request_ident, #response_ident>
|
||||||
|
>(Stub);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -705,20 +522,31 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
quote! {
|
quote! {
|
||||||
impl #client_ident {
|
impl #client_ident {
|
||||||
/// Returns a new client stub that sends requests over the given transport.
|
/// Returns a new client stub that sends requests over the given transport.
|
||||||
#vis fn new<T>(config: tarpc::client::Config, transport: T)
|
#vis fn new<T>(config: ::tarpc::client::Config, transport: T)
|
||||||
-> tarpc::client::NewClient<
|
-> ::tarpc::client::NewClient<
|
||||||
Self,
|
Self,
|
||||||
tarpc::client::RequestDispatch<#request_ident, #response_ident, T>
|
::tarpc::client::RequestDispatch<#request_ident, #response_ident, T>
|
||||||
>
|
>
|
||||||
where
|
where
|
||||||
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>>
|
T: ::tarpc::Transport<::tarpc::ClientMessage<#request_ident>, ::tarpc::Response<#response_ident>>
|
||||||
{
|
{
|
||||||
let new_client = tarpc::client::new(config, transport);
|
let new_client = ::tarpc::client::new(config, transport);
|
||||||
tarpc::client::NewClient {
|
::tarpc::client::NewClient {
|
||||||
client: #client_ident(new_client.client),
|
client: #client_ident(new_client.client),
|
||||||
dispatch: new_client.dispatch,
|
dispatch: new_client.dispatch,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Stub> ::core::convert::From<Stub> for #client_ident<Stub>
|
||||||
|
where Stub: ::tarpc::client::stub::Stub<
|
||||||
|
Req = #request_ident,
|
||||||
|
Resp = #response_ident>
|
||||||
|
{
|
||||||
|
/// Returns a new client stub that sends requests over the given transport.
|
||||||
|
fn from(stub: Stub) -> Self {
|
||||||
|
#client_ident(stub)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -741,18 +569,22 @@ impl<'a> ServiceGenerator<'a> {
|
|||||||
} = self;
|
} = self;
|
||||||
|
|
||||||
quote! {
|
quote! {
|
||||||
impl #client_ident {
|
impl<Stub> #client_ident<Stub>
|
||||||
|
where Stub: ::tarpc::client::stub::Stub<
|
||||||
|
Req = #request_ident,
|
||||||
|
Resp = #response_ident>
|
||||||
|
{
|
||||||
#(
|
#(
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
#( #method_attrs )*
|
#( #method_attrs )*
|
||||||
#vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*)
|
#vis fn #method_idents(&self, ctx: ::tarpc::context::Context, #( #args ),*)
|
||||||
-> impl std::future::Future<Output = Result<#return_types, tarpc::client::RpcError>> + '_ {
|
-> impl ::core::future::Future<Output = ::core::result::Result<#return_types, ::tarpc::client::RpcError>> + '_ {
|
||||||
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
|
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
|
||||||
let resp = self.0.call(ctx, #request_names, request);
|
let resp = self.0.call(ctx, #request_names, request);
|
||||||
async move {
|
async move {
|
||||||
match resp.await? {
|
match resp.await? {
|
||||||
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
|
#response_ident::#camel_case_idents(msg) => ::core::result::Result::Ok(msg),
|
||||||
_ => unreachable!(),
|
_ => ::core::unreachable!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -770,9 +602,6 @@ impl<'a> ToTokens for ServiceGenerator<'a> {
|
|||||||
self.impl_serve_for_server(),
|
self.impl_serve_for_server(),
|
||||||
self.enum_request(),
|
self.enum_request(),
|
||||||
self.enum_response(),
|
self.enum_response(),
|
||||||
self.enum_response_future(),
|
|
||||||
self.impl_debug_for_response_future(),
|
|
||||||
self.impl_future_for_response_future(),
|
|
||||||
self.struct_client(),
|
self.struct_client(),
|
||||||
self.impl_client_new(),
|
self.impl_client_new(),
|
||||||
self.impl_client_rpc_methods(),
|
self.impl_client_rpc_methods(),
|
||||||
|
|||||||
@@ -1,8 +1,3 @@
|
|||||||
use assert_type_eq::assert_type_eq;
|
|
||||||
use futures::Future;
|
|
||||||
use std::pin::Pin;
|
|
||||||
use tarpc::context;
|
|
||||||
|
|
||||||
// these need to be out here rather than inside the function so that the
|
// these need to be out here rather than inside the function so that the
|
||||||
// assert_type_eq macro can pick them up.
|
// assert_type_eq macro can pick them up.
|
||||||
#[tarpc::service]
|
#[tarpc::service]
|
||||||
@@ -12,42 +7,6 @@ trait Foo {
|
|||||||
async fn baz();
|
async fn baz();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn type_generation_works() {
|
|
||||||
#[tarpc::server]
|
|
||||||
impl Foo for () {
|
|
||||||
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
|
|
||||||
(s, i)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn bar(self, _: context::Context, s: String) -> String {
|
|
||||||
s
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn baz(self, _: context::Context) {}
|
|
||||||
}
|
|
||||||
|
|
||||||
// the assert_type_eq macro can only be used once per block.
|
|
||||||
{
|
|
||||||
assert_type_eq!(
|
|
||||||
<() as Foo>::TwoPartFut,
|
|
||||||
Pin<Box<dyn Future<Output = (String, i32)> + Send>>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
assert_type_eq!(
|
|
||||||
<() as Foo>::BarFut,
|
|
||||||
Pin<Box<dyn Future<Output = String> + Send>>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
assert_type_eq!(
|
|
||||||
<() as Foo>::BazFut,
|
|
||||||
Pin<Box<dyn Future<Output = ()> + Send>>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(non_camel_case_types)]
|
#[allow(non_camel_case_types)]
|
||||||
#[test]
|
#[test]
|
||||||
fn raw_idents_work() {
|
fn raw_idents_work() {
|
||||||
@@ -59,24 +18,6 @@ fn raw_idents_work() {
|
|||||||
async fn r#fn(r#impl: r#yield) -> r#yield;
|
async fn r#fn(r#impl: r#yield) -> r#yield;
|
||||||
async fn r#async();
|
async fn r#async();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl r#trait for () {
|
|
||||||
async fn r#await(
|
|
||||||
self,
|
|
||||||
_: context::Context,
|
|
||||||
r#struct: r#yield,
|
|
||||||
r#enum: i32,
|
|
||||||
) -> (r#yield, i32) {
|
|
||||||
(r#struct, r#enum)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
|
|
||||||
r#impl
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn r#async(self, _: context::Context) {}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -100,45 +41,4 @@ fn syntax() {
|
|||||||
#[doc = "attr"]
|
#[doc = "attr"]
|
||||||
async fn one_arg_implicit_return_error(one: String);
|
async fn one_arg_implicit_return_error(one: String);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl Syntax for () {
|
|
||||||
#[deny(warnings)]
|
|
||||||
#[allow(non_snake_case)]
|
|
||||||
async fn TestCamelCaseDoesntConflict(self, _: context::Context) {}
|
|
||||||
|
|
||||||
async fn hello(self, _: context::Context) -> String {
|
|
||||||
String::new()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn attr(self, _: context::Context, _s: String) -> String {
|
|
||||||
String::new()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn no_args_no_return(self, _: context::Context) {}
|
|
||||||
|
|
||||||
async fn no_args(self, _: context::Context) -> () {}
|
|
||||||
|
|
||||||
async fn one_arg(self, _: context::Context, _one: String) -> i32 {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn two_args_no_return(self, _: context::Context, _one: String, _two: u64) {}
|
|
||||||
|
|
||||||
async fn two_args(self, _: context::Context, _one: String, _two: u64) -> String {
|
|
||||||
String::new()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn no_args_ret_error(self, _: context::Context) -> i32 {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn one_arg_ret_error(self, _: context::Context, _one: String) -> String {
|
|
||||||
String::new()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn no_arg_implicit_return_error(self, _: context::Context) {}
|
|
||||||
|
|
||||||
async fn one_arg_implicit_return_error(self, _: context::Context, _one: String) {}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ use tarpc::context;
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn att_service_trait() {
|
fn att_service_trait() {
|
||||||
use futures::future::{ready, Ready};
|
|
||||||
|
|
||||||
#[tarpc::service]
|
#[tarpc::service]
|
||||||
trait Foo {
|
trait Foo {
|
||||||
async fn two_part(s: String, i: i32) -> (String, i32);
|
async fn two_part(s: String, i: i32) -> (String, i32);
|
||||||
@@ -12,19 +10,16 @@ fn att_service_trait() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Foo for () {
|
impl Foo for () {
|
||||||
type TwoPartFut = Ready<(String, i32)>;
|
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
|
||||||
fn two_part(self, _: context::Context, s: String, i: i32) -> Self::TwoPartFut {
|
(s, i)
|
||||||
ready((s, i))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type BarFut = Ready<String>;
|
async fn bar(self, _: context::Context, s: String) -> String {
|
||||||
fn bar(self, _: context::Context, s: String) -> Self::BarFut {
|
s
|
||||||
ready(s)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type BazFut = Ready<()>;
|
async fn baz(self, _: context::Context) {
|
||||||
fn baz(self, _: context::Context) -> Self::BazFut {
|
()
|
||||||
ready(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -32,8 +27,6 @@ fn att_service_trait() {
|
|||||||
#[allow(non_camel_case_types)]
|
#[allow(non_camel_case_types)]
|
||||||
#[test]
|
#[test]
|
||||||
fn raw_idents() {
|
fn raw_idents() {
|
||||||
use futures::future::{ready, Ready};
|
|
||||||
|
|
||||||
type r#yield = String;
|
type r#yield = String;
|
||||||
|
|
||||||
#[tarpc::service]
|
#[tarpc::service]
|
||||||
@@ -44,19 +37,21 @@ fn raw_idents() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl r#trait for () {
|
impl r#trait for () {
|
||||||
type AwaitFut = Ready<(r#yield, i32)>;
|
async fn r#await(
|
||||||
fn r#await(self, _: context::Context, r#struct: r#yield, r#enum: i32) -> Self::AwaitFut {
|
self,
|
||||||
ready((r#struct, r#enum))
|
_: context::Context,
|
||||||
|
r#struct: r#yield,
|
||||||
|
r#enum: i32,
|
||||||
|
) -> (r#yield, i32) {
|
||||||
|
(r#struct, r#enum)
|
||||||
}
|
}
|
||||||
|
|
||||||
type FnFut = Ready<r#yield>;
|
async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
|
||||||
fn r#fn(self, _: context::Context, r#impl: r#yield) -> Self::FnFut {
|
r#impl
|
||||||
ready(r#impl)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AsyncFut = Ready<()>;
|
async fn r#async(self, _: context::Context) {
|
||||||
fn r#async(self, _: context::Context) -> Self::AsyncFut {
|
()
|
||||||
ready(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "tarpc"
|
name = "tarpc"
|
||||||
version = "0.32.0"
|
version = "0.34.0"
|
||||||
rust-version = "1.58.0"
|
rust-version = "1.58.0"
|
||||||
authors = [
|
authors = [
|
||||||
"Adam Wright <adam.austin.wright@gmail.com>",
|
"Adam Wright <adam.austin.wright@gmail.com>",
|
||||||
@@ -19,7 +19,7 @@ description = "An RPC framework for Rust with a focus on ease of use."
|
|||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
|
|
||||||
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"]
|
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive", "serde/rc"]
|
||||||
tokio1 = ["tokio/rt"]
|
tokio1 = ["tokio/rt"]
|
||||||
serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
|
serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
|
||||||
serde-transport-json = ["tokio-serde/json"]
|
serde-transport-json = ["tokio-serde/json"]
|
||||||
@@ -49,7 +49,7 @@ pin-project = "1.0"
|
|||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
serde = { optional = true, version = "1.0", features = ["derive"] }
|
serde = { optional = true, version = "1.0", features = ["derive"] }
|
||||||
static_assertions = "1.1.0"
|
static_assertions = "1.1.0"
|
||||||
tarpc-plugins = { path = "../plugins", version = "0.12" }
|
tarpc-plugins = { path = "../plugins", version = "0.13" }
|
||||||
thiserror = "1.0"
|
thiserror = "1.0"
|
||||||
tokio = { version = "1", features = ["time"] }
|
tokio = { version = "1", features = ["time"] }
|
||||||
tokio-util = { version = "0.7.3", features = ["time"] }
|
tokio-util = { version = "0.7.3", features = ["time"] }
|
||||||
@@ -58,8 +58,8 @@ tracing = { version = "0.1", default-features = false, features = [
|
|||||||
"attributes",
|
"attributes",
|
||||||
"log",
|
"log",
|
||||||
] }
|
] }
|
||||||
tracing-opentelemetry = { version = "0.17.2", default-features = false }
|
tracing-opentelemetry = { version = "0.18.0", default-features = false }
|
||||||
opentelemetry = { version = "0.17.0", default-features = false }
|
opentelemetry = { version = "0.18.0", default-features = false }
|
||||||
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
@@ -68,14 +68,15 @@ bincode = "1.3"
|
|||||||
bytes = { version = "1", features = ["serde"] }
|
bytes = { version = "1", features = ["serde"] }
|
||||||
flate2 = "1.0"
|
flate2 = "1.0"
|
||||||
futures-test = "0.3"
|
futures-test = "0.3"
|
||||||
opentelemetry = { version = "0.17.0", default-features = false, features = [
|
opentelemetry = { version = "0.18.0", default-features = false, features = [
|
||||||
"rt-tokio",
|
"rt-tokio",
|
||||||
] }
|
] }
|
||||||
opentelemetry-jaeger = { version = "0.16.0", features = ["rt-tokio"] }
|
opentelemetry-jaeger = { version = "0.17.0", features = ["rt-tokio"] }
|
||||||
pin-utils = "0.1.0-alpha"
|
pin-utils = "0.1.0-alpha"
|
||||||
serde_bytes = "0.11"
|
serde_bytes = "0.11"
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
tokio = { version = "1", features = ["full", "test-util"] }
|
tokio = { version = "1", features = ["full", "test-util", "tracing"] }
|
||||||
|
console-subscriber = "0.1"
|
||||||
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
|
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
|
||||||
trybuild = "1.0"
|
trybuild = "1.0"
|
||||||
tokio-rustls = "0.23"
|
tokio-rustls = "0.23"
|
||||||
|
|||||||
@@ -1,5 +1,11 @@
|
|||||||
|
// Copyright 2022 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression};
|
use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression};
|
||||||
use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt};
|
use futures::{prelude::*, Sink, SinkExt, Stream, StreamExt, TryStreamExt};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_bytes::ByteBuf;
|
use serde_bytes::ByteBuf;
|
||||||
use std::{io, io::Read, io::Write};
|
use std::{io, io::Read, io::Write};
|
||||||
@@ -99,13 +105,16 @@ pub trait World {
|
|||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
struct HelloServer;
|
struct HelloServer;
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl World for HelloServer {
|
impl World for HelloServer {
|
||||||
async fn hello(self, _: context::Context, name: String) -> String {
|
async fn hello(self, _: context::Context, name: String) -> String {
|
||||||
format!("Hey, {name}!")
|
format!("Hey, {name}!")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
let mut incoming = tcp::listen("localhost:0", Bincode::default).await?;
|
let mut incoming = tcp::listen("localhost:0", Bincode::default).await?;
|
||||||
@@ -114,6 +123,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
let transport = incoming.next().await.unwrap().unwrap();
|
let transport = incoming.next().await.unwrap().unwrap();
|
||||||
BaseChannel::with_defaults(add_compression(transport))
|
BaseChannel::with_defaults(add_compression(transport))
|
||||||
.execute(HelloServer.serve())
|
.execute(HelloServer.serve())
|
||||||
|
.for_each(spawn)
|
||||||
.await;
|
.await;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,10 @@
|
|||||||
|
// Copyright 2022 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
use futures::prelude::*;
|
||||||
use tarpc::context::Context;
|
use tarpc::context::Context;
|
||||||
use tarpc::serde_transport as transport;
|
use tarpc::serde_transport as transport;
|
||||||
use tarpc::server::{BaseChannel, Channel};
|
use tarpc::server::{BaseChannel, Channel};
|
||||||
@@ -13,7 +20,6 @@ pub trait PingService {
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct Service;
|
struct Service;
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl PingService for Service {
|
impl PingService for Service {
|
||||||
async fn ping(self, _: Context) {}
|
async fn ping(self, _: Context) {}
|
||||||
}
|
}
|
||||||
@@ -26,13 +32,18 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let listener = UnixListener::bind(bind_addr).unwrap();
|
let listener = UnixListener::bind(bind_addr).unwrap();
|
||||||
let codec_builder = LengthDelimitedCodec::builder();
|
let codec_builder = LengthDelimitedCodec::builder();
|
||||||
|
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
let (conn, _addr) = listener.accept().await.unwrap();
|
let (conn, _addr) = listener.accept().await.unwrap();
|
||||||
let framed = codec_builder.new_framed(conn);
|
let framed = codec_builder.new_framed(conn);
|
||||||
let transport = transport::new(framed, Bincode::default());
|
let transport = transport::new(framed, Bincode::default());
|
||||||
|
|
||||||
let fut = BaseChannel::with_defaults(transport).execute(Service.serve());
|
let fut = BaseChannel::with_defaults(transport)
|
||||||
|
.execute(Service.serve())
|
||||||
|
.for_each(spawn);
|
||||||
tokio::spawn(fut);
|
tokio::spawn(fut);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -79,7 +79,6 @@ struct Subscriber {
|
|||||||
topics: Vec<String>,
|
topics: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl subscriber::Subscriber for Subscriber {
|
impl subscriber::Subscriber for Subscriber {
|
||||||
async fn topics(self, _: context::Context) -> Vec<String> {
|
async fn topics(self, _: context::Context) -> Vec<String> {
|
||||||
self.topics.clone()
|
self.topics.clone()
|
||||||
@@ -117,7 +116,8 @@ impl Subscriber {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve()));
|
let (handler, abort_handle) =
|
||||||
|
future::abortable(handler.execute(subscriber.serve()).for_each(spawn));
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
match handler.await {
|
match handler.await {
|
||||||
Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."),
|
Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."),
|
||||||
@@ -143,6 +143,10 @@ struct PublisherAddrs {
|
|||||||
subscriptions: SocketAddr,
|
subscriptions: SocketAddr,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
|
|
||||||
impl Publisher {
|
impl Publisher {
|
||||||
async fn start(self) -> io::Result<PublisherAddrs> {
|
async fn start(self) -> io::Result<PublisherAddrs> {
|
||||||
let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?;
|
let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?;
|
||||||
@@ -162,6 +166,7 @@ impl Publisher {
|
|||||||
|
|
||||||
server::BaseChannel::with_defaults(publisher)
|
server::BaseChannel::with_defaults(publisher)
|
||||||
.execute(self.serve())
|
.execute(self.serve())
|
||||||
|
.for_each(spawn)
|
||||||
.await
|
.await
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -257,7 +262,6 @@ impl Publisher {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl publisher::Publisher for Publisher {
|
impl publisher::Publisher for Publisher {
|
||||||
async fn publish(self, _: context::Context, topic: String, message: String) {
|
async fn publish(self, _: context::Context, topic: String, message: String) {
|
||||||
info!("received message to publish.");
|
info!("received message to publish.");
|
||||||
@@ -282,7 +286,7 @@ impl publisher::Publisher for Publisher {
|
|||||||
/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
|
/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
|
||||||
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||||
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
|
||||||
let tracer = opentelemetry_jaeger::new_pipeline()
|
let tracer = opentelemetry_jaeger::new_agent_pipeline()
|
||||||
.with_service_name(service_name)
|
.with_service_name(service_name)
|
||||||
.with_max_packet_size(2usize.pow(13))
|
.with_max_packet_size(2usize.pow(13))
|
||||||
.install_batch(opentelemetry::runtime::Tokio)?;
|
.install_batch(opentelemetry::runtime::Tokio)?;
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
// license that can be found in the LICENSE file or at
|
// license that can be found in the LICENSE file or at
|
||||||
// https://opensource.org/licenses/MIT.
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
use futures::future::{self, Ready};
|
use futures::prelude::*;
|
||||||
use tarpc::{
|
use tarpc::{
|
||||||
client, context,
|
client, context,
|
||||||
server::{self, Channel},
|
server::{self, Channel},
|
||||||
@@ -23,22 +23,21 @@ pub trait World {
|
|||||||
struct HelloServer;
|
struct HelloServer;
|
||||||
|
|
||||||
impl World for HelloServer {
|
impl World for HelloServer {
|
||||||
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and
|
async fn hello(self, _: context::Context, name: String) -> String {
|
||||||
// an associated type representing the future output by the fn.
|
format!("Hello, {name}!")
|
||||||
|
|
||||||
type HelloFut = Ready<String>;
|
|
||||||
|
|
||||||
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
|
||||||
future::ready(format!("Hello, {name}!"))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||||
|
|
||||||
let server = server::BaseChannel::with_defaults(server_transport);
|
let server = server::BaseChannel::with_defaults(server_transport);
|
||||||
tokio::spawn(server.execute(HelloServer.serve()));
|
tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn));
|
||||||
|
|
||||||
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||||
// that takes a config and any Transport as input.
|
// that takes a config and any Transport as input.
|
||||||
|
|||||||
@@ -1,3 +1,10 @@
|
|||||||
|
// Copyright 2023 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
use futures::prelude::*;
|
||||||
use rustls_pemfile::certs;
|
use rustls_pemfile::certs;
|
||||||
use std::io::{BufReader, Cursor};
|
use std::io::{BufReader, Cursor};
|
||||||
use std::net::{IpAddr, Ipv4Addr};
|
use std::net::{IpAddr, Ipv4Addr};
|
||||||
@@ -6,8 +13,8 @@ use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio_rustls::rustls::{self, Certificate, OwnedTrustAnchor, RootCertStore};
|
use tokio_rustls::rustls::{self, RootCertStore};
|
||||||
use tokio_rustls::{webpki, TlsAcceptor, TlsConnector};
|
use tokio_rustls::{TlsAcceptor, TlsConnector};
|
||||||
|
|
||||||
use tarpc::context::Context;
|
use tarpc::context::Context;
|
||||||
use tarpc::serde_transport as transport;
|
use tarpc::serde_transport as transport;
|
||||||
@@ -23,7 +30,6 @@ pub trait PingService {
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct Service;
|
struct Service;
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl PingService for Service {
|
impl PingService for Service {
|
||||||
async fn ping(self, _: Context) -> String {
|
async fn ping(self, _: Context) -> String {
|
||||||
"🔒".to_owned()
|
"🔒".to_owned()
|
||||||
@@ -32,7 +38,7 @@ impl PingService for Service {
|
|||||||
|
|
||||||
// certs were generated with openssl 3 https://github.com/rustls/rustls/tree/main/test-ca
|
// certs were generated with openssl 3 https://github.com/rustls/rustls/tree/main/test-ca
|
||||||
// used on client-side for server tls
|
// used on client-side for server tls
|
||||||
const END_CHAIN: &[u8] = include_bytes!("certs/eddsa/end.chain");
|
const END_CHAIN: &str = include_str!("certs/eddsa/end.chain");
|
||||||
// used on client-side for client-auth
|
// used on client-side for client-auth
|
||||||
const CLIENT_PRIVATEKEY_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.key");
|
const CLIENT_PRIVATEKEY_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.key");
|
||||||
const CLIENT_CERT_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.cert");
|
const CLIENT_CERT_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.cert");
|
||||||
@@ -43,6 +49,14 @@ const END_PRIVATEKEY: &str = include_str!("certs/eddsa/end.key");
|
|||||||
// used on server-side for client-auth
|
// used on server-side for client-auth
|
||||||
const CLIENT_CHAIN_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.chain");
|
const CLIENT_CHAIN_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.chain");
|
||||||
|
|
||||||
|
pub fn load_certs(data: &str) -> Vec<rustls::Certificate> {
|
||||||
|
certs(&mut BufReader::new(Cursor::new(data)))
|
||||||
|
.unwrap()
|
||||||
|
.into_iter()
|
||||||
|
.map(rustls::Certificate)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn load_private_key(key: &str) -> rustls::PrivateKey {
|
pub fn load_private_key(key: &str) -> rustls::PrivateKey {
|
||||||
let mut reader = BufReader::new(Cursor::new(key));
|
let mut reader = BufReader::new(Cursor::new(key));
|
||||||
loop {
|
loop {
|
||||||
@@ -57,27 +71,22 @@ pub fn load_private_key(key: &str) -> rustls::PrivateKey {
|
|||||||
panic!("no keys found in {:?} (encrypted keys not supported)", key);
|
panic!("no keys found in {:?} (encrypted keys not supported)", key);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
// -------------------- start here to setup tls tcp tokio stream --------------------------
|
// -------------------- start here to setup tls tcp tokio stream --------------------------
|
||||||
// ref certs and loading from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/tests/test.rs
|
// ref certs and loading from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/tests/test.rs
|
||||||
// ref basic tls server setup from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/server/src/main.rs
|
// ref basic tls server setup from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/server/src/main.rs
|
||||||
let cert = certs(&mut BufReader::new(Cursor::new(END_CERT)))
|
let cert = load_certs(END_CERT);
|
||||||
.unwrap()
|
|
||||||
.into_iter()
|
|
||||||
.map(rustls::Certificate)
|
|
||||||
.collect();
|
|
||||||
let key = load_private_key(END_PRIVATEKEY);
|
let key = load_private_key(END_PRIVATEKEY);
|
||||||
let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), 5000);
|
let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), 5000);
|
||||||
|
|
||||||
// ------------- server side client_auth cert loading start
|
// ------------- server side client_auth cert loading start
|
||||||
let roots: Vec<Certificate> = certs(&mut BufReader::new(Cursor::new(CLIENT_CHAIN_CLIENT_AUTH)))
|
|
||||||
.unwrap()
|
|
||||||
.into_iter()
|
|
||||||
.map(rustls::Certificate)
|
|
||||||
.collect();
|
|
||||||
let mut client_auth_roots = RootCertStore::empty();
|
let mut client_auth_roots = RootCertStore::empty();
|
||||||
for root in roots {
|
for root in load_certs(CLIENT_CHAIN_CLIENT_AUTH) {
|
||||||
client_auth_roots.add(&root).unwrap();
|
client_auth_roots.add(&root).unwrap();
|
||||||
}
|
}
|
||||||
let client_auth = AllowAnyAuthenticatedClient::new(client_auth_roots);
|
let client_auth = AllowAnyAuthenticatedClient::new(client_auth_roots);
|
||||||
@@ -96,38 +105,27 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
let (stream, _peer_addr) = listener.accept().await.unwrap();
|
let (stream, _peer_addr) = listener.accept().await.unwrap();
|
||||||
let acceptor = acceptor.clone();
|
|
||||||
let tls_stream = acceptor.accept(stream).await.unwrap();
|
let tls_stream = acceptor.accept(stream).await.unwrap();
|
||||||
let framed = codec_builder.new_framed(tls_stream);
|
let framed = codec_builder.new_framed(tls_stream);
|
||||||
|
|
||||||
let transport = transport::new(framed, Bincode::default());
|
let transport = transport::new(framed, Bincode::default());
|
||||||
|
|
||||||
let fut = BaseChannel::with_defaults(transport).execute(Service.serve());
|
let fut = BaseChannel::with_defaults(transport)
|
||||||
|
.execute(Service.serve())
|
||||||
|
.for_each(spawn);
|
||||||
tokio::spawn(fut);
|
tokio::spawn(fut);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// ---------------------- client connection ---------------------
|
// ---------------------- client connection ---------------------
|
||||||
// cert loading from: https://github.com/tokio-rs/tls/blob/357bc562483dcf04c1f8d08bd1a831b144bf7d4c/tokio-rustls/tests/test.rs#L113
|
|
||||||
// tls client connection from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
|
// tls client connection from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
|
||||||
let chain = certs(&mut std::io::Cursor::new(END_CHAIN)).unwrap();
|
|
||||||
let mut root_store = rustls::RootCertStore::empty();
|
let mut root_store = rustls::RootCertStore::empty();
|
||||||
root_store.add_server_trust_anchors(chain.iter().map(|cert| {
|
for root in load_certs(END_CHAIN) {
|
||||||
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
|
root_store.add(&root).unwrap();
|
||||||
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
}
|
||||||
ta.subject,
|
|
||||||
ta.spki,
|
|
||||||
ta.name_constraints,
|
|
||||||
)
|
|
||||||
}));
|
|
||||||
|
|
||||||
let client_auth_private_key = load_private_key(CLIENT_PRIVATEKEY_CLIENT_AUTH);
|
let client_auth_private_key = load_private_key(CLIENT_PRIVATEKEY_CLIENT_AUTH);
|
||||||
let client_auth_certs: Vec<Certificate> =
|
let client_auth_certs = load_certs(CLIENT_CERT_CLIENT_AUTH);
|
||||||
certs(&mut BufReader::new(Cursor::new(CLIENT_CERT_CLIENT_AUTH)))
|
|
||||||
.unwrap()
|
|
||||||
.into_iter()
|
|
||||||
.map(rustls::Certificate)
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let config = rustls::ClientConfig::builder()
|
let config = rustls::ClientConfig::builder()
|
||||||
.with_safe_defaults()
|
.with_safe_defaults()
|
||||||
|
|||||||
@@ -4,13 +4,34 @@
|
|||||||
// license that can be found in the LICENSE file or at
|
// license that can be found in the LICENSE file or at
|
||||||
// https://opensource.org/licenses/MIT.
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
use crate::{add::Add as AddService, double::Double as DoubleService};
|
use crate::{
|
||||||
use futures::{future, prelude::*};
|
add::{Add as AddService, AddStub},
|
||||||
use tarpc::{
|
double::Double as DoubleService,
|
||||||
client, context,
|
|
||||||
server::{incoming::Incoming, BaseChannel},
|
|
||||||
tokio_serde::formats::Json,
|
|
||||||
};
|
};
|
||||||
|
use futures::{future, prelude::*};
|
||||||
|
use std::{
|
||||||
|
io,
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicBool, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use tarpc::{
|
||||||
|
client::{
|
||||||
|
self,
|
||||||
|
stub::{load_balance, retry},
|
||||||
|
RpcError,
|
||||||
|
},
|
||||||
|
context, serde_transport,
|
||||||
|
server::{
|
||||||
|
incoming::{spawn_incoming, Incoming},
|
||||||
|
request_hook::{self, BeforeRequestList},
|
||||||
|
BaseChannel,
|
||||||
|
},
|
||||||
|
tokio_serde::formats::Json,
|
||||||
|
ClientMessage, Response, ServerError, Transport,
|
||||||
|
};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
use tracing_subscriber::prelude::*;
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
pub mod add {
|
pub mod add {
|
||||||
@@ -32,7 +53,6 @@ pub mod double {
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct AddServer;
|
struct AddServer;
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl AddService for AddServer {
|
impl AddService for AddServer {
|
||||||
async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
|
async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
|
||||||
x + y
|
x + y
|
||||||
@@ -40,12 +60,14 @@ impl AddService for AddServer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct DoubleServer {
|
struct DoubleServer<Stub> {
|
||||||
add_client: add::AddClient,
|
add_client: add::AddClient<Stub>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tarpc::server]
|
impl<Stub> DoubleService for DoubleServer<Stub>
|
||||||
impl DoubleService for DoubleServer {
|
where
|
||||||
|
Stub: AddStub + Clone + Send + Sync + 'static,
|
||||||
|
{
|
||||||
async fn double(self, _: context::Context, x: i32) -> Result<i32, String> {
|
async fn double(self, _: context::Context, x: i32) -> Result<i32, String> {
|
||||||
self.add_client
|
self.add_client
|
||||||
.add(context::current(), x, x)
|
.add(context::current(), x, x)
|
||||||
@@ -55,7 +77,7 @@ impl DoubleService for DoubleServer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
||||||
let tracer = opentelemetry_jaeger::new_pipeline()
|
let tracer = opentelemetry_jaeger::new_agent_pipeline()
|
||||||
.with_service_name(service_name)
|
.with_service_name(service_name)
|
||||||
.with_auto_split_batch(true)
|
.with_auto_split_batch(true)
|
||||||
.with_max_packet_size(2usize.pow(13))
|
.with_max_packet_size(2usize.pow(13))
|
||||||
@@ -70,32 +92,88 @@ fn init_tracing(service_name: &str) -> anyhow::Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn listen_on_random_port<Item, SinkItem>() -> anyhow::Result<(
|
||||||
|
impl Stream<Item = serde_transport::Transport<TcpStream, Item, SinkItem, Json<Item, SinkItem>>>,
|
||||||
|
std::net::SocketAddr,
|
||||||
|
)>
|
||||||
|
where
|
||||||
|
Item: for<'de> serde::Deserialize<'de>,
|
||||||
|
SinkItem: serde::Serialize,
|
||||||
|
{
|
||||||
|
let listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
||||||
|
.await?
|
||||||
|
.filter_map(|r| future::ready(r.ok()))
|
||||||
|
.take(1);
|
||||||
|
let addr = listener.get_ref().get_ref().local_addr();
|
||||||
|
Ok((listener, addr))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_stub<Req, Resp, const N: usize>(
|
||||||
|
backends: [impl Transport<ClientMessage<Arc<Req>>, Response<Resp>> + Send + Sync + 'static; N],
|
||||||
|
) -> retry::Retry<
|
||||||
|
impl Fn(&Result<Resp, RpcError>, u32) -> bool + Clone,
|
||||||
|
load_balance::RoundRobin<client::Channel<Arc<Req>, Resp>>,
|
||||||
|
>
|
||||||
|
where
|
||||||
|
Req: Send + Sync + 'static,
|
||||||
|
Resp: Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
let stub = load_balance::RoundRobin::new(
|
||||||
|
backends
|
||||||
|
.into_iter()
|
||||||
|
.map(|transport| tarpc::client::new(client::Config::default(), transport).spawn())
|
||||||
|
.collect(),
|
||||||
|
);
|
||||||
|
let stub = retry::Retry::new(stub, |resp, attempts| {
|
||||||
|
if let Err(e) = resp {
|
||||||
|
tracing::warn!("Got an error: {e:?}");
|
||||||
|
attempts < 3
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
});
|
||||||
|
stub
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
init_tracing("tarpc_tracing_example")?;
|
init_tracing("tarpc_tracing_example")?;
|
||||||
|
|
||||||
let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
let (add_listener1, addr1) = listen_on_random_port().await?;
|
||||||
.await?
|
let (add_listener2, addr2) = listen_on_random_port().await?;
|
||||||
.filter_map(|r| future::ready(r.ok()));
|
let something_bad_happened = Arc::new(AtomicBool::new(false));
|
||||||
let addr = add_listener.get_ref().local_addr();
|
let server = request_hook::before()
|
||||||
let add_server = add_listener
|
.then_fn(move |_: &mut _, _: &_| {
|
||||||
.map(BaseChannel::with_defaults)
|
let something_bad_happened = something_bad_happened.clone();
|
||||||
.take(1)
|
async move {
|
||||||
.execute(AddServer.serve());
|
if something_bad_happened.fetch_xor(true, Ordering::Relaxed) {
|
||||||
tokio::spawn(add_server);
|
Err(ServerError::new(
|
||||||
|
io::ErrorKind::NotFound,
|
||||||
|
"Gamma Ray!".into(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.serving(AddServer.serve());
|
||||||
|
let add_server = add_listener1
|
||||||
|
.chain(add_listener2)
|
||||||
|
.map(BaseChannel::with_defaults);
|
||||||
|
tokio::spawn(spawn_incoming(add_server.execute(server)));
|
||||||
|
|
||||||
let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
let add_client = add::AddClient::from(make_stub([
|
||||||
let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn();
|
tarpc::serde_transport::tcp::connect(addr1, Json::default).await?,
|
||||||
|
tarpc::serde_transport::tcp::connect(addr2, Json::default).await?,
|
||||||
|
]));
|
||||||
|
|
||||||
let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
|
||||||
.await?
|
.await?
|
||||||
.filter_map(|r| future::ready(r.ok()));
|
.filter_map(|r| future::ready(r.ok()));
|
||||||
let addr = double_listener.get_ref().local_addr();
|
let addr = double_listener.get_ref().local_addr();
|
||||||
let double_server = double_listener
|
let double_server = double_listener.map(BaseChannel::with_defaults).take(1);
|
||||||
.map(BaseChannel::with_defaults)
|
let server = DoubleServer { add_client }.serve();
|
||||||
.take(1)
|
tokio::spawn(spawn_incoming(double_server.execute(server)));
|
||||||
.execute(DoubleServer { add_client }.serve());
|
|
||||||
tokio::spawn(double_server);
|
|
||||||
|
|
||||||
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
|
||||||
let double_client =
|
let double_client =
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
//! Provides a client that connects to a server and sends multiplexed requests.
|
//! Provides a client that connects to a server and sends multiplexed requests.
|
||||||
|
|
||||||
mod in_flight_requests;
|
mod in_flight_requests;
|
||||||
|
pub mod stub;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
cancellations::{cancellations, CanceledRequests, RequestCancellation},
|
cancellations::{cancellations, CanceledRequests, RequestCancellation},
|
||||||
@@ -500,7 +501,6 @@ where
|
|||||||
// poll_next_request only returns Ready if there is room to buffer another request.
|
// poll_next_request only returns Ready if there is room to buffer another request.
|
||||||
// Therefore, we can call write_request without fear of erroring due to a full
|
// Therefore, we can call write_request without fear of erroring due to a full
|
||||||
// buffer.
|
// buffer.
|
||||||
let request_id = request_id;
|
|
||||||
let request = ClientMessage::Request(Request {
|
let request = ClientMessage::Request(Request {
|
||||||
id: request_id,
|
id: request_id,
|
||||||
message: request,
|
message: request,
|
||||||
@@ -543,10 +543,15 @@ where
|
|||||||
|
|
||||||
/// Sends a server response to the client task that initiated the associated request.
|
/// Sends a server response to the client task that initiated the associated request.
|
||||||
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
|
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
|
||||||
self.in_flight_requests().complete_request(
|
if let Some(span) = self.in_flight_requests().complete_request(
|
||||||
response.request_id,
|
response.request_id,
|
||||||
response.message.map_err(RpcError::Server),
|
response.message.map_err(RpcError::Server),
|
||||||
)
|
) {
|
||||||
|
let _entered = span.enter();
|
||||||
|
tracing::info!("ReceiveResponse");
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -77,20 +77,18 @@ impl<Res> InFlightRequests<Res> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Removes a request without aborting. Returns true iff the request was found.
|
/// Removes a request without aborting. Returns true iff the request was found.
|
||||||
pub fn complete_request(&mut self, request_id: u64, result: Res) -> bool {
|
pub fn complete_request(&mut self, request_id: u64, result: Res) -> Option<Span> {
|
||||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||||
let _entered = request_data.span.enter();
|
|
||||||
tracing::info!("ReceiveResponse");
|
|
||||||
self.request_data.compact(0.1);
|
self.request_data.compact(0.1);
|
||||||
self.deadlines.remove(&request_data.deadline_key);
|
self.deadlines.remove(&request_data.deadline_key);
|
||||||
let _ = request_data.response_completion.send(result);
|
let _ = request_data.response_completion.send(result);
|
||||||
return true;
|
return Some(request_data.span);
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::debug!("No in-flight request found for request_id = {request_id}.");
|
tracing::debug!("No in-flight request found for request_id = {request_id}.");
|
||||||
|
|
||||||
// If the response completion was absent, then the request was already canceled.
|
// If the response completion was absent, then the request was already canceled.
|
||||||
false
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Completes all requests using the provided function.
|
/// Completes all requests using the provided function.
|
||||||
|
|||||||
45
tarpc/src/client/stub.rs
Normal file
45
tarpc/src/client/stub.rs
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
//! Provides a Stub trait, implemented by types that can call remote services.
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
client::{Channel, RpcError},
|
||||||
|
context,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub mod load_balance;
|
||||||
|
pub mod retry;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod mock;
|
||||||
|
|
||||||
|
/// A connection to a remote service.
|
||||||
|
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
|
||||||
|
#[allow(async_fn_in_trait)]
|
||||||
|
pub trait Stub {
|
||||||
|
/// The service request type.
|
||||||
|
type Req;
|
||||||
|
|
||||||
|
/// The service response type.
|
||||||
|
type Resp;
|
||||||
|
|
||||||
|
/// Calls a remote service.
|
||||||
|
async fn call(
|
||||||
|
&self,
|
||||||
|
ctx: context::Context,
|
||||||
|
request_name: &'static str,
|
||||||
|
request: Self::Req,
|
||||||
|
) -> Result<Self::Resp, RpcError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req, Resp> Stub for Channel<Req, Resp> {
|
||||||
|
type Req = Req;
|
||||||
|
type Resp = Resp;
|
||||||
|
|
||||||
|
async fn call(
|
||||||
|
&self,
|
||||||
|
ctx: context::Context,
|
||||||
|
request_name: &'static str,
|
||||||
|
request: Req,
|
||||||
|
) -> Result<Self::Resp, RpcError> {
|
||||||
|
Self::call(self, ctx, request_name, request).await
|
||||||
|
}
|
||||||
|
}
|
||||||
279
tarpc/src/client/stub/load_balance.rs
Normal file
279
tarpc/src/client/stub/load_balance.rs
Normal file
@@ -0,0 +1,279 @@
|
|||||||
|
//! Provides load-balancing [Stubs](crate::client::stub::Stub).
|
||||||
|
|
||||||
|
pub use consistent_hash::ConsistentHash;
|
||||||
|
pub use round_robin::RoundRobin;
|
||||||
|
|
||||||
|
/// Provides a stub that load-balances with a simple round-robin strategy.
|
||||||
|
mod round_robin {
|
||||||
|
use crate::{
|
||||||
|
client::{stub, RpcError},
|
||||||
|
context,
|
||||||
|
};
|
||||||
|
use cycle::AtomicCycle;
|
||||||
|
|
||||||
|
impl<Stub> stub::Stub for RoundRobin<Stub>
|
||||||
|
where
|
||||||
|
Stub: stub::Stub,
|
||||||
|
{
|
||||||
|
type Req = Stub::Req;
|
||||||
|
type Resp = Stub::Resp;
|
||||||
|
|
||||||
|
async fn call(
|
||||||
|
&self,
|
||||||
|
ctx: context::Context,
|
||||||
|
request_name: &'static str,
|
||||||
|
request: Self::Req,
|
||||||
|
) -> Result<Stub::Resp, RpcError> {
|
||||||
|
let next = self.stubs.next();
|
||||||
|
next.call(ctx, request_name, request).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Stub that load-balances across backing stubs by round robin.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct RoundRobin<Stub> {
|
||||||
|
stubs: AtomicCycle<Stub>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Stub> RoundRobin<Stub>
|
||||||
|
where
|
||||||
|
Stub: stub::Stub,
|
||||||
|
{
|
||||||
|
/// Returns a new RoundRobin stub.
|
||||||
|
pub fn new(stubs: Vec<Stub>) -> Self {
|
||||||
|
Self {
|
||||||
|
stubs: AtomicCycle::new(stubs),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mod cycle {
|
||||||
|
use std::sync::{
|
||||||
|
atomic::{AtomicUsize, Ordering},
|
||||||
|
Arc,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Cycles endlessly and atomically over a collection of elements of type T.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct AtomicCycle<T>(Arc<State<T>>);
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct State<T> {
|
||||||
|
elements: Vec<T>,
|
||||||
|
next: AtomicUsize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> AtomicCycle<T> {
|
||||||
|
pub fn new(elements: Vec<T>) -> Self {
|
||||||
|
Self(Arc::new(State {
|
||||||
|
elements,
|
||||||
|
next: Default::default(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn next(&self) -> &T {
|
||||||
|
self.0.next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> State<T> {
|
||||||
|
pub fn next(&self) -> &T {
|
||||||
|
let next = self.next.fetch_add(1, Ordering::Relaxed);
|
||||||
|
&self.elements[next % self.elements.len()]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cycle() {
|
||||||
|
let cycle = AtomicCycle::new(vec![1, 2, 3]);
|
||||||
|
assert_eq!(cycle.next(), &1);
|
||||||
|
assert_eq!(cycle.next(), &2);
|
||||||
|
assert_eq!(cycle.next(), &3);
|
||||||
|
assert_eq!(cycle.next(), &1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Provides a stub that load-balances with a consistent hashing strategy.
|
||||||
|
///
|
||||||
|
/// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use
|
||||||
|
/// the same stub.
|
||||||
|
mod consistent_hash {
|
||||||
|
use crate::{
|
||||||
|
client::{stub, RpcError},
|
||||||
|
context,
|
||||||
|
};
|
||||||
|
use std::{
|
||||||
|
collections::hash_map::RandomState,
|
||||||
|
hash::{BuildHasher, Hash, Hasher},
|
||||||
|
num::TryFromIntError,
|
||||||
|
};
|
||||||
|
|
||||||
|
impl<Stub, S> stub::Stub for ConsistentHash<Stub, S>
|
||||||
|
where
|
||||||
|
Stub: stub::Stub,
|
||||||
|
Stub::Req: Hash,
|
||||||
|
S: BuildHasher,
|
||||||
|
{
|
||||||
|
type Req = Stub::Req;
|
||||||
|
type Resp = Stub::Resp;
|
||||||
|
|
||||||
|
async fn call(
|
||||||
|
&self,
|
||||||
|
ctx: context::Context,
|
||||||
|
request_name: &'static str,
|
||||||
|
request: Self::Req,
|
||||||
|
) -> Result<Stub::Resp, RpcError> {
|
||||||
|
let index = usize::try_from(self.hash_request(&request) % self.stubs_len).expect(
|
||||||
|
"invariant broken: stubs_len is not larger than a usize, \
|
||||||
|
so the hash modulo stubs_len should always fit in a usize",
|
||||||
|
);
|
||||||
|
let next = &self.stubs[index];
|
||||||
|
next.call(ctx, request_name, request).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Stub that load-balances across backing stubs by round robin.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct ConsistentHash<Stub, S = RandomState> {
|
||||||
|
stubs: Vec<Stub>,
|
||||||
|
stubs_len: u64,
|
||||||
|
hasher: S,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Stub> ConsistentHash<Stub, RandomState>
|
||||||
|
where
|
||||||
|
Stub: stub::Stub,
|
||||||
|
Stub::Req: Hash,
|
||||||
|
{
|
||||||
|
/// Returns a new RoundRobin stub.
|
||||||
|
/// Returns an err if the length of `stubs` overflows a u64.
|
||||||
|
pub fn new(stubs: Vec<Stub>) -> Result<Self, TryFromIntError> {
|
||||||
|
Ok(Self {
|
||||||
|
stubs_len: stubs.len().try_into()?,
|
||||||
|
stubs,
|
||||||
|
hasher: RandomState::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Stub, S> ConsistentHash<Stub, S>
|
||||||
|
where
|
||||||
|
Stub: stub::Stub,
|
||||||
|
Stub::Req: Hash,
|
||||||
|
S: BuildHasher,
|
||||||
|
{
|
||||||
|
/// Returns a new RoundRobin stub.
|
||||||
|
/// Returns an err if the length of `stubs` overflows a u64.
|
||||||
|
pub fn with_hasher(stubs: Vec<Stub>, hasher: S) -> Result<Self, TryFromIntError> {
|
||||||
|
Ok(Self {
|
||||||
|
stubs_len: stubs.len().try_into()?,
|
||||||
|
stubs,
|
||||||
|
hasher,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn hash_request(&self, req: &Stub::Req) -> u64 {
|
||||||
|
let mut hasher = self.hasher.build_hasher();
|
||||||
|
req.hash(&mut hasher);
|
||||||
|
hasher.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::ConsistentHash;
|
||||||
|
use crate::{
|
||||||
|
client::stub::{mock::Mock, Stub},
|
||||||
|
context,
|
||||||
|
};
|
||||||
|
use std::{
|
||||||
|
collections::HashMap,
|
||||||
|
hash::{BuildHasher, Hash, Hasher},
|
||||||
|
rc::Rc,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test() -> anyhow::Result<()> {
|
||||||
|
let stub = ConsistentHash::<_, FakeHasherBuilder>::with_hasher(
|
||||||
|
vec![
|
||||||
|
// For easier reading of the assertions made in this test, each Mock's response
|
||||||
|
// value is equal to a hash value that should map to its index: 3 % 3 = 0, 1 %
|
||||||
|
// 3 = 1, etc.
|
||||||
|
Mock::new([('a', 3), ('b', 3), ('c', 3)]),
|
||||||
|
Mock::new([('a', 1), ('b', 1), ('c', 1)]),
|
||||||
|
Mock::new([('a', 2), ('b', 2), ('c', 2)]),
|
||||||
|
],
|
||||||
|
FakeHasherBuilder::new([('a', 1), ('b', 2), ('c', 3)]),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
for _ in 0..2 {
|
||||||
|
let resp = stub.call(context::current(), "", 'a').await?;
|
||||||
|
assert_eq!(resp, 1);
|
||||||
|
|
||||||
|
let resp = stub.call(context::current(), "", 'b').await?;
|
||||||
|
assert_eq!(resp, 2);
|
||||||
|
|
||||||
|
let resp = stub.call(context::current(), "", 'c').await?;
|
||||||
|
assert_eq!(resp, 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
struct HashRecorder(Vec<u8>);
|
||||||
|
impl Hasher for HashRecorder {
|
||||||
|
fn write(&mut self, bytes: &[u8]) {
|
||||||
|
self.0 = Vec::from(bytes);
|
||||||
|
}
|
||||||
|
fn finish(&self) -> u64 {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FakeHasherBuilder {
|
||||||
|
recorded_hashes: Rc<HashMap<Vec<u8>, u64>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FakeHasher {
|
||||||
|
recorded_hashes: Rc<HashMap<Vec<u8>, u64>>,
|
||||||
|
output: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BuildHasher for FakeHasherBuilder {
|
||||||
|
type Hasher = FakeHasher;
|
||||||
|
|
||||||
|
fn build_hasher(&self) -> Self::Hasher {
|
||||||
|
FakeHasher {
|
||||||
|
recorded_hashes: self.recorded_hashes.clone(),
|
||||||
|
output: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FakeHasherBuilder {
|
||||||
|
fn new<T: Hash, const N: usize>(fake_hashes: [(T, u64); N]) -> Self {
|
||||||
|
let mut recorded_hashes = HashMap::new();
|
||||||
|
for (to_hash, fake_hash) in fake_hashes {
|
||||||
|
let mut recorder = HashRecorder(vec![]);
|
||||||
|
to_hash.hash(&mut recorder);
|
||||||
|
recorded_hashes.insert(recorder.0, fake_hash);
|
||||||
|
}
|
||||||
|
Self {
|
||||||
|
recorded_hashes: Rc::new(recorded_hashes),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Hasher for FakeHasher {
|
||||||
|
fn write(&mut self, bytes: &[u8]) {
|
||||||
|
if let Some(hash) = self.recorded_hashes.get(bytes) {
|
||||||
|
self.output = *hash;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn finish(&self) -> u64 {
|
||||||
|
self.output
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
49
tarpc/src/client/stub/mock.rs
Normal file
49
tarpc/src/client/stub/mock.rs
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
use crate::{
|
||||||
|
client::{stub::Stub, RpcError},
|
||||||
|
context, ServerError,
|
||||||
|
};
|
||||||
|
use std::{collections::HashMap, hash::Hash, io};
|
||||||
|
|
||||||
|
/// A mock stub that returns user-specified responses.
|
||||||
|
pub struct Mock<Req, Resp> {
|
||||||
|
responses: HashMap<Req, Resp>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req, Resp> Mock<Req, Resp>
|
||||||
|
where
|
||||||
|
Req: Eq + Hash,
|
||||||
|
{
|
||||||
|
/// Returns a new mock, mocking the specified (request, response) pairs.
|
||||||
|
pub fn new<const N: usize>(responses: [(Req, Resp); N]) -> Self {
|
||||||
|
Self {
|
||||||
|
responses: HashMap::from(responses),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req, Resp> Stub for Mock<Req, Resp>
|
||||||
|
where
|
||||||
|
Req: Eq + Hash,
|
||||||
|
Resp: Clone,
|
||||||
|
{
|
||||||
|
type Req = Req;
|
||||||
|
type Resp = Resp;
|
||||||
|
|
||||||
|
async fn call(
|
||||||
|
&self,
|
||||||
|
_: context::Context,
|
||||||
|
_: &'static str,
|
||||||
|
request: Self::Req,
|
||||||
|
) -> Result<Resp, RpcError> {
|
||||||
|
self.responses
|
||||||
|
.get(&request)
|
||||||
|
.cloned()
|
||||||
|
.map(Ok)
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
Err(RpcError::Server(ServerError {
|
||||||
|
kind: io::ErrorKind::NotFound,
|
||||||
|
detail: "mock (request, response) entry not found".into(),
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
56
tarpc/src/client/stub/retry.rs
Normal file
56
tarpc/src/client/stub/retry.rs
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
//! Provides a stub that retries requests based on response contents..
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
client::{stub, RpcError},
|
||||||
|
context,
|
||||||
|
};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
impl<Stub, Req, F> stub::Stub for Retry<F, Stub>
|
||||||
|
where
|
||||||
|
Stub: stub::Stub<Req = Arc<Req>>,
|
||||||
|
F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool,
|
||||||
|
{
|
||||||
|
type Req = Req;
|
||||||
|
type Resp = Stub::Resp;
|
||||||
|
|
||||||
|
async fn call(
|
||||||
|
&self,
|
||||||
|
ctx: context::Context,
|
||||||
|
request_name: &'static str,
|
||||||
|
request: Self::Req,
|
||||||
|
) -> Result<Stub::Resp, RpcError> {
|
||||||
|
let request = Arc::new(request);
|
||||||
|
for i in 1.. {
|
||||||
|
let result = self
|
||||||
|
.stub
|
||||||
|
.call(ctx, request_name, Arc::clone(&request))
|
||||||
|
.await;
|
||||||
|
if (self.should_retry)(&result, i) {
|
||||||
|
tracing::trace!("Retrying on attempt {i}");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
unreachable!("Wow, that was a lot of attempts!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Stub that retries requests based on response contents.
|
||||||
|
/// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct Retry<F, Stub> {
|
||||||
|
should_retry: F,
|
||||||
|
stub: Stub,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Stub, Req, F> Retry<F, Stub>
|
||||||
|
where
|
||||||
|
Stub: stub::Stub<Req = Arc<Req>>,
|
||||||
|
F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool,
|
||||||
|
{
|
||||||
|
/// Creates a new Retry stub that delegates calls to the underlying `stub`.
|
||||||
|
pub fn new(stub: Stub, should_retry: F) -> Self {
|
||||||
|
Self { stub, should_retry }
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -126,13 +126,9 @@
|
|||||||
//! struct HelloServer;
|
//! struct HelloServer;
|
||||||
//!
|
//!
|
||||||
//! impl World for HelloServer {
|
//! impl World for HelloServer {
|
||||||
//! // Each defined rpc generates two items in the trait, a fn that serves the RPC, and
|
//! // Each defined rpc generates an async fn that serves the RPC
|
||||||
//! // an associated type representing the future output by the fn.
|
//! async fn hello(self, _: context::Context, name: String) -> String {
|
||||||
//!
|
//! format!("Hello, {name}!")
|
||||||
//! type HelloFut = Ready<String>;
|
|
||||||
//!
|
|
||||||
//! fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
|
||||||
//! future::ready(format!("Hello, {name}!"))
|
|
||||||
//! }
|
//! }
|
||||||
//! }
|
//! }
|
||||||
//! ```
|
//! ```
|
||||||
@@ -164,11 +160,9 @@
|
|||||||
//! # #[derive(Clone)]
|
//! # #[derive(Clone)]
|
||||||
//! # struct HelloServer;
|
//! # struct HelloServer;
|
||||||
//! # impl World for HelloServer {
|
//! # impl World for HelloServer {
|
||||||
//! # // Each defined rpc generates two items in the trait, a fn that serves the RPC, and
|
//! // Each defined rpc generates an async fn that serves the RPC
|
||||||
//! # // an associated type representing the future output by the fn.
|
//! # async fn hello(self, _: context::Context, name: String) -> String {
|
||||||
//! # type HelloFut = Ready<String>;
|
//! # format!("Hello, {name}!")
|
||||||
//! # fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
|
|
||||||
//! # future::ready(format!("Hello, {name}!"))
|
|
||||||
//! # }
|
//! # }
|
||||||
//! # }
|
//! # }
|
||||||
//! # #[cfg(not(feature = "tokio1"))]
|
//! # #[cfg(not(feature = "tokio1"))]
|
||||||
@@ -179,7 +173,12 @@
|
|||||||
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
|
||||||
//!
|
//!
|
||||||
//! let server = server::BaseChannel::with_defaults(server_transport);
|
//! let server = server::BaseChannel::with_defaults(server_transport);
|
||||||
//! tokio::spawn(server.execute(HelloServer.serve()));
|
//! tokio::spawn(
|
||||||
|
//! server.execute(HelloServer.serve())
|
||||||
|
//! // Handle all requests concurrently.
|
||||||
|
//! .for_each(|response| async move {
|
||||||
|
//! tokio::spawn(response);
|
||||||
|
//! }));
|
||||||
//!
|
//!
|
||||||
//! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
//! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
|
||||||
//! // that takes a config and any Transport as input.
|
//! // that takes a config and any Transport as input.
|
||||||
@@ -200,6 +199,7 @@
|
|||||||
//!
|
//!
|
||||||
//! Use `cargo doc` as you normally would to see the documentation created for all
|
//! Use `cargo doc` as you normally would to see the documentation created for all
|
||||||
//! items expanded by a `service!` invocation.
|
//! items expanded by a `service!` invocation.
|
||||||
|
|
||||||
#![deny(missing_docs)]
|
#![deny(missing_docs)]
|
||||||
#![allow(clippy::type_complexity)]
|
#![allow(clippy::type_complexity)]
|
||||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||||
@@ -244,62 +244,6 @@ pub use tarpc_plugins::derive_serde;
|
|||||||
/// * `fn new_stub` -- creates a new Client stub.
|
/// * `fn new_stub` -- creates a new Client stub.
|
||||||
pub use tarpc_plugins::service;
|
pub use tarpc_plugins::service;
|
||||||
|
|
||||||
/// A utility macro that can be used for RPC server implementations.
|
|
||||||
///
|
|
||||||
/// Syntactic sugar to make using async functions in the server implementation
|
|
||||||
/// easier. It does this by rewriting code like this, which would normally not
|
|
||||||
/// compile because async functions are disallowed in trait implementations:
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// # use tarpc::context;
|
|
||||||
/// # use std::net::SocketAddr;
|
|
||||||
/// #[tarpc::service]
|
|
||||||
/// trait World {
|
|
||||||
/// async fn hello(name: String) -> String;
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// #[derive(Clone)]
|
|
||||||
/// struct HelloServer(SocketAddr);
|
|
||||||
///
|
|
||||||
/// #[tarpc::server]
|
|
||||||
/// impl World for HelloServer {
|
|
||||||
/// async fn hello(self, _: context::Context, name: String) -> String {
|
|
||||||
/// format!("Hello, {name}! You are connected from {:?}.", self.0)
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// Into code like this, which matches the service trait definition:
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// # use tarpc::context;
|
|
||||||
/// # use std::pin::Pin;
|
|
||||||
/// # use futures::Future;
|
|
||||||
/// # use std::net::SocketAddr;
|
|
||||||
/// #[derive(Clone)]
|
|
||||||
/// struct HelloServer(SocketAddr);
|
|
||||||
///
|
|
||||||
/// #[tarpc::service]
|
|
||||||
/// trait World {
|
|
||||||
/// async fn hello(name: String) -> String;
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// impl World for HelloServer {
|
|
||||||
/// type HelloFut = Pin<Box<dyn Future<Output = String> + Send>>;
|
|
||||||
///
|
|
||||||
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
|
|
||||||
/// + Send>> {
|
|
||||||
/// Box::pin(async move {
|
|
||||||
/// format!("Hello, {name}! You are connected from {:?}.", self.0)
|
|
||||||
/// })
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// Note that this won't touch functions unless they have been annotated with
|
|
||||||
/// `async`, meaning that this should not break existing code.
|
|
||||||
pub use tarpc_plugins::server;
|
|
||||||
|
|
||||||
pub(crate) mod cancellations;
|
pub(crate) mod cancellations;
|
||||||
pub mod client;
|
pub mod client;
|
||||||
pub mod context;
|
pub mod context;
|
||||||
@@ -407,6 +351,13 @@ where
|
|||||||
Close(#[source] E),
|
Close(#[source] E),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ServerError {
|
||||||
|
/// Returns a new server error with `kind` and `detail`.
|
||||||
|
pub fn new(kind: io::ErrorKind, detail: String) -> ServerError {
|
||||||
|
Self { kind, detail }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T> Request<T> {
|
impl<T> Request<T> {
|
||||||
/// Returns the deadline for this request.
|
/// Returns the deadline for this request.
|
||||||
pub fn deadline(&self) -> &SystemTime {
|
pub fn deadline(&self) -> &SystemTime {
|
||||||
|
|||||||
@@ -210,7 +210,19 @@ pub mod tcp {
|
|||||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
CodecFn: Fn() -> Codec,
|
CodecFn: Fn() -> Codec,
|
||||||
{
|
{
|
||||||
let listener = TcpListener::bind(addr).await?;
|
listen_on(TcpListener::bind(addr).await?, codec_fn).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wrap accepted connections from `listener` in TCP transports.
|
||||||
|
pub async fn listen_on<Item, SinkItem, Codec, CodecFn>(
|
||||||
|
listener: TcpListener,
|
||||||
|
codec_fn: CodecFn,
|
||||||
|
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
|
||||||
|
where
|
||||||
|
Item: for<'de> Deserialize<'de>,
|
||||||
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
|
CodecFn: Fn() -> Codec,
|
||||||
|
{
|
||||||
let local_addr = listener.local_addr()?;
|
let local_addr = listener.local_addr()?;
|
||||||
Ok(Incoming {
|
Ok(Incoming {
|
||||||
listener,
|
listener,
|
||||||
@@ -364,7 +376,19 @@ pub mod unix {
|
|||||||
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
CodecFn: Fn() -> Codec,
|
CodecFn: Fn() -> Codec,
|
||||||
{
|
{
|
||||||
let listener = UnixListener::bind(path)?;
|
listen_on(UnixListener::bind(path)?, codec_fn).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wrap accepted connections from `listener` in Unix Domain Socket transports.
|
||||||
|
pub async fn listen_on<Item, SinkItem, Codec, CodecFn>(
|
||||||
|
listener: UnixListener,
|
||||||
|
codec_fn: CodecFn,
|
||||||
|
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
|
||||||
|
where
|
||||||
|
Item: for<'de> Deserialize<'de>,
|
||||||
|
Codec: Serializer<SinkItem> + Deserializer<Item>,
|
||||||
|
CodecFn: Fn() -> Codec,
|
||||||
|
{
|
||||||
let local_addr = listener.local_addr()?;
|
let local_addr = listener.local_addr()?;
|
||||||
Ok(Incoming {
|
Ok(Incoming {
|
||||||
listener,
|
listener,
|
||||||
@@ -537,7 +561,7 @@ pub mod unix {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::Transport;
|
use super::Transport;
|
||||||
use assert_matches::assert_matches;
|
use assert_matches::assert_matches;
|
||||||
use futures::{task::*, Sink, Stream};
|
use futures::{task::*, Sink, SinkExt, Stream, StreamExt};
|
||||||
use pin_utils::pin_mut;
|
use pin_utils::pin_mut;
|
||||||
use std::{
|
use std::{
|
||||||
io::{self, Cursor},
|
io::{self, Cursor},
|
||||||
@@ -631,7 +655,7 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(tcp)]
|
#[cfg(feature = "tcp")]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn tcp() -> io::Result<()> {
|
async fn tcp() -> io::Result<()> {
|
||||||
use super::tcp;
|
use super::tcp;
|
||||||
@@ -650,11 +674,30 @@ mod tests {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "tcp")]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn tcp_on_existing_transport() -> io::Result<()> {
|
||||||
|
use super::tcp;
|
||||||
|
|
||||||
|
let transport = tokio::net::TcpListener::bind("0.0.0.0:0").await?;
|
||||||
|
let mut listener = tcp::listen_on(transport, SymmetricalJson::<String>::default).await?;
|
||||||
|
let addr = listener.local_addr();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut transport = listener.next().await.unwrap().unwrap();
|
||||||
|
let message = transport.next().await.unwrap().unwrap();
|
||||||
|
transport.send(message).await.unwrap();
|
||||||
|
});
|
||||||
|
let mut transport = tcp::connect(addr, SymmetricalJson::<String>::default).await?;
|
||||||
|
transport.send(String::from("test")).await?;
|
||||||
|
assert_matches!(transport.next().await, Some(Ok(s)) if s == "test");
|
||||||
|
assert_matches!(transport.next().await, None);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(all(unix, feature = "unix"))]
|
#[cfg(all(unix, feature = "unix"))]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn uds() -> io::Result<()> {
|
async fn uds() -> io::Result<()> {
|
||||||
use super::unix;
|
use super::unix;
|
||||||
use super::*;
|
|
||||||
|
|
||||||
let sock = unix::TempPathBuf::with_random("uds");
|
let sock = unix::TempPathBuf::with_random("uds");
|
||||||
let mut listener = unix::listen(&sock, SymmetricalJson::<String>::default).await?;
|
let mut listener = unix::listen(&sock, SymmetricalJson::<String>::default).await?;
|
||||||
@@ -669,4 +712,24 @@ mod tests {
|
|||||||
assert_matches!(transport.next().await, None);
|
assert_matches!(transport.next().await, None);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(all(unix, feature = "unix"))]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn uds_on_existing_transport() -> io::Result<()> {
|
||||||
|
use super::unix;
|
||||||
|
|
||||||
|
let sock = unix::TempPathBuf::with_random("uds");
|
||||||
|
let transport = tokio::net::UnixListener::bind(&sock)?;
|
||||||
|
let mut listener = unix::listen_on(transport, SymmetricalJson::<String>::default).await?;
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut transport = listener.next().await.unwrap().unwrap();
|
||||||
|
let message = transport.next().await.unwrap().unwrap();
|
||||||
|
transport.send(message).await.unwrap();
|
||||||
|
});
|
||||||
|
let mut transport = unix::connect(&sock, SymmetricalJson::<String>::default).await?;
|
||||||
|
transport.send(String::from("test")).await?;
|
||||||
|
assert_matches!(transport.next().await, Some(Ok(s)) if s == "test");
|
||||||
|
assert_matches!(transport.next().await, None);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
cancellations::{cancellations, CanceledRequests, RequestCancellation},
|
cancellations::{cancellations, CanceledRequests, RequestCancellation},
|
||||||
context::{self, SpanExt},
|
context::{self, SpanExt},
|
||||||
trace, ChannelError, ClientMessage, Request, Response, Transport,
|
trace, ChannelError, ClientMessage, Request, Response, ServerError, Transport,
|
||||||
};
|
};
|
||||||
use ::tokio::sync::mpsc;
|
use ::tokio::sync::mpsc;
|
||||||
use futures::{
|
use futures::{
|
||||||
@@ -25,6 +25,7 @@ use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sy
|
|||||||
use tracing::{info_span, instrument::Instrument, Span};
|
use tracing::{info_span, instrument::Instrument, Span};
|
||||||
|
|
||||||
mod in_flight_requests;
|
mod in_flight_requests;
|
||||||
|
pub mod request_hook;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod testing;
|
mod testing;
|
||||||
|
|
||||||
@@ -34,10 +35,9 @@ pub mod limits;
|
|||||||
/// Provides helper methods for streams of Channels.
|
/// Provides helper methods for streams of Channels.
|
||||||
pub mod incoming;
|
pub mod incoming;
|
||||||
|
|
||||||
/// Provides convenience functionality for tokio-enabled applications.
|
use request_hook::{
|
||||||
#[cfg(feature = "tokio1")]
|
AfterRequest, BeforeRequest, HookThenServe, HookThenServeThenHook, ServeThenHook,
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
};
|
||||||
pub mod tokio;
|
|
||||||
|
|
||||||
/// Settings that control the behavior of [channels](Channel).
|
/// Settings that control the behavior of [channels](Channel).
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
@@ -67,32 +67,204 @@ impl Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`.
|
/// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`.
|
||||||
pub trait Serve<Req> {
|
#[allow(async_fn_in_trait)]
|
||||||
|
pub trait Serve {
|
||||||
|
/// Type of request.
|
||||||
|
type Req;
|
||||||
|
|
||||||
/// Type of response.
|
/// Type of response.
|
||||||
type Resp;
|
type Resp;
|
||||||
|
|
||||||
/// Type of response future.
|
/// Responds to a single request.
|
||||||
type Fut: Future<Output = Self::Resp>;
|
async fn serve(self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError>;
|
||||||
|
|
||||||
/// Extracts a method name from the request.
|
/// Extracts a method name from the request.
|
||||||
fn method(&self, _request: &Req) -> Option<&'static str> {
|
fn method(&self, _request: &Self::Req) -> Option<&'static str> {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Responds to a single request.
|
/// Runs a hook before execution of the request.
|
||||||
fn serve(self, ctx: context::Context, req: Req) -> Self::Fut;
|
///
|
||||||
|
/// If the hook returns an error, the request will not be executed and the error will be
|
||||||
|
/// returned instead.
|
||||||
|
///
|
||||||
|
/// The hook can also modify the request context. This could be used, for example, to enforce a
|
||||||
|
/// maximum deadline on all requests.
|
||||||
|
///
|
||||||
|
/// Any type that implements [`BeforeRequest`] can be used as the hook. Types that implement
|
||||||
|
/// `FnMut(&mut Context, &RequestType) -> impl Future<Output = Result<(), ServerError>>` can
|
||||||
|
/// also be used.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use futures::{executor::block_on, future};
|
||||||
|
/// use tarpc::{context, ServerError, server::{Serve, serve}};
|
||||||
|
/// use std::io;
|
||||||
|
///
|
||||||
|
/// let serve = serve(|_ctx, i| async move { Ok(i + 1) })
|
||||||
|
/// .before(|_ctx: &mut context::Context, req: &i32| {
|
||||||
|
/// future::ready(
|
||||||
|
/// if *req == 1 {
|
||||||
|
/// Err(ServerError::new(
|
||||||
|
/// io::ErrorKind::Other,
|
||||||
|
/// format!("I don't like {req}")))
|
||||||
|
/// } else {
|
||||||
|
/// Ok(())
|
||||||
|
/// })
|
||||||
|
/// });
|
||||||
|
/// let response = serve.serve(context::current(), 1);
|
||||||
|
/// assert!(block_on(response).is_err());
|
||||||
|
/// ```
|
||||||
|
fn before<Hook>(self, hook: Hook) -> HookThenServe<Self, Hook>
|
||||||
|
where
|
||||||
|
Hook: BeforeRequest<Self::Req>,
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
HookThenServe::new(self, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runs a hook after completion of a request.
|
||||||
|
///
|
||||||
|
/// The hook can modify the request context and the response.
|
||||||
|
///
|
||||||
|
/// Any type that implements [`AfterRequest`] can be used as the hook. Types that implement
|
||||||
|
/// `FnMut(&mut Context, &mut Result<ResponseType, ServerError>) -> impl Future<Output = ()>`
|
||||||
|
/// can also be used.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use futures::{executor::block_on, future};
|
||||||
|
/// use tarpc::{context, ServerError, server::{Serve, serve}};
|
||||||
|
/// use std::io;
|
||||||
|
///
|
||||||
|
/// let serve = serve(
|
||||||
|
/// |_ctx, i| async move {
|
||||||
|
/// if i == 1 {
|
||||||
|
/// Err(ServerError::new(
|
||||||
|
/// io::ErrorKind::Other,
|
||||||
|
/// format!("{i} is the loneliest number")))
|
||||||
|
/// } else {
|
||||||
|
/// Ok(i + 1)
|
||||||
|
/// }
|
||||||
|
/// })
|
||||||
|
/// .after(|_ctx: &mut context::Context, resp: &mut Result<i32, ServerError>| {
|
||||||
|
/// if let Err(e) = resp {
|
||||||
|
/// eprintln!("server error: {e:?}");
|
||||||
|
/// }
|
||||||
|
/// future::ready(())
|
||||||
|
/// });
|
||||||
|
///
|
||||||
|
/// let response = serve.serve(context::current(), 1);
|
||||||
|
/// assert!(block_on(response).is_err());
|
||||||
|
/// ```
|
||||||
|
fn after<Hook>(self, hook: Hook) -> ServeThenHook<Self, Hook>
|
||||||
|
where
|
||||||
|
Hook: AfterRequest<Self::Resp>,
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
ServeThenHook::new(self, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runs a hook before and after execution of the request.
|
||||||
|
///
|
||||||
|
/// If the hook returns an error, the request will not be executed and the error will be
|
||||||
|
/// returned instead.
|
||||||
|
///
|
||||||
|
/// The hook can also modify the request context and the response. This could be used, for
|
||||||
|
/// example, to enforce a maximum deadline on all requests.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use futures::{executor::block_on, future};
|
||||||
|
/// use tarpc::{
|
||||||
|
/// context, ServerError, server::{Serve, serve, request_hook::{BeforeRequest, AfterRequest}}
|
||||||
|
/// };
|
||||||
|
/// use std::{io, time::Instant};
|
||||||
|
///
|
||||||
|
/// struct PrintLatency(Instant);
|
||||||
|
///
|
||||||
|
/// impl<Req> BeforeRequest<Req> for PrintLatency {
|
||||||
|
/// async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> {
|
||||||
|
/// self.0 = Instant::now();
|
||||||
|
/// Ok(())
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// impl<Resp> AfterRequest<Resp> for PrintLatency {
|
||||||
|
/// async fn after(
|
||||||
|
/// &mut self,
|
||||||
|
/// _: &mut context::Context,
|
||||||
|
/// _: &mut Result<Resp, ServerError>,
|
||||||
|
/// ) {
|
||||||
|
/// tracing::info!("Elapsed: {:?}", self.0.elapsed());
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// let serve = serve(|_ctx, i| async move {
|
||||||
|
/// Ok(i + 1)
|
||||||
|
/// }).before_and_after(PrintLatency(Instant::now()));
|
||||||
|
/// let response = serve.serve(context::current(), 1);
|
||||||
|
/// assert!(block_on(response).is_ok());
|
||||||
|
/// ```
|
||||||
|
fn before_and_after<Hook>(
|
||||||
|
self,
|
||||||
|
hook: Hook,
|
||||||
|
) -> HookThenServeThenHook<Self::Req, Self::Resp, Self, Hook>
|
||||||
|
where
|
||||||
|
Hook: BeforeRequest<Self::Req> + AfterRequest<Self::Resp>,
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
HookThenServeThenHook::new(self, hook)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Req, Resp, Fut, F> Serve<Req> for F
|
/// A Serve wrapper around a Fn.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ServeFn<Req, Resp, F> {
|
||||||
|
f: F,
|
||||||
|
data: PhantomData<fn(Req) -> Resp>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req, Resp, F> Clone for ServeFn<Req, Resp, F>
|
||||||
|
where
|
||||||
|
F: Clone,
|
||||||
|
{
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
f: self.f.clone(),
|
||||||
|
data: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req, Resp, F> Copy for ServeFn<Req, Resp, F> where F: Copy {}
|
||||||
|
|
||||||
|
/// Creates a [`Serve`] wrapper around a `FnOnce(context::Context, Req) -> impl Future<Output =
|
||||||
|
/// Result<Resp, ServerError>>`.
|
||||||
|
pub fn serve<Req, Resp, Fut, F>(f: F) -> ServeFn<Req, Resp, F>
|
||||||
where
|
where
|
||||||
F: FnOnce(context::Context, Req) -> Fut,
|
F: FnOnce(context::Context, Req) -> Fut,
|
||||||
Fut: Future<Output = Resp>,
|
Fut: Future<Output = Result<Resp, ServerError>>,
|
||||||
{
|
{
|
||||||
type Resp = Resp;
|
ServeFn {
|
||||||
type Fut = Fut;
|
f,
|
||||||
|
data: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn serve(self, ctx: context::Context, req: Req) -> Self::Fut {
|
impl<Req, Resp, Fut, F> Serve for ServeFn<Req, Resp, F>
|
||||||
self(ctx, req)
|
where
|
||||||
|
F: FnOnce(context::Context, Req) -> Fut,
|
||||||
|
Fut: Future<Output = Result<Resp, ServerError>>,
|
||||||
|
{
|
||||||
|
type Req = Req;
|
||||||
|
type Resp = Resp;
|
||||||
|
|
||||||
|
async fn serve(self, ctx: context::Context, req: Req) -> Result<Resp, ServerError> {
|
||||||
|
(self.f)(ctx, req).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,7 +292,7 @@ pub struct BaseChannel<Req, Resp, T> {
|
|||||||
/// Holds data necessary to clean up in-flight requests.
|
/// Holds data necessary to clean up in-flight requests.
|
||||||
in_flight_requests: InFlightRequests,
|
in_flight_requests: InFlightRequests,
|
||||||
/// Types the request and response.
|
/// Types the request and response.
|
||||||
ghost: PhantomData<(Req, Resp)>,
|
ghost: PhantomData<(fn() -> Req, fn(Resp))>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
|
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
|
||||||
@@ -307,6 +479,34 @@ where
|
|||||||
/// This is a terminal operation. After calling `requests`, the channel cannot be retrieved,
|
/// This is a terminal operation. After calling `requests`, the channel cannot be retrieved,
|
||||||
/// and the only way to complete requests is via [`Requests::execute`] or
|
/// and the only way to complete requests is via [`Requests::execute`] or
|
||||||
/// [`InFlightRequest::execute`].
|
/// [`InFlightRequest::execute`].
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use tarpc::{
|
||||||
|
/// context,
|
||||||
|
/// client::{self, NewClient},
|
||||||
|
/// server::{self, BaseChannel, Channel, serve},
|
||||||
|
/// transport,
|
||||||
|
/// };
|
||||||
|
/// use futures::prelude::*;
|
||||||
|
///
|
||||||
|
/// #[tokio::main]
|
||||||
|
/// async fn main() {
|
||||||
|
/// let (tx, rx) = transport::channel::unbounded();
|
||||||
|
/// let server = BaseChannel::new(server::Config::default(), rx);
|
||||||
|
/// let NewClient { client, dispatch } = client::new(client::Config::default(), tx);
|
||||||
|
/// tokio::spawn(dispatch);
|
||||||
|
///
|
||||||
|
/// let mut requests = server.requests();
|
||||||
|
/// tokio::spawn(async move {
|
||||||
|
/// while let Some(Ok(request)) = requests.next().await {
|
||||||
|
/// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) })));
|
||||||
|
/// }
|
||||||
|
/// });
|
||||||
|
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
fn requests(self) -> Requests<Self>
|
fn requests(self) -> Requests<Self>
|
||||||
where
|
where
|
||||||
Self: Sized,
|
Self: Sized,
|
||||||
@@ -320,18 +520,42 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Runs the channel until completion by executing all requests using the given service
|
/// Returns a stream of request execution futures. Each future represents an in-flight request
|
||||||
/// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's
|
/// being responded to by the server. The futures must be awaited or spawned to complete their
|
||||||
/// default executor.
|
/// requests.
|
||||||
#[cfg(feature = "tokio1")]
|
///
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
/// # Example
|
||||||
fn execute<S>(self, serve: S) -> self::tokio::TokioChannelExecutor<Requests<Self>, S>
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport};
|
||||||
|
/// use futures::prelude::*;
|
||||||
|
/// use tracing_subscriber::prelude::*;
|
||||||
|
///
|
||||||
|
/// #[derive(PartialEq, Eq, Debug)]
|
||||||
|
/// struct MyInt(i32);
|
||||||
|
///
|
||||||
|
/// # #[cfg(not(feature = "tokio1"))]
|
||||||
|
/// # fn main() {}
|
||||||
|
/// # #[cfg(feature = "tokio1")]
|
||||||
|
/// #[tokio::main]
|
||||||
|
/// async fn main() {
|
||||||
|
/// let (tx, rx) = transport::channel::unbounded();
|
||||||
|
/// let client = client::new(client::Config::default(), tx).spawn();
|
||||||
|
/// let channel = BaseChannel::with_defaults(rx);
|
||||||
|
/// tokio::spawn(
|
||||||
|
/// channel.execute(serve(|_, MyInt(i)| async move { Ok(MyInt(i + 1)) }))
|
||||||
|
/// .for_each(|response| async move {
|
||||||
|
/// tokio::spawn(response);
|
||||||
|
/// }));
|
||||||
|
/// assert_eq!(
|
||||||
|
/// client.call(context::current(), "AddOne", MyInt(1)).await.unwrap(),
|
||||||
|
/// MyInt(2));
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
fn execute<S>(self, serve: S) -> impl Stream<Item = impl Future<Output = ()>>
|
||||||
where
|
where
|
||||||
Self: Sized,
|
Self: Sized,
|
||||||
S: Serve<Self::Req, Resp = Self::Resp> + Send + 'static,
|
S: Serve<Req = Self::Req, Resp = Self::Resp> + Clone,
|
||||||
S::Fut: Send,
|
|
||||||
Self::Req: Send + 'static,
|
|
||||||
Self::Resp: Send + 'static,
|
|
||||||
{
|
{
|
||||||
self.requests().execute(serve)
|
self.requests().execute(serve)
|
||||||
}
|
}
|
||||||
@@ -425,15 +649,17 @@ where
|
|||||||
Poll::Pending => Pending,
|
Poll::Pending => Pending,
|
||||||
};
|
};
|
||||||
|
|
||||||
tracing::trace!(
|
let status = cancellation_status
|
||||||
"Expired requests: {:?}, Inbound: {:?}",
|
|
||||||
expiration_status,
|
|
||||||
request_status
|
|
||||||
);
|
|
||||||
match cancellation_status
|
|
||||||
.combine(expiration_status)
|
.combine(expiration_status)
|
||||||
.combine(request_status)
|
.combine(request_status);
|
||||||
{
|
|
||||||
|
tracing::trace!(
|
||||||
|
"Cancellations: {cancellation_status:?}, \
|
||||||
|
Expired requests: {expiration_status:?}, \
|
||||||
|
Inbound: {request_status:?}, \
|
||||||
|
Overall: {status:?}",
|
||||||
|
);
|
||||||
|
match status {
|
||||||
Ready => continue,
|
Ready => continue,
|
||||||
Closed => return Poll::Ready(None),
|
Closed => return Poll::Ready(None),
|
||||||
Pending => return Poll::Pending,
|
Pending => return Poll::Pending,
|
||||||
@@ -565,6 +791,10 @@ where
|
|||||||
}| {
|
}| {
|
||||||
// The response guard becomes active once in an InFlightRequest.
|
// The response guard becomes active once in an InFlightRequest.
|
||||||
response_guard.cancel = true;
|
response_guard.cancel = true;
|
||||||
|
{
|
||||||
|
let _entered = span.enter();
|
||||||
|
tracing::info!("BeginRequest");
|
||||||
|
}
|
||||||
InFlightRequest {
|
InFlightRequest {
|
||||||
request,
|
request,
|
||||||
abort_registration,
|
abort_registration,
|
||||||
@@ -639,6 +869,51 @@ where
|
|||||||
}
|
}
|
||||||
Poll::Ready(Some(Ok(())))
|
Poll::Ready(Some(Ok(())))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a stream of request execution futures. Each future represents an in-flight request
|
||||||
|
/// being responded to by the server. The futures must be awaited or spawned to complete their
|
||||||
|
/// requests.
|
||||||
|
///
|
||||||
|
/// If the channel encounters an error, the stream is terminated and the error is logged.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport};
|
||||||
|
/// use futures::prelude::*;
|
||||||
|
///
|
||||||
|
/// # #[cfg(not(feature = "tokio1"))]
|
||||||
|
/// # fn main() {}
|
||||||
|
/// # #[cfg(feature = "tokio1")]
|
||||||
|
/// #[tokio::main]
|
||||||
|
/// async fn main() {
|
||||||
|
/// let (tx, rx) = transport::channel::unbounded();
|
||||||
|
/// let requests = BaseChannel::new(server::Config::default(), rx).requests();
|
||||||
|
/// let client = client::new(client::Config::default(), tx).spawn();
|
||||||
|
/// tokio::spawn(
|
||||||
|
/// requests.execute(serve(|_, i| async move { Ok(i + 1) }))
|
||||||
|
/// .for_each(|response| async move {
|
||||||
|
/// tokio::spawn(response);
|
||||||
|
/// }));
|
||||||
|
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
pub fn execute<S>(self, serve: S) -> impl Stream<Item = impl Future<Output = ()>>
|
||||||
|
where
|
||||||
|
S: Serve<Req = C::Req, Resp = C::Resp> + Clone,
|
||||||
|
{
|
||||||
|
self.take_while(|result| {
|
||||||
|
if let Err(e) = result {
|
||||||
|
tracing::warn!("Requests stream errored out: {}", e);
|
||||||
|
}
|
||||||
|
futures::future::ready(result.is_ok())
|
||||||
|
})
|
||||||
|
.filter_map(|result| async move { result.ok() })
|
||||||
|
.map(move |request| {
|
||||||
|
let serve = serve.clone();
|
||||||
|
request.execute(serve)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<C> fmt::Debug for Requests<C>
|
impl<C> fmt::Debug for Requests<C>
|
||||||
@@ -700,9 +975,39 @@ impl<Req, Res> InFlightRequest<Req, Res> {
|
|||||||
///
|
///
|
||||||
/// If the returned Future is dropped before completion, a cancellation message will be sent to
|
/// If the returned Future is dropped before completion, a cancellation message will be sent to
|
||||||
/// the Channel to clean up associated request state.
|
/// the Channel to clean up associated request state.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use tarpc::{
|
||||||
|
/// context,
|
||||||
|
/// client::{self, NewClient},
|
||||||
|
/// server::{self, BaseChannel, Channel, serve},
|
||||||
|
/// transport,
|
||||||
|
/// };
|
||||||
|
/// use futures::prelude::*;
|
||||||
|
///
|
||||||
|
/// #[tokio::main]
|
||||||
|
/// async fn main() {
|
||||||
|
/// let (tx, rx) = transport::channel::unbounded();
|
||||||
|
/// let server = BaseChannel::new(server::Config::default(), rx);
|
||||||
|
/// let NewClient { client, dispatch } = client::new(client::Config::default(), tx);
|
||||||
|
/// tokio::spawn(dispatch);
|
||||||
|
///
|
||||||
|
/// tokio::spawn(async move {
|
||||||
|
/// let mut requests = server.requests();
|
||||||
|
/// while let Some(Ok(in_flight_request)) = requests.next().await {
|
||||||
|
/// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) })).await;
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// });
|
||||||
|
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
pub async fn execute<S>(self, serve: S)
|
pub async fn execute<S>(self, serve: S)
|
||||||
where
|
where
|
||||||
S: Serve<Req, Resp = Res>,
|
S: Serve<Req = Req, Resp = Res>,
|
||||||
{
|
{
|
||||||
let Self {
|
let Self {
|
||||||
response_tx,
|
response_tx,
|
||||||
@@ -717,18 +1022,14 @@ impl<Req, Res> InFlightRequest<Req, Res> {
|
|||||||
},
|
},
|
||||||
} = self;
|
} = self;
|
||||||
let method = serve.method(&message);
|
let method = serve.method(&message);
|
||||||
// TODO(https://github.com/rust-lang/rust-clippy/issues/9111)
|
span.record("otel.name", method.unwrap_or(""));
|
||||||
// remove when clippy is fixed
|
|
||||||
#[allow(clippy::needless_borrow)]
|
|
||||||
span.record("otel.name", &method.unwrap_or(""));
|
|
||||||
let _ = Abortable::new(
|
let _ = Abortable::new(
|
||||||
async move {
|
async move {
|
||||||
tracing::info!("BeginRequest");
|
let message = serve.serve(context, message).await;
|
||||||
let response = serve.serve(context, message).await;
|
|
||||||
tracing::info!("CompleteRequest");
|
tracing::info!("CompleteRequest");
|
||||||
let response = Response {
|
let response = Response {
|
||||||
request_id,
|
request_id,
|
||||||
message: Ok(response),
|
message,
|
||||||
};
|
};
|
||||||
let _ = response_tx.send(response).await;
|
let _ = response_tx.send(response).await;
|
||||||
tracing::info!("BufferResponse");
|
tracing::info!("BufferResponse");
|
||||||
@@ -744,6 +1045,13 @@ impl<Req, Res> InFlightRequest<Req, Res> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn print_err(e: &(dyn Error + 'static)) -> String {
|
||||||
|
anyhow::Chain::new(e)
|
||||||
|
.map(|e| e.to_string())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(": ")
|
||||||
|
}
|
||||||
|
|
||||||
impl<C> Stream for Requests<C>
|
impl<C> Stream for Requests<C>
|
||||||
where
|
where
|
||||||
C: Channel,
|
C: Channel,
|
||||||
@@ -752,17 +1060,33 @@ where
|
|||||||
|
|
||||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
loop {
|
loop {
|
||||||
let read = self.as_mut().pump_read(cx)?;
|
let read = self.as_mut().pump_read(cx).map_err(|e| {
|
||||||
|
tracing::trace!("read: {}", print_err(&e));
|
||||||
|
e
|
||||||
|
})?;
|
||||||
let read_closed = matches!(read, Poll::Ready(None));
|
let read_closed = matches!(read, Poll::Ready(None));
|
||||||
match (read, self.as_mut().pump_write(cx, read_closed)?) {
|
let write = self.as_mut().pump_write(cx, read_closed).map_err(|e| {
|
||||||
|
tracing::trace!("write: {}", print_err(&e));
|
||||||
|
e
|
||||||
|
})?;
|
||||||
|
match (read, write) {
|
||||||
(Poll::Ready(None), Poll::Ready(None)) => {
|
(Poll::Ready(None), Poll::Ready(None)) => {
|
||||||
|
tracing::trace!("read: Poll::Ready(None), write: Poll::Ready(None)");
|
||||||
return Poll::Ready(None);
|
return Poll::Ready(None);
|
||||||
}
|
}
|
||||||
(Poll::Ready(Some(request_handler)), _) => {
|
(Poll::Ready(Some(request_handler)), _) => {
|
||||||
|
tracing::trace!("read: Poll::Ready(Some), write: _");
|
||||||
return Poll::Ready(Some(Ok(request_handler)));
|
return Poll::Ready(Some(Ok(request_handler)));
|
||||||
}
|
}
|
||||||
(_, Poll::Ready(Some(()))) => {}
|
(_, Poll::Ready(Some(()))) => {
|
||||||
_ => {
|
tracing::trace!("read: _, write: Poll::Ready(Some)");
|
||||||
|
}
|
||||||
|
(read @ Poll::Pending, write) | (read, write @ Poll::Pending) => {
|
||||||
|
tracing::trace!(
|
||||||
|
"read pending: {}, write pending: {}",
|
||||||
|
read.is_pending(),
|
||||||
|
write.is_pending()
|
||||||
|
);
|
||||||
return Poll::Pending;
|
return Poll::Pending;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -772,11 +1096,14 @@ where
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{in_flight_requests::AlreadyExistsError, BaseChannel, Channel, Config, Requests};
|
use super::{
|
||||||
|
in_flight_requests::AlreadyExistsError, serve, AfterRequest, BaseChannel, BeforeRequest,
|
||||||
|
Channel, Config, Requests, Serve,
|
||||||
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
context, trace,
|
context, trace,
|
||||||
transport::channel::{self, UnboundedChannel},
|
transport::channel::{self, UnboundedChannel},
|
||||||
ClientMessage, Request, Response,
|
ClientMessage, Request, Response, ServerError,
|
||||||
};
|
};
|
||||||
use assert_matches::assert_matches;
|
use assert_matches::assert_matches;
|
||||||
use futures::{
|
use futures::{
|
||||||
@@ -785,7 +1112,12 @@ mod tests {
|
|||||||
Future,
|
Future,
|
||||||
};
|
};
|
||||||
use futures_test::task::noop_context;
|
use futures_test::task::noop_context;
|
||||||
use std::{pin::Pin, task::Poll};
|
use std::{
|
||||||
|
io,
|
||||||
|
pin::Pin,
|
||||||
|
task::Poll,
|
||||||
|
time::{Duration, Instant, SystemTime},
|
||||||
|
};
|
||||||
|
|
||||||
fn test_channel<Req, Resp>() -> (
|
fn test_channel<Req, Resp>() -> (
|
||||||
Pin<Box<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>,
|
Pin<Box<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>,
|
||||||
@@ -846,6 +1178,89 @@ mod tests {
|
|||||||
Abortable::new(pending(), abort_registration)
|
Abortable::new(pending(), abort_registration)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_serve() {
|
||||||
|
let serve = serve(|_, i| async move { Ok(i) });
|
||||||
|
assert_matches!(serve.serve(context::current(), 7).await, Ok(7));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn serve_before_mutates_context() -> anyhow::Result<()> {
|
||||||
|
struct SetDeadline(SystemTime);
|
||||||
|
impl<Req> BeforeRequest<Req> for SetDeadline {
|
||||||
|
async fn before(
|
||||||
|
&mut self,
|
||||||
|
ctx: &mut context::Context,
|
||||||
|
_: &Req,
|
||||||
|
) -> Result<(), ServerError> {
|
||||||
|
ctx.deadline = self.0;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let some_time = SystemTime::UNIX_EPOCH + Duration::from_secs(37);
|
||||||
|
let some_other_time = SystemTime::UNIX_EPOCH + Duration::from_secs(83);
|
||||||
|
|
||||||
|
let serve = serve(move |ctx: context::Context, i| async move {
|
||||||
|
assert_eq!(ctx.deadline, some_time);
|
||||||
|
Ok(i)
|
||||||
|
});
|
||||||
|
let deadline_hook = serve.before(SetDeadline(some_time));
|
||||||
|
let mut ctx = context::current();
|
||||||
|
ctx.deadline = some_other_time;
|
||||||
|
deadline_hook.serve(ctx, 7).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn serve_before_and_after() -> anyhow::Result<()> {
|
||||||
|
let _ = tracing_subscriber::fmt::try_init();
|
||||||
|
|
||||||
|
struct PrintLatency {
|
||||||
|
start: Instant,
|
||||||
|
}
|
||||||
|
impl PrintLatency {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
start: Instant::now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<Req> BeforeRequest<Req> for PrintLatency {
|
||||||
|
async fn before(
|
||||||
|
&mut self,
|
||||||
|
_: &mut context::Context,
|
||||||
|
_: &Req,
|
||||||
|
) -> Result<(), ServerError> {
|
||||||
|
self.start = Instant::now();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<Resp> AfterRequest<Resp> for PrintLatency {
|
||||||
|
async fn after(&mut self, _: &mut context::Context, _: &mut Result<Resp, ServerError>) {
|
||||||
|
tracing::info!("Elapsed: {:?}", self.start.elapsed());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let serve = serve(move |_: context::Context, i| async move { Ok(i) });
|
||||||
|
serve
|
||||||
|
.before_and_after(PrintLatency::new())
|
||||||
|
.serve(context::current(), 7)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn serve_before_error_aborts_request() -> anyhow::Result<()> {
|
||||||
|
let serve = serve(|_, _| async { panic!("Shouldn't get here") });
|
||||||
|
let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async {
|
||||||
|
Err(ServerError::new(io::ErrorKind::Other, "oops".into()))
|
||||||
|
});
|
||||||
|
let resp: Result<i32, _> = deadline_hook.serve(context::current(), 7).await;
|
||||||
|
assert_matches!(resp, Err(_));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn base_channel_start_send_duplicate_request_returns_error() {
|
async fn base_channel_start_send_duplicate_request_returns_error() {
|
||||||
let (mut channel, _tx) = test_channel::<(), ()>();
|
let (mut channel, _tx) = test_channel::<(), ()>();
|
||||||
@@ -1046,7 +1461,7 @@ mod tests {
|
|||||||
Poll::Ready(Some(Ok(request))) => request,
|
Poll::Ready(Some(Ok(request))) => request,
|
||||||
result => panic!("Unexpected result: {:?}", result),
|
result => panic!("Unexpected result: {:?}", result),
|
||||||
};
|
};
|
||||||
request.execute(|_, _| async {}).await;
|
request.execute(serve(|_, _| async { Ok(()) })).await;
|
||||||
assert!(requests
|
assert!(requests
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.channel_pin_mut()
|
.channel_pin_mut()
|
||||||
|
|||||||
@@ -1,13 +1,10 @@
|
|||||||
use super::{
|
use super::{
|
||||||
limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel},
|
limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel},
|
||||||
Channel,
|
Channel, Serve,
|
||||||
};
|
};
|
||||||
use futures::prelude::*;
|
use futures::prelude::*;
|
||||||
use std::{fmt, hash::Hash};
|
use std::{fmt, hash::Hash};
|
||||||
|
|
||||||
#[cfg(feature = "tokio1")]
|
|
||||||
use super::{tokio::TokioServerExecutor, Serve};
|
|
||||||
|
|
||||||
/// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel).
|
/// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel).
|
||||||
pub trait Incoming<C>
|
pub trait Incoming<C>
|
||||||
where
|
where
|
||||||
@@ -28,16 +25,62 @@ where
|
|||||||
MaxRequestsPerChannel::new(self, n)
|
MaxRequestsPerChannel::new(self, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled
|
/// Returns a stream of channels in execution. Each channel in execution is a stream of
|
||||||
/// concurrently by spawning on tokio's default executor, and each request will be also
|
/// futures, where each future is an in-flight request being rsponded to.
|
||||||
/// be spawned on tokio's default executor.
|
fn execute<S>(
|
||||||
#[cfg(feature = "tokio1")]
|
self,
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
serve: S,
|
||||||
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
|
) -> impl Stream<Item = impl Stream<Item = impl Future<Output = ()>>>
|
||||||
where
|
where
|
||||||
S: Serve<C::Req, Resp = C::Resp>,
|
S: Serve<Req = C::Req, Resp = C::Resp> + Clone,
|
||||||
{
|
{
|
||||||
TokioServerExecutor::new(self, serve)
|
self.map(move |channel| channel.execute(serve.clone()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "tokio1")]
|
||||||
|
/// Spawns all channels-in-execution, delegating to the tokio runtime to manage their completion.
|
||||||
|
/// Each channel is spawned, and each request from each channel is spawned.
|
||||||
|
/// Note that this function is generic over any stream-of-streams-of-futures, but it is intended
|
||||||
|
/// for spawning streams of channels.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
/// ```rust
|
||||||
|
/// use tarpc::{
|
||||||
|
/// context,
|
||||||
|
/// client::{self, NewClient},
|
||||||
|
/// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve},
|
||||||
|
/// transport,
|
||||||
|
/// };
|
||||||
|
/// use futures::prelude::*;
|
||||||
|
///
|
||||||
|
/// #[tokio::main]
|
||||||
|
/// async fn main() {
|
||||||
|
/// let (tx, rx) = transport::channel::unbounded();
|
||||||
|
/// let NewClient { client, dispatch } = client::new(client::Config::default(), tx);
|
||||||
|
/// tokio::spawn(dispatch);
|
||||||
|
///
|
||||||
|
/// let incoming = stream::once(async move {
|
||||||
|
/// BaseChannel::new(server::Config::default(), rx)
|
||||||
|
/// }).execute(serve(|_, i| async move { Ok(i + 1) }));
|
||||||
|
/// tokio::spawn(spawn_incoming(incoming));
|
||||||
|
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
pub async fn spawn_incoming(
|
||||||
|
incoming: impl Stream<
|
||||||
|
Item = impl Stream<Item = impl Future<Output = ()> + Send + 'static> + Send + 'static,
|
||||||
|
>,
|
||||||
|
) {
|
||||||
|
use futures::pin_mut;
|
||||||
|
pin_mut!(incoming);
|
||||||
|
while let Some(channel) = incoming.next().await {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
pin_mut!(channel);
|
||||||
|
while let Some(request) = channel.next().await {
|
||||||
|
tokio::spawn(request);
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
25
tarpc/src/server/request_hook.rs
Normal file
25
tarpc/src/server/request_hook.rs
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
// Copyright 2022 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
//! Hooks for horizontal functionality that can run either before or after a request is executed.
|
||||||
|
|
||||||
|
/// A request hook that runs before a request is executed.
|
||||||
|
mod before;
|
||||||
|
|
||||||
|
/// A request hook that runs after a request is completed.
|
||||||
|
mod after;
|
||||||
|
|
||||||
|
/// A request hook that runs both before a request is executed and after it is completed.
|
||||||
|
mod before_and_after;
|
||||||
|
|
||||||
|
pub use {
|
||||||
|
after::{AfterRequest, ServeThenHook},
|
||||||
|
before::{
|
||||||
|
before, BeforeRequest, BeforeRequestCons, BeforeRequestList, BeforeRequestNil,
|
||||||
|
HookThenServe,
|
||||||
|
},
|
||||||
|
before_and_after::HookThenServeThenHook,
|
||||||
|
};
|
||||||
72
tarpc/src/server/request_hook/after.rs
Normal file
72
tarpc/src/server/request_hook/after.rs
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
// Copyright 2022 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
//! Provides a hook that runs after request execution.
|
||||||
|
|
||||||
|
use crate::{context, server::Serve, ServerError};
|
||||||
|
use futures::prelude::*;
|
||||||
|
|
||||||
|
/// A hook that runs after request execution.
|
||||||
|
#[allow(async_fn_in_trait)]
|
||||||
|
pub trait AfterRequest<Resp> {
|
||||||
|
/// The function that is called after request execution.
|
||||||
|
///
|
||||||
|
/// The hook can modify the request context and the response.
|
||||||
|
async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result<Resp, ServerError>);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F, Fut, Resp> AfterRequest<Resp> for F
|
||||||
|
where
|
||||||
|
F: FnMut(&mut context::Context, &mut Result<Resp, ServerError>) -> Fut,
|
||||||
|
Fut: Future<Output = ()>,
|
||||||
|
{
|
||||||
|
async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result<Resp, ServerError>) {
|
||||||
|
self(ctx, resp).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Service function that runs a hook after request execution.
|
||||||
|
pub struct ServeThenHook<Serv, Hook> {
|
||||||
|
serve: Serv,
|
||||||
|
hook: Hook,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Serv, Hook> ServeThenHook<Serv, Hook> {
|
||||||
|
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
|
||||||
|
Self { serve, hook }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Serv: Clone, Hook: Clone> Clone for ServeThenHook<Serv, Hook> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
serve: self.serve.clone(),
|
||||||
|
hook: self.hook.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Serv, Hook> Serve for ServeThenHook<Serv, Hook>
|
||||||
|
where
|
||||||
|
Serv: Serve,
|
||||||
|
Hook: AfterRequest<Serv::Resp>,
|
||||||
|
{
|
||||||
|
type Req = Serv::Req;
|
||||||
|
type Resp = Serv::Resp;
|
||||||
|
|
||||||
|
async fn serve(
|
||||||
|
self,
|
||||||
|
mut ctx: context::Context,
|
||||||
|
req: Serv::Req,
|
||||||
|
) -> Result<Serv::Resp, ServerError> {
|
||||||
|
let ServeThenHook {
|
||||||
|
serve, mut hook, ..
|
||||||
|
} = self;
|
||||||
|
let mut resp = serve.serve(ctx, req).await;
|
||||||
|
hook.after(&mut ctx, &mut resp).await;
|
||||||
|
resp
|
||||||
|
}
|
||||||
|
}
|
||||||
210
tarpc/src/server/request_hook/before.rs
Normal file
210
tarpc/src/server/request_hook/before.rs
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
// Copyright 2022 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
//! Provides a hook that runs before request execution.
|
||||||
|
|
||||||
|
use crate::{context, server::Serve, ServerError};
|
||||||
|
use futures::prelude::*;
|
||||||
|
|
||||||
|
/// A hook that runs before request execution.
|
||||||
|
#[allow(async_fn_in_trait)]
|
||||||
|
pub trait BeforeRequest<Req> {
|
||||||
|
/// The function that is called before request execution.
|
||||||
|
///
|
||||||
|
/// If this function returns an error, the request will not be executed and the error will be
|
||||||
|
/// returned instead.
|
||||||
|
///
|
||||||
|
/// This function can also modify the request context. This could be used, for example, to
|
||||||
|
/// enforce a maximum deadline on all requests.
|
||||||
|
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A list of hooks that run in order before request execution.
|
||||||
|
pub trait BeforeRequestList<Req>: BeforeRequest<Req> {
|
||||||
|
/// The hook returned by `BeforeRequestList::then`.
|
||||||
|
type Then<Next>: BeforeRequest<Req>
|
||||||
|
where
|
||||||
|
Next: BeforeRequest<Req>;
|
||||||
|
|
||||||
|
/// Returns a hook that, when run, runs two hooks, first `self` and then `next`.
|
||||||
|
fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next>;
|
||||||
|
|
||||||
|
/// Same as `then`, but helps the compiler with type inference when Next is a closure.
|
||||||
|
fn then_fn<
|
||||||
|
Next: FnMut(&mut context::Context, &Req) -> Fut,
|
||||||
|
Fut: Future<Output = Result<(), ServerError>>,
|
||||||
|
>(
|
||||||
|
self,
|
||||||
|
next: Next,
|
||||||
|
) -> Self::Then<Next>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
self.then(next)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The service fn returned by `BeforeRequestList::serving`.
|
||||||
|
type Serve<S: Serve<Req = Req>>: Serve<Req = Req>;
|
||||||
|
|
||||||
|
/// Runs the list of request hooks before execution of the given serve fn.
|
||||||
|
/// This is equivalent to `serve.before(before_request_chain)` but may be syntactically nicer.
|
||||||
|
fn serving<S: Serve<Req = Req>>(self, serve: S) -> Self::Serve<S>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F, Fut, Req> BeforeRequest<Req> for F
|
||||||
|
where
|
||||||
|
F: FnMut(&mut context::Context, &Req) -> Fut,
|
||||||
|
Fut: Future<Output = Result<(), ServerError>>,
|
||||||
|
{
|
||||||
|
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> {
|
||||||
|
self(ctx, req).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Service function that runs a hook before request execution.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct HookThenServe<Serv, Hook> {
|
||||||
|
serve: Serv,
|
||||||
|
hook: Hook,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Serv, Hook> HookThenServe<Serv, Hook> {
|
||||||
|
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
|
||||||
|
Self { serve, hook }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Serv, Hook> Serve for HookThenServe<Serv, Hook>
|
||||||
|
where
|
||||||
|
Serv: Serve,
|
||||||
|
Hook: BeforeRequest<Serv::Req>,
|
||||||
|
{
|
||||||
|
type Req = Serv::Req;
|
||||||
|
type Resp = Serv::Resp;
|
||||||
|
|
||||||
|
async fn serve(
|
||||||
|
self,
|
||||||
|
mut ctx: context::Context,
|
||||||
|
req: Self::Req,
|
||||||
|
) -> Result<Serv::Resp, ServerError> {
|
||||||
|
let HookThenServe {
|
||||||
|
serve, mut hook, ..
|
||||||
|
} = self;
|
||||||
|
hook.before(&mut ctx, &req).await?;
|
||||||
|
serve.serve(ctx, req).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a request hook builder that runs a series of hooks before request execution.
|
||||||
|
///
|
||||||
|
/// Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use futures::{executor::block_on, future};
|
||||||
|
/// use tarpc::{context, ServerError, server::{Serve, serve, request_hook::{self,
|
||||||
|
/// BeforeRequest, BeforeRequestList}}};
|
||||||
|
/// use std::{cell::Cell, io};
|
||||||
|
///
|
||||||
|
/// let i = Cell::new(0);
|
||||||
|
/// let serve = request_hook::before()
|
||||||
|
/// .then_fn(|_, _| async {
|
||||||
|
/// assert!(i.get() == 0);
|
||||||
|
/// i.set(1);
|
||||||
|
/// Ok(())
|
||||||
|
/// })
|
||||||
|
/// .then_fn(|_, _| async {
|
||||||
|
/// assert!(i.get() == 1);
|
||||||
|
/// i.set(2);
|
||||||
|
/// Ok(())
|
||||||
|
/// })
|
||||||
|
/// .serving(serve(|_ctx, i| async move { Ok(i + 1) }));
|
||||||
|
/// let response = serve.clone().serve(context::current(), 1);
|
||||||
|
/// assert!(block_on(response).is_ok());
|
||||||
|
/// assert!(i.get() == 2);
|
||||||
|
/// ```
|
||||||
|
pub fn before() -> BeforeRequestNil {
|
||||||
|
BeforeRequestNil
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A list of hooks that run in order before a request is executed.
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub struct BeforeRequestCons<First, Rest>(First, Rest);
|
||||||
|
|
||||||
|
/// A noop hook that runs before a request is executed.
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub struct BeforeRequestNil;
|
||||||
|
|
||||||
|
impl<Req, First: BeforeRequest<Req>, Rest: BeforeRequest<Req>> BeforeRequest<Req>
|
||||||
|
for BeforeRequestCons<First, Rest>
|
||||||
|
{
|
||||||
|
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> {
|
||||||
|
let BeforeRequestCons(first, rest) = self;
|
||||||
|
first.before(ctx, req).await?;
|
||||||
|
rest.before(ctx, req).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req> BeforeRequest<Req> for BeforeRequestNil {
|
||||||
|
async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req, First: BeforeRequest<Req>, Rest: BeforeRequestList<Req>> BeforeRequestList<Req>
|
||||||
|
for BeforeRequestCons<First, Rest>
|
||||||
|
{
|
||||||
|
type Then<Next> = BeforeRequestCons<First, Rest::Then<Next>> where Next: BeforeRequest<Req>;
|
||||||
|
|
||||||
|
fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next> {
|
||||||
|
let BeforeRequestCons(first, rest) = self;
|
||||||
|
BeforeRequestCons(first, rest.then(next))
|
||||||
|
}
|
||||||
|
|
||||||
|
type Serve<S: Serve<Req = Req>> = HookThenServe<S, Self>;
|
||||||
|
|
||||||
|
fn serving<S: Serve<Req = Req>>(self, serve: S) -> Self::Serve<S> {
|
||||||
|
HookThenServe::new(serve, self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req> BeforeRequestList<Req> for BeforeRequestNil {
|
||||||
|
type Then<Next> = BeforeRequestCons<Next, BeforeRequestNil> where Next: BeforeRequest<Req>;
|
||||||
|
|
||||||
|
fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next> {
|
||||||
|
BeforeRequestCons(next, BeforeRequestNil)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Serve<S: Serve<Req = Req>> = S;
|
||||||
|
|
||||||
|
fn serving<S: Serve<Req = Req>>(self, serve: S) -> S {
|
||||||
|
serve
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn before_request_list() {
|
||||||
|
use crate::server::serve;
|
||||||
|
use futures::executor::block_on;
|
||||||
|
use std::cell::Cell;
|
||||||
|
|
||||||
|
let i = Cell::new(0);
|
||||||
|
let serve = before()
|
||||||
|
.then_fn(|_, _| async {
|
||||||
|
assert!(i.get() == 0);
|
||||||
|
i.set(1);
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.then_fn(|_, _| async {
|
||||||
|
assert!(i.get() == 1);
|
||||||
|
i.set(2);
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.serving(serve(|_ctx, i| async move { Ok(i + 1) }));
|
||||||
|
let response = serve.clone().serve(context::current(), 1);
|
||||||
|
assert!(block_on(response).is_ok());
|
||||||
|
assert!(i.get() == 2);
|
||||||
|
}
|
||||||
57
tarpc/src/server/request_hook/before_and_after.rs
Normal file
57
tarpc/src/server/request_hook/before_and_after.rs
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
// Copyright 2022 Google LLC
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file or at
|
||||||
|
// https://opensource.org/licenses/MIT.
|
||||||
|
|
||||||
|
//! Provides a hook that runs both before and after request execution.
|
||||||
|
|
||||||
|
use super::{after::AfterRequest, before::BeforeRequest};
|
||||||
|
use crate::{context, server::Serve, ServerError};
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
/// A Service function that runs a hook both before and after request execution.
|
||||||
|
pub struct HookThenServeThenHook<Req, Resp, Serv, Hook> {
|
||||||
|
serve: Serv,
|
||||||
|
hook: Hook,
|
||||||
|
fns: PhantomData<(fn(Req), fn(Resp))>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req, Resp, Serv, Hook> HookThenServeThenHook<Req, Resp, Serv, Hook> {
|
||||||
|
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
|
||||||
|
Self {
|
||||||
|
serve,
|
||||||
|
hook,
|
||||||
|
fns: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req, Resp, Serv: Clone, Hook: Clone> Clone for HookThenServeThenHook<Req, Resp, Serv, Hook> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
serve: self.serve.clone(),
|
||||||
|
hook: self.hook.clone(),
|
||||||
|
fns: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Req, Resp, Serv, Hook> Serve for HookThenServeThenHook<Req, Resp, Serv, Hook>
|
||||||
|
where
|
||||||
|
Serv: Serve<Req = Req, Resp = Resp>,
|
||||||
|
Hook: BeforeRequest<Req> + AfterRequest<Resp>,
|
||||||
|
{
|
||||||
|
type Req = Req;
|
||||||
|
type Resp = Resp;
|
||||||
|
|
||||||
|
async fn serve(self, mut ctx: context::Context, req: Req) -> Result<Serv::Resp, ServerError> {
|
||||||
|
let HookThenServeThenHook {
|
||||||
|
serve, mut hook, ..
|
||||||
|
} = self;
|
||||||
|
hook.before(&mut ctx, &req).await?;
|
||||||
|
let mut resp = serve.serve(ctx, req).await;
|
||||||
|
hook.after(&mut ctx, &mut resp).await;
|
||||||
|
resp
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,113 +0,0 @@
|
|||||||
use super::{Channel, Requests, Serve};
|
|
||||||
use futures::{prelude::*, ready, task::*};
|
|
||||||
use pin_project::pin_project;
|
|
||||||
use std::pin::Pin;
|
|
||||||
|
|
||||||
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
|
|
||||||
/// for each new channel. Returned by
|
|
||||||
/// [`Incoming::execute`](crate::server::incoming::Incoming::execute).
|
|
||||||
#[must_use]
|
|
||||||
#[pin_project]
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct TokioServerExecutor<T, S> {
|
|
||||||
#[pin]
|
|
||||||
inner: T,
|
|
||||||
serve: S,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, S> TokioServerExecutor<T, S> {
|
|
||||||
pub(crate) fn new(inner: T, serve: S) -> Self {
|
|
||||||
Self { inner, serve }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A future that drives the server by [spawning](tokio::spawn) each [response
|
|
||||||
/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by
|
|
||||||
/// [`Channel::execute`](crate::server::Channel::execute).
|
|
||||||
#[must_use]
|
|
||||||
#[pin_project]
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct TokioChannelExecutor<T, S> {
|
|
||||||
#[pin]
|
|
||||||
inner: T,
|
|
||||||
serve: S,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, S> TokioServerExecutor<T, S> {
|
|
||||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
|
||||||
self.as_mut().project().inner
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, S> TokioChannelExecutor<T, S> {
|
|
||||||
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
|
|
||||||
self.as_mut().project().inner
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send + 'static execution helper methods.
|
|
||||||
|
|
||||||
impl<C> Requests<C>
|
|
||||||
where
|
|
||||||
C: Channel,
|
|
||||||
C::Req: Send + 'static,
|
|
||||||
C::Resp: Send + 'static,
|
|
||||||
{
|
|
||||||
/// Executes all requests using the given service function. Requests are handled concurrently
|
|
||||||
/// by [spawning](::tokio::spawn) each handler on tokio's default executor.
|
|
||||||
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
|
|
||||||
where
|
|
||||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
|
|
||||||
{
|
|
||||||
TokioChannelExecutor { inner: self, serve }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
|
|
||||||
where
|
|
||||||
St: Sized + Stream<Item = C>,
|
|
||||||
C: Channel + Send + 'static,
|
|
||||||
C::Req: Send + 'static,
|
|
||||||
C::Resp: Send + 'static,
|
|
||||||
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
|
||||||
Se::Fut: Send,
|
|
||||||
{
|
|
||||||
type Output = ();
|
|
||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
|
||||||
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
|
||||||
tokio::spawn(channel.execute(self.serve.clone()));
|
|
||||||
}
|
|
||||||
tracing::info!("Server shutting down.");
|
|
||||||
Poll::Ready(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
|
|
||||||
where
|
|
||||||
C: Channel + 'static,
|
|
||||||
C::Req: Send + 'static,
|
|
||||||
C::Resp: Send + 'static,
|
|
||||||
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
|
|
||||||
S::Fut: Send,
|
|
||||||
{
|
|
||||||
type Output = ();
|
|
||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
||||||
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
|
||||||
match response_handler {
|
|
||||||
Ok(resp) => {
|
|
||||||
let server = self.serve.clone();
|
|
||||||
tokio::spawn(async move {
|
|
||||||
resp.execute(server).await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::warn!("Requests stream errored out: {}", e);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Poll::Ready(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -14,9 +14,15 @@ use tokio::sync::mpsc;
|
|||||||
/// Errors that occur in the sending or receiving of messages over a channel.
|
/// Errors that occur in the sending or receiving of messages over a channel.
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
pub enum ChannelError {
|
pub enum ChannelError {
|
||||||
/// An error occurred sending over the channel.
|
/// An error occurred readying to send into the channel.
|
||||||
#[error("an error occurred sending over the channel")]
|
#[error("an error occurred readying to send into the channel")]
|
||||||
|
Ready(#[source] Box<dyn Error + Send + Sync + 'static>),
|
||||||
|
/// An error occurred sending into the channel.
|
||||||
|
#[error("an error occurred sending into the channel")]
|
||||||
Send(#[source] Box<dyn Error + Send + Sync + 'static>),
|
Send(#[source] Box<dyn Error + Send + Sync + 'static>),
|
||||||
|
/// An error occurred receiving from the channel.
|
||||||
|
#[error("an error occurred receiving from the channel")]
|
||||||
|
Receive(#[source] Box<dyn Error + Send + Sync + 'static>),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
|
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
|
||||||
@@ -48,7 +54,10 @@ impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
|
|||||||
mut self: Pin<&mut Self>,
|
mut self: Pin<&mut Self>,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
) -> Poll<Option<Result<Item, ChannelError>>> {
|
) -> Poll<Option<Result<Item, ChannelError>>> {
|
||||||
self.rx.poll_recv(cx).map(|option| option.map(Ok))
|
self.rx
|
||||||
|
.poll_recv(cx)
|
||||||
|
.map(|option| option.map(Ok))
|
||||||
|
.map_err(ChannelError::Receive)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,7 +68,7 @@ impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
|
|||||||
|
|
||||||
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
Poll::Ready(if self.tx.is_closed() {
|
Poll::Ready(if self.tx.is_closed() {
|
||||||
Err(ChannelError::Send(CLOSED_MESSAGE.into()))
|
Err(ChannelError::Ready(CLOSED_MESSAGE.into()))
|
||||||
} else {
|
} else {
|
||||||
Ok(())
|
Ok(())
|
||||||
})
|
})
|
||||||
@@ -110,7 +119,11 @@ impl<Item, SinkItem> Stream for Channel<Item, SinkItem> {
|
|||||||
self: Pin<&mut Self>,
|
self: Pin<&mut Self>,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
) -> Poll<Option<Result<Item, ChannelError>>> {
|
) -> Poll<Option<Result<Item, ChannelError>>> {
|
||||||
self.project().rx.poll_next(cx).map(|option| option.map(Ok))
|
self.project()
|
||||||
|
.rx
|
||||||
|
.poll_next(cx)
|
||||||
|
.map(|option| option.map(Ok))
|
||||||
|
.map_err(ChannelError::Receive)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,7 +134,7 @@ impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
|
|||||||
self.project()
|
self.project()
|
||||||
.tx
|
.tx
|
||||||
.poll_ready(cx)
|
.poll_ready(cx)
|
||||||
.map_err(|e| ChannelError::Send(Box::new(e)))
|
.map_err(|e| ChannelError::Ready(Box::new(e)))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
|
||||||
@@ -146,16 +159,17 @@ impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(all(test, feature = "tokio1"))]
|
||||||
#[cfg(feature = "tokio1")]
|
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::{
|
use crate::{
|
||||||
client, context,
|
client::{self, RpcError},
|
||||||
server::{incoming::Incoming, BaseChannel},
|
context,
|
||||||
|
server::{incoming::Incoming, serve, BaseChannel},
|
||||||
transport::{
|
transport::{
|
||||||
self,
|
self,
|
||||||
channel::{Channel, UnboundedChannel},
|
channel::{Channel, UnboundedChannel},
|
||||||
},
|
},
|
||||||
|
ServerError,
|
||||||
};
|
};
|
||||||
use assert_matches::assert_matches;
|
use assert_matches::assert_matches;
|
||||||
use futures::{prelude::*, stream};
|
use futures::{prelude::*, stream};
|
||||||
@@ -177,25 +191,28 @@ mod tests {
|
|||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
stream::once(future::ready(server_channel))
|
stream::once(future::ready(server_channel))
|
||||||
.map(BaseChannel::with_defaults)
|
.map(BaseChannel::with_defaults)
|
||||||
.execute(|_ctx, request: String| {
|
.execute(serve(|_ctx, request: String| async move {
|
||||||
future::ready(request.parse::<u64>().map_err(|_| {
|
request.parse::<u64>().map_err(|_| {
|
||||||
io::Error::new(
|
ServerError::new(
|
||||||
io::ErrorKind::InvalidInput,
|
io::ErrorKind::InvalidInput,
|
||||||
format!("{request:?} is not an int"),
|
format!("{request:?} is not an int"),
|
||||||
)
|
)
|
||||||
}))
|
})
|
||||||
|
}))
|
||||||
|
.for_each(|channel| async move {
|
||||||
|
tokio::spawn(channel.for_each(|response| response));
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
|
|
||||||
let client = client::new(client::Config::default(), client_channel).spawn();
|
let client = client::new(client::Config::default(), client_channel).spawn();
|
||||||
|
|
||||||
let response1 = client.call(context::current(), "", "123".into()).await?;
|
let response1 = client.call(context::current(), "", "123".into()).await;
|
||||||
let response2 = client.call(context::current(), "", "abc".into()).await?;
|
let response2 = client.call(context::current(), "", "abc".into()).await;
|
||||||
|
|
||||||
trace!("response1: {:?}, response2: {:?}", response1, response2);
|
trace!("response1: {:?}, response2: {:?}", response1, response2);
|
||||||
|
|
||||||
assert_matches!(response1, Ok(123));
|
assert_matches!(response1, Ok(123));
|
||||||
assert_matches!(response2, Err(ref e) if e.kind() == io::ErrorKind::InvalidInput);
|
assert_matches!(response2, Err(RpcError::Server(e)) if e.kind == io::ErrorKind::InvalidInput);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,8 +2,6 @@
|
|||||||
fn ui() {
|
fn ui() {
|
||||||
let t = trybuild::TestCases::new();
|
let t = trybuild::TestCases::new();
|
||||||
t.compile_fail("tests/compile_fail/*.rs");
|
t.compile_fail("tests/compile_fail/*.rs");
|
||||||
#[cfg(feature = "tokio1")]
|
|
||||||
t.compile_fail("tests/compile_fail/tokio/*.rs");
|
|
||||||
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
|
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
|
||||||
t.compile_fail("tests/compile_fail/serde_transport/*.rs");
|
t.compile_fail("tests/compile_fail/serde_transport/*.rs");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,3 +9,7 @@ note: the lint level is defined here
|
|||||||
|
|
|
|
||||||
11 | #[deny(unused_must_use)]
|
11 | #[deny(unused_must_use)]
|
||||||
| ^^^^^^^^^^^^^^^
|
| ^^^^^^^^^^^^^^^
|
||||||
|
help: use `let _ = ...` to ignore the resulting value
|
||||||
|
|
|
||||||
|
13 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch;
|
||||||
|
| +++++++
|
||||||
|
|||||||
@@ -9,3 +9,7 @@ note: the lint level is defined here
|
|||||||
|
|
|
|
||||||
5 | #[deny(unused_must_use)]
|
5 | #[deny(unused_must_use)]
|
||||||
| ^^^^^^^^^^^^^^^
|
| ^^^^^^^^^^^^^^^
|
||||||
|
help: use `let _ = ...` to ignore the resulting value
|
||||||
|
|
|
||||||
|
7 | let _ = serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
|
||||||
|
| +++++++
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
#[tarpc::service(derive_serde = false)]
|
|
||||||
trait World {
|
|
||||||
async fn hello(name: String) -> String;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct HelloServer;
|
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl World for HelloServer {
|
|
||||||
fn hello(name: String) -> String {
|
|
||||||
format!("Hello, {name}!", name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() {}
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
error: not all trait items implemented, missing: `HelloFut`
|
|
||||||
--> $DIR/tarpc_server_missing_async.rs:9:1
|
|
||||||
|
|
|
||||||
9 | impl World for HelloServer {
|
|
||||||
| ^^^^
|
|
||||||
|
|
||||||
error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async
|
|
||||||
--> $DIR/tarpc_server_missing_async.rs:10:5
|
|
||||||
|
|
|
||||||
10 | fn hello(name: String) -> String {
|
|
||||||
| ^^
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
use tarpc::{
|
|
||||||
context,
|
|
||||||
server::{self, Channel},
|
|
||||||
};
|
|
||||||
|
|
||||||
#[tarpc::service]
|
|
||||||
trait World {
|
|
||||||
async fn hello(name: String) -> String;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct HelloServer;
|
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl World for HelloServer {
|
|
||||||
async fn hello(self, _: context::Context, name: String) -> String {
|
|
||||||
format!("Hello, {name}!")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() {
|
|
||||||
let (_, server_transport) = tarpc::transport::channel::unbounded();
|
|
||||||
let server = server::BaseChannel::with_defaults(server_transport);
|
|
||||||
|
|
||||||
#[deny(unused_must_use)]
|
|
||||||
{
|
|
||||||
server.execute(HelloServer.serve());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
error: unused `TokioChannelExecutor` that must be used
|
|
||||||
--> tests/compile_fail/tokio/must_use_channel_executor.rs:27:9
|
|
||||||
|
|
|
||||||
27 | server.execute(HelloServer.serve());
|
|
||||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
||||||
|
|
|
||||||
note: the lint level is defined here
|
|
||||||
--> tests/compile_fail/tokio/must_use_channel_executor.rs:25:12
|
|
||||||
|
|
|
||||||
25 | #[deny(unused_must_use)]
|
|
||||||
| ^^^^^^^^^^^^^^^
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
use futures::stream::once;
|
|
||||||
use tarpc::{
|
|
||||||
context,
|
|
||||||
server::{self, incoming::Incoming},
|
|
||||||
};
|
|
||||||
|
|
||||||
#[tarpc::service]
|
|
||||||
trait World {
|
|
||||||
async fn hello(name: String) -> String;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct HelloServer;
|
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl World for HelloServer {
|
|
||||||
async fn hello(self, _: context::Context, name: String) -> String {
|
|
||||||
format!("Hello, {name}!")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() {
|
|
||||||
let (_, server_transport) = tarpc::transport::channel::unbounded();
|
|
||||||
let server = once(async move { server::BaseChannel::with_defaults(server_transport) });
|
|
||||||
|
|
||||||
#[deny(unused_must_use)]
|
|
||||||
{
|
|
||||||
server.execute(HelloServer.serve());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
error: unused `TokioServerExecutor` that must be used
|
|
||||||
--> tests/compile_fail/tokio/must_use_server_executor.rs:28:9
|
|
||||||
|
|
|
||||||
28 | server.execute(HelloServer.serve());
|
|
||||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
||||||
|
|
|
||||||
note: the lint level is defined here
|
|
||||||
--> tests/compile_fail/tokio/must_use_server_executor.rs:26:12
|
|
||||||
|
|
|
||||||
26 | #[deny(unused_must_use)]
|
|
||||||
| ^^^^^^^^^^^^^^^
|
|
||||||
@@ -21,7 +21,6 @@ pub trait ColorProtocol {
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct ColorServer;
|
struct ColorServer;
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl ColorProtocol for ColorServer {
|
impl ColorProtocol for ColorServer {
|
||||||
async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData {
|
async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData {
|
||||||
match color {
|
match color {
|
||||||
@@ -31,6 +30,11 @@ impl ColorProtocol for ColorServer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_call() -> anyhow::Result<()> {
|
async fn test_call() -> anyhow::Result<()> {
|
||||||
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
|
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
|
||||||
@@ -40,7 +44,9 @@ async fn test_call() -> anyhow::Result<()> {
|
|||||||
.take(1)
|
.take(1)
|
||||||
.filter_map(|r| async { r.ok() })
|
.filter_map(|r| async { r.ok() })
|
||||||
.map(BaseChannel::with_defaults)
|
.map(BaseChannel::with_defaults)
|
||||||
.execute(ColorServer.serve()),
|
.execute(ColorServer.serve())
|
||||||
|
.map(|channel| channel.for_each(spawn))
|
||||||
|
.for_each(spawn),
|
||||||
);
|
);
|
||||||
|
|
||||||
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
||||||
|
|||||||
17
tarpc/tests/proc_macro_hygene.rs
Normal file
17
tarpc/tests/proc_macro_hygene.rs
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
#![no_implicit_prelude]
|
||||||
|
extern crate tarpc as some_random_other_name;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde1")]
|
||||||
|
mod serde1_feature {
|
||||||
|
#[::tarpc::derive_serde]
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub enum TestData {
|
||||||
|
Black,
|
||||||
|
White,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[::tarpc::service]
|
||||||
|
pub trait ColorProtocol {
|
||||||
|
async fn get_opposite_color(color: u8) -> u8;
|
||||||
|
}
|
||||||
@@ -1,13 +1,13 @@
|
|||||||
use assert_matches::assert_matches;
|
use assert_matches::assert_matches;
|
||||||
use futures::{
|
use futures::{
|
||||||
future::{join_all, ready, Ready},
|
future::{join_all, ready},
|
||||||
prelude::*,
|
prelude::*,
|
||||||
};
|
};
|
||||||
use std::time::{Duration, SystemTime};
|
use std::time::{Duration, SystemTime};
|
||||||
use tarpc::{
|
use tarpc::{
|
||||||
client::{self},
|
client::{self},
|
||||||
context,
|
context,
|
||||||
server::{self, incoming::Incoming, BaseChannel, Channel},
|
server::{incoming::Incoming, BaseChannel, Channel},
|
||||||
transport::channel,
|
transport::channel,
|
||||||
};
|
};
|
||||||
use tokio::join;
|
use tokio::join;
|
||||||
@@ -22,39 +22,29 @@ trait Service {
|
|||||||
struct Server;
|
struct Server;
|
||||||
|
|
||||||
impl Service for Server {
|
impl Service for Server {
|
||||||
type AddFut = Ready<i32>;
|
async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
|
||||||
|
x + y
|
||||||
fn add(self, _: context::Context, x: i32, y: i32) -> Self::AddFut {
|
|
||||||
ready(x + y)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type HeyFut = Ready<String>;
|
async fn hey(self, _: context::Context, name: String) -> String {
|
||||||
|
format!("Hey, {name}.")
|
||||||
fn hey(self, _: context::Context, name: String) -> Self::HeyFut {
|
|
||||||
ready(format!("Hey, {name}."))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sequential() -> anyhow::Result<()> {
|
async fn sequential() {
|
||||||
let _ = tracing_subscriber::fmt::try_init();
|
let (tx, rx) = tarpc::transport::channel::unbounded();
|
||||||
|
let client = client::new(client::Config::default(), tx).spawn();
|
||||||
let (tx, rx) = channel::unbounded();
|
let channel = BaseChannel::with_defaults(rx);
|
||||||
|
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
BaseChannel::new(server::Config::default(), rx)
|
channel
|
||||||
.requests()
|
.execute(tarpc::server::serve(|_, i| async move { Ok(i + 1) }))
|
||||||
.execute(Server.serve()),
|
.for_each(|response| response),
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
client.call(context::current(), "AddOne", 1).await.unwrap(),
|
||||||
|
2
|
||||||
);
|
);
|
||||||
|
|
||||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
|
||||||
|
|
||||||
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
|
|
||||||
assert_matches!(
|
|
||||||
client.hey(context::current(), "Tim".into()).await,
|
|
||||||
Ok(ref s) if s == "Hey, Tim.");
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -70,7 +60,6 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct AllHandlersComplete;
|
struct AllHandlersComplete;
|
||||||
|
|
||||||
#[tarpc::server]
|
|
||||||
impl Loop for LoopServer {
|
impl Loop for LoopServer {
|
||||||
async fn r#loop(self, _: context::Context) {
|
async fn r#loop(self, _: context::Context) {
|
||||||
loop {
|
loop {
|
||||||
@@ -121,7 +110,9 @@ async fn serde_tcp() -> anyhow::Result<()> {
|
|||||||
.take(1)
|
.take(1)
|
||||||
.filter_map(|r| async { r.ok() })
|
.filter_map(|r| async { r.ok() })
|
||||||
.map(BaseChannel::with_defaults)
|
.map(BaseChannel::with_defaults)
|
||||||
.execute(Server.serve()),
|
.execute(Server.serve())
|
||||||
|
.map(|channel| channel.for_each(spawn))
|
||||||
|
.for_each(spawn),
|
||||||
);
|
);
|
||||||
|
|
||||||
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
|
||||||
@@ -151,7 +142,9 @@ async fn serde_uds() -> anyhow::Result<()> {
|
|||||||
.take(1)
|
.take(1)
|
||||||
.filter_map(|r| async { r.ok() })
|
.filter_map(|r| async { r.ok() })
|
||||||
.map(BaseChannel::with_defaults)
|
.map(BaseChannel::with_defaults)
|
||||||
.execute(Server.serve()),
|
.execute(Server.serve())
|
||||||
|
.map(|channel| channel.for_each(spawn))
|
||||||
|
.for_each(spawn),
|
||||||
);
|
);
|
||||||
|
|
||||||
let transport = serde_transport::unix::connect(&sock, Json::default).await?;
|
let transport = serde_transport::unix::connect(&sock, Json::default).await?;
|
||||||
@@ -175,7 +168,9 @@ async fn concurrent() -> anyhow::Result<()> {
|
|||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
stream::once(ready(rx))
|
stream::once(ready(rx))
|
||||||
.map(BaseChannel::with_defaults)
|
.map(BaseChannel::with_defaults)
|
||||||
.execute(Server.serve()),
|
.execute(Server.serve())
|
||||||
|
.map(|channel| channel.for_each(spawn))
|
||||||
|
.for_each(spawn),
|
||||||
);
|
);
|
||||||
|
|
||||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||||
@@ -199,7 +194,9 @@ async fn concurrent_join() -> anyhow::Result<()> {
|
|||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
stream::once(ready(rx))
|
stream::once(ready(rx))
|
||||||
.map(BaseChannel::with_defaults)
|
.map(BaseChannel::with_defaults)
|
||||||
.execute(Server.serve()),
|
.execute(Server.serve())
|
||||||
|
.map(|channel| channel.for_each(spawn))
|
||||||
|
.for_each(spawn),
|
||||||
);
|
);
|
||||||
|
|
||||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||||
@@ -216,15 +213,20 @@ async fn concurrent_join() -> anyhow::Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn concurrent_join_all() -> anyhow::Result<()> {
|
async fn concurrent_join_all() -> anyhow::Result<()> {
|
||||||
let _ = tracing_subscriber::fmt::try_init();
|
let _ = tracing_subscriber::fmt::try_init();
|
||||||
|
|
||||||
let (tx, rx) = channel::unbounded();
|
let (tx, rx) = channel::unbounded();
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
stream::once(ready(rx))
|
BaseChannel::with_defaults(rx)
|
||||||
.map(BaseChannel::with_defaults)
|
.execute(Server.serve())
|
||||||
.execute(Server.serve()),
|
.for_each(spawn),
|
||||||
);
|
);
|
||||||
|
|
||||||
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
let client = ServiceClient::new(client::Config::default(), tx).spawn();
|
||||||
@@ -249,11 +251,9 @@ async fn counter() -> anyhow::Result<()> {
|
|||||||
struct CountService(u32);
|
struct CountService(u32);
|
||||||
|
|
||||||
impl Counter for &mut CountService {
|
impl Counter for &mut CountService {
|
||||||
type CountFut = futures::future::Ready<u32>;
|
async fn count(self, _: context::Context) -> u32 {
|
||||||
|
|
||||||
fn count(self, _: context::Context) -> Self::CountFut {
|
|
||||||
self.0 += 1;
|
self.0 += 1;
|
||||||
futures::future::ready(self.0)
|
self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user