18 Commits

Author SHA1 Message Date
Tim Kuehn
a6758fd1f9 BeforeRequest hook chaining.
It's unintuitive that serve.before(hook1).before(hook2) executes in
reverse order, with hook2 going before hook1. With BeforeRequestList,
users can write `before().then(hook1).then(hook2).serving(serve)`, and
it will run hook1, then hook2, then the service fn.
2023-12-29 23:13:06 -08:00
Tim Kuehn
2c241cc809 Fix readme and release notes. 2023-12-29 21:03:09 -08:00
Tim Kuehn
263ef8a897 Prepare for v0.34.0 release 2023-12-29 20:52:37 -08:00
Kevin Ge
d50290a21c Add README for example app 2023-12-29 20:32:00 -08:00
Kevin Ge
26988cb833 Update OpenTelemetry packages in example app 2023-12-29 20:32:00 -08:00
Tim Kuehn
6cf18a1caf Rewrite traits to use async-fn-in-trait.
- Stub
- BeforeRequest
- AfterRequest

Also removed the last remaining usage of an unstable feature,
iter_intersperse.
2023-12-29 13:52:05 -08:00
Tim Kuehn
84932df9b4 Return span in InFlightRequests::complete_request.
Rather than returning a bool, return the Span associated with the
request. This gives RequestDispatch more flexibility to annotate the
request span.
2023-12-29 13:52:05 -08:00
Tim Kuehn
8dc3711a80 Use async fn in generated traits!!
The major breaking change is that Channel::execute no longer internally
spawns RPC handlers, because it is no longer possible to place a Send
bound on the return type of Serve::serve. Instead, Channel::execute
returns a stream of RPC handler futures.

Service authors can reproduce the old behavior by spawning each response
handler (the compiler knows whether or not the futures can be spawned;
it's just that the bounds can't be expressed generically):

    channel.execute(server.serve())
           .for_each(|rpc| { tokio::spawn(rpc); })
2023-12-29 13:52:05 -08:00
Tim Kuehn
7c5afa97bb Add request hooks to the Serve trait.
This allows plugging in horizontal functionality, such as authorization,
throttling, or latency recording, that should run before and/or after
execution of every request, regardless of the request type.

The tracing example is updated to show off both client stubs as well as
server hooks.

As part of this change, there were some changes to the Serve trait:

1. Serve's output type is now a Result<Response, ServerError>..
   Serve previously did not allow returning ServerErrors, which
   prevented using Serve for horizontal functionality like throttling or
   auth. Now, Serve's output type is Result<Resp, ServerError>, making
   Serve a more natural integration point for horizontal capabilities.
2. Serve's generic Request type changed to an associated type. The
   primary benefit of the generic type is that it allows one type to
   impl a trait multiple times (for example, u64 impls TryFrom<usize>,
   TryFrom<u128>, etc.). In the case of Serve impls, while it is
   theoretically possible to contrive a type that could serve multiple
   request types, in practice I don't expect that to be needed.  Most
   users will use the Serve impl generated by #[tarpc::service], which
   only ever serves one type of request.
2023-12-29 13:52:05 -08:00
Tim Kuehn
324df5cd15 Add back the Client trait, renamed Stub.
Also adds a Client stub trait alias for each generated service.

Now that generic associated types are stable, it's almost possible to
define a trait for Channel that works with async fns on stable. `impl
trait in type aliases` is still necessary (and unstable), but we're
getting closer.

As a proof of concept, three more implementations of Stub are implemented;

1. A load balancer that round-robins requests between different stubs.
2. A load balancer that selects a stub based on a request hash, so that
   the same requests go to the same stubs.
3. A stub that retries requests based on a configurable policy.

   The "serde/rc" feature is added to the "full" feature because the Retry
   stub wraps the request in an Arc, so that the request is reusable for
   multiple calls.

   Server implementors commonly need to operate generically across all
   services or request types. For example, a server throttler may want to
   return errors telling clients to back off, which is not specific to any
   one service.
2023-12-29 13:52:05 -08:00
Guillaume Charmetant
3264979993 Fix warnings in README's example 2023-11-16 09:54:54 -08:00
Guillaume Charmetant
dd63fb59bf Fix tokio dep in the README's example
Add missing tokio feature in the example's dependencies.
2023-11-16 09:54:54 -08:00
Tim Kuehn
f4db8cc5b4 Address clippy lints 2023-11-16 00:00:27 -08:00
Tim Kuehn
e9ba350496 Update must-use UI tests 2023-11-16 00:00:27 -08:00
Tim Kuehn
e6d779e70b Remove mipsel workflow.
Mipsel was downgraded to tier 3, which broke this workflow.
https://github.com/rust-lang/compiler-team/issues/648
2023-11-16 00:00:27 -08:00
Izumi Raine
ce5f8cfb0c Simplify TLS example (#404) 2023-04-14 12:33:22 -07:00
Tim Kuehn
4b69dc8db5 Prepare release of v0.33.0 2023-04-03 11:03:55 -07:00
Bruno
866db2a2cd Bump opentelemetry to 0.18 (#401) 2023-04-03 10:38:17 -07:00
44 changed files with 1753 additions and 881 deletions

View File

@@ -19,10 +19,7 @@ jobs:
access_token: ${{ github.token }} access_token: ${{ github.token }}
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable - uses: dtolnay/rust-toolchain@stable
with:
targets: mipsel-unknown-linux-gnu
- run: cargo check --all-features - run: cargo check --all-features
- run: cargo check --all-features --target mipsel-unknown-linux-gnu
test: test:
name: Test Suite name: Test Suite

View File

@@ -67,7 +67,7 @@ Some other features of tarpc:
Add to your `Cargo.toml` dependencies: Add to your `Cargo.toml` dependencies:
```toml ```toml
tarpc = "0.32" tarpc = "0.34"
``` ```
The `tarpc::service` attribute expands to a collection of items that form an rpc service. The `tarpc::service` attribute expands to a collection of items that form an rpc service.
@@ -83,7 +83,7 @@ your `Cargo.toml`:
anyhow = "1.0" anyhow = "1.0"
futures = "0.3" futures = "0.3"
tarpc = { version = "0.31", features = ["tokio1"] } tarpc = { version = "0.31", features = ["tokio1"] }
tokio = { version = "1.0", features = ["macros"] } tokio = { version = "1.0", features = ["rt-multi-thread", "macros"] }
``` ```
In the following example, we use an in-process channel for communication between In the following example, we use an in-process channel for communication between
@@ -93,14 +93,10 @@ For a more real-world example, see [example-service](example-service).
First, let's set up the dependencies and service definition. First, let's set up the dependencies and service definition.
```rust ```rust
use futures::future::{self, Ready};
use futures::{
future::{self, Ready},
prelude::*,
};
use tarpc::{ use tarpc::{
client, context, client, context,
server::{self, incoming::Incoming, Channel}, server::{self, Channel},
}; };
// This is the service definition. It looks a lot like a trait definition. // This is the service definition. It looks a lot like a trait definition.
@@ -122,13 +118,8 @@ implement it for our Server struct.
struct HelloServer; struct HelloServer;
impl World for HelloServer { impl World for HelloServer {
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and async fn hello(self, _: context::Context, name: String) -> String {
// an associated type representing the future output by the fn. format!("Hello, {name}!")
type HelloFut = Ready<String>;
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
future::ready(format!("Hello, {name}!"))
} }
} }
``` ```
@@ -148,7 +139,7 @@ async fn main() -> anyhow::Result<()> {
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
// that takes a config and any Transport as input. // that takes a config and any Transport as input.
let mut client = WorldClient::new(client::Config::default(), client_transport).spawn(); let client = WorldClient::new(client::Config::default(), client_transport).spawn();
// The client has an RPC method for each RPC defined in the annotated trait. It takes the same // The client has an RPC method for each RPC defined in the annotated trait. It takes the same
// args as defined, with the addition of a Context, which is always the first arg. The Context // args as defined, with the addition of a Context, which is always the first arg. The Context

View File

@@ -1,3 +1,29 @@
## 0.34.0 (2023-12-29)
### Breaking Changes
- `#[tarpc::server]` is no more! Service traits now use async fns.
- `Channel::execute` no longer spawns request handlers. Async-fn-in-traits makes it impossible to
add a Send bound to the future returned by `Serve::serve`. Instead, `Channel::execute` returns a
stream of futures, where each future is a request handler. To achieve the former behavior:
```rust
channel.execute(server.serve())
.for_each(|rpc| { tokio::spawn(rpc); })
```
### New Features
- Request hooks are added to the serve trait, so that it's easy to hook in cross-cutting
functionality like throttling, authorization, etc.
- The Client trait is back! This makes it possible to hook in generic client functionality like load
balancing, retries, etc.
## 0.33.0 (2023-04-01)
### Breaking Changes
Opentelemetry dependency version increased to 0.18.
## 0.32.0 (2023-03-24) ## 0.32.0 (2023-03-24)
### Breaking Changes ### Breaking Changes

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "tarpc-example-service" name = "tarpc-example-service"
version = "0.14.0" version = "0.15.0"
rust-version = "1.56" rust-version = "1.56"
authors = ["Tim Kuehn <tikue@google.com>"] authors = ["Tim Kuehn <tikue@google.com>"]
edition = "2021" edition = "2021"
@@ -18,14 +18,15 @@ anyhow = "1.0"
clap = { version = "3.0.0-rc.9", features = ["derive"] } clap = { version = "3.0.0-rc.9", features = ["derive"] }
log = "0.4" log = "0.4"
futures = "0.3" futures = "0.3"
opentelemetry = { version = "0.17", features = ["rt-tokio"] } opentelemetry = { version = "0.21.0" }
opentelemetry-jaeger = { version = "0.16", features = ["rt-tokio"] } opentelemetry-jaeger = { version = "0.20.0", features = ["rt-tokio"] }
rand = "0.8" rand = "0.8"
tarpc = { version = "0.32", path = "../tarpc", features = ["full"] } tarpc = { version = "0.34", path = "../tarpc", features = ["full"] }
tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] } tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
tracing = { version = "0.1" } tracing = { version = "0.1" }
tracing-opentelemetry = "0.17" tracing-opentelemetry = "0.22.0"
tracing-subscriber = {version = "0.3", features = ["env-filter"]} tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
opentelemetry_sdk = "0.21.1"
[lib] [lib]
name = "service" name = "service"

15
example-service/README.md Normal file
View File

@@ -0,0 +1,15 @@
# Example
Example service to demonstrate how to set up `tarpc` with [Jaeger](https://www.jaegertracing.io). To see traces Jaeger, run the following with `RUST_LOG=trace`.
## Server
```bash
cargo run --bin server -- --port 50051
```
## Client
```bash
cargo run --bin client -- --server-addr "[::1]:50051" --name "Bob"
```

View File

@@ -19,10 +19,10 @@ pub trait World {
pub fn init_tracing(service_name: &str) -> anyhow::Result<()> { pub fn init_tracing(service_name: &str) -> anyhow::Result<()> {
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12"); env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
let tracer = opentelemetry_jaeger::new_pipeline() let tracer = opentelemetry_jaeger::new_agent_pipeline()
.with_service_name(service_name) .with_service_name(service_name)
.with_max_packet_size(2usize.pow(13)) .with_max_packet_size(2usize.pow(13))
.install_batch(opentelemetry::runtime::Tokio)?; .install_batch(opentelemetry_sdk::runtime::Tokio)?;
tracing_subscriber::registry() tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::from_default_env()) .with(tracing_subscriber::EnvFilter::from_default_env())

View File

@@ -34,7 +34,6 @@ struct Flags {
#[derive(Clone)] #[derive(Clone)]
struct HelloServer(SocketAddr); struct HelloServer(SocketAddr);
#[tarpc::server]
impl World for HelloServer { impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String { async fn hello(self, _: context::Context, name: String) -> String {
let sleep_time = let sleep_time =
@@ -44,6 +43,10 @@ impl World for HelloServer {
} }
} }
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
let flags = Flags::parse(); let flags = Flags::parse();
@@ -66,7 +69,7 @@ async fn main() -> anyhow::Result<()> {
// the generated World trait. // the generated World trait.
.map(|channel| { .map(|channel| {
let server = HelloServer(channel.transport().peer_addr().unwrap()); let server = HelloServer(channel.transport().peer_addr().unwrap());
channel.execute(server.serve()) channel.execute(server.serve()).for_each(spawn)
}) })
// Max 10 channels. // Max 10 channels.
.buffer_unordered(10) .buffer_unordered(10)

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "tarpc-plugins" name = "tarpc-plugins"
version = "0.12.0" version = "0.13.0"
rust-version = "1.56" rust-version = "1.56"
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"] authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
edition = "2021" edition = "2021"

View File

@@ -12,18 +12,18 @@ extern crate quote;
extern crate syn; extern crate syn;
use proc_macro::TokenStream; use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2}; use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote, ToTokens}; use quote::{format_ident, quote, ToTokens};
use syn::{ use syn::{
braced, braced,
ext::IdentExt, ext::IdentExt,
parenthesized, parenthesized,
parse::{Parse, ParseStream}, parse::{Parse, ParseStream},
parse_macro_input, parse_quote, parse_str, parse_macro_input, parse_quote,
spanned::Spanned, spanned::Spanned,
token::Comma, token::Comma,
Attribute, FnArg, Ident, ImplItem, ImplItemMethod, ImplItemType, ItemImpl, Lit, LitBool, Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type,
MetaNameValue, Pat, PatType, ReturnType, Token, Type, Visibility, Visibility,
}; };
/// Accumulates multiple errors into a result. /// Accumulates multiple errors into a result.
@@ -257,7 +257,6 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
.map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string())) .map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
.collect(); .collect();
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>(); let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
let response_fut_name = &format!("{}ResponseFut", ident.unraw());
let derive_serialize = if derive_serde.0 { let derive_serialize = if derive_serde.0 {
Some( Some(
quote! {#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)] quote! {#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
@@ -274,10 +273,9 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
.collect::<Vec<_>>(); .collect::<Vec<_>>();
ServiceGenerator { ServiceGenerator {
response_fut_name,
service_ident: ident, service_ident: ident,
client_stub_ident: &format_ident!("{}Stub", ident),
server_ident: &format_ident!("Serve{}", ident), server_ident: &format_ident!("Serve{}", ident),
response_fut_ident: &Ident::new(response_fut_name, ident.span()),
client_ident: &format_ident!("{}Client", ident), client_ident: &format_ident!("{}Client", ident),
request_ident: &format_ident!("{}Request", ident), request_ident: &format_ident!("{}Request", ident),
response_ident: &format_ident!("{}Response", ident), response_ident: &format_ident!("{}Response", ident),
@@ -304,137 +302,18 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
.zip(camel_case_fn_names.iter()) .zip(camel_case_fn_names.iter())
.map(|(rpc, name)| Ident::new(name, rpc.ident.span())) .map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
future_types: &camel_case_fn_names
.iter()
.map(|name| parse_str(&format!("{name}Fut")).unwrap())
.collect::<Vec<_>>(),
derive_serialize: derive_serialize.as_ref(), derive_serialize: derive_serialize.as_ref(),
} }
.into_token_stream() .into_token_stream()
.into() .into()
} }
/// generate an identifier consisting of the method name to CamelCase with
/// Fut appended to it.
fn associated_type_for_rpc(method: &ImplItemMethod) -> String {
snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut"
}
/// Transforms an async function into a sync one, returning a type declaration
/// for the return type (a future).
fn transform_method(method: &mut ImplItemMethod) -> ImplItemType {
method.sig.asyncness = None;
// get either the return type or ().
let ret = match &method.sig.output {
ReturnType::Default => quote!(()),
ReturnType::Type(_, ret) => quote!(#ret),
};
let fut_name = associated_type_for_rpc(method);
let fut_name_ident = Ident::new(&fut_name, method.sig.ident.span());
// generate the updated return signature.
method.sig.output = parse_quote! {
-> ::core::pin::Pin<Box<
dyn ::core::future::Future<Output = #ret> + ::core::marker::Send
>>
};
// transform the body of the method into Box::pin(async move { body }).
let block = method.block.clone();
method.block = parse_quote! [{
Box::pin(async move
#block
)
}];
// generate and return type declaration for return type.
let t: ImplItemType = parse_quote! {
type #fut_name_ident = ::core::pin::Pin<Box<dyn ::core::future::Future<Output = #ret> + ::core::marker::Send>>;
};
t
}
#[proc_macro_attribute]
pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream {
let mut item = syn::parse_macro_input!(input as ItemImpl);
let span = item.span();
// the generated type declarations
let mut types: Vec<ImplItemType> = Vec::new();
let mut expected_non_async_types: Vec<(&ImplItemMethod, String)> = Vec::new();
let mut found_non_async_types: Vec<&ImplItemType> = Vec::new();
for inner in &mut item.items {
match inner {
ImplItem::Method(method) => {
if method.sig.asyncness.is_some() {
// if this function is declared async, transform it into a regular function
let typedecl = transform_method(method);
types.push(typedecl);
} else {
// If it's not async, keep track of all required associated types for better
// error reporting.
expected_non_async_types.push((method, associated_type_for_rpc(method)));
}
}
ImplItem::Type(typedecl) => found_non_async_types.push(typedecl),
_ => {}
}
}
if let Err(e) =
verify_types_were_provided(span, &expected_non_async_types, &found_non_async_types)
{
return TokenStream::from(e.to_compile_error());
}
// add the type declarations into the impl block
for t in types.into_iter() {
item.items.push(syn::ImplItem::Type(t));
}
TokenStream::from(quote!(#item))
}
fn verify_types_were_provided(
span: Span,
expected: &[(&ImplItemMethod, String)],
provided: &[&ImplItemType],
) -> syn::Result<()> {
let mut result = Ok(());
for (method, expected) in expected {
if !provided.iter().any(|typedecl| typedecl.ident == expected) {
let mut e = syn::Error::new(
span,
format!("not all trait items implemented, missing: `{expected}`"),
);
let fn_span = method.sig.fn_token.span();
e.extend(syn::Error::new(
fn_span.join(method.sig.ident.span()).unwrap_or(fn_span),
format!(
"hint: `#[tarpc::server]` only rewrites async fns, and `fn {}` is not async",
method.sig.ident
),
));
match result {
Ok(_) => result = Err(e),
Err(ref mut error) => error.extend(Some(e)),
}
}
}
result
}
// Things needed to generate the service items: trait, serve impl, request/response enums, and // Things needed to generate the service items: trait, serve impl, request/response enums, and
// the client stub. // the client stub.
struct ServiceGenerator<'a> { struct ServiceGenerator<'a> {
service_ident: &'a Ident, service_ident: &'a Ident,
client_stub_ident: &'a Ident,
server_ident: &'a Ident, server_ident: &'a Ident,
response_fut_ident: &'a Ident,
response_fut_name: &'a str,
client_ident: &'a Ident, client_ident: &'a Ident,
request_ident: &'a Ident, request_ident: &'a Ident,
response_ident: &'a Ident, response_ident: &'a Ident,
@@ -442,7 +321,6 @@ struct ServiceGenerator<'a> {
attrs: &'a [Attribute], attrs: &'a [Attribute],
rpcs: &'a [RpcMethod], rpcs: &'a [RpcMethod],
camel_case_idents: &'a [Ident], camel_case_idents: &'a [Ident],
future_types: &'a [Type],
method_idents: &'a [&'a Ident], method_idents: &'a [&'a Ident],
request_names: &'a [String], request_names: &'a [String],
method_attrs: &'a [&'a [Attribute]], method_attrs: &'a [&'a [Attribute]],
@@ -458,42 +336,37 @@ impl<'a> ServiceGenerator<'a> {
attrs, attrs,
rpcs, rpcs,
vis, vis,
future_types,
return_types, return_types,
service_ident, service_ident,
client_stub_ident,
request_ident,
response_ident,
server_ident, server_ident,
.. ..
} = self; } = self;
let types_and_fns = rpcs let rpc_fns = rpcs
.iter() .iter()
.zip(future_types.iter())
.zip(return_types.iter()) .zip(return_types.iter())
.map( .map(
|( |(
( RpcMethod {
RpcMethod { attrs, ident, args, ..
attrs, ident, args, .. },
},
future_type,
),
output, output,
)| { )| {
let ty_doc = format!("The response future returned by [`{service_ident}::{ident}`].");
quote! { quote! {
#[doc = #ty_doc]
type #future_type: std::future::Future<Output = #output>;
#( #attrs )* #( #attrs )*
fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type; async fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> #output;
} }
}, },
); );
let stub_doc = format!("The stub trait for service [`{service_ident}`].");
quote! { quote! {
#( #attrs )* #( #attrs )*
#vis trait #service_ident: Sized { #vis trait #service_ident: Sized {
#( #types_and_fns )* #( #rpc_fns )*
/// Returns a serving function to use with /// Returns a serving function to use with
/// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute). /// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute).
@@ -501,6 +374,15 @@ impl<'a> ServiceGenerator<'a> {
#server_ident { service: self } #server_ident { service: self }
} }
} }
#[doc = #stub_doc]
#vis trait #client_stub_ident: tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident> {
}
impl<S> #client_stub_ident for S
where S: tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident>
{
}
} }
} }
@@ -524,7 +406,6 @@ impl<'a> ServiceGenerator<'a> {
server_ident, server_ident,
service_ident, service_ident,
response_ident, response_ident,
response_fut_ident,
camel_case_idents, camel_case_idents,
arg_pats, arg_pats,
method_idents, method_idents,
@@ -533,11 +414,11 @@ impl<'a> ServiceGenerator<'a> {
} = self; } = self;
quote! { quote! {
impl<S> tarpc::server::Serve<#request_ident> for #server_ident<S> impl<S> tarpc::server::Serve for #server_ident<S>
where S: #service_ident where S: #service_ident
{ {
type Req = #request_ident;
type Resp = #response_ident; type Resp = #response_ident;
type Fut = #response_fut_ident<S>;
fn method(&self, req: &#request_ident) -> Option<&'static str> { fn method(&self, req: &#request_ident) -> Option<&'static str> {
Some(match req { Some(match req {
@@ -549,15 +430,16 @@ impl<'a> ServiceGenerator<'a> {
}) })
} }
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut { async fn serve(self, ctx: tarpc::context::Context, req: #request_ident)
-> Result<#response_ident, tarpc::ServerError> {
match req { match req {
#( #(
#request_ident::#camel_case_idents{ #( #arg_pats ),* } => { #request_ident::#camel_case_idents{ #( #arg_pats ),* } => {
#response_fut_ident::#camel_case_idents( Ok(#response_ident::#camel_case_idents(
#service_ident::#method_idents( #service_ident::#method_idents(
self.service, ctx, #( #arg_pats ),* self.service, ctx, #( #arg_pats ),*
) ).await
) ))
} }
)* )*
} }
@@ -608,73 +490,6 @@ impl<'a> ServiceGenerator<'a> {
} }
} }
fn enum_response_future(&self) -> TokenStream2 {
let &Self {
vis,
service_ident,
response_fut_ident,
camel_case_idents,
future_types,
..
} = self;
quote! {
/// A future resolving to a server response.
#[allow(missing_docs)]
#vis enum #response_fut_ident<S: #service_ident> {
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
}
}
}
fn impl_debug_for_response_future(&self) -> TokenStream2 {
let &Self {
service_ident,
response_fut_ident,
response_fut_name,
..
} = self;
quote! {
impl<S: #service_ident> std::fmt::Debug for #response_fut_ident<S> {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct(#response_fut_name).finish()
}
}
}
}
fn impl_future_for_response_future(&self) -> TokenStream2 {
let &Self {
service_ident,
response_fut_ident,
response_ident,
camel_case_idents,
..
} = self;
quote! {
impl<S: #service_ident> std::future::Future for #response_fut_ident<S> {
type Output = #response_ident;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>)
-> std::task::Poll<#response_ident>
{
unsafe {
match std::pin::Pin::get_unchecked_mut(self) {
#(
#response_fut_ident::#camel_case_idents(resp) =>
std::pin::Pin::new_unchecked(resp)
.poll(cx)
.map(#response_ident::#camel_case_idents),
)*
}
}
}
}
}
}
fn struct_client(&self) -> TokenStream2 { fn struct_client(&self) -> TokenStream2 {
let &Self { let &Self {
vis, vis,
@@ -689,7 +504,9 @@ impl<'a> ServiceGenerator<'a> {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
/// The client stub that makes RPC calls to the server. All request methods return /// The client stub that makes RPC calls to the server. All request methods return
/// [Futures](std::future::Future). /// [Futures](std::future::Future).
#vis struct #client_ident(tarpc::client::Channel<#request_ident, #response_ident>); #vis struct #client_ident<
Stub = tarpc::client::Channel<#request_ident, #response_ident>
>(Stub);
} }
} }
@@ -719,6 +536,17 @@ impl<'a> ServiceGenerator<'a> {
dispatch: new_client.dispatch, dispatch: new_client.dispatch,
} }
} }
}
impl<Stub> From<Stub> for #client_ident<Stub>
where Stub: tarpc::client::stub::Stub<
Req = #request_ident,
Resp = #response_ident>
{
/// Returns a new client stub that sends requests over the given transport.
fn from(stub: Stub) -> Self {
#client_ident(stub)
}
} }
} }
@@ -741,7 +569,11 @@ impl<'a> ServiceGenerator<'a> {
} = self; } = self;
quote! { quote! {
impl #client_ident { impl<Stub> #client_ident<Stub>
where Stub: tarpc::client::stub::Stub<
Req = #request_ident,
Resp = #response_ident>
{
#( #(
#[allow(unused)] #[allow(unused)]
#( #method_attrs )* #( #method_attrs )*
@@ -770,9 +602,6 @@ impl<'a> ToTokens for ServiceGenerator<'a> {
self.impl_serve_for_server(), self.impl_serve_for_server(),
self.enum_request(), self.enum_request(),
self.enum_response(), self.enum_response(),
self.enum_response_future(),
self.impl_debug_for_response_future(),
self.impl_future_for_response_future(),
self.struct_client(), self.struct_client(),
self.impl_client_new(), self.impl_client_new(),
self.impl_client_rpc_methods(), self.impl_client_rpc_methods(),

View File

@@ -1,8 +1,3 @@
use assert_type_eq::assert_type_eq;
use futures::Future;
use std::pin::Pin;
use tarpc::context;
// these need to be out here rather than inside the function so that the // these need to be out here rather than inside the function so that the
// assert_type_eq macro can pick them up. // assert_type_eq macro can pick them up.
#[tarpc::service] #[tarpc::service]
@@ -12,42 +7,6 @@ trait Foo {
async fn baz(); async fn baz();
} }
#[test]
fn type_generation_works() {
#[tarpc::server]
impl Foo for () {
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
(s, i)
}
async fn bar(self, _: context::Context, s: String) -> String {
s
}
async fn baz(self, _: context::Context) {}
}
// the assert_type_eq macro can only be used once per block.
{
assert_type_eq!(
<() as Foo>::TwoPartFut,
Pin<Box<dyn Future<Output = (String, i32)> + Send>>
);
}
{
assert_type_eq!(
<() as Foo>::BarFut,
Pin<Box<dyn Future<Output = String> + Send>>
);
}
{
assert_type_eq!(
<() as Foo>::BazFut,
Pin<Box<dyn Future<Output = ()> + Send>>
);
}
}
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
#[test] #[test]
fn raw_idents_work() { fn raw_idents_work() {
@@ -59,24 +18,6 @@ fn raw_idents_work() {
async fn r#fn(r#impl: r#yield) -> r#yield; async fn r#fn(r#impl: r#yield) -> r#yield;
async fn r#async(); async fn r#async();
} }
#[tarpc::server]
impl r#trait for () {
async fn r#await(
self,
_: context::Context,
r#struct: r#yield,
r#enum: i32,
) -> (r#yield, i32) {
(r#struct, r#enum)
}
async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
r#impl
}
async fn r#async(self, _: context::Context) {}
}
} }
#[test] #[test]
@@ -100,45 +41,4 @@ fn syntax() {
#[doc = "attr"] #[doc = "attr"]
async fn one_arg_implicit_return_error(one: String); async fn one_arg_implicit_return_error(one: String);
} }
#[tarpc::server]
impl Syntax for () {
#[deny(warnings)]
#[allow(non_snake_case)]
async fn TestCamelCaseDoesntConflict(self, _: context::Context) {}
async fn hello(self, _: context::Context) -> String {
String::new()
}
async fn attr(self, _: context::Context, _s: String) -> String {
String::new()
}
async fn no_args_no_return(self, _: context::Context) {}
async fn no_args(self, _: context::Context) -> () {}
async fn one_arg(self, _: context::Context, _one: String) -> i32 {
0
}
async fn two_args_no_return(self, _: context::Context, _one: String, _two: u64) {}
async fn two_args(self, _: context::Context, _one: String, _two: u64) -> String {
String::new()
}
async fn no_args_ret_error(self, _: context::Context) -> i32 {
0
}
async fn one_arg_ret_error(self, _: context::Context, _one: String) -> String {
String::new()
}
async fn no_arg_implicit_return_error(self, _: context::Context) {}
async fn one_arg_implicit_return_error(self, _: context::Context, _one: String) {}
}
} }

View File

@@ -2,8 +2,6 @@ use tarpc::context;
#[test] #[test]
fn att_service_trait() { fn att_service_trait() {
use futures::future::{ready, Ready};
#[tarpc::service] #[tarpc::service]
trait Foo { trait Foo {
async fn two_part(s: String, i: i32) -> (String, i32); async fn two_part(s: String, i: i32) -> (String, i32);
@@ -12,19 +10,16 @@ fn att_service_trait() {
} }
impl Foo for () { impl Foo for () {
type TwoPartFut = Ready<(String, i32)>; async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
fn two_part(self, _: context::Context, s: String, i: i32) -> Self::TwoPartFut { (s, i)
ready((s, i))
} }
type BarFut = Ready<String>; async fn bar(self, _: context::Context, s: String) -> String {
fn bar(self, _: context::Context, s: String) -> Self::BarFut { s
ready(s)
} }
type BazFut = Ready<()>; async fn baz(self, _: context::Context) {
fn baz(self, _: context::Context) -> Self::BazFut { ()
ready(())
} }
} }
} }
@@ -32,8 +27,6 @@ fn att_service_trait() {
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
#[test] #[test]
fn raw_idents() { fn raw_idents() {
use futures::future::{ready, Ready};
type r#yield = String; type r#yield = String;
#[tarpc::service] #[tarpc::service]
@@ -44,19 +37,21 @@ fn raw_idents() {
} }
impl r#trait for () { impl r#trait for () {
type AwaitFut = Ready<(r#yield, i32)>; async fn r#await(
fn r#await(self, _: context::Context, r#struct: r#yield, r#enum: i32) -> Self::AwaitFut { self,
ready((r#struct, r#enum)) _: context::Context,
r#struct: r#yield,
r#enum: i32,
) -> (r#yield, i32) {
(r#struct, r#enum)
} }
type FnFut = Ready<r#yield>; async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
fn r#fn(self, _: context::Context, r#impl: r#yield) -> Self::FnFut { r#impl
ready(r#impl)
} }
type AsyncFut = Ready<()>; async fn r#async(self, _: context::Context) {
fn r#async(self, _: context::Context) -> Self::AsyncFut { ()
ready(())
} }
} }
} }

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "tarpc" name = "tarpc"
version = "0.32.0" version = "0.34.0"
rust-version = "1.58.0" rust-version = "1.58.0"
authors = [ authors = [
"Adam Wright <adam.austin.wright@gmail.com>", "Adam Wright <adam.austin.wright@gmail.com>",
@@ -19,7 +19,7 @@ description = "An RPC framework for Rust with a focus on ease of use."
[features] [features]
default = [] default = []
serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive"] serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive", "serde/rc"]
tokio1 = ["tokio/rt"] tokio1 = ["tokio/rt"]
serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"] serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"]
serde-transport-json = ["tokio-serde/json"] serde-transport-json = ["tokio-serde/json"]
@@ -49,7 +49,7 @@ pin-project = "1.0"
rand = "0.8" rand = "0.8"
serde = { optional = true, version = "1.0", features = ["derive"] } serde = { optional = true, version = "1.0", features = ["derive"] }
static_assertions = "1.1.0" static_assertions = "1.1.0"
tarpc-plugins = { path = "../plugins", version = "0.12" } tarpc-plugins = { path = "../plugins", version = "0.13" }
thiserror = "1.0" thiserror = "1.0"
tokio = { version = "1", features = ["time"] } tokio = { version = "1", features = ["time"] }
tokio-util = { version = "0.7.3", features = ["time"] } tokio-util = { version = "0.7.3", features = ["time"] }
@@ -58,8 +58,8 @@ tracing = { version = "0.1", default-features = false, features = [
"attributes", "attributes",
"log", "log",
] } ] }
tracing-opentelemetry = { version = "0.17.2", default-features = false } tracing-opentelemetry = { version = "0.18.0", default-features = false }
opentelemetry = { version = "0.17.0", default-features = false } opentelemetry = { version = "0.18.0", default-features = false }
[dev-dependencies] [dev-dependencies]
@@ -68,14 +68,15 @@ bincode = "1.3"
bytes = { version = "1", features = ["serde"] } bytes = { version = "1", features = ["serde"] }
flate2 = "1.0" flate2 = "1.0"
futures-test = "0.3" futures-test = "0.3"
opentelemetry = { version = "0.17.0", default-features = false, features = [ opentelemetry = { version = "0.18.0", default-features = false, features = [
"rt-tokio", "rt-tokio",
] } ] }
opentelemetry-jaeger = { version = "0.16.0", features = ["rt-tokio"] } opentelemetry-jaeger = { version = "0.17.0", features = ["rt-tokio"] }
pin-utils = "0.1.0-alpha" pin-utils = "0.1.0-alpha"
serde_bytes = "0.11" serde_bytes = "0.11"
tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tokio = { version = "1", features = ["full", "test-util"] } tokio = { version = "1", features = ["full", "test-util", "tracing"] }
console-subscriber = "0.1"
tokio-serde = { version = "0.8", features = ["json", "bincode"] } tokio-serde = { version = "0.8", features = ["json", "bincode"] }
trybuild = "1.0" trybuild = "1.0"
tokio-rustls = "0.23" tokio-rustls = "0.23"

View File

@@ -1,5 +1,11 @@
// Copyright 2022 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression}; use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression};
use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt}; use futures::{prelude::*, Sink, SinkExt, Stream, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_bytes::ByteBuf; use serde_bytes::ByteBuf;
use std::{io, io::Read, io::Write}; use std::{io, io::Read, io::Write};
@@ -99,13 +105,16 @@ pub trait World {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
struct HelloServer; struct HelloServer;
#[tarpc::server]
impl World for HelloServer { impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String { async fn hello(self, _: context::Context, name: String) -> String {
format!("Hey, {name}!") format!("Hey, {name}!")
} }
} }
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
let mut incoming = tcp::listen("localhost:0", Bincode::default).await?; let mut incoming = tcp::listen("localhost:0", Bincode::default).await?;
@@ -114,6 +123,7 @@ async fn main() -> anyhow::Result<()> {
let transport = incoming.next().await.unwrap().unwrap(); let transport = incoming.next().await.unwrap().unwrap();
BaseChannel::with_defaults(add_compression(transport)) BaseChannel::with_defaults(add_compression(transport))
.execute(HelloServer.serve()) .execute(HelloServer.serve())
.for_each(spawn)
.await; .await;
}); });

View File

@@ -1,3 +1,10 @@
// Copyright 2022 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use futures::prelude::*;
use tarpc::context::Context; use tarpc::context::Context;
use tarpc::serde_transport as transport; use tarpc::serde_transport as transport;
use tarpc::server::{BaseChannel, Channel}; use tarpc::server::{BaseChannel, Channel};
@@ -13,7 +20,6 @@ pub trait PingService {
#[derive(Clone)] #[derive(Clone)]
struct Service; struct Service;
#[tarpc::server]
impl PingService for Service { impl PingService for Service {
async fn ping(self, _: Context) {} async fn ping(self, _: Context) {}
} }
@@ -26,13 +32,18 @@ async fn main() -> anyhow::Result<()> {
let listener = UnixListener::bind(bind_addr).unwrap(); let listener = UnixListener::bind(bind_addr).unwrap();
let codec_builder = LengthDelimitedCodec::builder(); let codec_builder = LengthDelimitedCodec::builder();
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
let (conn, _addr) = listener.accept().await.unwrap(); let (conn, _addr) = listener.accept().await.unwrap();
let framed = codec_builder.new_framed(conn); let framed = codec_builder.new_framed(conn);
let transport = transport::new(framed, Bincode::default()); let transport = transport::new(framed, Bincode::default());
let fut = BaseChannel::with_defaults(transport).execute(Service.serve()); let fut = BaseChannel::with_defaults(transport)
.execute(Service.serve())
.for_each(spawn);
tokio::spawn(fut); tokio::spawn(fut);
} }
}); });

View File

@@ -79,7 +79,6 @@ struct Subscriber {
topics: Vec<String>, topics: Vec<String>,
} }
#[tarpc::server]
impl subscriber::Subscriber for Subscriber { impl subscriber::Subscriber for Subscriber {
async fn topics(self, _: context::Context) -> Vec<String> { async fn topics(self, _: context::Context) -> Vec<String> {
self.topics.clone() self.topics.clone()
@@ -117,7 +116,8 @@ impl Subscriber {
)) ))
} }
}; };
let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve())); let (handler, abort_handle) =
future::abortable(handler.execute(subscriber.serve()).for_each(spawn));
tokio::spawn(async move { tokio::spawn(async move {
match handler.await { match handler.await {
Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."), Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."),
@@ -143,6 +143,10 @@ struct PublisherAddrs {
subscriptions: SocketAddr, subscriptions: SocketAddr,
} }
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
impl Publisher { impl Publisher {
async fn start(self) -> io::Result<PublisherAddrs> { async fn start(self) -> io::Result<PublisherAddrs> {
let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?; let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?;
@@ -162,6 +166,7 @@ impl Publisher {
server::BaseChannel::with_defaults(publisher) server::BaseChannel::with_defaults(publisher)
.execute(self.serve()) .execute(self.serve())
.for_each(spawn)
.await .await
}); });
@@ -257,7 +262,6 @@ impl Publisher {
} }
} }
#[tarpc::server]
impl publisher::Publisher for Publisher { impl publisher::Publisher for Publisher {
async fn publish(self, _: context::Context, topic: String, message: String) { async fn publish(self, _: context::Context, topic: String, message: String) {
info!("received message to publish."); info!("received message to publish.");
@@ -282,7 +286,7 @@ impl publisher::Publisher for Publisher {
/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend. /// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
fn init_tracing(service_name: &str) -> anyhow::Result<()> { fn init_tracing(service_name: &str) -> anyhow::Result<()> {
env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12"); env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
let tracer = opentelemetry_jaeger::new_pipeline() let tracer = opentelemetry_jaeger::new_agent_pipeline()
.with_service_name(service_name) .with_service_name(service_name)
.with_max_packet_size(2usize.pow(13)) .with_max_packet_size(2usize.pow(13))
.install_batch(opentelemetry::runtime::Tokio)?; .install_batch(opentelemetry::runtime::Tokio)?;

View File

@@ -4,7 +4,7 @@
// license that can be found in the LICENSE file or at // license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT. // https://opensource.org/licenses/MIT.
use futures::future::{self, Ready}; use futures::prelude::*;
use tarpc::{ use tarpc::{
client, context, client, context,
server::{self, Channel}, server::{self, Channel},
@@ -23,22 +23,21 @@ pub trait World {
struct HelloServer; struct HelloServer;
impl World for HelloServer { impl World for HelloServer {
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and async fn hello(self, _: context::Context, name: String) -> String {
// an associated type representing the future output by the fn. format!("Hello, {name}!")
type HelloFut = Ready<String>;
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
future::ready(format!("Hello, {name}!"))
} }
} }
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
let server = server::BaseChannel::with_defaults(server_transport); let server = server::BaseChannel::with_defaults(server_transport);
tokio::spawn(server.execute(HelloServer.serve())); tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn));
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
// that takes a config and any Transport as input. // that takes a config and any Transport as input.

View File

@@ -1,3 +1,10 @@
// Copyright 2023 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use futures::prelude::*;
use rustls_pemfile::certs; use rustls_pemfile::certs;
use std::io::{BufReader, Cursor}; use std::io::{BufReader, Cursor};
use std::net::{IpAddr, Ipv4Addr}; use std::net::{IpAddr, Ipv4Addr};
@@ -6,8 +13,8 @@ use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient;
use std::sync::Arc; use std::sync::Arc;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_rustls::rustls::{self, Certificate, OwnedTrustAnchor, RootCertStore}; use tokio_rustls::rustls::{self, RootCertStore};
use tokio_rustls::{webpki, TlsAcceptor, TlsConnector}; use tokio_rustls::{TlsAcceptor, TlsConnector};
use tarpc::context::Context; use tarpc::context::Context;
use tarpc::serde_transport as transport; use tarpc::serde_transport as transport;
@@ -23,7 +30,6 @@ pub trait PingService {
#[derive(Clone)] #[derive(Clone)]
struct Service; struct Service;
#[tarpc::server]
impl PingService for Service { impl PingService for Service {
async fn ping(self, _: Context) -> String { async fn ping(self, _: Context) -> String {
"🔒".to_owned() "🔒".to_owned()
@@ -32,7 +38,7 @@ impl PingService for Service {
// certs were generated with openssl 3 https://github.com/rustls/rustls/tree/main/test-ca // certs were generated with openssl 3 https://github.com/rustls/rustls/tree/main/test-ca
// used on client-side for server tls // used on client-side for server tls
const END_CHAIN: &[u8] = include_bytes!("certs/eddsa/end.chain"); const END_CHAIN: &str = include_str!("certs/eddsa/end.chain");
// used on client-side for client-auth // used on client-side for client-auth
const CLIENT_PRIVATEKEY_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.key"); const CLIENT_PRIVATEKEY_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.key");
const CLIENT_CERT_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.cert"); const CLIENT_CERT_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.cert");
@@ -43,6 +49,14 @@ const END_PRIVATEKEY: &str = include_str!("certs/eddsa/end.key");
// used on server-side for client-auth // used on server-side for client-auth
const CLIENT_CHAIN_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.chain"); const CLIENT_CHAIN_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.chain");
pub fn load_certs(data: &str) -> Vec<rustls::Certificate> {
certs(&mut BufReader::new(Cursor::new(data)))
.unwrap()
.into_iter()
.map(rustls::Certificate)
.collect()
}
pub fn load_private_key(key: &str) -> rustls::PrivateKey { pub fn load_private_key(key: &str) -> rustls::PrivateKey {
let mut reader = BufReader::new(Cursor::new(key)); let mut reader = BufReader::new(Cursor::new(key));
loop { loop {
@@ -57,27 +71,22 @@ pub fn load_private_key(key: &str) -> rustls::PrivateKey {
panic!("no keys found in {:?} (encrypted keys not supported)", key); panic!("no keys found in {:?} (encrypted keys not supported)", key);
} }
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
// -------------------- start here to setup tls tcp tokio stream -------------------------- // -------------------- start here to setup tls tcp tokio stream --------------------------
// ref certs and loading from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/tests/test.rs // ref certs and loading from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/tests/test.rs
// ref basic tls server setup from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/server/src/main.rs // ref basic tls server setup from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/server/src/main.rs
let cert = certs(&mut BufReader::new(Cursor::new(END_CERT))) let cert = load_certs(END_CERT);
.unwrap()
.into_iter()
.map(rustls::Certificate)
.collect();
let key = load_private_key(END_PRIVATEKEY); let key = load_private_key(END_PRIVATEKEY);
let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), 5000); let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), 5000);
// ------------- server side client_auth cert loading start // ------------- server side client_auth cert loading start
let roots: Vec<Certificate> = certs(&mut BufReader::new(Cursor::new(CLIENT_CHAIN_CLIENT_AUTH)))
.unwrap()
.into_iter()
.map(rustls::Certificate)
.collect();
let mut client_auth_roots = RootCertStore::empty(); let mut client_auth_roots = RootCertStore::empty();
for root in roots { for root in load_certs(CLIENT_CHAIN_CLIENT_AUTH) {
client_auth_roots.add(&root).unwrap(); client_auth_roots.add(&root).unwrap();
} }
let client_auth = AllowAnyAuthenticatedClient::new(client_auth_roots); let client_auth = AllowAnyAuthenticatedClient::new(client_auth_roots);
@@ -96,38 +105,27 @@ async fn main() -> anyhow::Result<()> {
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
let (stream, _peer_addr) = listener.accept().await.unwrap(); let (stream, _peer_addr) = listener.accept().await.unwrap();
let acceptor = acceptor.clone();
let tls_stream = acceptor.accept(stream).await.unwrap(); let tls_stream = acceptor.accept(stream).await.unwrap();
let framed = codec_builder.new_framed(tls_stream); let framed = codec_builder.new_framed(tls_stream);
let transport = transport::new(framed, Bincode::default()); let transport = transport::new(framed, Bincode::default());
let fut = BaseChannel::with_defaults(transport).execute(Service.serve()); let fut = BaseChannel::with_defaults(transport)
.execute(Service.serve())
.for_each(spawn);
tokio::spawn(fut); tokio::spawn(fut);
} }
}); });
// ---------------------- client connection --------------------- // ---------------------- client connection ---------------------
// cert loading from: https://github.com/tokio-rs/tls/blob/357bc562483dcf04c1f8d08bd1a831b144bf7d4c/tokio-rustls/tests/test.rs#L113
// tls client connection from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs // tls client connection from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
let chain = certs(&mut std::io::Cursor::new(END_CHAIN)).unwrap();
let mut root_store = rustls::RootCertStore::empty(); let mut root_store = rustls::RootCertStore::empty();
root_store.add_server_trust_anchors(chain.iter().map(|cert| { for root in load_certs(END_CHAIN) {
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); root_store.add(&root).unwrap();
OwnedTrustAnchor::from_subject_spki_name_constraints( }
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let client_auth_private_key = load_private_key(CLIENT_PRIVATEKEY_CLIENT_AUTH); let client_auth_private_key = load_private_key(CLIENT_PRIVATEKEY_CLIENT_AUTH);
let client_auth_certs: Vec<Certificate> = let client_auth_certs = load_certs(CLIENT_CERT_CLIENT_AUTH);
certs(&mut BufReader::new(Cursor::new(CLIENT_CERT_CLIENT_AUTH)))
.unwrap()
.into_iter()
.map(rustls::Certificate)
.collect();
let config = rustls::ClientConfig::builder() let config = rustls::ClientConfig::builder()
.with_safe_defaults() .with_safe_defaults()

View File

@@ -4,13 +4,34 @@
// license that can be found in the LICENSE file or at // license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT. // https://opensource.org/licenses/MIT.
use crate::{add::Add as AddService, double::Double as DoubleService}; use crate::{
use futures::{future, prelude::*}; add::{Add as AddService, AddStub},
use tarpc::{ double::Double as DoubleService,
client, context,
server::{incoming::Incoming, BaseChannel},
tokio_serde::formats::Json,
}; };
use futures::{future, prelude::*};
use std::{
io,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use tarpc::{
client::{
self,
stub::{load_balance, retry},
RpcError,
},
context, serde_transport,
server::{
incoming::{spawn_incoming, Incoming},
request_hook::{self, BeforeRequestList},
BaseChannel,
},
tokio_serde::formats::Json,
ClientMessage, Response, ServerError, Transport,
};
use tokio::net::TcpStream;
use tracing_subscriber::prelude::*; use tracing_subscriber::prelude::*;
pub mod add { pub mod add {
@@ -32,7 +53,6 @@ pub mod double {
#[derive(Clone)] #[derive(Clone)]
struct AddServer; struct AddServer;
#[tarpc::server]
impl AddService for AddServer { impl AddService for AddServer {
async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
x + y x + y
@@ -40,12 +60,14 @@ impl AddService for AddServer {
} }
#[derive(Clone)] #[derive(Clone)]
struct DoubleServer { struct DoubleServer<Stub> {
add_client: add::AddClient, add_client: add::AddClient<Stub>,
} }
#[tarpc::server] impl<Stub> DoubleService for DoubleServer<Stub>
impl DoubleService for DoubleServer { where
Stub: AddStub + Clone + Send + Sync + 'static,
{
async fn double(self, _: context::Context, x: i32) -> Result<i32, String> { async fn double(self, _: context::Context, x: i32) -> Result<i32, String> {
self.add_client self.add_client
.add(context::current(), x, x) .add(context::current(), x, x)
@@ -55,7 +77,7 @@ impl DoubleService for DoubleServer {
} }
fn init_tracing(service_name: &str) -> anyhow::Result<()> { fn init_tracing(service_name: &str) -> anyhow::Result<()> {
let tracer = opentelemetry_jaeger::new_pipeline() let tracer = opentelemetry_jaeger::new_agent_pipeline()
.with_service_name(service_name) .with_service_name(service_name)
.with_auto_split_batch(true) .with_auto_split_batch(true)
.with_max_packet_size(2usize.pow(13)) .with_max_packet_size(2usize.pow(13))
@@ -70,32 +92,88 @@ fn init_tracing(service_name: &str) -> anyhow::Result<()> {
Ok(()) Ok(())
} }
async fn listen_on_random_port<Item, SinkItem>() -> anyhow::Result<(
impl Stream<Item = serde_transport::Transport<TcpStream, Item, SinkItem, Json<Item, SinkItem>>>,
std::net::SocketAddr,
)>
where
Item: for<'de> serde::Deserialize<'de>,
SinkItem: serde::Serialize,
{
let listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
.await?
.filter_map(|r| future::ready(r.ok()))
.take(1);
let addr = listener.get_ref().get_ref().local_addr();
Ok((listener, addr))
}
fn make_stub<Req, Resp, const N: usize>(
backends: [impl Transport<ClientMessage<Arc<Req>>, Response<Resp>> + Send + Sync + 'static; N],
) -> retry::Retry<
impl Fn(&Result<Resp, RpcError>, u32) -> bool + Clone,
load_balance::RoundRobin<client::Channel<Arc<Req>, Resp>>,
>
where
Req: Send + Sync + 'static,
Resp: Send + Sync + 'static,
{
let stub = load_balance::RoundRobin::new(
backends
.into_iter()
.map(|transport| tarpc::client::new(client::Config::default(), transport).spawn())
.collect(),
);
let stub = retry::Retry::new(stub, |resp, attempts| {
if let Err(e) = resp {
tracing::warn!("Got an error: {e:?}");
attempts < 3
} else {
false
}
});
stub
}
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
init_tracing("tarpc_tracing_example")?; init_tracing("tarpc_tracing_example")?;
let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) let (add_listener1, addr1) = listen_on_random_port().await?;
.await? let (add_listener2, addr2) = listen_on_random_port().await?;
.filter_map(|r| future::ready(r.ok())); let something_bad_happened = Arc::new(AtomicBool::new(false));
let addr = add_listener.get_ref().local_addr(); let server = request_hook::before()
let add_server = add_listener .then_fn(move |_: &mut _, _: &_| {
.map(BaseChannel::with_defaults) let something_bad_happened = something_bad_happened.clone();
.take(1) async move {
.execute(AddServer.serve()); if something_bad_happened.fetch_xor(true, Ordering::Relaxed) {
tokio::spawn(add_server); Err(ServerError::new(
io::ErrorKind::NotFound,
"Gamma Ray!".into(),
))
} else {
Ok(())
}
}
})
.serving(AddServer.serve());
let add_server = add_listener1
.chain(add_listener2)
.map(BaseChannel::with_defaults);
tokio::spawn(spawn_incoming(add_server.execute(server)));
let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; let add_client = add::AddClient::from(make_stub([
let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn(); tarpc::serde_transport::tcp::connect(addr1, Json::default).await?,
tarpc::serde_transport::tcp::connect(addr2, Json::default).await?,
]));
let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
.await? .await?
.filter_map(|r| future::ready(r.ok())); .filter_map(|r| future::ready(r.ok()));
let addr = double_listener.get_ref().local_addr(); let addr = double_listener.get_ref().local_addr();
let double_server = double_listener let double_server = double_listener.map(BaseChannel::with_defaults).take(1);
.map(BaseChannel::with_defaults) let server = DoubleServer { add_client }.serve();
.take(1) tokio::spawn(spawn_incoming(double_server.execute(server)));
.execute(DoubleServer { add_client }.serve());
tokio::spawn(double_server);
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
let double_client = let double_client =

View File

@@ -7,6 +7,7 @@
//! Provides a client that connects to a server and sends multiplexed requests. //! Provides a client that connects to a server and sends multiplexed requests.
mod in_flight_requests; mod in_flight_requests;
pub mod stub;
use crate::{ use crate::{
cancellations::{cancellations, CanceledRequests, RequestCancellation}, cancellations::{cancellations, CanceledRequests, RequestCancellation},
@@ -500,7 +501,6 @@ where
// poll_next_request only returns Ready if there is room to buffer another request. // poll_next_request only returns Ready if there is room to buffer another request.
// Therefore, we can call write_request without fear of erroring due to a full // Therefore, we can call write_request without fear of erroring due to a full
// buffer. // buffer.
let request_id = request_id;
let request = ClientMessage::Request(Request { let request = ClientMessage::Request(Request {
id: request_id, id: request_id,
message: request, message: request,
@@ -543,10 +543,15 @@ where
/// Sends a server response to the client task that initiated the associated request. /// Sends a server response to the client task that initiated the associated request.
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool { fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
self.in_flight_requests().complete_request( if let Some(span) = self.in_flight_requests().complete_request(
response.request_id, response.request_id,
response.message.map_err(RpcError::Server), response.message.map_err(RpcError::Server),
) ) {
let _entered = span.enter();
tracing::info!("ReceiveResponse");
return true;
}
false
} }
} }

View File

@@ -77,20 +77,18 @@ impl<Res> InFlightRequests<Res> {
} }
/// Removes a request without aborting. Returns true iff the request was found. /// Removes a request without aborting. Returns true iff the request was found.
pub fn complete_request(&mut self, request_id: u64, result: Res) -> bool { pub fn complete_request(&mut self, request_id: u64, result: Res) -> Option<Span> {
if let Some(request_data) = self.request_data.remove(&request_id) { if let Some(request_data) = self.request_data.remove(&request_id) {
let _entered = request_data.span.enter();
tracing::info!("ReceiveResponse");
self.request_data.compact(0.1); self.request_data.compact(0.1);
self.deadlines.remove(&request_data.deadline_key); self.deadlines.remove(&request_data.deadline_key);
let _ = request_data.response_completion.send(result); let _ = request_data.response_completion.send(result);
return true; return Some(request_data.span);
} }
tracing::debug!("No in-flight request found for request_id = {request_id}."); tracing::debug!("No in-flight request found for request_id = {request_id}.");
// If the response completion was absent, then the request was already canceled. // If the response completion was absent, then the request was already canceled.
false None
} }
/// Completes all requests using the provided function. /// Completes all requests using the provided function.

45
tarpc/src/client/stub.rs Normal file
View File

@@ -0,0 +1,45 @@
//! Provides a Stub trait, implemented by types that can call remote services.
use crate::{
client::{Channel, RpcError},
context,
};
pub mod load_balance;
pub mod retry;
#[cfg(test)]
mod mock;
/// A connection to a remote service.
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
#[allow(async_fn_in_trait)]
pub trait Stub {
/// The service request type.
type Req;
/// The service response type.
type Resp;
/// Calls a remote service.
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Self::Req,
) -> Result<Self::Resp, RpcError>;
}
impl<Req, Resp> Stub for Channel<Req, Resp> {
type Req = Req;
type Resp = Resp;
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Req,
) -> Result<Self::Resp, RpcError> {
Self::call(self, ctx, request_name, request).await
}
}

View File

@@ -0,0 +1,279 @@
//! Provides load-balancing [Stubs](crate::client::stub::Stub).
pub use consistent_hash::ConsistentHash;
pub use round_robin::RoundRobin;
/// Provides a stub that load-balances with a simple round-robin strategy.
mod round_robin {
use crate::{
client::{stub, RpcError},
context,
};
use cycle::AtomicCycle;
impl<Stub> stub::Stub for RoundRobin<Stub>
where
Stub: stub::Stub,
{
type Req = Stub::Req;
type Resp = Stub::Resp;
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Self::Req,
) -> Result<Stub::Resp, RpcError> {
let next = self.stubs.next();
next.call(ctx, request_name, request).await
}
}
/// A Stub that load-balances across backing stubs by round robin.
#[derive(Clone, Debug)]
pub struct RoundRobin<Stub> {
stubs: AtomicCycle<Stub>,
}
impl<Stub> RoundRobin<Stub>
where
Stub: stub::Stub,
{
/// Returns a new RoundRobin stub.
pub fn new(stubs: Vec<Stub>) -> Self {
Self {
stubs: AtomicCycle::new(stubs),
}
}
}
mod cycle {
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
/// Cycles endlessly and atomically over a collection of elements of type T.
#[derive(Clone, Debug)]
pub struct AtomicCycle<T>(Arc<State<T>>);
#[derive(Debug)]
struct State<T> {
elements: Vec<T>,
next: AtomicUsize,
}
impl<T> AtomicCycle<T> {
pub fn new(elements: Vec<T>) -> Self {
Self(Arc::new(State {
elements,
next: Default::default(),
}))
}
pub fn next(&self) -> &T {
self.0.next()
}
}
impl<T> State<T> {
pub fn next(&self) -> &T {
let next = self.next.fetch_add(1, Ordering::Relaxed);
&self.elements[next % self.elements.len()]
}
}
#[test]
fn test_cycle() {
let cycle = AtomicCycle::new(vec![1, 2, 3]);
assert_eq!(cycle.next(), &1);
assert_eq!(cycle.next(), &2);
assert_eq!(cycle.next(), &3);
assert_eq!(cycle.next(), &1);
}
}
}
/// Provides a stub that load-balances with a consistent hashing strategy.
///
/// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use
/// the same stub.
mod consistent_hash {
use crate::{
client::{stub, RpcError},
context,
};
use std::{
collections::hash_map::RandomState,
hash::{BuildHasher, Hash, Hasher},
num::TryFromIntError,
};
impl<Stub, S> stub::Stub for ConsistentHash<Stub, S>
where
Stub: stub::Stub,
Stub::Req: Hash,
S: BuildHasher,
{
type Req = Stub::Req;
type Resp = Stub::Resp;
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Self::Req,
) -> Result<Stub::Resp, RpcError> {
let index = usize::try_from(self.hash_request(&request) % self.stubs_len).expect(
"invariant broken: stubs_len is not larger than a usize, \
so the hash modulo stubs_len should always fit in a usize",
);
let next = &self.stubs[index];
next.call(ctx, request_name, request).await
}
}
/// A Stub that load-balances across backing stubs by round robin.
#[derive(Clone, Debug)]
pub struct ConsistentHash<Stub, S = RandomState> {
stubs: Vec<Stub>,
stubs_len: u64,
hasher: S,
}
impl<Stub> ConsistentHash<Stub, RandomState>
where
Stub: stub::Stub,
Stub::Req: Hash,
{
/// Returns a new RoundRobin stub.
/// Returns an err if the length of `stubs` overflows a u64.
pub fn new(stubs: Vec<Stub>) -> Result<Self, TryFromIntError> {
Ok(Self {
stubs_len: stubs.len().try_into()?,
stubs,
hasher: RandomState::new(),
})
}
}
impl<Stub, S> ConsistentHash<Stub, S>
where
Stub: stub::Stub,
Stub::Req: Hash,
S: BuildHasher,
{
/// Returns a new RoundRobin stub.
/// Returns an err if the length of `stubs` overflows a u64.
pub fn with_hasher(stubs: Vec<Stub>, hasher: S) -> Result<Self, TryFromIntError> {
Ok(Self {
stubs_len: stubs.len().try_into()?,
stubs,
hasher,
})
}
fn hash_request(&self, req: &Stub::Req) -> u64 {
let mut hasher = self.hasher.build_hasher();
req.hash(&mut hasher);
hasher.finish()
}
}
#[cfg(test)]
mod tests {
use super::ConsistentHash;
use crate::{
client::stub::{mock::Mock, Stub},
context,
};
use std::{
collections::HashMap,
hash::{BuildHasher, Hash, Hasher},
rc::Rc,
};
#[tokio::test]
async fn test() -> anyhow::Result<()> {
let stub = ConsistentHash::<_, FakeHasherBuilder>::with_hasher(
vec![
// For easier reading of the assertions made in this test, each Mock's response
// value is equal to a hash value that should map to its index: 3 % 3 = 0, 1 %
// 3 = 1, etc.
Mock::new([('a', 3), ('b', 3), ('c', 3)]),
Mock::new([('a', 1), ('b', 1), ('c', 1)]),
Mock::new([('a', 2), ('b', 2), ('c', 2)]),
],
FakeHasherBuilder::new([('a', 1), ('b', 2), ('c', 3)]),
)?;
for _ in 0..2 {
let resp = stub.call(context::current(), "", 'a').await?;
assert_eq!(resp, 1);
let resp = stub.call(context::current(), "", 'b').await?;
assert_eq!(resp, 2);
let resp = stub.call(context::current(), "", 'c').await?;
assert_eq!(resp, 3);
}
Ok(())
}
struct HashRecorder(Vec<u8>);
impl Hasher for HashRecorder {
fn write(&mut self, bytes: &[u8]) {
self.0 = Vec::from(bytes);
}
fn finish(&self) -> u64 {
0
}
}
struct FakeHasherBuilder {
recorded_hashes: Rc<HashMap<Vec<u8>, u64>>,
}
struct FakeHasher {
recorded_hashes: Rc<HashMap<Vec<u8>, u64>>,
output: u64,
}
impl BuildHasher for FakeHasherBuilder {
type Hasher = FakeHasher;
fn build_hasher(&self) -> Self::Hasher {
FakeHasher {
recorded_hashes: self.recorded_hashes.clone(),
output: 0,
}
}
}
impl FakeHasherBuilder {
fn new<T: Hash, const N: usize>(fake_hashes: [(T, u64); N]) -> Self {
let mut recorded_hashes = HashMap::new();
for (to_hash, fake_hash) in fake_hashes {
let mut recorder = HashRecorder(vec![]);
to_hash.hash(&mut recorder);
recorded_hashes.insert(recorder.0, fake_hash);
}
Self {
recorded_hashes: Rc::new(recorded_hashes),
}
}
}
impl Hasher for FakeHasher {
fn write(&mut self, bytes: &[u8]) {
if let Some(hash) = self.recorded_hashes.get(bytes) {
self.output = *hash;
}
}
fn finish(&self) -> u64 {
self.output
}
}
}
}

View File

@@ -0,0 +1,49 @@
use crate::{
client::{stub::Stub, RpcError},
context, ServerError,
};
use std::{collections::HashMap, hash::Hash, io};
/// A mock stub that returns user-specified responses.
pub struct Mock<Req, Resp> {
responses: HashMap<Req, Resp>,
}
impl<Req, Resp> Mock<Req, Resp>
where
Req: Eq + Hash,
{
/// Returns a new mock, mocking the specified (request, response) pairs.
pub fn new<const N: usize>(responses: [(Req, Resp); N]) -> Self {
Self {
responses: HashMap::from(responses),
}
}
}
impl<Req, Resp> Stub for Mock<Req, Resp>
where
Req: Eq + Hash,
Resp: Clone,
{
type Req = Req;
type Resp = Resp;
async fn call(
&self,
_: context::Context,
_: &'static str,
request: Self::Req,
) -> Result<Resp, RpcError> {
self.responses
.get(&request)
.cloned()
.map(Ok)
.unwrap_or_else(|| {
Err(RpcError::Server(ServerError {
kind: io::ErrorKind::NotFound,
detail: "mock (request, response) entry not found".into(),
}))
})
}
}

View File

@@ -0,0 +1,56 @@
//! Provides a stub that retries requests based on response contents..
use crate::{
client::{stub, RpcError},
context,
};
use std::sync::Arc;
impl<Stub, Req, F> stub::Stub for Retry<F, Stub>
where
Stub: stub::Stub<Req = Arc<Req>>,
F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool,
{
type Req = Req;
type Resp = Stub::Resp;
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Self::Req,
) -> Result<Stub::Resp, RpcError> {
let request = Arc::new(request);
for i in 1.. {
let result = self
.stub
.call(ctx, request_name, Arc::clone(&request))
.await;
if (self.should_retry)(&result, i) {
tracing::trace!("Retrying on attempt {i}");
continue;
}
return result;
}
unreachable!("Wow, that was a lot of attempts!");
}
}
/// A Stub that retries requests based on response contents.
/// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled.
#[derive(Clone, Debug)]
pub struct Retry<F, Stub> {
should_retry: F,
stub: Stub,
}
impl<Stub, Req, F> Retry<F, Stub>
where
Stub: stub::Stub<Req = Arc<Req>>,
F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool,
{
/// Creates a new Retry stub that delegates calls to the underlying `stub`.
pub fn new(stub: Stub, should_retry: F) -> Self {
Self { stub, should_retry }
}
}

View File

@@ -126,13 +126,9 @@
//! struct HelloServer; //! struct HelloServer;
//! //!
//! impl World for HelloServer { //! impl World for HelloServer {
//! // Each defined rpc generates two items in the trait, a fn that serves the RPC, and //! // Each defined rpc generates an async fn that serves the RPC
//! // an associated type representing the future output by the fn. //! async fn hello(self, _: context::Context, name: String) -> String {
//! //! format!("Hello, {name}!")
//! type HelloFut = Ready<String>;
//!
//! fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
//! future::ready(format!("Hello, {name}!"))
//! } //! }
//! } //! }
//! ``` //! ```
@@ -164,11 +160,9 @@
//! # #[derive(Clone)] //! # #[derive(Clone)]
//! # struct HelloServer; //! # struct HelloServer;
//! # impl World for HelloServer { //! # impl World for HelloServer {
//! # // Each defined rpc generates two items in the trait, a fn that serves the RPC, and //! // Each defined rpc generates an async fn that serves the RPC
//! # // an associated type representing the future output by the fn. //! # async fn hello(self, _: context::Context, name: String) -> String {
//! # type HelloFut = Ready<String>; //! # format!("Hello, {name}!")
//! # fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
//! # future::ready(format!("Hello, {name}!"))
//! # } //! # }
//! # } //! # }
//! # #[cfg(not(feature = "tokio1"))] //! # #[cfg(not(feature = "tokio1"))]
@@ -179,7 +173,12 @@
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); //! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
//! //!
//! let server = server::BaseChannel::with_defaults(server_transport); //! let server = server::BaseChannel::with_defaults(server_transport);
//! tokio::spawn(server.execute(HelloServer.serve())); //! tokio::spawn(
//! server.execute(HelloServer.serve())
//! // Handle all requests concurrently.
//! .for_each(|response| async move {
//! tokio::spawn(response);
//! }));
//! //!
//! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` //! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
//! // that takes a config and any Transport as input. //! // that takes a config and any Transport as input.
@@ -200,6 +199,7 @@
//! //!
//! Use `cargo doc` as you normally would to see the documentation created for all //! Use `cargo doc` as you normally would to see the documentation created for all
//! items expanded by a `service!` invocation. //! items expanded by a `service!` invocation.
#![deny(missing_docs)] #![deny(missing_docs)]
#![allow(clippy::type_complexity)] #![allow(clippy::type_complexity)]
#![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, feature(doc_cfg))]
@@ -244,62 +244,6 @@ pub use tarpc_plugins::derive_serde;
/// * `fn new_stub` -- creates a new Client stub. /// * `fn new_stub` -- creates a new Client stub.
pub use tarpc_plugins::service; pub use tarpc_plugins::service;
/// A utility macro that can be used for RPC server implementations.
///
/// Syntactic sugar to make using async functions in the server implementation
/// easier. It does this by rewriting code like this, which would normally not
/// compile because async functions are disallowed in trait implementations:
///
/// ```rust
/// # use tarpc::context;
/// # use std::net::SocketAddr;
/// #[tarpc::service]
/// trait World {
/// async fn hello(name: String) -> String;
/// }
///
/// #[derive(Clone)]
/// struct HelloServer(SocketAddr);
///
/// #[tarpc::server]
/// impl World for HelloServer {
/// async fn hello(self, _: context::Context, name: String) -> String {
/// format!("Hello, {name}! You are connected from {:?}.", self.0)
/// }
/// }
/// ```
///
/// Into code like this, which matches the service trait definition:
///
/// ```rust
/// # use tarpc::context;
/// # use std::pin::Pin;
/// # use futures::Future;
/// # use std::net::SocketAddr;
/// #[derive(Clone)]
/// struct HelloServer(SocketAddr);
///
/// #[tarpc::service]
/// trait World {
/// async fn hello(name: String) -> String;
/// }
///
/// impl World for HelloServer {
/// type HelloFut = Pin<Box<dyn Future<Output = String> + Send>>;
///
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
/// + Send>> {
/// Box::pin(async move {
/// format!("Hello, {name}! You are connected from {:?}.", self.0)
/// })
/// }
/// }
/// ```
///
/// Note that this won't touch functions unless they have been annotated with
/// `async`, meaning that this should not break existing code.
pub use tarpc_plugins::server;
pub(crate) mod cancellations; pub(crate) mod cancellations;
pub mod client; pub mod client;
pub mod context; pub mod context;
@@ -407,6 +351,13 @@ where
Close(#[source] E), Close(#[source] E),
} }
impl ServerError {
/// Returns a new server error with `kind` and `detail`.
pub fn new(kind: io::ErrorKind, detail: String) -> ServerError {
Self { kind, detail }
}
}
impl<T> Request<T> { impl<T> Request<T> {
/// Returns the deadline for this request. /// Returns the deadline for this request.
pub fn deadline(&self) -> &SystemTime { pub fn deadline(&self) -> &SystemTime {

View File

@@ -9,7 +9,7 @@
use crate::{ use crate::{
cancellations::{cancellations, CanceledRequests, RequestCancellation}, cancellations::{cancellations, CanceledRequests, RequestCancellation},
context::{self, SpanExt}, context::{self, SpanExt},
trace, ChannelError, ClientMessage, Request, Response, Transport, trace, ChannelError, ClientMessage, Request, Response, ServerError, Transport,
}; };
use ::tokio::sync::mpsc; use ::tokio::sync::mpsc;
use futures::{ use futures::{
@@ -25,6 +25,7 @@ use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sy
use tracing::{info_span, instrument::Instrument, Span}; use tracing::{info_span, instrument::Instrument, Span};
mod in_flight_requests; mod in_flight_requests;
pub mod request_hook;
#[cfg(test)] #[cfg(test)]
mod testing; mod testing;
@@ -34,10 +35,9 @@ pub mod limits;
/// Provides helper methods for streams of Channels. /// Provides helper methods for streams of Channels.
pub mod incoming; pub mod incoming;
/// Provides convenience functionality for tokio-enabled applications. use request_hook::{
#[cfg(feature = "tokio1")] AfterRequest, BeforeRequest, HookThenServe, HookThenServeThenHook, ServeThenHook,
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] };
pub mod tokio;
/// Settings that control the behavior of [channels](Channel). /// Settings that control the behavior of [channels](Channel).
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@@ -67,32 +67,204 @@ impl Config {
} }
/// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`. /// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`.
pub trait Serve<Req> { #[allow(async_fn_in_trait)]
pub trait Serve {
/// Type of request.
type Req;
/// Type of response. /// Type of response.
type Resp; type Resp;
/// Type of response future. /// Responds to a single request.
type Fut: Future<Output = Self::Resp>; async fn serve(self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError>;
/// Extracts a method name from the request. /// Extracts a method name from the request.
fn method(&self, _request: &Req) -> Option<&'static str> { fn method(&self, _request: &Self::Req) -> Option<&'static str> {
None None
} }
/// Responds to a single request. /// Runs a hook before execution of the request.
fn serve(self, ctx: context::Context, req: Req) -> Self::Fut; ///
/// If the hook returns an error, the request will not be executed and the error will be
/// returned instead.
///
/// The hook can also modify the request context. This could be used, for example, to enforce a
/// maximum deadline on all requests.
///
/// Any type that implements [`BeforeRequest`] can be used as the hook. Types that implement
/// `FnMut(&mut Context, &RequestType) -> impl Future<Output = Result<(), ServerError>>` can
/// also be used.
///
/// # Example
///
/// ```rust
/// use futures::{executor::block_on, future};
/// use tarpc::{context, ServerError, server::{Serve, serve}};
/// use std::io;
///
/// let serve = serve(|_ctx, i| async move { Ok(i + 1) })
/// .before(|_ctx: &mut context::Context, req: &i32| {
/// future::ready(
/// if *req == 1 {
/// Err(ServerError::new(
/// io::ErrorKind::Other,
/// format!("I don't like {req}")))
/// } else {
/// Ok(())
/// })
/// });
/// let response = serve.serve(context::current(), 1);
/// assert!(block_on(response).is_err());
/// ```
fn before<Hook>(self, hook: Hook) -> HookThenServe<Self, Hook>
where
Hook: BeforeRequest<Self::Req>,
Self: Sized,
{
HookThenServe::new(self, hook)
}
/// Runs a hook after completion of a request.
///
/// The hook can modify the request context and the response.
///
/// Any type that implements [`AfterRequest`] can be used as the hook. Types that implement
/// `FnMut(&mut Context, &mut Result<ResponseType, ServerError>) -> impl Future<Output = ()>`
/// can also be used.
///
/// # Example
///
/// ```rust
/// use futures::{executor::block_on, future};
/// use tarpc::{context, ServerError, server::{Serve, serve}};
/// use std::io;
///
/// let serve = serve(
/// |_ctx, i| async move {
/// if i == 1 {
/// Err(ServerError::new(
/// io::ErrorKind::Other,
/// format!("{i} is the loneliest number")))
/// } else {
/// Ok(i + 1)
/// }
/// })
/// .after(|_ctx: &mut context::Context, resp: &mut Result<i32, ServerError>| {
/// if let Err(e) = resp {
/// eprintln!("server error: {e:?}");
/// }
/// future::ready(())
/// });
///
/// let response = serve.serve(context::current(), 1);
/// assert!(block_on(response).is_err());
/// ```
fn after<Hook>(self, hook: Hook) -> ServeThenHook<Self, Hook>
where
Hook: AfterRequest<Self::Resp>,
Self: Sized,
{
ServeThenHook::new(self, hook)
}
/// Runs a hook before and after execution of the request.
///
/// If the hook returns an error, the request will not be executed and the error will be
/// returned instead.
///
/// The hook can also modify the request context and the response. This could be used, for
/// example, to enforce a maximum deadline on all requests.
///
/// # Example
///
/// ```rust
/// use futures::{executor::block_on, future};
/// use tarpc::{
/// context, ServerError, server::{Serve, serve, request_hook::{BeforeRequest, AfterRequest}}
/// };
/// use std::{io, time::Instant};
///
/// struct PrintLatency(Instant);
///
/// impl<Req> BeforeRequest<Req> for PrintLatency {
/// async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> {
/// self.0 = Instant::now();
/// Ok(())
/// }
/// }
///
/// impl<Resp> AfterRequest<Resp> for PrintLatency {
/// async fn after(
/// &mut self,
/// _: &mut context::Context,
/// _: &mut Result<Resp, ServerError>,
/// ) {
/// tracing::info!("Elapsed: {:?}", self.0.elapsed());
/// }
/// }
///
/// let serve = serve(|_ctx, i| async move {
/// Ok(i + 1)
/// }).before_and_after(PrintLatency(Instant::now()));
/// let response = serve.serve(context::current(), 1);
/// assert!(block_on(response).is_ok());
/// ```
fn before_and_after<Hook>(
self,
hook: Hook,
) -> HookThenServeThenHook<Self::Req, Self::Resp, Self, Hook>
where
Hook: BeforeRequest<Self::Req> + AfterRequest<Self::Resp>,
Self: Sized,
{
HookThenServeThenHook::new(self, hook)
}
} }
impl<Req, Resp, Fut, F> Serve<Req> for F /// A Serve wrapper around a Fn.
#[derive(Debug)]
pub struct ServeFn<Req, Resp, F> {
f: F,
data: PhantomData<fn(Req) -> Resp>,
}
impl<Req, Resp, F> Clone for ServeFn<Req, Resp, F>
where
F: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
data: PhantomData,
}
}
}
impl<Req, Resp, F> Copy for ServeFn<Req, Resp, F> where F: Copy {}
/// Creates a [`Serve`] wrapper around a `FnOnce(context::Context, Req) -> impl Future<Output =
/// Result<Resp, ServerError>>`.
pub fn serve<Req, Resp, Fut, F>(f: F) -> ServeFn<Req, Resp, F>
where where
F: FnOnce(context::Context, Req) -> Fut, F: FnOnce(context::Context, Req) -> Fut,
Fut: Future<Output = Resp>, Fut: Future<Output = Result<Resp, ServerError>>,
{ {
type Resp = Resp; ServeFn {
type Fut = Fut; f,
data: PhantomData,
}
}
fn serve(self, ctx: context::Context, req: Req) -> Self::Fut { impl<Req, Resp, Fut, F> Serve for ServeFn<Req, Resp, F>
self(ctx, req) where
F: FnOnce(context::Context, Req) -> Fut,
Fut: Future<Output = Result<Resp, ServerError>>,
{
type Req = Req;
type Resp = Resp;
async fn serve(self, ctx: context::Context, req: Req) -> Result<Resp, ServerError> {
(self.f)(ctx, req).await
} }
} }
@@ -120,7 +292,7 @@ pub struct BaseChannel<Req, Resp, T> {
/// Holds data necessary to clean up in-flight requests. /// Holds data necessary to clean up in-flight requests.
in_flight_requests: InFlightRequests, in_flight_requests: InFlightRequests,
/// Types the request and response. /// Types the request and response.
ghost: PhantomData<(Req, Resp)>, ghost: PhantomData<(fn() -> Req, fn(Resp))>,
} }
impl<Req, Resp, T> BaseChannel<Req, Resp, T> impl<Req, Resp, T> BaseChannel<Req, Resp, T>
@@ -307,6 +479,34 @@ where
/// This is a terminal operation. After calling `requests`, the channel cannot be retrieved, /// This is a terminal operation. After calling `requests`, the channel cannot be retrieved,
/// and the only way to complete requests is via [`Requests::execute`] or /// and the only way to complete requests is via [`Requests::execute`] or
/// [`InFlightRequest::execute`]. /// [`InFlightRequest::execute`].
///
/// # Example
///
/// ```rust
/// use tarpc::{
/// context,
/// client::{self, NewClient},
/// server::{self, BaseChannel, Channel, serve},
/// transport,
/// };
/// use futures::prelude::*;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, rx) = transport::channel::unbounded();
/// let server = BaseChannel::new(server::Config::default(), rx);
/// let NewClient { client, dispatch } = client::new(client::Config::default(), tx);
/// tokio::spawn(dispatch);
///
/// let mut requests = server.requests();
/// tokio::spawn(async move {
/// while let Some(Ok(request)) = requests.next().await {
/// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) })));
/// }
/// });
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
/// }
/// ```
fn requests(self) -> Requests<Self> fn requests(self) -> Requests<Self>
where where
Self: Sized, Self: Sized,
@@ -320,18 +520,42 @@ where
} }
} }
/// Runs the channel until completion by executing all requests using the given service /// Returns a stream of request execution futures. Each future represents an in-flight request
/// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's /// being responded to by the server. The futures must be awaited or spawned to complete their
/// default executor. /// requests.
#[cfg(feature = "tokio1")] ///
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] /// # Example
fn execute<S>(self, serve: S) -> self::tokio::TokioChannelExecutor<Requests<Self>, S> ///
/// ```rust
/// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport};
/// use futures::prelude::*;
/// use tracing_subscriber::prelude::*;
///
/// #[derive(PartialEq, Eq, Debug)]
/// struct MyInt(i32);
///
/// # #[cfg(not(feature = "tokio1"))]
/// # fn main() {}
/// # #[cfg(feature = "tokio1")]
/// #[tokio::main]
/// async fn main() {
/// let (tx, rx) = transport::channel::unbounded();
/// let client = client::new(client::Config::default(), tx).spawn();
/// let channel = BaseChannel::with_defaults(rx);
/// tokio::spawn(
/// channel.execute(serve(|_, MyInt(i)| async move { Ok(MyInt(i + 1)) }))
/// .for_each(|response| async move {
/// tokio::spawn(response);
/// }));
/// assert_eq!(
/// client.call(context::current(), "AddOne", MyInt(1)).await.unwrap(),
/// MyInt(2));
/// }
/// ```
fn execute<S>(self, serve: S) -> impl Stream<Item = impl Future<Output = ()>>
where where
Self: Sized, Self: Sized,
S: Serve<Self::Req, Resp = Self::Resp> + Send + 'static, S: Serve<Req = Self::Req, Resp = Self::Resp> + Clone,
S::Fut: Send,
Self::Req: Send + 'static,
Self::Resp: Send + 'static,
{ {
self.requests().execute(serve) self.requests().execute(serve)
} }
@@ -425,15 +649,17 @@ where
Poll::Pending => Pending, Poll::Pending => Pending,
}; };
tracing::trace!( let status = cancellation_status
"Expired requests: {:?}, Inbound: {:?}",
expiration_status,
request_status
);
match cancellation_status
.combine(expiration_status) .combine(expiration_status)
.combine(request_status) .combine(request_status);
{
tracing::trace!(
"Cancellations: {cancellation_status:?}, \
Expired requests: {expiration_status:?}, \
Inbound: {request_status:?}, \
Overall: {status:?}",
);
match status {
Ready => continue, Ready => continue,
Closed => return Poll::Ready(None), Closed => return Poll::Ready(None),
Pending => return Poll::Pending, Pending => return Poll::Pending,
@@ -565,6 +791,10 @@ where
}| { }| {
// The response guard becomes active once in an InFlightRequest. // The response guard becomes active once in an InFlightRequest.
response_guard.cancel = true; response_guard.cancel = true;
{
let _entered = span.enter();
tracing::info!("BeginRequest");
}
InFlightRequest { InFlightRequest {
request, request,
abort_registration, abort_registration,
@@ -639,6 +869,51 @@ where
} }
Poll::Ready(Some(Ok(()))) Poll::Ready(Some(Ok(())))
} }
/// Returns a stream of request execution futures. Each future represents an in-flight request
/// being responded to by the server. The futures must be awaited or spawned to complete their
/// requests.
///
/// If the channel encounters an error, the stream is terminated and the error is logged.
///
/// # Example
///
/// ```rust
/// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport};
/// use futures::prelude::*;
///
/// # #[cfg(not(feature = "tokio1"))]
/// # fn main() {}
/// # #[cfg(feature = "tokio1")]
/// #[tokio::main]
/// async fn main() {
/// let (tx, rx) = transport::channel::unbounded();
/// let requests = BaseChannel::new(server::Config::default(), rx).requests();
/// let client = client::new(client::Config::default(), tx).spawn();
/// tokio::spawn(
/// requests.execute(serve(|_, i| async move { Ok(i + 1) }))
/// .for_each(|response| async move {
/// tokio::spawn(response);
/// }));
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
/// }
/// ```
pub fn execute<S>(self, serve: S) -> impl Stream<Item = impl Future<Output = ()>>
where
S: Serve<Req = C::Req, Resp = C::Resp> + Clone,
{
self.take_while(|result| {
if let Err(e) = result {
tracing::warn!("Requests stream errored out: {}", e);
}
futures::future::ready(result.is_ok())
})
.filter_map(|result| async move { result.ok() })
.map(move |request| {
let serve = serve.clone();
request.execute(serve)
})
}
} }
impl<C> fmt::Debug for Requests<C> impl<C> fmt::Debug for Requests<C>
@@ -700,9 +975,39 @@ impl<Req, Res> InFlightRequest<Req, Res> {
/// ///
/// If the returned Future is dropped before completion, a cancellation message will be sent to /// If the returned Future is dropped before completion, a cancellation message will be sent to
/// the Channel to clean up associated request state. /// the Channel to clean up associated request state.
///
/// # Example
///
/// ```rust
/// use tarpc::{
/// context,
/// client::{self, NewClient},
/// server::{self, BaseChannel, Channel, serve},
/// transport,
/// };
/// use futures::prelude::*;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, rx) = transport::channel::unbounded();
/// let server = BaseChannel::new(server::Config::default(), rx);
/// let NewClient { client, dispatch } = client::new(client::Config::default(), tx);
/// tokio::spawn(dispatch);
///
/// tokio::spawn(async move {
/// let mut requests = server.requests();
/// while let Some(Ok(in_flight_request)) = requests.next().await {
/// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) })).await;
/// }
///
/// });
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
/// }
/// ```
///
pub async fn execute<S>(self, serve: S) pub async fn execute<S>(self, serve: S)
where where
S: Serve<Req, Resp = Res>, S: Serve<Req = Req, Resp = Res>,
{ {
let Self { let Self {
response_tx, response_tx,
@@ -717,18 +1022,14 @@ impl<Req, Res> InFlightRequest<Req, Res> {
}, },
} = self; } = self;
let method = serve.method(&message); let method = serve.method(&message);
// TODO(https://github.com/rust-lang/rust-clippy/issues/9111) span.record("otel.name", method.unwrap_or(""));
// remove when clippy is fixed
#[allow(clippy::needless_borrow)]
span.record("otel.name", &method.unwrap_or(""));
let _ = Abortable::new( let _ = Abortable::new(
async move { async move {
tracing::info!("BeginRequest"); let message = serve.serve(context, message).await;
let response = serve.serve(context, message).await;
tracing::info!("CompleteRequest"); tracing::info!("CompleteRequest");
let response = Response { let response = Response {
request_id, request_id,
message: Ok(response), message,
}; };
let _ = response_tx.send(response).await; let _ = response_tx.send(response).await;
tracing::info!("BufferResponse"); tracing::info!("BufferResponse");
@@ -744,6 +1045,13 @@ impl<Req, Res> InFlightRequest<Req, Res> {
} }
} }
fn print_err(e: &(dyn Error + 'static)) -> String {
anyhow::Chain::new(e)
.map(|e| e.to_string())
.collect::<Vec<_>>()
.join(": ")
}
impl<C> Stream for Requests<C> impl<C> Stream for Requests<C>
where where
C: Channel, C: Channel,
@@ -752,17 +1060,33 @@ where
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop { loop {
let read = self.as_mut().pump_read(cx)?; let read = self.as_mut().pump_read(cx).map_err(|e| {
tracing::trace!("read: {}", print_err(&e));
e
})?;
let read_closed = matches!(read, Poll::Ready(None)); let read_closed = matches!(read, Poll::Ready(None));
match (read, self.as_mut().pump_write(cx, read_closed)?) { let write = self.as_mut().pump_write(cx, read_closed).map_err(|e| {
tracing::trace!("write: {}", print_err(&e));
e
})?;
match (read, write) {
(Poll::Ready(None), Poll::Ready(None)) => { (Poll::Ready(None), Poll::Ready(None)) => {
tracing::trace!("read: Poll::Ready(None), write: Poll::Ready(None)");
return Poll::Ready(None); return Poll::Ready(None);
} }
(Poll::Ready(Some(request_handler)), _) => { (Poll::Ready(Some(request_handler)), _) => {
tracing::trace!("read: Poll::Ready(Some), write: _");
return Poll::Ready(Some(Ok(request_handler))); return Poll::Ready(Some(Ok(request_handler)));
} }
(_, Poll::Ready(Some(()))) => {} (_, Poll::Ready(Some(()))) => {
_ => { tracing::trace!("read: _, write: Poll::Ready(Some)");
}
(read @ Poll::Pending, write) | (read, write @ Poll::Pending) => {
tracing::trace!(
"read pending: {}, write pending: {}",
read.is_pending(),
write.is_pending()
);
return Poll::Pending; return Poll::Pending;
} }
} }
@@ -772,11 +1096,14 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{in_flight_requests::AlreadyExistsError, BaseChannel, Channel, Config, Requests}; use super::{
in_flight_requests::AlreadyExistsError, serve, AfterRequest, BaseChannel, BeforeRequest,
Channel, Config, Requests, Serve,
};
use crate::{ use crate::{
context, trace, context, trace,
transport::channel::{self, UnboundedChannel}, transport::channel::{self, UnboundedChannel},
ClientMessage, Request, Response, ClientMessage, Request, Response, ServerError,
}; };
use assert_matches::assert_matches; use assert_matches::assert_matches;
use futures::{ use futures::{
@@ -785,7 +1112,12 @@ mod tests {
Future, Future,
}; };
use futures_test::task::noop_context; use futures_test::task::noop_context;
use std::{pin::Pin, task::Poll}; use std::{
io,
pin::Pin,
task::Poll,
time::{Duration, Instant, SystemTime},
};
fn test_channel<Req, Resp>() -> ( fn test_channel<Req, Resp>() -> (
Pin<Box<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>, Pin<Box<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>,
@@ -846,6 +1178,89 @@ mod tests {
Abortable::new(pending(), abort_registration) Abortable::new(pending(), abort_registration)
} }
#[tokio::test]
async fn test_serve() {
let serve = serve(|_, i| async move { Ok(i) });
assert_matches!(serve.serve(context::current(), 7).await, Ok(7));
}
#[tokio::test]
async fn serve_before_mutates_context() -> anyhow::Result<()> {
struct SetDeadline(SystemTime);
impl<Req> BeforeRequest<Req> for SetDeadline {
async fn before(
&mut self,
ctx: &mut context::Context,
_: &Req,
) -> Result<(), ServerError> {
ctx.deadline = self.0;
Ok(())
}
}
let some_time = SystemTime::UNIX_EPOCH + Duration::from_secs(37);
let some_other_time = SystemTime::UNIX_EPOCH + Duration::from_secs(83);
let serve = serve(move |ctx: context::Context, i| async move {
assert_eq!(ctx.deadline, some_time);
Ok(i)
});
let deadline_hook = serve.before(SetDeadline(some_time));
let mut ctx = context::current();
ctx.deadline = some_other_time;
deadline_hook.serve(ctx, 7).await?;
Ok(())
}
#[tokio::test]
async fn serve_before_and_after() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
struct PrintLatency {
start: Instant,
}
impl PrintLatency {
fn new() -> Self {
Self {
start: Instant::now(),
}
}
}
impl<Req> BeforeRequest<Req> for PrintLatency {
async fn before(
&mut self,
_: &mut context::Context,
_: &Req,
) -> Result<(), ServerError> {
self.start = Instant::now();
Ok(())
}
}
impl<Resp> AfterRequest<Resp> for PrintLatency {
async fn after(&mut self, _: &mut context::Context, _: &mut Result<Resp, ServerError>) {
tracing::info!("Elapsed: {:?}", self.start.elapsed());
}
}
let serve = serve(move |_: context::Context, i| async move { Ok(i) });
serve
.before_and_after(PrintLatency::new())
.serve(context::current(), 7)
.await?;
Ok(())
}
#[tokio::test]
async fn serve_before_error_aborts_request() -> anyhow::Result<()> {
let serve = serve(|_, _| async { panic!("Shouldn't get here") });
let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async {
Err(ServerError::new(io::ErrorKind::Other, "oops".into()))
});
let resp: Result<i32, _> = deadline_hook.serve(context::current(), 7).await;
assert_matches!(resp, Err(_));
Ok(())
}
#[tokio::test] #[tokio::test]
async fn base_channel_start_send_duplicate_request_returns_error() { async fn base_channel_start_send_duplicate_request_returns_error() {
let (mut channel, _tx) = test_channel::<(), ()>(); let (mut channel, _tx) = test_channel::<(), ()>();
@@ -1046,7 +1461,7 @@ mod tests {
Poll::Ready(Some(Ok(request))) => request, Poll::Ready(Some(Ok(request))) => request,
result => panic!("Unexpected result: {:?}", result), result => panic!("Unexpected result: {:?}", result),
}; };
request.execute(|_, _| async {}).await; request.execute(serve(|_, _| async { Ok(()) })).await;
assert!(requests assert!(requests
.as_mut() .as_mut()
.channel_pin_mut() .channel_pin_mut()

View File

@@ -1,13 +1,10 @@
use super::{ use super::{
limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel}, limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel},
Channel, Channel, Serve,
}; };
use futures::prelude::*; use futures::prelude::*;
use std::{fmt, hash::Hash}; use std::{fmt, hash::Hash};
#[cfg(feature = "tokio1")]
use super::{tokio::TokioServerExecutor, Serve};
/// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel). /// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel).
pub trait Incoming<C> pub trait Incoming<C>
where where
@@ -28,16 +25,62 @@ where
MaxRequestsPerChannel::new(self, n) MaxRequestsPerChannel::new(self, n)
} }
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled /// Returns a stream of channels in execution. Each channel in execution is a stream of
/// concurrently by spawning on tokio's default executor, and each request will be also /// futures, where each future is an in-flight request being rsponded to.
/// be spawned on tokio's default executor. fn execute<S>(
#[cfg(feature = "tokio1")] self,
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] serve: S,
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S> ) -> impl Stream<Item = impl Stream<Item = impl Future<Output = ()>>>
where where
S: Serve<C::Req, Resp = C::Resp>, S: Serve<Req = C::Req, Resp = C::Resp> + Clone,
{ {
TokioServerExecutor::new(self, serve) self.map(move |channel| channel.execute(serve.clone()))
}
}
#[cfg(feature = "tokio1")]
/// Spawns all channels-in-execution, delegating to the tokio runtime to manage their completion.
/// Each channel is spawned, and each request from each channel is spawned.
/// Note that this function is generic over any stream-of-streams-of-futures, but it is intended
/// for spawning streams of channels.
///
/// # Example
/// ```rust
/// use tarpc::{
/// context,
/// client::{self, NewClient},
/// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve},
/// transport,
/// };
/// use futures::prelude::*;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, rx) = transport::channel::unbounded();
/// let NewClient { client, dispatch } = client::new(client::Config::default(), tx);
/// tokio::spawn(dispatch);
///
/// let incoming = stream::once(async move {
/// BaseChannel::new(server::Config::default(), rx)
/// }).execute(serve(|_, i| async move { Ok(i + 1) }));
/// tokio::spawn(spawn_incoming(incoming));
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
/// }
/// ```
pub async fn spawn_incoming(
incoming: impl Stream<
Item = impl Stream<Item = impl Future<Output = ()> + Send + 'static> + Send + 'static,
>,
) {
use futures::pin_mut;
pin_mut!(incoming);
while let Some(channel) = incoming.next().await {
tokio::spawn(async move {
pin_mut!(channel);
while let Some(request) = channel.next().await {
tokio::spawn(request);
}
});
} }
} }

View File

@@ -0,0 +1,25 @@
// Copyright 2022 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Hooks for horizontal functionality that can run either before or after a request is executed.
/// A request hook that runs before a request is executed.
mod before;
/// A request hook that runs after a request is completed.
mod after;
/// A request hook that runs both before a request is executed and after it is completed.
mod before_and_after;
pub use {
after::{AfterRequest, ServeThenHook},
before::{
before, BeforeRequest, BeforeRequestCons, BeforeRequestList, BeforeRequestNil,
HookThenServe,
},
before_and_after::HookThenServeThenHook,
};

View File

@@ -0,0 +1,72 @@
// Copyright 2022 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a hook that runs after request execution.
use crate::{context, server::Serve, ServerError};
use futures::prelude::*;
/// A hook that runs after request execution.
#[allow(async_fn_in_trait)]
pub trait AfterRequest<Resp> {
/// The function that is called after request execution.
///
/// The hook can modify the request context and the response.
async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result<Resp, ServerError>);
}
impl<F, Fut, Resp> AfterRequest<Resp> for F
where
F: FnMut(&mut context::Context, &mut Result<Resp, ServerError>) -> Fut,
Fut: Future<Output = ()>,
{
async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result<Resp, ServerError>) {
self(ctx, resp).await
}
}
/// A Service function that runs a hook after request execution.
pub struct ServeThenHook<Serv, Hook> {
serve: Serv,
hook: Hook,
}
impl<Serv, Hook> ServeThenHook<Serv, Hook> {
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
Self { serve, hook }
}
}
impl<Serv: Clone, Hook: Clone> Clone for ServeThenHook<Serv, Hook> {
fn clone(&self) -> Self {
Self {
serve: self.serve.clone(),
hook: self.hook.clone(),
}
}
}
impl<Serv, Hook> Serve for ServeThenHook<Serv, Hook>
where
Serv: Serve,
Hook: AfterRequest<Serv::Resp>,
{
type Req = Serv::Req;
type Resp = Serv::Resp;
async fn serve(
self,
mut ctx: context::Context,
req: Serv::Req,
) -> Result<Serv::Resp, ServerError> {
let ServeThenHook {
serve, mut hook, ..
} = self;
let mut resp = serve.serve(ctx, req).await;
hook.after(&mut ctx, &mut resp).await;
resp
}
}

View File

@@ -0,0 +1,210 @@
// Copyright 2022 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a hook that runs before request execution.
use crate::{context, server::Serve, ServerError};
use futures::prelude::*;
/// A hook that runs before request execution.
#[allow(async_fn_in_trait)]
pub trait BeforeRequest<Req> {
/// The function that is called before request execution.
///
/// If this function returns an error, the request will not be executed and the error will be
/// returned instead.
///
/// This function can also modify the request context. This could be used, for example, to
/// enforce a maximum deadline on all requests.
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError>;
}
/// A list of hooks that run in order before request execution.
pub trait BeforeRequestList<Req>: BeforeRequest<Req> {
/// The hook returned by `BeforeRequestList::then`.
type Then<Next>: BeforeRequest<Req>
where
Next: BeforeRequest<Req>;
/// Returns a hook that, when run, runs two hooks, first `self` and then `next`.
fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next>;
/// Same as `then`, but helps the compiler with type inference when Next is a closure.
fn then_fn<
Next: FnMut(&mut context::Context, &Req) -> Fut,
Fut: Future<Output = Result<(), ServerError>>,
>(
self,
next: Next,
) -> Self::Then<Next>
where
Self: Sized,
{
self.then(next)
}
/// The service fn returned by `BeforeRequestList::serving`.
type Serve<S: Serve<Req = Req>>: Serve<Req = Req>;
/// Runs the list of request hooks before execution of the given serve fn.
/// This is equivalent to `serve.before(before_request_chain)` but may be syntactically nicer.
fn serving<S: Serve<Req = Req>>(self, serve: S) -> Self::Serve<S>;
}
impl<F, Fut, Req> BeforeRequest<Req> for F
where
F: FnMut(&mut context::Context, &Req) -> Fut,
Fut: Future<Output = Result<(), ServerError>>,
{
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> {
self(ctx, req).await
}
}
/// A Service function that runs a hook before request execution.
#[derive(Clone)]
pub struct HookThenServe<Serv, Hook> {
serve: Serv,
hook: Hook,
}
impl<Serv, Hook> HookThenServe<Serv, Hook> {
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
Self { serve, hook }
}
}
impl<Serv, Hook> Serve for HookThenServe<Serv, Hook>
where
Serv: Serve,
Hook: BeforeRequest<Serv::Req>,
{
type Req = Serv::Req;
type Resp = Serv::Resp;
async fn serve(
self,
mut ctx: context::Context,
req: Self::Req,
) -> Result<Serv::Resp, ServerError> {
let HookThenServe {
serve, mut hook, ..
} = self;
hook.before(&mut ctx, &req).await?;
serve.serve(ctx, req).await
}
}
/// Returns a request hook builder that runs a series of hooks before request execution.
///
/// Example
///
/// ```rust
/// use futures::{executor::block_on, future};
/// use tarpc::{context, ServerError, server::{Serve, serve, request_hook::{self,
/// BeforeRequest, BeforeRequestList}}};
/// use std::{cell::Cell, io};
///
/// let i = Cell::new(0);
/// let serve = request_hook::before()
/// .then_fn(|_, _| async {
/// assert!(i.get() == 0);
/// i.set(1);
/// Ok(())
/// })
/// .then_fn(|_, _| async {
/// assert!(i.get() == 1);
/// i.set(2);
/// Ok(())
/// })
/// .serving(serve(|_ctx, i| async move { Ok(i + 1) }));
/// let response = serve.clone().serve(context::current(), 1);
/// assert!(block_on(response).is_ok());
/// assert!(i.get() == 2);
/// ```
pub fn before() -> BeforeRequestNil {
BeforeRequestNil
}
/// A list of hooks that run in order before a request is executed.
#[derive(Clone, Copy)]
pub struct BeforeRequestCons<First, Rest>(First, Rest);
/// A noop hook that runs before a request is executed.
#[derive(Clone, Copy)]
pub struct BeforeRequestNil;
impl<Req, First: BeforeRequest<Req>, Rest: BeforeRequest<Req>> BeforeRequest<Req>
for BeforeRequestCons<First, Rest>
{
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> {
let BeforeRequestCons(first, rest) = self;
first.before(ctx, req).await?;
rest.before(ctx, req).await?;
Ok(())
}
}
impl<Req> BeforeRequest<Req> for BeforeRequestNil {
async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> {
Ok(())
}
}
impl<Req, First: BeforeRequest<Req>, Rest: BeforeRequestList<Req>> BeforeRequestList<Req>
for BeforeRequestCons<First, Rest>
{
type Then<Next> = BeforeRequestCons<First, Rest::Then<Next>> where Next: BeforeRequest<Req>;
fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next> {
let BeforeRequestCons(first, rest) = self;
BeforeRequestCons(first, rest.then(next))
}
type Serve<S: Serve<Req = Req>> = HookThenServe<S, Self>;
fn serving<S: Serve<Req = Req>>(self, serve: S) -> Self::Serve<S> {
HookThenServe::new(serve, self)
}
}
impl<Req> BeforeRequestList<Req> for BeforeRequestNil {
type Then<Next> = BeforeRequestCons<Next, BeforeRequestNil> where Next: BeforeRequest<Req>;
fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next> {
BeforeRequestCons(next, BeforeRequestNil)
}
type Serve<S: Serve<Req = Req>> = S;
fn serving<S: Serve<Req = Req>>(self, serve: S) -> S {
serve
}
}
#[test]
fn before_request_list() {
use crate::server::serve;
use futures::executor::block_on;
use std::cell::Cell;
let i = Cell::new(0);
let serve = before()
.then_fn(|_, _| async {
assert!(i.get() == 0);
i.set(1);
Ok(())
})
.then_fn(|_, _| async {
assert!(i.get() == 1);
i.set(2);
Ok(())
})
.serving(serve(|_ctx, i| async move { Ok(i + 1) }));
let response = serve.clone().serve(context::current(), 1);
assert!(block_on(response).is_ok());
assert!(i.get() == 2);
}

View File

@@ -0,0 +1,57 @@
// Copyright 2022 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
//! Provides a hook that runs both before and after request execution.
use super::{after::AfterRequest, before::BeforeRequest};
use crate::{context, server::Serve, ServerError};
use std::marker::PhantomData;
/// A Service function that runs a hook both before and after request execution.
pub struct HookThenServeThenHook<Req, Resp, Serv, Hook> {
serve: Serv,
hook: Hook,
fns: PhantomData<(fn(Req), fn(Resp))>,
}
impl<Req, Resp, Serv, Hook> HookThenServeThenHook<Req, Resp, Serv, Hook> {
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
Self {
serve,
hook,
fns: PhantomData,
}
}
}
impl<Req, Resp, Serv: Clone, Hook: Clone> Clone for HookThenServeThenHook<Req, Resp, Serv, Hook> {
fn clone(&self) -> Self {
Self {
serve: self.serve.clone(),
hook: self.hook.clone(),
fns: PhantomData,
}
}
}
impl<Req, Resp, Serv, Hook> Serve for HookThenServeThenHook<Req, Resp, Serv, Hook>
where
Serv: Serve<Req = Req, Resp = Resp>,
Hook: BeforeRequest<Req> + AfterRequest<Resp>,
{
type Req = Req;
type Resp = Resp;
async fn serve(self, mut ctx: context::Context, req: Req) -> Result<Serv::Resp, ServerError> {
let HookThenServeThenHook {
serve, mut hook, ..
} = self;
hook.before(&mut ctx, &req).await?;
let mut resp = serve.serve(ctx, req).await;
hook.after(&mut ctx, &mut resp).await;
resp
}
}

View File

@@ -1,113 +0,0 @@
use super::{Channel, Requests, Serve};
use futures::{prelude::*, ready, task::*};
use pin_project::pin_project;
use std::pin::Pin;
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
/// for each new channel. Returned by
/// [`Incoming::execute`](crate::server::incoming::Incoming::execute).
#[must_use]
#[pin_project]
#[derive(Debug)]
pub struct TokioServerExecutor<T, S> {
#[pin]
inner: T,
serve: S,
}
impl<T, S> TokioServerExecutor<T, S> {
pub(crate) fn new(inner: T, serve: S) -> Self {
Self { inner, serve }
}
}
/// A future that drives the server by [spawning](tokio::spawn) each [response
/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by
/// [`Channel::execute`](crate::server::Channel::execute).
#[must_use]
#[pin_project]
#[derive(Debug)]
pub struct TokioChannelExecutor<T, S> {
#[pin]
inner: T,
serve: S,
}
impl<T, S> TokioServerExecutor<T, S> {
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
self.as_mut().project().inner
}
}
impl<T, S> TokioChannelExecutor<T, S> {
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
self.as_mut().project().inner
}
}
// Send + 'static execution helper methods.
impl<C> Requests<C>
where
C: Channel,
C::Req: Send + 'static,
C::Resp: Send + 'static,
{
/// Executes all requests using the given service function. Requests are handled concurrently
/// by [spawning](::tokio::spawn) each handler on tokio's default executor.
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
where
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
{
TokioChannelExecutor { inner: self, serve }
}
}
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
where
St: Sized + Stream<Item = C>,
C: Channel + Send + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
Se::Fut: Send,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
tokio::spawn(channel.execute(self.serve.clone()));
}
tracing::info!("Server shutting down.");
Poll::Ready(())
}
}
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
where
C: Channel + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
S::Fut: Send,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
match response_handler {
Ok(resp) => {
let server = self.serve.clone();
tokio::spawn(async move {
resp.execute(server).await;
});
}
Err(e) => {
tracing::warn!("Requests stream errored out: {}", e);
break;
}
}
}
Poll::Ready(())
}
}

View File

@@ -14,9 +14,15 @@ use tokio::sync::mpsc;
/// Errors that occur in the sending or receiving of messages over a channel. /// Errors that occur in the sending or receiving of messages over a channel.
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
pub enum ChannelError { pub enum ChannelError {
/// An error occurred sending over the channel. /// An error occurred readying to send into the channel.
#[error("an error occurred sending over the channel")] #[error("an error occurred readying to send into the channel")]
Ready(#[source] Box<dyn Error + Send + Sync + 'static>),
/// An error occurred sending into the channel.
#[error("an error occurred sending into the channel")]
Send(#[source] Box<dyn Error + Send + Sync + 'static>), Send(#[source] Box<dyn Error + Send + Sync + 'static>),
/// An error occurred receiving from the channel.
#[error("an error occurred receiving from the channel")]
Receive(#[source] Box<dyn Error + Send + Sync + 'static>),
} }
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's /// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
@@ -48,7 +54,10 @@ impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
mut self: Pin<&mut Self>, mut self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<Option<Result<Item, ChannelError>>> { ) -> Poll<Option<Result<Item, ChannelError>>> {
self.rx.poll_recv(cx).map(|option| option.map(Ok)) self.rx
.poll_recv(cx)
.map(|option| option.map(Ok))
.map_err(ChannelError::Receive)
} }
} }
@@ -59,7 +68,7 @@ impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(if self.tx.is_closed() { Poll::Ready(if self.tx.is_closed() {
Err(ChannelError::Send(CLOSED_MESSAGE.into())) Err(ChannelError::Ready(CLOSED_MESSAGE.into()))
} else { } else {
Ok(()) Ok(())
}) })
@@ -110,7 +119,11 @@ impl<Item, SinkItem> Stream for Channel<Item, SinkItem> {
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<Option<Result<Item, ChannelError>>> { ) -> Poll<Option<Result<Item, ChannelError>>> {
self.project().rx.poll_next(cx).map(|option| option.map(Ok)) self.project()
.rx
.poll_next(cx)
.map(|option| option.map(Ok))
.map_err(ChannelError::Receive)
} }
} }
@@ -121,7 +134,7 @@ impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
self.project() self.project()
.tx .tx
.poll_ready(cx) .poll_ready(cx)
.map_err(|e| ChannelError::Send(Box::new(e))) .map_err(|e| ChannelError::Ready(Box::new(e)))
} }
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> { fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
@@ -146,16 +159,17 @@ impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
} }
} }
#[cfg(test)] #[cfg(all(test, feature = "tokio1"))]
#[cfg(feature = "tokio1")]
mod tests { mod tests {
use crate::{ use crate::{
client, context, client::{self, RpcError},
server::{incoming::Incoming, BaseChannel}, context,
server::{incoming::Incoming, serve, BaseChannel},
transport::{ transport::{
self, self,
channel::{Channel, UnboundedChannel}, channel::{Channel, UnboundedChannel},
}, },
ServerError,
}; };
use assert_matches::assert_matches; use assert_matches::assert_matches;
use futures::{prelude::*, stream}; use futures::{prelude::*, stream};
@@ -177,25 +191,28 @@ mod tests {
tokio::spawn( tokio::spawn(
stream::once(future::ready(server_channel)) stream::once(future::ready(server_channel))
.map(BaseChannel::with_defaults) .map(BaseChannel::with_defaults)
.execute(|_ctx, request: String| { .execute(serve(|_ctx, request: String| async move {
future::ready(request.parse::<u64>().map_err(|_| { request.parse::<u64>().map_err(|_| {
io::Error::new( ServerError::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
format!("{request:?} is not an int"), format!("{request:?} is not an int"),
) )
})) })
}))
.for_each(|channel| async move {
tokio::spawn(channel.for_each(|response| response));
}), }),
); );
let client = client::new(client::Config::default(), client_channel).spawn(); let client = client::new(client::Config::default(), client_channel).spawn();
let response1 = client.call(context::current(), "", "123".into()).await?; let response1 = client.call(context::current(), "", "123".into()).await;
let response2 = client.call(context::current(), "", "abc".into()).await?; let response2 = client.call(context::current(), "", "abc".into()).await;
trace!("response1: {:?}, response2: {:?}", response1, response2); trace!("response1: {:?}, response2: {:?}", response1, response2);
assert_matches!(response1, Ok(123)); assert_matches!(response1, Ok(123));
assert_matches!(response2, Err(ref e) if e.kind() == io::ErrorKind::InvalidInput); assert_matches!(response2, Err(RpcError::Server(e)) if e.kind == io::ErrorKind::InvalidInput);
Ok(()) Ok(())
} }

View File

@@ -2,8 +2,6 @@
fn ui() { fn ui() {
let t = trybuild::TestCases::new(); let t = trybuild::TestCases::new();
t.compile_fail("tests/compile_fail/*.rs"); t.compile_fail("tests/compile_fail/*.rs");
#[cfg(feature = "tokio1")]
t.compile_fail("tests/compile_fail/tokio/*.rs");
#[cfg(all(feature = "serde-transport", feature = "tcp"))] #[cfg(all(feature = "serde-transport", feature = "tcp"))]
t.compile_fail("tests/compile_fail/serde_transport/*.rs"); t.compile_fail("tests/compile_fail/serde_transport/*.rs");
} }

View File

@@ -9,3 +9,7 @@ note: the lint level is defined here
| |
11 | #[deny(unused_must_use)] 11 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^ | ^^^^^^^^^^^^^^^
help: use `let _ = ...` to ignore the resulting value
|
13 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch;
| +++++++

View File

@@ -9,3 +9,7 @@ note: the lint level is defined here
| |
5 | #[deny(unused_must_use)] 5 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^ | ^^^^^^^^^^^^^^^
help: use `let _ = ...` to ignore the resulting value
|
7 | let _ = serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default);
| +++++++

View File

@@ -1,15 +0,0 @@
#[tarpc::service(derive_serde = false)]
trait World {
async fn hello(name: String) -> String;
}
struct HelloServer;
#[tarpc::server]
impl World for HelloServer {
fn hello(name: String) -> String {
format!("Hello, {name}!", name)
}
}
fn main() {}

View File

@@ -1,11 +0,0 @@
error: not all trait items implemented, missing: `HelloFut`
--> $DIR/tarpc_server_missing_async.rs:9:1
|
9 | impl World for HelloServer {
| ^^^^
error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async
--> $DIR/tarpc_server_missing_async.rs:10:5
|
10 | fn hello(name: String) -> String {
| ^^

View File

@@ -1,29 +0,0 @@
use tarpc::{
context,
server::{self, Channel},
};
#[tarpc::service]
trait World {
async fn hello(name: String) -> String;
}
#[derive(Clone)]
struct HelloServer;
#[tarpc::server]
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
format!("Hello, {name}!")
}
}
fn main() {
let (_, server_transport) = tarpc::transport::channel::unbounded();
let server = server::BaseChannel::with_defaults(server_transport);
#[deny(unused_must_use)]
{
server.execute(HelloServer.serve());
}
}

View File

@@ -1,11 +0,0 @@
error: unused `TokioChannelExecutor` that must be used
--> tests/compile_fail/tokio/must_use_channel_executor.rs:27:9
|
27 | server.execute(HelloServer.serve());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/tokio/must_use_channel_executor.rs:25:12
|
25 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^

View File

@@ -1,30 +0,0 @@
use futures::stream::once;
use tarpc::{
context,
server::{self, incoming::Incoming},
};
#[tarpc::service]
trait World {
async fn hello(name: String) -> String;
}
#[derive(Clone)]
struct HelloServer;
#[tarpc::server]
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
format!("Hello, {name}!")
}
}
fn main() {
let (_, server_transport) = tarpc::transport::channel::unbounded();
let server = once(async move { server::BaseChannel::with_defaults(server_transport) });
#[deny(unused_must_use)]
{
server.execute(HelloServer.serve());
}
}

View File

@@ -1,11 +0,0 @@
error: unused `TokioServerExecutor` that must be used
--> tests/compile_fail/tokio/must_use_server_executor.rs:28:9
|
28 | server.execute(HelloServer.serve());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/tokio/must_use_server_executor.rs:26:12
|
26 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^

View File

@@ -21,7 +21,6 @@ pub trait ColorProtocol {
#[derive(Clone)] #[derive(Clone)]
struct ColorServer; struct ColorServer;
#[tarpc::server]
impl ColorProtocol for ColorServer { impl ColorProtocol for ColorServer {
async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData { async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData {
match color { match color {
@@ -31,6 +30,11 @@ impl ColorProtocol for ColorServer {
} }
} }
#[cfg(test)]
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
#[tokio::test] #[tokio::test]
async fn test_call() -> anyhow::Result<()> { async fn test_call() -> anyhow::Result<()> {
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?; let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
@@ -40,7 +44,9 @@ async fn test_call() -> anyhow::Result<()> {
.take(1) .take(1)
.filter_map(|r| async { r.ok() }) .filter_map(|r| async { r.ok() })
.map(BaseChannel::with_defaults) .map(BaseChannel::with_defaults)
.execute(ColorServer.serve()), .execute(ColorServer.serve())
.map(|channel| channel.for_each(spawn))
.for_each(spawn),
); );
let transport = serde_transport::tcp::connect(addr, Json::default).await?; let transport = serde_transport::tcp::connect(addr, Json::default).await?;

View File

@@ -1,13 +1,13 @@
use assert_matches::assert_matches; use assert_matches::assert_matches;
use futures::{ use futures::{
future::{join_all, ready, Ready}, future::{join_all, ready},
prelude::*, prelude::*,
}; };
use std::time::{Duration, SystemTime}; use std::time::{Duration, SystemTime};
use tarpc::{ use tarpc::{
client::{self}, client::{self},
context, context,
server::{self, incoming::Incoming, BaseChannel, Channel}, server::{incoming::Incoming, BaseChannel, Channel},
transport::channel, transport::channel,
}; };
use tokio::join; use tokio::join;
@@ -22,39 +22,29 @@ trait Service {
struct Server; struct Server;
impl Service for Server { impl Service for Server {
type AddFut = Ready<i32>; async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
x + y
fn add(self, _: context::Context, x: i32, y: i32) -> Self::AddFut {
ready(x + y)
} }
type HeyFut = Ready<String>; async fn hey(self, _: context::Context, name: String) -> String {
format!("Hey, {name}.")
fn hey(self, _: context::Context, name: String) -> Self::HeyFut {
ready(format!("Hey, {name}."))
} }
} }
#[tokio::test] #[tokio::test]
async fn sequential() -> anyhow::Result<()> { async fn sequential() {
let _ = tracing_subscriber::fmt::try_init(); let (tx, rx) = tarpc::transport::channel::unbounded();
let client = client::new(client::Config::default(), tx).spawn();
let (tx, rx) = channel::unbounded(); let channel = BaseChannel::with_defaults(rx);
tokio::spawn( tokio::spawn(
BaseChannel::new(server::Config::default(), rx) channel
.requests() .execute(tarpc::server::serve(|_, i| async move { Ok(i + 1) }))
.execute(Server.serve()), .for_each(|response| response),
);
assert_eq!(
client.call(context::current(), "AddOne", 1).await.unwrap(),
2
); );
let client = ServiceClient::new(client::Config::default(), tx).spawn();
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
assert_matches!(
client.hey(context::current(), "Tim".into()).await,
Ok(ref s) if s == "Hey, Tim.");
Ok(())
} }
#[tokio::test] #[tokio::test]
@@ -70,7 +60,6 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> {
#[derive(Debug)] #[derive(Debug)]
struct AllHandlersComplete; struct AllHandlersComplete;
#[tarpc::server]
impl Loop for LoopServer { impl Loop for LoopServer {
async fn r#loop(self, _: context::Context) { async fn r#loop(self, _: context::Context) {
loop { loop {
@@ -121,7 +110,9 @@ async fn serde_tcp() -> anyhow::Result<()> {
.take(1) .take(1)
.filter_map(|r| async { r.ok() }) .filter_map(|r| async { r.ok() })
.map(BaseChannel::with_defaults) .map(BaseChannel::with_defaults)
.execute(Server.serve()), .execute(Server.serve())
.map(|channel| channel.for_each(spawn))
.for_each(spawn),
); );
let transport = serde_transport::tcp::connect(addr, Json::default).await?; let transport = serde_transport::tcp::connect(addr, Json::default).await?;
@@ -151,7 +142,9 @@ async fn serde_uds() -> anyhow::Result<()> {
.take(1) .take(1)
.filter_map(|r| async { r.ok() }) .filter_map(|r| async { r.ok() })
.map(BaseChannel::with_defaults) .map(BaseChannel::with_defaults)
.execute(Server.serve()), .execute(Server.serve())
.map(|channel| channel.for_each(spawn))
.for_each(spawn),
); );
let transport = serde_transport::unix::connect(&sock, Json::default).await?; let transport = serde_transport::unix::connect(&sock, Json::default).await?;
@@ -175,7 +168,9 @@ async fn concurrent() -> anyhow::Result<()> {
tokio::spawn( tokio::spawn(
stream::once(ready(rx)) stream::once(ready(rx))
.map(BaseChannel::with_defaults) .map(BaseChannel::with_defaults)
.execute(Server.serve()), .execute(Server.serve())
.map(|channel| channel.for_each(spawn))
.for_each(spawn),
); );
let client = ServiceClient::new(client::Config::default(), tx).spawn(); let client = ServiceClient::new(client::Config::default(), tx).spawn();
@@ -199,7 +194,9 @@ async fn concurrent_join() -> anyhow::Result<()> {
tokio::spawn( tokio::spawn(
stream::once(ready(rx)) stream::once(ready(rx))
.map(BaseChannel::with_defaults) .map(BaseChannel::with_defaults)
.execute(Server.serve()), .execute(Server.serve())
.map(|channel| channel.for_each(spawn))
.for_each(spawn),
); );
let client = ServiceClient::new(client::Config::default(), tx).spawn(); let client = ServiceClient::new(client::Config::default(), tx).spawn();
@@ -216,15 +213,20 @@ async fn concurrent_join() -> anyhow::Result<()> {
Ok(()) Ok(())
} }
#[cfg(test)]
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
#[tokio::test] #[tokio::test]
async fn concurrent_join_all() -> anyhow::Result<()> { async fn concurrent_join_all() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init(); let _ = tracing_subscriber::fmt::try_init();
let (tx, rx) = channel::unbounded(); let (tx, rx) = channel::unbounded();
tokio::spawn( tokio::spawn(
stream::once(ready(rx)) BaseChannel::with_defaults(rx)
.map(BaseChannel::with_defaults) .execute(Server.serve())
.execute(Server.serve()), .for_each(spawn),
); );
let client = ServiceClient::new(client::Config::default(), tx).spawn(); let client = ServiceClient::new(client::Config::default(), tx).spawn();
@@ -249,11 +251,9 @@ async fn counter() -> anyhow::Result<()> {
struct CountService(u32); struct CountService(u32);
impl Counter for &mut CountService { impl Counter for &mut CountService {
type CountFut = futures::future::Ready<u32>; async fn count(self, _: context::Context) -> u32 {
fn count(self, _: context::Context) -> Self::CountFut {
self.0 += 1; self.0 += 1;
futures::future::ready(self.0) self.0
} }
} }