12 Commits

Author SHA1 Message Date
Tim Kuehn
7e521768ab Prepare for v0.21.0 release. 2020-06-26 20:05:02 -07:00
Tim Kuehn
e9b1e7d101 Use #[non_exhaustive] in lieu of _NonExhaustive enum variant. 2020-06-26 19:47:20 -07:00
Taiki Endo
f0322fb892 Remove uses of pin_project::project attribute
pin-project will deprecate the project attribute due to some unfixable
limitations.

Refs: https://github.com/taiki-e/pin-project/issues/225
2020-06-05 20:34:44 -07:00
Patrick Elsen
617daebb88 Add tarpc::server proc-macro as syntactic sugar for async methods. (#302)
The tarpc::server proc-macro can be used to annotate implementations of
services to turn async functions into the proper declarations needed
for tarpc to be able to call them.

This uses the assert_type_eq crate to check that the transformations
applied by the tarpc::server proc macro are correct and lead to code
that compiles.
2020-05-16 10:25:25 -07:00
Tim Kuehn
a11d4fff58 Remove raii_counter 2020-04-22 02:13:02 -07:00
Tim
bf42a04d83 Move the request timeout so that it surrounds the entire call, not just the response future. (#295)
* Move the request timeout so that it surrounds the entire call, not just the response future.

This will enable the timeout earlier, so that a backlog in the outbound request buffer can not cause requests to stall indefinitely.

* Run cargo fmt
2020-02-25 14:42:40 -08:00
Tim Kuehn
06528d6953 Fix clippy lint. 2019-12-19 12:28:26 -08:00
Tim Kuehn
9f00395746 Replace _non_exhaustive fields with #[non_exhaustive] attribute.
The attribute landed on stable rust (1.40.0) today.

Fixes https://github.com/google/tarpc/issues/275
2019-12-19 12:14:34 -08:00
Tim Kuehn
e0674cd57f Make pre-push run on rust stable. 2019-12-19 12:06:06 -08:00
Tim Kuehn
7e49bd9ee7 Clean up badges a bit. 2019-12-16 13:21:00 -08:00
Tim Kuehn
8a1baa9c4e Remove usage of unsafe in rpc::client::channel.
pin_project is actually able to handle the complexities of enum Futures.
2019-12-16 11:10:57 -08:00
Oleg Nosov
31c713d188 Allow raw identifiers + fixed naming + place all code generation methods in impl (#291)
Allows defining services using raw identifiers like:

```rust
pub mod service {
    #[tarpc::service]
    pub trait r#trait {
        async fn r#fn(x: i32) -> Result<u8, String>;
    }
}
```

Also:

- Refactored names (ident -> type)
- All code generation methods placed in impl
2019-12-12 10:13:57 -08:00
18 changed files with 829 additions and 499 deletions

View File

@@ -1,6 +1,16 @@
[![Build Status](https://github.com/google/tarpc/workflows/Continuous%20integration/badge.svg)](https://github.com/google/tarpc/actions?query=workflow%3A%22Continuous+integration%22)
[![Latest Version](https://img.shields.io/crates/v/tarpc.svg)](https://crates.io/crates/tarpc)
[![Chat on Discord](https://img.shields.io/discord/647529123996237854)](https://discordapp.com/channels/647529123996237854)
[![Crates.io][crates-badge]][crates-url]
[![MIT licensed][mit-badge]][mit-url]
[![Build status][gh-actions-badge]][gh-actions-url]
[![Discord chat][discord-badge]][discord-url]
[crates-badge]: https://img.shields.io/crates/v/tarpc.svg
[crates-url]: https://crates.io/crates/tarpc
[mit-badge]: https://img.shields.io/badge/license-MIT-blue.svg
[mit-url]: LICENSE
[gh-actions-badge]: https://github.com/google/tarpc/workflows/Continuous%20integration/badge.svg
[gh-actions-url]: https://github.com/google/tarpc/actions?query=workflow%3A%22Continuous+integration%22
[discord-badge]: https://img.shields.io/discord/647529123996237854.svg?logo=discord&style=flat-square
[discord-url]: https://discord.gg/gXwpdSt
# tarpc
@@ -47,7 +57,7 @@ Some other features of tarpc:
Add to your `Cargo.toml` dependencies:
```toml
tarpc = { version = "0.18.0", features = ["full"] }
tarpc = { version = "0.21.0", features = ["full"] }
```
The `tarpc::service` attribute expands to a collection of items that form an rpc service.

View File

@@ -1,3 +1,26 @@
## 0.21.0 (2020-06-26)
### New Features
A new proc macro, `#[tarpc::server]` was added! This enables service impls to elide the boilerplate
of specifying associated types for each RPC. With the ubiquity of async-await, most code won't have
nameable futures and will just be boxing the return type anyway. This macro does that for you.
### Breaking Changes
- Enums had _non_exhaustive fields replaced with the #[non_exhaustive] attribute.
### Bug Fixes
- https://github.com/google/tarpc/issues/304
A race condition in code that limits number of connections per client caused occasional panics.
- https://github.com/google/tarpc/pull/295
Made request timeouts account for time spent in the outbound buffer. Previously, a large outbound
queue would lead to requests not timing out correctly.
## 0.20.0 (2019-12-11)
### Breaking Changes

View File

@@ -16,7 +16,7 @@ description = "An example server built on tarpc."
clap = "2.0"
futures = "0.3"
serde = { version = "1.0" }
tarpc = { version = "0.20", path = "../tarpc", features = ["full"] }
tarpc = { version = "0.21", path = "../tarpc", features = ["full"] }
tokio = { version = "0.2", features = ["full"] }
tokio-serde = { version = "0.6", features = ["json"] }
env_logger = "0.6"

View File

@@ -89,12 +89,12 @@ if [ "$?" == 0 ]; then
exit 1
fi
try_run "Building ... " cargo build --color=always
try_run "Testing ... " cargo test --color=always
try_run "Testing with all features enabled ... " cargo test --all-features --color=always
for EXAMPLE in $(cargo run --example 2>&1 | grep ' ' | awk '{print $1}')
try_run "Building ... " cargo +stable build --color=always
try_run "Testing ... " cargo +stable test --color=always
try_run "Testing with all features enabled ... " cargo +stable test --all-features --color=always
for EXAMPLE in $(cargo +stable run --example 2>&1 | grep ' ' | awk '{print $1}')
do
try_run "Running example \"$EXAMPLE\" ... " cargo run --example $EXAMPLE
try_run "Running example \"$EXAMPLE\" ... " cargo +stable run --example $EXAMPLE
done
fi

View File

@@ -1,6 +1,6 @@
[package]
name = "tarpc-plugins"
version = "0.7.0"
version = "0.8.0"
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
edition = "2018"
license = "MIT"
@@ -30,3 +30,4 @@ proc-macro = true
futures = "0.3"
serde = { version = "1.0", features = ["derive"] }
tarpc = { path = "../tarpc" }
assert-type-eq = "0.1.0"

View File

@@ -12,15 +12,18 @@ extern crate quote;
extern crate syn;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote, ToTokens};
use syn::{
braced, parenthesized,
braced,
ext::IdentExt,
parenthesized,
parse::{Parse, ParseStream},
parse_macro_input, parse_quote,
parse_macro_input, parse_quote, parse_str,
punctuated::Punctuated,
token::Comma,
Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type,
Visibility,
Attribute, FnArg, Ident, ImplItem, ImplItemMethod, ImplItemType, ItemImpl, Lit, LitBool,
MetaNameValue, Pat, PatType, ReturnType, Token, Type, Visibility,
};
struct Service {
@@ -42,7 +45,7 @@ impl Parse for Service {
let attrs = input.call(Attribute::parse_outer)?;
let vis = input.parse()?;
input.parse::<Token![trait]>()?;
let ident = input.parse()?;
let ident: Ident = input.parse()?;
let content;
braced!(content in input);
let mut rpcs = Vec::<RpcMethod>::new();
@@ -53,7 +56,7 @@ impl Parse for Service {
if rpc.ident == "new" {
return Err(input.error(format!(
"method name conflicts with generated fn `{}Client::new`",
ident
ident.unraw()
)));
}
if rpc.ident == "serve" {
@@ -63,6 +66,7 @@ impl Parse for Service {
)));
}
}
Ok(Self {
attrs,
vis,
@@ -156,19 +160,19 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
ref rpcs,
} = parse_macro_input!(input as Service);
let camel_case_fn_names: &[String] = &rpcs
let camel_case_fn_names: &Vec<_> = &rpcs
.iter()
.map(|rpc| snake_to_camel(&rpc.ident.to_string()))
.collect::<Vec<_>>();
.map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
.collect();
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
let response_fut_name = &format!("{}ResponseFut", ident);
let derive_serialize = quote!(#[derive(serde::Serialize, serde::Deserialize)]);
let response_fut_name = &format!("{}ResponseFut", ident.unraw());
let derive_serialize = if derive_serde.0 {
Some(quote!(#[derive(serde::Serialize, serde::Deserialize)]))
} else {
None
};
let gen_args = &GenArgs {
attrs,
vis,
rpcs,
args,
ServiceGenerator {
response_fut_name,
service_ident: ident,
server_ident: &format_ident!("Serve{}", ident),
@@ -176,8 +180,12 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
client_ident: &format_ident!("{}Client", ident),
request_ident: &format_ident!("{}Request", ident),
response_ident: &format_ident!("{}Response", ident),
vis,
args,
method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
method_names: &rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>(),
method_idents: &rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>(),
attrs,
rpcs,
return_types: &rpcs
.iter()
.map(|rpc| match rpc.output {
@@ -185,7 +193,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
ReturnType::Default => unit_type,
})
.collect::<Vec<_>>(),
arg_vars: &args
arg_pats: &args
.iter()
.map(|args| args.iter().map(|arg| &*arg.pat).collect())
.collect::<Vec<_>>(),
@@ -194,43 +202,139 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
.zip(camel_case_fn_names.iter())
.map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
.collect::<Vec<_>>(),
future_idents: &camel_case_fn_names
future_types: &camel_case_fn_names
.iter()
.map(|name| format_ident!("{}Fut", name))
.map(|name| parse_str(&format!("{}Fut", name)).unwrap())
.collect::<Vec<_>>(),
derive_serialize: if derive_serde.0 {
Some(&derive_serialize)
} else {
None
},
derive_serialize: derive_serialize.as_ref(),
}
.into_token_stream()
.into()
}
/// Transforms an async function into a sync one, returning a type declaration
/// for the return type (a future).
fn transform_method(method: &mut ImplItemMethod) -> ImplItemType {
method.sig.asyncness = None;
// get either the return type or ().
let ret = match &method.sig.output {
ReturnType::Default => quote!(()),
ReturnType::Type(_, ret) => quote!(#ret),
};
let tokens = vec![
trait_service(gen_args),
struct_server(gen_args),
impl_serve_for_server(gen_args),
enum_request(gen_args),
enum_response(gen_args),
enum_response_future(gen_args),
impl_debug_for_response_future(gen_args),
impl_future_for_response_future(gen_args),
struct_client(gen_args),
impl_from_for_client(gen_args),
impl_client_new(gen_args),
impl_client_rpc_methods(gen_args),
];
// generate an identifier consisting of the method name to CamelCase with
// Fut appended to it.
let fut_name = snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut";
let fut_name_ident = Ident::new(&fut_name, method.sig.ident.span());
tokens
.into_iter()
.collect::<proc_macro2::TokenStream>()
.into()
// generate the updated return signature.
method.sig.output = parse_quote! {
-> ::core::pin::Pin<Box<
dyn ::core::future::Future<Output = #ret> + ::core::marker::Send
>>
};
// transform the body of the method into Box::pin(async move { body }).
let block = method.block.clone();
method.block = parse_quote! [{
Box::pin(async move
#block
)
}];
// generate and return type declaration for return type.
let t: ImplItemType = parse_quote! {
type #fut_name_ident = ::core::pin::Pin<Box<dyn ::core::future::Future<Output = #ret> + ::core::marker::Send>>;
};
t
}
/// Syntactic sugar to make using async functions in the server implementation
/// easier. It does this by rewriting code like this, which would normally not
/// compile because async functions are disallowed in trait implementations:
///
/// ```rust
/// # extern crate tarpc;
/// # use tarpc::context;
/// # use std::net::SocketAddr;
/// #[tarpc_plugins::service]
/// trait World {
/// async fn hello(name: String) -> String;
/// }
///
/// #[derive(Clone)]
/// struct HelloServer(SocketAddr);
///
/// #[tarpc_plugins::server]
/// impl World for HelloServer {
/// async fn hello(self, _: context::Context, name: String) -> String {
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
/// }
/// }
/// ```
///
/// Into code like this, which matches the service trait definition:
///
/// ```rust
/// # extern crate tarpc;
/// # use tarpc::context;
/// # use std::pin::Pin;
/// # use futures::Future;
/// # use std::net::SocketAddr;
/// #[tarpc_plugins::service]
/// trait World {
/// async fn hello(name: String) -> String;
/// }
///
/// #[derive(Clone)]
/// struct HelloServer(SocketAddr);
///
/// impl World for HelloServer {
/// type HelloFut = Pin<Box<dyn Future<Output = String> + Send>>;
///
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
/// + Send>> {
/// Box::pin(async move {
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
/// })
/// }
/// }
/// ```
///
/// Note that this won't touch functions unless they have been annotated with
/// `async`, meaning that this should not break existing code.
#[proc_macro_attribute]
pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream {
let mut item = syn::parse_macro_input!(input as ItemImpl);
// the generated type declarations
let mut types: Vec<ImplItemType> = Vec::new();
for inner in &mut item.items {
if let ImplItem::Method(method) = inner {
let sig = &method.sig;
// if this function is declared async, transform it into a regular function
if sig.asyncness.is_some() {
let typedecl = transform_method(method);
types.push(typedecl);
}
}
}
// add the type declarations into the impl block
for t in types.into_iter() {
item.items.push(syn::ImplItem::Type(t));
}
TokenStream::from(quote!(#item))
}
// Things needed to generate the service items: trait, serve impl, request/response enums, and
// the client stub.
struct GenArgs<'a> {
attrs: &'a [Attribute],
rpcs: &'a [RpcMethod],
struct ServiceGenerator<'a> {
service_ident: &'a Ident,
server_ident: &'a Ident,
response_fut_ident: &'a Ident,
@@ -238,331 +342,356 @@ struct GenArgs<'a> {
client_ident: &'a Ident,
request_ident: &'a Ident,
response_ident: &'a Ident,
method_attrs: &'a [&'a [Attribute]],
vis: &'a Visibility,
method_names: &'a [&'a Ident],
attrs: &'a [Attribute],
rpcs: &'a [RpcMethod],
camel_case_idents: &'a [Ident],
future_types: &'a [Type],
method_idents: &'a [&'a Ident],
method_attrs: &'a [&'a [Attribute]],
args: &'a [&'a [PatType]],
return_types: &'a [&'a Type],
arg_vars: &'a [Vec<&'a Pat>],
camel_case_idents: &'a [Ident],
future_idents: &'a [Ident],
derive_serialize: Option<&'a proc_macro2::TokenStream>,
arg_pats: &'a [Vec<&'a Pat>],
derive_serialize: Option<&'a TokenStream2>,
}
fn trait_service(
&GenArgs {
attrs,
rpcs,
vis,
future_idents,
return_types,
service_ident,
server_ident,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
let types_and_fns = rpcs
.iter()
.zip(future_idents.iter())
.zip(return_types.iter())
.map(
|(
(
RpcMethod {
attrs, ident, args, ..
},
future_type,
),
output,
)| {
let ty_doc = format!("The response future returned by {}.", ident);
quote! {
#[doc = #ty_doc]
type #future_type: std::future::Future<Output = #output>;
impl<'a> ServiceGenerator<'a> {
fn trait_service(&self) -> TokenStream2 {
let &Self {
attrs,
rpcs,
vis,
future_types,
return_types,
service_ident,
server_ident,
..
} = self;
#( #attrs )*
fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type;
}
},
);
let types_and_fns = rpcs
.iter()
.zip(future_types.iter())
.zip(return_types.iter())
.map(
|(
(
RpcMethod {
attrs, ident, args, ..
},
future_type,
),
output,
)| {
let ty_doc = format!("The response future returned by {}.", ident);
quote! {
#[doc = #ty_doc]
type #future_type: std::future::Future<Output = #output>;
quote! {
#( #attrs )*
#vis trait #service_ident: Clone {
#( #types_and_fns )*
#( #attrs )*
fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type;
}
},
);
/// Returns a serving function to use with tarpc::server::Server.
fn serve(self) -> #server_ident<Self> {
#server_ident { service: self }
}
}
}
}
quote! {
#( #attrs )*
#vis trait #service_ident: Clone {
#( #types_and_fns )*
fn struct_server(
&GenArgs {
vis, server_ident, ..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
#[derive(Clone)]
#vis struct #server_ident<S> {
service: S,
}
}
}
fn impl_serve_for_server(
&GenArgs {
request_ident,
server_ident,
service_ident,
response_ident,
response_fut_ident,
camel_case_idents,
arg_vars,
method_names,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
impl<S> tarpc::server::Serve<#request_ident> for #server_ident<S>
where S: #service_ident
{
type Resp = #response_ident;
type Fut = #response_fut_ident<S>;
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
match req {
#(
#request_ident::#camel_case_idents{ #( #arg_vars ),* } => {
#response_fut_ident::#camel_case_idents(
#service_ident::#method_names(
self.service, ctx, #( #arg_vars ),*))
}
)*
/// Returns a serving function to use with tarpc::server::Server.
fn serve(self) -> #server_ident<Self> {
#server_ident { service: self }
}
}
}
}
}
fn enum_request(
&GenArgs {
derive_serialize,
vis,
request_ident,
camel_case_idents,
args,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
/// The request sent over the wire from the client to the server.
#[derive(Debug)]
#derive_serialize
#vis enum #request_ident {
#( #camel_case_idents{ #( #args ),* } ),*
}
}
}
fn struct_server(&self) -> TokenStream2 {
let &Self {
vis, server_ident, ..
} = self;
fn enum_response(
&GenArgs {
derive_serialize,
vis,
response_ident,
camel_case_idents,
return_types,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
/// The response sent over the wire from the server to the client.
#[derive(Debug)]
#derive_serialize
#vis enum #response_ident {
#( #camel_case_idents(#return_types) ),*
}
}
}
fn enum_response_future(
&GenArgs {
vis,
service_ident,
response_fut_ident,
camel_case_idents,
future_idents,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
/// A future resolving to a server response.
#vis enum #response_fut_ident<S: #service_ident> {
#( #camel_case_idents(<S as #service_ident>::#future_idents) ),*
}
}
}
fn impl_debug_for_response_future(
&GenArgs {
service_ident,
response_fut_ident,
response_fut_name,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
impl<S: #service_ident> std::fmt::Debug for #response_fut_ident<S> {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct(#response_fut_name).finish()
quote! {
#[derive(Clone)]
#vis struct #server_ident<S> {
service: S,
}
}
}
}
fn impl_future_for_response_future(
&GenArgs {
service_ident,
response_fut_ident,
response_ident,
camel_case_idents,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
impl<S: #service_ident> std::future::Future for #response_fut_ident<S> {
type Output = #response_ident;
fn impl_serve_for_server(&self) -> TokenStream2 {
let &Self {
request_ident,
server_ident,
service_ident,
response_ident,
response_fut_ident,
camel_case_idents,
arg_pats,
method_idents,
..
} = self;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>)
-> std::task::Poll<#response_ident>
quote! {
impl<S> tarpc::server::Serve<#request_ident> for #server_ident<S>
where S: #service_ident
{
unsafe {
match std::pin::Pin::get_unchecked_mut(self) {
type Resp = #response_ident;
type Fut = #response_fut_ident<S>;
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
match req {
#(
#response_fut_ident::#camel_case_idents(resp) =>
std::pin::Pin::new_unchecked(resp)
.poll(cx)
.map(#response_ident::#camel_case_idents),
#request_ident::#camel_case_idents{ #( #arg_pats ),* } => {
#response_fut_ident::#camel_case_idents(
#service_ident::#method_idents(
self.service, ctx, #( #arg_pats ),*
)
)
}
)*
}
}
}
}
}
}
fn struct_client(
&GenArgs {
vis,
client_ident,
request_ident,
response_ident,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
#[allow(unused)]
#[derive(Clone, Debug)]
/// The client stub that makes RPC calls to the server. Exposes a Future interface.
#vis struct #client_ident<C = tarpc::client::Channel<#request_ident, #response_ident>>(C);
}
}
fn enum_request(&self) -> TokenStream2 {
let &Self {
derive_serialize,
vis,
request_ident,
camel_case_idents,
args,
..
} = self;
fn impl_from_for_client(
&GenArgs {
client_ident,
request_ident,
response_ident,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
impl<C> From<C> for #client_ident<C>
where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
{
fn from(client: C) -> Self {
#client_ident(client)
quote! {
/// The request sent over the wire from the client to the server.
#[derive(Debug)]
#derive_serialize
#vis enum #request_ident {
#( #camel_case_idents{ #( #args ),* } ),*
}
}
}
}
fn impl_client_new(
&GenArgs {
client_ident,
vis,
request_ident,
response_ident,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
impl #client_ident {
/// Returns a new client stub that sends requests over the given transport.
#vis fn new<T>(config: tarpc::client::Config, transport: T)
-> tarpc::client::NewClient<
Self,
tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T>>
where
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>>
{
let new_client = tarpc::client::new(config, transport);
tarpc::client::NewClient {
client: #client_ident(new_client.client),
dispatch: new_client.dispatch,
fn enum_response(&self) -> TokenStream2 {
let &Self {
derive_serialize,
vis,
response_ident,
camel_case_idents,
return_types,
..
} = self;
quote! {
/// The response sent over the wire from the server to the client.
#[derive(Debug)]
#derive_serialize
#vis enum #response_ident {
#( #camel_case_idents(#return_types) ),*
}
}
}
fn enum_response_future(&self) -> TokenStream2 {
let &Self {
vis,
service_ident,
response_fut_ident,
camel_case_idents,
future_types,
..
} = self;
quote! {
/// A future resolving to a server response.
#vis enum #response_fut_ident<S: #service_ident> {
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
}
}
}
fn impl_debug_for_response_future(&self) -> TokenStream2 {
let &Self {
service_ident,
response_fut_ident,
response_fut_name,
..
} = self;
quote! {
impl<S: #service_ident> std::fmt::Debug for #response_fut_ident<S> {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct(#response_fut_name).finish()
}
}
}
}
}
fn impl_client_rpc_methods(
&GenArgs {
client_ident,
request_ident,
response_ident,
method_attrs,
vis,
method_names,
args,
return_types,
arg_vars,
camel_case_idents,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
impl<C> #client_ident<C>
where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
{
#(
#[allow(unused)]
#( #method_attrs )*
#vis fn #method_names(&mut self, ctx: tarpc::context::Context, #( #args ),*)
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
let request = #request_ident::#camel_case_idents { #( #arg_vars ),* };
let resp = tarpc::Client::call(&mut self.0, ctx, request);
async move {
match resp.await? {
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
_ => unreachable!(),
fn impl_future_for_response_future(&self) -> TokenStream2 {
let &Self {
service_ident,
response_fut_ident,
response_ident,
camel_case_idents,
..
} = self;
quote! {
impl<S: #service_ident> std::future::Future for #response_fut_ident<S> {
type Output = #response_ident;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>)
-> std::task::Poll<#response_ident>
{
unsafe {
match std::pin::Pin::get_unchecked_mut(self) {
#(
#response_fut_ident::#camel_case_idents(resp) =>
std::pin::Pin::new_unchecked(resp)
.poll(cx)
.map(#response_ident::#camel_case_idents),
)*
}
}
}
)*
}
}
}
fn struct_client(&self) -> TokenStream2 {
let &Self {
vis,
client_ident,
request_ident,
response_ident,
..
} = self;
quote! {
#[allow(unused)]
#[derive(Clone, Debug)]
/// The client stub that makes RPC calls to the server. Exposes a Future interface.
#vis struct #client_ident<C = tarpc::client::Channel<#request_ident, #response_ident>>(C);
}
}
fn impl_from_for_client(&self) -> TokenStream2 {
let &Self {
client_ident,
request_ident,
response_ident,
..
} = self;
quote! {
impl<C> From<C> for #client_ident<C>
where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
{
fn from(client: C) -> Self {
#client_ident(client)
}
}
}
}
fn impl_client_new(&self) -> TokenStream2 {
let &Self {
client_ident,
vis,
request_ident,
response_ident,
..
} = self;
quote! {
impl #client_ident {
/// Returns a new client stub that sends requests over the given transport.
#vis fn new<T>(config: tarpc::client::Config, transport: T)
-> tarpc::client::NewClient<
Self,
tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T>
>
where
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>>
{
let new_client = tarpc::client::new(config, transport);
tarpc::client::NewClient {
client: #client_ident(new_client.client),
dispatch: new_client.dispatch,
}
}
}
}
}
fn impl_client_rpc_methods(&self) -> TokenStream2 {
let &Self {
client_ident,
request_ident,
response_ident,
method_attrs,
vis,
method_idents,
args,
return_types,
arg_pats,
camel_case_idents,
..
} = self;
quote! {
impl<C> #client_ident<C>
where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
{
#(
#[allow(unused)]
#( #method_attrs )*
#vis fn #method_idents(&mut self, ctx: tarpc::context::Context, #( #args ),*)
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
let resp = tarpc::Client::call(&mut self.0, ctx, request);
async move {
match resp.await? {
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
_ => unreachable!(),
}
}
}
)*
}
}
}
}
impl<'a> ToTokens for ServiceGenerator<'a> {
fn to_tokens(&self, output: &mut TokenStream2) {
output.extend(vec![
self.trait_service(),
self.struct_server(),
self.impl_serve_for_server(),
self.enum_request(),
self.enum_response(),
self.enum_response_future(),
self.impl_debug_for_response_future(),
self.impl_future_for_response_future(),
self.struct_client(),
self.impl_from_for_client(),
self.impl_client_new(),
self.impl_client_rpc_methods(),
])
}
}
fn snake_to_camel(ident_str: &str) -> String {
let chars = ident_str.chars();
let mut camel_ty = String::with_capacity(ident_str.len());
let mut last_char_was_underscore = true;
for c in chars {
for c in ident_str.chars() {
match c {
'_' => last_char_was_underscore = true,
c if last_char_was_underscore => {

144
plugins/tests/server.rs Normal file
View File

@@ -0,0 +1,144 @@
use assert_type_eq::assert_type_eq;
use futures::Future;
use std::pin::Pin;
use tarpc::context;
// these need to be out here rather than inside the function so that the
// assert_type_eq macro can pick them up.
#[tarpc::service]
trait Foo {
async fn two_part(s: String, i: i32) -> (String, i32);
async fn bar(s: String) -> String;
async fn baz();
}
#[test]
fn type_generation_works() {
#[tarpc::server]
impl Foo for () {
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
(s, i)
}
async fn bar(self, _: context::Context, s: String) -> String {
s
}
async fn baz(self, _: context::Context) {}
}
// the assert_type_eq macro can only be used once per block.
{
assert_type_eq!(
<() as Foo>::TwoPartFut,
Pin<Box<dyn Future<Output = (String, i32)> + Send>>
);
}
{
assert_type_eq!(
<() as Foo>::BarFut,
Pin<Box<dyn Future<Output = String> + Send>>
);
}
{
assert_type_eq!(
<() as Foo>::BazFut,
Pin<Box<dyn Future<Output = ()> + Send>>
);
}
}
#[allow(non_camel_case_types)]
#[test]
fn raw_idents_work() {
type r#yield = String;
#[tarpc::service]
trait r#trait {
async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32);
async fn r#fn(r#impl: r#yield) -> r#yield;
async fn r#async();
}
#[tarpc::server]
impl r#trait for () {
async fn r#await(
self,
_: context::Context,
r#struct: r#yield,
r#enum: i32,
) -> (r#yield, i32) {
(r#struct, r#enum)
}
async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
r#impl
}
async fn r#async(self, _: context::Context) {}
}
}
#[test]
fn syntax() {
#[tarpc::service]
trait Syntax {
#[deny(warnings)]
#[allow(non_snake_case)]
async fn TestCamelCaseDoesntConflict();
async fn hello() -> String;
#[doc = "attr"]
async fn attr(s: String) -> String;
async fn no_args_no_return();
async fn no_args() -> ();
async fn one_arg(one: String) -> i32;
async fn two_args_no_return(one: String, two: u64);
async fn two_args(one: String, two: u64) -> String;
async fn no_args_ret_error() -> i32;
async fn one_arg_ret_error(one: String) -> String;
async fn no_arg_implicit_return_error();
#[doc = "attr"]
async fn one_arg_implicit_return_error(one: String);
}
#[tarpc::server]
impl Syntax for () {
#[deny(warnings)]
#[allow(non_snake_case)]
async fn TestCamelCaseDoesntConflict(self, _: context::Context) {}
async fn hello(self, _: context::Context) -> String {
String::new()
}
async fn attr(self, _: context::Context, _s: String) -> String {
String::new()
}
async fn no_args_no_return(self, _: context::Context) {}
async fn no_args(self, _: context::Context) -> () {}
async fn one_arg(self, _: context::Context, _one: String) -> i32 {
0
}
async fn two_args_no_return(self, _: context::Context, _one: String, _two: u64) {}
async fn two_args(self, _: context::Context, _one: String, _two: u64) -> String {
String::new()
}
async fn no_args_ret_error(self, _: context::Context) -> i32 {
0
}
async fn one_arg_ret_error(self, _: context::Context, _one: String) -> String {
String::new()
}
async fn no_arg_implicit_return_error(self, _: context::Context) {}
async fn one_arg_implicit_return_error(self, _: context::Context, _one: String) {}
}
}

View File

@@ -29,6 +29,38 @@ fn att_service_trait() {
}
}
#[allow(non_camel_case_types)]
#[test]
fn raw_idents() {
use futures::future::{ready, Ready};
type r#yield = String;
#[tarpc::service]
trait r#trait {
async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32);
async fn r#fn(r#impl: r#yield) -> r#yield;
async fn r#async();
}
impl r#trait for () {
type AwaitFut = Ready<(r#yield, i32)>;
fn r#await(self, _: context::Context, r#struct: r#yield, r#enum: i32) -> Self::AwaitFut {
ready((r#struct, r#enum))
}
type FnFut = Ready<r#yield>;
fn r#fn(self, _: context::Context, r#impl: r#yield) -> Self::FnFut {
ready(r#impl)
}
type AsyncFut = Ready<()>;
fn r#async(self, _: context::Context) -> Self::AsyncFut {
ready(())
}
}
}
#[test]
fn syntax() {
#[tarpc::service]

View File

@@ -1,6 +1,6 @@
[package]
name = "tarpc"
version = "0.20.0"
version = "0.21.0"
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
edition = "2018"
license = "MIT"
@@ -30,13 +30,12 @@ fnv = "1.0"
futures = "0.3"
humantime = "1.0"
log = "0.4"
pin-project = "0.4"
raii-counter = "0.2"
pin-project = "0.4.17"
rand = "0.7"
tokio = { version = "0.2", features = ["time"] }
serde = { optional = true, version = "1.0", features = ["derive"] }
tokio-util = { optional = true, version = "0.2" }
tarpc-plugins = { path = "../plugins", version = "0.7" }
tarpc-plugins = { path = "../plugins", version = "0.8" }
tokio-serde = { optional = true, version = "0.6" }
[dev-dependencies]
@@ -61,4 +60,3 @@ required-features = ["full"]
[[example]]
name = "pubsub"
required-features = ["full"]

View File

@@ -50,7 +50,7 @@
//! Add to your `Cargo.toml` dependencies:
//!
//! ```toml
//! tarpc = "0.20.0"
//! tarpc = "0.21.0"
//! ```
//!
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
@@ -215,7 +215,6 @@ pub mod trace;
/// Rpc methods are specified, mirroring trait syntax:
///
/// ```
/// # fn main() {}
/// #[tarpc::service]
/// trait Service {
/// /// Say hello
@@ -234,3 +233,59 @@ pub mod trace;
/// * `Client` -- a client stub with a fn for each RPC.
/// * `fn new_stub` -- creates a new Client stub.
pub use tarpc_plugins::service;
/// A utility macro that can be used for RPC server implementations.
///
/// Syntactic sugar to make using async functions in the server implementation
/// easier. It does this by rewriting code like this, which would normally not
/// compile because async functions are disallowed in trait implementations:
///
/// ```rust
/// # use tarpc::context;
/// # use std::net::SocketAddr;
/// #[tarpc::service]
/// trait World {
/// async fn hello(name: String) -> String;
/// }
///
/// #[derive(Clone)]
/// struct HelloServer(SocketAddr);
///
/// #[tarpc::server]
/// impl World for HelloServer {
/// async fn hello(self, _: context::Context, name: String) -> String {
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
/// }
/// }
/// ```
///
/// Into code like this, which matches the service trait definition:
///
/// ```rust
/// # use tarpc::context;
/// # use std::pin::Pin;
/// # use futures::Future;
/// # use std::net::SocketAddr;
/// #[derive(Clone)]
/// struct HelloServer(SocketAddr);
///
/// #[tarpc::service]
/// trait World {
/// async fn hello(name: String) -> String;
/// }
///
/// impl World for HelloServer {
/// type HelloFut = Pin<Box<dyn Future<Output = String> + Send>>;
///
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
/// + Send>> {
/// Box::pin(async move {
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
/// })
/// }
/// }
/// ```
///
/// Note that this won't touch functions unless they have been annotated with
/// `async`, meaning that this should not break existing code.
pub use tarpc_plugins::server;

View File

@@ -78,14 +78,21 @@ impl<'a, Req, Resp> Future for Send<'a, Req, Resp> {
#[must_use = "futures do nothing unless polled"]
pub struct Call<'a, Req, Resp> {
#[pin]
fut: AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>,
fut: tokio::time::Timeout<AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>>,
}
impl<'a, Req, Resp> Future for Call<'a, Req, Resp> {
type Output = io::Result<Resp>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.as_mut().project().fut.poll(cx)
let resp = ready!(self.as_mut().project().fut.poll(cx));
Poll::Ready(match resp {
Ok(resp) => resp,
Err(tokio::time::Elapsed { .. }) => Err(io::Error::new(
io::ErrorKind::TimedOut,
"Client dropped expired request.".to_string(),
)),
})
}
}
@@ -97,13 +104,6 @@ impl<Req, Resp> Channel<Req, Resp> {
ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
let timeout = ctx.deadline.time_until();
trace!(
"[{}] Queuing request with timeout {:?}.",
ctx.trace_id(),
timeout,
);
let (response_completion, response) = oneshot::channel();
let cancellation = self.cancellation.clone();
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
@@ -116,7 +116,7 @@ impl<Req, Resp> Channel<Req, Resp> {
response_completion,
})),
DispatchResponse {
response: tokio::time::timeout(timeout, response),
response,
complete: false,
request_id,
cancellation,
@@ -128,9 +128,16 @@ impl<Req, Resp> Channel<Req, Resp> {
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
/// resolves to the response.
pub fn call(&mut self, context: context::Context, request: Req) -> Call<Req, Resp> {
pub fn call(&mut self, ctx: context::Context, request: Req) -> Call<Req, Resp> {
let timeout = ctx.deadline.time_until();
trace!(
"[{}] Queuing request with timeout {:?}.",
ctx.trace_id(),
timeout,
);
Call {
fut: AndThenIdent::new(self.send(context, request)),
fut: tokio::time::timeout(timeout, AndThenIdent::new(self.send(ctx, request))),
}
}
}
@@ -140,7 +147,7 @@ impl<Req, Resp> Channel<Req, Resp> {
#[pin_project(PinnedDrop)]
#[derive(Debug)]
struct DispatchResponse<Resp> {
response: tokio::time::Timeout<oneshot::Receiver<Response<Resp>>>,
response: oneshot::Receiver<Response<Resp>>,
ctx: context::Context,
complete: bool,
cancellation: RequestCancellation,
@@ -152,24 +159,15 @@ impl<Resp> Future for DispatchResponse<Resp> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> {
let resp = ready!(self.response.poll_unpin(cx));
self.complete = true;
Poll::Ready(match resp {
Ok(resp) => {
self.complete = true;
match resp {
Ok(resp) => Ok(resp.message?),
Err(oneshot::Canceled) => {
// The oneshot is Canceled when the dispatch task ends. In that case,
// there's nothing listening on the other side, so there's no point in
// propagating cancellation.
Err(io::Error::from(io::ErrorKind::ConnectionReset))
}
}
Ok(resp) => Ok(resp.message?),
Err(oneshot::Canceled) => {
// The oneshot is Canceled when the dispatch task ends. In that case,
// there's nothing listening on the other side, so there's no point in
// propagating cancellation.
Err(io::Error::from(io::ErrorKind::ConnectionReset))
}
Err(tokio::time::Elapsed { .. }) => Err(io::Error::new(
io::ErrorKind::TimedOut,
"Client dropped expired request.".to_string(),
)),
})
}
}
@@ -189,7 +187,7 @@ impl<Resp> PinnedDrop for DispatchResponse<Resp> {
// closing the receiver before sending the cancel message, it is guaranteed that if the
// dispatch task misses an early-arriving cancellation message, then it will see the
// receiver as closed.
self.response.get_mut().close();
self.response.close();
let request_id = self.request_id;
self.cancellation.cancel(request_id);
}
@@ -385,9 +383,7 @@ where
context: context::Context {
deadline: dispatch_request.ctx.deadline,
trace_context: dispatch_request.ctx.trace_context,
_non_exhaustive: (),
},
_non_exhaustive: (),
});
self.as_mut().project().transport.start_send(request)?;
self.as_mut().project().in_flight_requests.insert(
@@ -632,11 +628,12 @@ where
}
}
#[pin_project(project = TryChainProj)]
#[must_use = "futures do nothing unless polled"]
#[derive(Debug)]
enum TryChain<Fut1, Fut2> {
First(Fut1),
Second(Fut2),
First(#[pin] Fut1),
Second(#[pin] Fut2),
Empty,
}
@@ -658,7 +655,7 @@ where
}
fn poll<F>(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
f: F,
) -> Poll<Result<Fut2::Ok, Fut2::Error>>
@@ -667,31 +664,28 @@ where
{
let mut f = Some(f);
// Safe to call `get_unchecked_mut` because we won't move the futures.
let this = unsafe { Pin::get_unchecked_mut(self) };
loop {
let output = match this {
TryChain::First(fut1) => {
let output = match self.as_mut().project() {
TryChainProj::First(fut1) => {
// Poll the first future
match unsafe { Pin::new_unchecked(fut1) }.try_poll(cx) {
match fut1.try_poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(output) => output,
}
}
TryChain::Second(fut2) => {
TryChainProj::Second(fut2) => {
// Poll the second future
return unsafe { Pin::new_unchecked(fut2) }.try_poll(cx);
return fut2.try_poll(cx);
}
TryChain::Empty => {
TryChainProj::Empty => {
panic!("future must not be polled after it returned `Poll::Ready`");
}
};
*this = TryChain::Empty; // Drop fut1
self.set(TryChain::Empty); // Drop fut1
let f = f.take().unwrap();
match f(output) {
TryChainAction::Future(fut2) => *this = TryChain::Second(fut2),
TryChainAction::Future(fut2) => self.set(TryChain::Second(fut2)),
TryChainAction::Output(output) => return Poll::Ready(output),
}
}
@@ -716,24 +710,21 @@ mod tests {
prelude::*,
task::*,
};
use std::time::Duration;
use std::{pin::Pin, sync::atomic::AtomicU64, sync::Arc};
#[tokio::test(threaded_scheduler)]
async fn dispatch_response_cancels_on_timeout() {
let (_response_completion, response) = oneshot::channel();
async fn dispatch_response_cancels_on_drop() {
let (cancellation, mut canceled_requests) = cancellations();
let resp = DispatchResponse::<u64> {
// Timeout in the past should cause resp to error out when polled.
response: tokio::time::timeout(Duration::from_secs(0), response),
let (_, response) = oneshot::channel();
drop(DispatchResponse::<u32> {
response,
cancellation,
complete: false,
request_id: 3,
cancellation,
ctx: context::current(),
};
let _ = futures::poll!(resp);
});
// resp's drop() is run, which should send a cancel message.
assert!(canceled_requests.0.try_next().unwrap() == Some(3));
assert_eq!(canceled_requests.0.try_next().unwrap(), Some(3));
}
#[tokio::test(threaded_scheduler)]
@@ -768,7 +759,6 @@ mod tests {
Response {
request_id: 0,
message: Ok("hello".into()),
_non_exhaustive: (),
},
)
.await;
@@ -822,7 +812,7 @@ mod tests {
// i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
// map.
let mut resp = send_request(&mut channel, "hi").await;
resp.response.get_mut().close();
resp.response.close();
assert!(dispatch.poll_next_request(cx).is_pending());
}

View File

@@ -104,6 +104,7 @@ where
/// Settings that control the behavior of the client.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct Config {
/// The number of requests that can be in flight at once.
/// `max_in_flight_requests` controls the size of the map used by the client
@@ -113,8 +114,6 @@ pub struct Config {
/// `pending_requests_buffer` controls the size of the channel clients use
/// to communicate with the request dispatch task.
pub pending_request_buffer: usize,
#[doc(hidden)]
_non_exhaustive: (),
}
impl Default for Config {
@@ -122,7 +121,6 @@ impl Default for Config {
Config {
max_in_flight_requests: 1_000,
pending_request_buffer: 100,
_non_exhaustive: (),
}
}
}

View File

@@ -16,6 +16,7 @@ use std::time::{Duration, SystemTime};
/// The context should not be stored directly in a server implementation, because the context will
/// be different for each request in scope.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Context {
/// When the client expects the request to be complete by. The server should cancel the request
@@ -35,9 +36,6 @@ pub struct Context {
/// include the same `trace_id` as that included on the original request. This way,
/// users can trace related actions across a distributed system.
pub trace_context: trace::Context,
#[doc(hidden)]
#[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
pub(crate) _non_exhaustive: (),
}
#[cfg(feature = "serde1")]
@@ -51,7 +49,6 @@ pub fn current() -> Context {
Context {
deadline: SystemTime::now() + Duration::from_secs(10),
trace_context: trace::Context::new_root(),
_non_exhaustive: (),
}
}

View File

@@ -38,6 +38,7 @@ use std::{io, time::SystemTime};
/// A message from a client to a server.
#[derive(Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum ClientMessage<T> {
/// A request initiated by a user. The server responds to a request by invoking a
/// service-provided request handler. The handler completes with a [`response`](Response), which
@@ -58,12 +59,11 @@ pub enum ClientMessage<T> {
/// The ID of the request to cancel.
request_id: u64,
},
#[doc(hidden)]
_NonExhaustive,
}
/// A request from a client to a server.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Request<T> {
/// Trace context, deadline, and other cross-cutting concerns.
@@ -72,26 +72,22 @@ pub struct Request<T> {
pub id: u64,
/// The request body.
pub message: T,
#[doc(hidden)]
#[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
_non_exhaustive: (),
}
/// A response from a server to a client.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Response<T> {
/// The ID of the request being responded to.
pub request_id: u64,
/// The response body, or an error if the request failed.
pub message: Result<T, ServerError>,
#[doc(hidden)]
#[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
_non_exhaustive: (),
}
/// An error response from a server to a client.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct ServerError {
#[cfg_attr(
@@ -106,9 +102,6 @@ pub struct ServerError {
pub kind: io::ErrorKind,
/// A message describing more detail about the error that occurred.
pub detail: Option<String>,
#[doc(hidden)]
#[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
_non_exhaustive: (),
}
impl From<ServerError> for io::Error {

View File

@@ -12,7 +12,6 @@ use fnv::FnvHashMap;
use futures::{channel::mpsc, future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
use log::{debug, info, trace};
use pin_project::pin_project;
use raii_counter::{Counter, WeakCounter};
use std::sync::{Arc, Weak};
use std::{
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin,
@@ -32,7 +31,7 @@ where
dropped_keys: mpsc::UnboundedReceiver<K>,
#[pin]
dropped_keys_tx: mpsc::UnboundedSender<K>,
key_counts: FnvHashMap<K, TrackerPrototype<K>>,
key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
keymaker: F,
}
@@ -42,37 +41,22 @@ where
pub struct TrackedChannel<C, K> {
#[pin]
inner: C,
tracker: Tracker<K>,
tracker: Arc<Tracker<K>>,
}
#[derive(Clone, Debug)]
#[derive(Debug)]
struct Tracker<K> {
key: Option<Arc<K>>,
counter: Counter,
key: Option<K>,
dropped_keys: mpsc::UnboundedSender<K>,
}
impl<K> Drop for Tracker<K> {
fn drop(&mut self) {
if self.counter.count() <= 1 {
// Don't care if the listener is dropped.
match Arc::try_unwrap(self.key.take().unwrap()) {
Ok(key) => {
let _ = self.dropped_keys.unbounded_send(key);
}
_ => unreachable!(),
}
}
// Don't care if the listener is dropped.
let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap());
}
}
#[derive(Clone, Debug)]
struct TrackerPrototype<K> {
key: Weak<K>,
counter: WeakCounter,
dropped_keys: mpsc::UnboundedSender<K>,
}
impl<C, K> Stream for TrackedChannel<C, K>
where
C: Stream,
@@ -181,7 +165,7 @@ where
trace!(
"[{}] Opening channel ({}/{}) channels for key.",
key,
tracker.counter.count(),
Arc::strong_count(&tracker),
self.as_mut().project().channels_per_key
);
@@ -191,28 +175,22 @@ where
})
}
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Tracker<K>, K> {
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
let channels_per_key = self.channels_per_key;
let dropped_keys = self.dropped_keys_tx.clone();
let key_counts = &mut self.as_mut().project().key_counts;
match key_counts.entry(key.clone()) {
Entry::Vacant(vacant) => {
let key = Arc::new(key);
let counter = WeakCounter::new();
vacant.insert(TrackerPrototype {
key: Arc::downgrade(&key),
counter: counter.clone(),
dropped_keys: dropped_keys.clone(),
});
Ok(Tracker {
let tracker = Arc::new(Tracker {
key: Some(key),
counter: counter.upgrade(),
dropped_keys,
})
});
vacant.insert(Arc::downgrade(&tracker));
Ok(tracker)
}
Entry::Occupied(o) => {
let count = o.get().counter.count();
Entry::Occupied(mut o) => {
let count = o.get().strong_count();
if count >= channels_per_key.try_into().unwrap() {
info!(
"[{}] Opened max channels from key ({}/{}).",
@@ -220,16 +198,15 @@ where
);
Err(key)
} else {
let TrackerPrototype {
key,
counter,
dropped_keys,
} = o.get().clone();
Ok(Tracker {
counter: counter.upgrade(),
key: Some(key.upgrade().unwrap()),
dropped_keys,
})
Ok(o.get().upgrade().unwrap_or_else(|| {
let tracker = Arc::new(Tracker {
key: Some(key),
dropped_keys,
});
*o.get_mut() = Arc::downgrade(&tracker);
tracker
}))
}
}
}
@@ -302,12 +279,10 @@ fn ctx() -> Context<'static> {
#[test]
fn tracker_drop() {
use assert_matches::assert_matches;
use raii_counter::Counter;
let (tx, mut rx) = mpsc::unbounded();
Tracker {
key: Some(Arc::new(1)),
counter: Counter::new(),
key: Some(1),
dropped_keys: tx,
};
assert_matches!(rx.try_next(), Ok(Some(1)));
@@ -317,17 +292,15 @@ fn tracker_drop() {
fn tracked_channel_stream() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
use raii_counter::Counter;
let (chan_tx, chan) = mpsc::unbounded();
let (dropped_keys, _) = mpsc::unbounded();
let channel = TrackedChannel {
inner: chan,
tracker: Tracker {
key: Some(Arc::new(1)),
counter: Counter::new(),
tracker: Arc::new(Tracker {
key: Some(1),
dropped_keys,
},
}),
};
chan_tx.unbounded_send("test").unwrap();
@@ -339,17 +312,15 @@ fn tracked_channel_stream() {
fn tracked_channel_sink() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
use raii_counter::Counter;
let (chan, mut chan_rx) = mpsc::unbounded();
let (dropped_keys, _) = mpsc::unbounded();
let channel = TrackedChannel {
inner: chan,
tracker: Tracker {
key: Some(Arc::new(1)),
counter: Counter::new(),
tracker: Arc::new(Tracker {
key: Some(1),
dropped_keys,
},
}),
};
pin_mut!(channel);
@@ -371,12 +342,12 @@ fn channel_filter_increment_channels_for_key() {
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
assert_eq!(tracker1.counter.count(), 1);
assert_eq!(Arc::strong_count(&tracker1), 1);
let tracker2 = filter.as_mut().increment_channels_for_key("key").unwrap();
assert_eq!(tracker1.counter.count(), 2);
assert_eq!(Arc::strong_count(&tracker1), 2);
assert_matches!(filter.increment_channels_for_key("key"), Err("key"));
drop(tracker2);
assert_eq!(tracker1.counter.count(), 1);
assert_eq!(Arc::strong_count(&tracker1), 1);
}
#[test]
@@ -395,20 +366,20 @@ fn channel_filter_handle_new_channel() {
.as_mut()
.handle_new_channel(TestChannel { key: "key" })
.unwrap();
assert_eq!(channel1.tracker.counter.count(), 1);
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
let channel2 = filter
.as_mut()
.handle_new_channel(TestChannel { key: "key" })
.unwrap();
assert_eq!(channel1.tracker.counter.count(), 2);
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
assert_matches!(
filter.handle_new_channel(TestChannel { key: "key" }),
Err("key")
);
drop(channel2);
assert_eq!(channel1.tracker.counter.count(), 1);
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
}
#[test]
@@ -429,14 +400,14 @@ fn channel_filter_poll_listener() {
.unwrap();
let channel1 =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
assert_eq!(channel1.tracker.counter.count(), 1);
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
new_channels
.unbounded_send(TestChannel { key: "key" })
.unwrap();
let _channel2 =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
assert_eq!(channel1.tracker.counter.count(), 2);
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
new_channels
.unbounded_send(TestChannel { key: "key" })
@@ -444,7 +415,7 @@ fn channel_filter_poll_listener() {
let key =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k);
assert_eq!(key, "key");
assert_eq!(channel1.tracker.counter.count(), 2);
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
}
#[test]

View File

@@ -304,7 +304,6 @@ where
} => {
self.as_mut().cancel_request(&trace_context, request_id);
}
ClientMessage::_NonExhaustive => unreachable!(),
},
None => return Poll::Ready(None),
}
@@ -569,11 +568,9 @@ where
"Response did not complete before deadline of {}s.",
format_rfc3339(self.deadline)
)),
_non_exhaustive: (),
})
}
},
_non_exhaustive: (),
});
*self.as_mut().project().state = RespState::PollReady;
}
@@ -653,11 +650,9 @@ where
pub fn execute(self) -> impl Future<Output = ()> {
use log::info;
self.try_for_each(|request_handler| {
async {
tokio::spawn(request_handler);
Ok(())
}
self.try_for_each(|request_handler| async {
tokio::spawn(request_handler);
Ok(())
})
.unwrap_or_else(|e| info!("ClientHandler errored out: {}", e))
}

View File

@@ -87,11 +87,9 @@ impl<Req, Resp> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
context: context::Context {
deadline: SystemTime::UNIX_EPOCH,
trace_context: Default::default(),
_non_exhaustive: (),
},
id,
message,
_non_exhaustive: (),
}));
}
}

View File

@@ -61,9 +61,7 @@ where
message: Err(ServerError {
kind: io::ErrorKind::WouldBlock,
detail: Some("Server throttled the request.".into()),
_non_exhaustive: (),
}),
_non_exhaustive: (),
})?;
}
None => return Poll::Ready(None),
@@ -311,7 +309,6 @@ fn throttler_start_send() {
.start_send(Response {
request_id: 0,
message: Ok(1),
_non_exhaustive: (),
})
.unwrap();
assert!(throttler.inner.in_flight_requests.is_empty());
@@ -320,7 +317,6 @@ fn throttler_start_send() {
Some(&Response {
request_id: 0,
message: Ok(1),
_non_exhaustive: ()
})
);
}