12 Commits

Author SHA1 Message Date
Tim Kuehn
7e521768ab Prepare for v0.21.0 release. 2020-06-26 20:05:02 -07:00
Tim Kuehn
e9b1e7d101 Use #[non_exhaustive] in lieu of _NonExhaustive enum variant. 2020-06-26 19:47:20 -07:00
Taiki Endo
f0322fb892 Remove uses of pin_project::project attribute
pin-project will deprecate the project attribute due to some unfixable
limitations.

Refs: https://github.com/taiki-e/pin-project/issues/225
2020-06-05 20:34:44 -07:00
Patrick Elsen
617daebb88 Add tarpc::server proc-macro as syntactic sugar for async methods. (#302)
The tarpc::server proc-macro can be used to annotate implementations of
services to turn async functions into the proper declarations needed
for tarpc to be able to call them.

This uses the assert_type_eq crate to check that the transformations
applied by the tarpc::server proc macro are correct and lead to code
that compiles.
2020-05-16 10:25:25 -07:00
Tim Kuehn
a11d4fff58 Remove raii_counter 2020-04-22 02:13:02 -07:00
Tim
bf42a04d83 Move the request timeout so that it surrounds the entire call, not just the response future. (#295)
* Move the request timeout so that it surrounds the entire call, not just the response future.

This will enable the timeout earlier, so that a backlog in the outbound request buffer can not cause requests to stall indefinitely.

* Run cargo fmt
2020-02-25 14:42:40 -08:00
Tim Kuehn
06528d6953 Fix clippy lint. 2019-12-19 12:28:26 -08:00
Tim Kuehn
9f00395746 Replace _non_exhaustive fields with #[non_exhaustive] attribute.
The attribute landed on stable rust (1.40.0) today.

Fixes https://github.com/google/tarpc/issues/275
2019-12-19 12:14:34 -08:00
Tim Kuehn
e0674cd57f Make pre-push run on rust stable. 2019-12-19 12:06:06 -08:00
Tim Kuehn
7e49bd9ee7 Clean up badges a bit. 2019-12-16 13:21:00 -08:00
Tim Kuehn
8a1baa9c4e Remove usage of unsafe in rpc::client::channel.
pin_project is actually able to handle the complexities of enum Futures.
2019-12-16 11:10:57 -08:00
Oleg Nosov
31c713d188 Allow raw identifiers + fixed naming + place all code generation methods in impl (#291)
Allows defining services using raw identifiers like:

```rust
pub mod service {
    #[tarpc::service]
    pub trait r#trait {
        async fn r#fn(x: i32) -> Result<u8, String>;
    }
}
```

Also:

- Refactored names (ident -> type)
- All code generation methods placed in impl
2019-12-12 10:13:57 -08:00
18 changed files with 829 additions and 499 deletions

View File

@@ -1,6 +1,16 @@
[![Build Status](https://github.com/google/tarpc/workflows/Continuous%20integration/badge.svg)](https://github.com/google/tarpc/actions?query=workflow%3A%22Continuous+integration%22) [![Crates.io][crates-badge]][crates-url]
[![Latest Version](https://img.shields.io/crates/v/tarpc.svg)](https://crates.io/crates/tarpc) [![MIT licensed][mit-badge]][mit-url]
[![Chat on Discord](https://img.shields.io/discord/647529123996237854)](https://discordapp.com/channels/647529123996237854) [![Build status][gh-actions-badge]][gh-actions-url]
[![Discord chat][discord-badge]][discord-url]
[crates-badge]: https://img.shields.io/crates/v/tarpc.svg
[crates-url]: https://crates.io/crates/tarpc
[mit-badge]: https://img.shields.io/badge/license-MIT-blue.svg
[mit-url]: LICENSE
[gh-actions-badge]: https://github.com/google/tarpc/workflows/Continuous%20integration/badge.svg
[gh-actions-url]: https://github.com/google/tarpc/actions?query=workflow%3A%22Continuous+integration%22
[discord-badge]: https://img.shields.io/discord/647529123996237854.svg?logo=discord&style=flat-square
[discord-url]: https://discord.gg/gXwpdSt
# tarpc # tarpc
@@ -47,7 +57,7 @@ Some other features of tarpc:
Add to your `Cargo.toml` dependencies: Add to your `Cargo.toml` dependencies:
```toml ```toml
tarpc = { version = "0.18.0", features = ["full"] } tarpc = { version = "0.21.0", features = ["full"] }
``` ```
The `tarpc::service` attribute expands to a collection of items that form an rpc service. The `tarpc::service` attribute expands to a collection of items that form an rpc service.

View File

@@ -1,3 +1,26 @@
## 0.21.0 (2020-06-26)
### New Features
A new proc macro, `#[tarpc::server]` was added! This enables service impls to elide the boilerplate
of specifying associated types for each RPC. With the ubiquity of async-await, most code won't have
nameable futures and will just be boxing the return type anyway. This macro does that for you.
### Breaking Changes
- Enums had _non_exhaustive fields replaced with the #[non_exhaustive] attribute.
### Bug Fixes
- https://github.com/google/tarpc/issues/304
A race condition in code that limits number of connections per client caused occasional panics.
- https://github.com/google/tarpc/pull/295
Made request timeouts account for time spent in the outbound buffer. Previously, a large outbound
queue would lead to requests not timing out correctly.
## 0.20.0 (2019-12-11) ## 0.20.0 (2019-12-11)
### Breaking Changes ### Breaking Changes

View File

@@ -16,7 +16,7 @@ description = "An example server built on tarpc."
clap = "2.0" clap = "2.0"
futures = "0.3" futures = "0.3"
serde = { version = "1.0" } serde = { version = "1.0" }
tarpc = { version = "0.20", path = "../tarpc", features = ["full"] } tarpc = { version = "0.21", path = "../tarpc", features = ["full"] }
tokio = { version = "0.2", features = ["full"] } tokio = { version = "0.2", features = ["full"] }
tokio-serde = { version = "0.6", features = ["json"] } tokio-serde = { version = "0.6", features = ["json"] }
env_logger = "0.6" env_logger = "0.6"

View File

@@ -89,12 +89,12 @@ if [ "$?" == 0 ]; then
exit 1 exit 1
fi fi
try_run "Building ... " cargo build --color=always try_run "Building ... " cargo +stable build --color=always
try_run "Testing ... " cargo test --color=always try_run "Testing ... " cargo +stable test --color=always
try_run "Testing with all features enabled ... " cargo test --all-features --color=always try_run "Testing with all features enabled ... " cargo +stable test --all-features --color=always
for EXAMPLE in $(cargo run --example 2>&1 | grep ' ' | awk '{print $1}') for EXAMPLE in $(cargo +stable run --example 2>&1 | grep ' ' | awk '{print $1}')
do do
try_run "Running example \"$EXAMPLE\" ... " cargo run --example $EXAMPLE try_run "Running example \"$EXAMPLE\" ... " cargo +stable run --example $EXAMPLE
done done
fi fi

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "tarpc-plugins" name = "tarpc-plugins"
version = "0.7.0" version = "0.8.0"
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"] authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
edition = "2018" edition = "2018"
license = "MIT" license = "MIT"
@@ -30,3 +30,4 @@ proc-macro = true
futures = "0.3" futures = "0.3"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
tarpc = { path = "../tarpc" } tarpc = { path = "../tarpc" }
assert-type-eq = "0.1.0"

View File

@@ -12,15 +12,18 @@ extern crate quote;
extern crate syn; extern crate syn;
use proc_macro::TokenStream; use proc_macro::TokenStream;
use quote::{format_ident, quote}; use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote, ToTokens};
use syn::{ use syn::{
braced, parenthesized, braced,
ext::IdentExt,
parenthesized,
parse::{Parse, ParseStream}, parse::{Parse, ParseStream},
parse_macro_input, parse_quote, parse_macro_input, parse_quote, parse_str,
punctuated::Punctuated, punctuated::Punctuated,
token::Comma, token::Comma,
Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type, Attribute, FnArg, Ident, ImplItem, ImplItemMethod, ImplItemType, ItemImpl, Lit, LitBool,
Visibility, MetaNameValue, Pat, PatType, ReturnType, Token, Type, Visibility,
}; };
struct Service { struct Service {
@@ -42,7 +45,7 @@ impl Parse for Service {
let attrs = input.call(Attribute::parse_outer)?; let attrs = input.call(Attribute::parse_outer)?;
let vis = input.parse()?; let vis = input.parse()?;
input.parse::<Token![trait]>()?; input.parse::<Token![trait]>()?;
let ident = input.parse()?; let ident: Ident = input.parse()?;
let content; let content;
braced!(content in input); braced!(content in input);
let mut rpcs = Vec::<RpcMethod>::new(); let mut rpcs = Vec::<RpcMethod>::new();
@@ -53,7 +56,7 @@ impl Parse for Service {
if rpc.ident == "new" { if rpc.ident == "new" {
return Err(input.error(format!( return Err(input.error(format!(
"method name conflicts with generated fn `{}Client::new`", "method name conflicts with generated fn `{}Client::new`",
ident ident.unraw()
))); )));
} }
if rpc.ident == "serve" { if rpc.ident == "serve" {
@@ -63,6 +66,7 @@ impl Parse for Service {
))); )));
} }
} }
Ok(Self { Ok(Self {
attrs, attrs,
vis, vis,
@@ -156,19 +160,19 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
ref rpcs, ref rpcs,
} = parse_macro_input!(input as Service); } = parse_macro_input!(input as Service);
let camel_case_fn_names: &[String] = &rpcs let camel_case_fn_names: &Vec<_> = &rpcs
.iter() .iter()
.map(|rpc| snake_to_camel(&rpc.ident.to_string())) .map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
.collect::<Vec<_>>(); .collect();
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>(); let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
let response_fut_name = &format!("{}ResponseFut", ident); let response_fut_name = &format!("{}ResponseFut", ident.unraw());
let derive_serialize = quote!(#[derive(serde::Serialize, serde::Deserialize)]); let derive_serialize = if derive_serde.0 {
Some(quote!(#[derive(serde::Serialize, serde::Deserialize)]))
} else {
None
};
let gen_args = &GenArgs { ServiceGenerator {
attrs,
vis,
rpcs,
args,
response_fut_name, response_fut_name,
service_ident: ident, service_ident: ident,
server_ident: &format_ident!("Serve{}", ident), server_ident: &format_ident!("Serve{}", ident),
@@ -176,8 +180,12 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
client_ident: &format_ident!("{}Client", ident), client_ident: &format_ident!("{}Client", ident),
request_ident: &format_ident!("{}Request", ident), request_ident: &format_ident!("{}Request", ident),
response_ident: &format_ident!("{}Response", ident), response_ident: &format_ident!("{}Response", ident),
vis,
args,
method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(), method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
method_names: &rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>(), method_idents: &rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>(),
attrs,
rpcs,
return_types: &rpcs return_types: &rpcs
.iter() .iter()
.map(|rpc| match rpc.output { .map(|rpc| match rpc.output {
@@ -185,7 +193,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
ReturnType::Default => unit_type, ReturnType::Default => unit_type,
}) })
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
arg_vars: &args arg_pats: &args
.iter() .iter()
.map(|args| args.iter().map(|arg| &*arg.pat).collect()) .map(|args| args.iter().map(|arg| &*arg.pat).collect())
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
@@ -194,43 +202,139 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
.zip(camel_case_fn_names.iter()) .zip(camel_case_fn_names.iter())
.map(|(rpc, name)| Ident::new(name, rpc.ident.span())) .map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
future_idents: &camel_case_fn_names future_types: &camel_case_fn_names
.iter() .iter()
.map(|name| format_ident!("{}Fut", name)) .map(|name| parse_str(&format!("{}Fut", name)).unwrap())
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
derive_serialize: if derive_serde.0 { derive_serialize: derive_serialize.as_ref(),
Some(&derive_serialize) }
} else { .into_token_stream()
None .into()
}, }
/// Transforms an async function into a sync one, returning a type declaration
/// for the return type (a future).
fn transform_method(method: &mut ImplItemMethod) -> ImplItemType {
method.sig.asyncness = None;
// get either the return type or ().
let ret = match &method.sig.output {
ReturnType::Default => quote!(()),
ReturnType::Type(_, ret) => quote!(#ret),
}; };
let tokens = vec![ // generate an identifier consisting of the method name to CamelCase with
trait_service(gen_args), // Fut appended to it.
struct_server(gen_args), let fut_name = snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut";
impl_serve_for_server(gen_args), let fut_name_ident = Ident::new(&fut_name, method.sig.ident.span());
enum_request(gen_args),
enum_response(gen_args),
enum_response_future(gen_args),
impl_debug_for_response_future(gen_args),
impl_future_for_response_future(gen_args),
struct_client(gen_args),
impl_from_for_client(gen_args),
impl_client_new(gen_args),
impl_client_rpc_methods(gen_args),
];
tokens // generate the updated return signature.
.into_iter() method.sig.output = parse_quote! {
.collect::<proc_macro2::TokenStream>() -> ::core::pin::Pin<Box<
.into() dyn ::core::future::Future<Output = #ret> + ::core::marker::Send
>>
};
// transform the body of the method into Box::pin(async move { body }).
let block = method.block.clone();
method.block = parse_quote! [{
Box::pin(async move
#block
)
}];
// generate and return type declaration for return type.
let t: ImplItemType = parse_quote! {
type #fut_name_ident = ::core::pin::Pin<Box<dyn ::core::future::Future<Output = #ret> + ::core::marker::Send>>;
};
t
}
/// Syntactic sugar to make using async functions in the server implementation
/// easier. It does this by rewriting code like this, which would normally not
/// compile because async functions are disallowed in trait implementations:
///
/// ```rust
/// # extern crate tarpc;
/// # use tarpc::context;
/// # use std::net::SocketAddr;
/// #[tarpc_plugins::service]
/// trait World {
/// async fn hello(name: String) -> String;
/// }
///
/// #[derive(Clone)]
/// struct HelloServer(SocketAddr);
///
/// #[tarpc_plugins::server]
/// impl World for HelloServer {
/// async fn hello(self, _: context::Context, name: String) -> String {
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
/// }
/// }
/// ```
///
/// Into code like this, which matches the service trait definition:
///
/// ```rust
/// # extern crate tarpc;
/// # use tarpc::context;
/// # use std::pin::Pin;
/// # use futures::Future;
/// # use std::net::SocketAddr;
/// #[tarpc_plugins::service]
/// trait World {
/// async fn hello(name: String) -> String;
/// }
///
/// #[derive(Clone)]
/// struct HelloServer(SocketAddr);
///
/// impl World for HelloServer {
/// type HelloFut = Pin<Box<dyn Future<Output = String> + Send>>;
///
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
/// + Send>> {
/// Box::pin(async move {
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
/// })
/// }
/// }
/// ```
///
/// Note that this won't touch functions unless they have been annotated with
/// `async`, meaning that this should not break existing code.
#[proc_macro_attribute]
pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream {
let mut item = syn::parse_macro_input!(input as ItemImpl);
// the generated type declarations
let mut types: Vec<ImplItemType> = Vec::new();
for inner in &mut item.items {
if let ImplItem::Method(method) = inner {
let sig = &method.sig;
// if this function is declared async, transform it into a regular function
if sig.asyncness.is_some() {
let typedecl = transform_method(method);
types.push(typedecl);
}
}
}
// add the type declarations into the impl block
for t in types.into_iter() {
item.items.push(syn::ImplItem::Type(t));
}
TokenStream::from(quote!(#item))
} }
// Things needed to generate the service items: trait, serve impl, request/response enums, and // Things needed to generate the service items: trait, serve impl, request/response enums, and
// the client stub. // the client stub.
struct GenArgs<'a> { struct ServiceGenerator<'a> {
attrs: &'a [Attribute],
rpcs: &'a [RpcMethod],
service_ident: &'a Ident, service_ident: &'a Ident,
server_ident: &'a Ident, server_ident: &'a Ident,
response_fut_ident: &'a Ident, response_fut_ident: &'a Ident,
@@ -238,331 +342,356 @@ struct GenArgs<'a> {
client_ident: &'a Ident, client_ident: &'a Ident,
request_ident: &'a Ident, request_ident: &'a Ident,
response_ident: &'a Ident, response_ident: &'a Ident,
method_attrs: &'a [&'a [Attribute]],
vis: &'a Visibility, vis: &'a Visibility,
method_names: &'a [&'a Ident], attrs: &'a [Attribute],
rpcs: &'a [RpcMethod],
camel_case_idents: &'a [Ident],
future_types: &'a [Type],
method_idents: &'a [&'a Ident],
method_attrs: &'a [&'a [Attribute]],
args: &'a [&'a [PatType]], args: &'a [&'a [PatType]],
return_types: &'a [&'a Type], return_types: &'a [&'a Type],
arg_vars: &'a [Vec<&'a Pat>], arg_pats: &'a [Vec<&'a Pat>],
camel_case_idents: &'a [Ident], derive_serialize: Option<&'a TokenStream2>,
future_idents: &'a [Ident],
derive_serialize: Option<&'a proc_macro2::TokenStream>,
} }
fn trait_service( impl<'a> ServiceGenerator<'a> {
&GenArgs { fn trait_service(&self) -> TokenStream2 {
attrs, let &Self {
rpcs, attrs,
vis, rpcs,
future_idents, vis,
return_types, future_types,
service_ident, return_types,
server_ident, service_ident,
.. server_ident,
}: &GenArgs, ..
) -> proc_macro2::TokenStream { } = self;
let types_and_fns = rpcs
.iter()
.zip(future_idents.iter())
.zip(return_types.iter())
.map(
|(
(
RpcMethod {
attrs, ident, args, ..
},
future_type,
),
output,
)| {
let ty_doc = format!("The response future returned by {}.", ident);
quote! {
#[doc = #ty_doc]
type #future_type: std::future::Future<Output = #output>;
#( #attrs )* let types_and_fns = rpcs
fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type; .iter()
} .zip(future_types.iter())
}, .zip(return_types.iter())
); .map(
|(
(
RpcMethod {
attrs, ident, args, ..
},
future_type,
),
output,
)| {
let ty_doc = format!("The response future returned by {}.", ident);
quote! {
#[doc = #ty_doc]
type #future_type: std::future::Future<Output = #output>;
quote! { #( #attrs )*
#( #attrs )* fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type;
#vis trait #service_ident: Clone { }
#( #types_and_fns )* },
);
/// Returns a serving function to use with tarpc::server::Server. quote! {
fn serve(self) -> #server_ident<Self> { #( #attrs )*
#server_ident { service: self } #vis trait #service_ident: Clone {
} #( #types_and_fns )*
}
}
}
fn struct_server( /// Returns a serving function to use with tarpc::server::Server.
&GenArgs { fn serve(self) -> #server_ident<Self> {
vis, server_ident, .. #server_ident { service: self }
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
#[derive(Clone)]
#vis struct #server_ident<S> {
service: S,
}
}
}
fn impl_serve_for_server(
&GenArgs {
request_ident,
server_ident,
service_ident,
response_ident,
response_fut_ident,
camel_case_idents,
arg_vars,
method_names,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
impl<S> tarpc::server::Serve<#request_ident> for #server_ident<S>
where S: #service_ident
{
type Resp = #response_ident;
type Fut = #response_fut_ident<S>;
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
match req {
#(
#request_ident::#camel_case_idents{ #( #arg_vars ),* } => {
#response_fut_ident::#camel_case_idents(
#service_ident::#method_names(
self.service, ctx, #( #arg_vars ),*))
}
)*
} }
} }
} }
} }
}
fn enum_request( fn struct_server(&self) -> TokenStream2 {
&GenArgs { let &Self {
derive_serialize, vis, server_ident, ..
vis, } = self;
request_ident,
camel_case_idents,
args,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
/// The request sent over the wire from the client to the server.
#[derive(Debug)]
#derive_serialize
#vis enum #request_ident {
#( #camel_case_idents{ #( #args ),* } ),*
}
}
}
fn enum_response( quote! {
&GenArgs { #[derive(Clone)]
derive_serialize, #vis struct #server_ident<S> {
vis, service: S,
response_ident,
camel_case_idents,
return_types,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
/// The response sent over the wire from the server to the client.
#[derive(Debug)]
#derive_serialize
#vis enum #response_ident {
#( #camel_case_idents(#return_types) ),*
}
}
}
fn enum_response_future(
&GenArgs {
vis,
service_ident,
response_fut_ident,
camel_case_idents,
future_idents,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
/// A future resolving to a server response.
#vis enum #response_fut_ident<S: #service_ident> {
#( #camel_case_idents(<S as #service_ident>::#future_idents) ),*
}
}
}
fn impl_debug_for_response_future(
&GenArgs {
service_ident,
response_fut_ident,
response_fut_name,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
impl<S: #service_ident> std::fmt::Debug for #response_fut_ident<S> {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct(#response_fut_name).finish()
} }
} }
} }
}
fn impl_future_for_response_future( fn impl_serve_for_server(&self) -> TokenStream2 {
&GenArgs { let &Self {
service_ident, request_ident,
response_fut_ident, server_ident,
response_ident, service_ident,
camel_case_idents, response_ident,
.. response_fut_ident,
}: &GenArgs, camel_case_idents,
) -> proc_macro2::TokenStream { arg_pats,
quote! { method_idents,
impl<S: #service_ident> std::future::Future for #response_fut_ident<S> { ..
type Output = #response_ident; } = self;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) quote! {
-> std::task::Poll<#response_ident> impl<S> tarpc::server::Serve<#request_ident> for #server_ident<S>
where S: #service_ident
{ {
unsafe { type Resp = #response_ident;
match std::pin::Pin::get_unchecked_mut(self) { type Fut = #response_fut_ident<S>;
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
match req {
#( #(
#response_fut_ident::#camel_case_idents(resp) => #request_ident::#camel_case_idents{ #( #arg_pats ),* } => {
std::pin::Pin::new_unchecked(resp) #response_fut_ident::#camel_case_idents(
.poll(cx) #service_ident::#method_idents(
.map(#response_ident::#camel_case_idents), self.service, ctx, #( #arg_pats ),*
)
)
}
)* )*
} }
} }
} }
} }
} }
}
fn struct_client( fn enum_request(&self) -> TokenStream2 {
&GenArgs { let &Self {
vis, derive_serialize,
client_ident, vis,
request_ident, request_ident,
response_ident, camel_case_idents,
.. args,
}: &GenArgs, ..
) -> proc_macro2::TokenStream { } = self;
quote! {
#[allow(unused)]
#[derive(Clone, Debug)]
/// The client stub that makes RPC calls to the server. Exposes a Future interface.
#vis struct #client_ident<C = tarpc::client::Channel<#request_ident, #response_ident>>(C);
}
}
fn impl_from_for_client( quote! {
&GenArgs { /// The request sent over the wire from the client to the server.
client_ident, #[derive(Debug)]
request_ident, #derive_serialize
response_ident, #vis enum #request_ident {
.. #( #camel_case_idents{ #( #args ),* } ),*
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
impl<C> From<C> for #client_ident<C>
where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
{
fn from(client: C) -> Self {
#client_ident(client)
} }
} }
} }
}
fn impl_client_new( fn enum_response(&self) -> TokenStream2 {
&GenArgs { let &Self {
client_ident, derive_serialize,
vis, vis,
request_ident, response_ident,
response_ident, camel_case_idents,
.. return_types,
}: &GenArgs, ..
) -> proc_macro2::TokenStream { } = self;
quote! {
impl #client_ident { quote! {
/// Returns a new client stub that sends requests over the given transport. /// The response sent over the wire from the server to the client.
#vis fn new<T>(config: tarpc::client::Config, transport: T) #[derive(Debug)]
-> tarpc::client::NewClient< #derive_serialize
Self, #vis enum #response_ident {
tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T>> #( #camel_case_idents(#return_types) ),*
where }
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>> }
{ }
let new_client = tarpc::client::new(config, transport);
tarpc::client::NewClient { fn enum_response_future(&self) -> TokenStream2 {
client: #client_ident(new_client.client), let &Self {
dispatch: new_client.dispatch, vis,
service_ident,
response_fut_ident,
camel_case_idents,
future_types,
..
} = self;
quote! {
/// A future resolving to a server response.
#vis enum #response_fut_ident<S: #service_ident> {
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
}
}
}
fn impl_debug_for_response_future(&self) -> TokenStream2 {
let &Self {
service_ident,
response_fut_ident,
response_fut_name,
..
} = self;
quote! {
impl<S: #service_ident> std::fmt::Debug for #response_fut_ident<S> {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct(#response_fut_name).finish()
} }
} }
} }
} }
}
fn impl_client_rpc_methods( fn impl_future_for_response_future(&self) -> TokenStream2 {
&GenArgs { let &Self {
client_ident, service_ident,
request_ident, response_fut_ident,
response_ident, response_ident,
method_attrs, camel_case_idents,
vis, ..
method_names, } = self;
args,
return_types, quote! {
arg_vars, impl<S: #service_ident> std::future::Future for #response_fut_ident<S> {
camel_case_idents, type Output = #response_ident;
..
}: &GenArgs, fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>)
) -> proc_macro2::TokenStream { -> std::task::Poll<#response_ident>
quote! { {
impl<C> #client_ident<C> unsafe {
where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident> match std::pin::Pin::get_unchecked_mut(self) {
{ #(
#( #response_fut_ident::#camel_case_idents(resp) =>
#[allow(unused)] std::pin::Pin::new_unchecked(resp)
#( #method_attrs )* .poll(cx)
#vis fn #method_names(&mut self, ctx: tarpc::context::Context, #( #args ),*) .map(#response_ident::#camel_case_idents),
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ { )*
let request = #request_ident::#camel_case_idents { #( #arg_vars ),* };
let resp = tarpc::Client::call(&mut self.0, ctx, request);
async move {
match resp.await? {
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
_ => unreachable!(),
} }
} }
} }
)* }
} }
} }
fn struct_client(&self) -> TokenStream2 {
let &Self {
vis,
client_ident,
request_ident,
response_ident,
..
} = self;
quote! {
#[allow(unused)]
#[derive(Clone, Debug)]
/// The client stub that makes RPC calls to the server. Exposes a Future interface.
#vis struct #client_ident<C = tarpc::client::Channel<#request_ident, #response_ident>>(C);
}
}
fn impl_from_for_client(&self) -> TokenStream2 {
let &Self {
client_ident,
request_ident,
response_ident,
..
} = self;
quote! {
impl<C> From<C> for #client_ident<C>
where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
{
fn from(client: C) -> Self {
#client_ident(client)
}
}
}
}
fn impl_client_new(&self) -> TokenStream2 {
let &Self {
client_ident,
vis,
request_ident,
response_ident,
..
} = self;
quote! {
impl #client_ident {
/// Returns a new client stub that sends requests over the given transport.
#vis fn new<T>(config: tarpc::client::Config, transport: T)
-> tarpc::client::NewClient<
Self,
tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T>
>
where
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>>
{
let new_client = tarpc::client::new(config, transport);
tarpc::client::NewClient {
client: #client_ident(new_client.client),
dispatch: new_client.dispatch,
}
}
}
}
}
fn impl_client_rpc_methods(&self) -> TokenStream2 {
let &Self {
client_ident,
request_ident,
response_ident,
method_attrs,
vis,
method_idents,
args,
return_types,
arg_pats,
camel_case_idents,
..
} = self;
quote! {
impl<C> #client_ident<C>
where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
{
#(
#[allow(unused)]
#( #method_attrs )*
#vis fn #method_idents(&mut self, ctx: tarpc::context::Context, #( #args ),*)
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
let resp = tarpc::Client::call(&mut self.0, ctx, request);
async move {
match resp.await? {
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
_ => unreachable!(),
}
}
}
)*
}
}
}
}
impl<'a> ToTokens for ServiceGenerator<'a> {
fn to_tokens(&self, output: &mut TokenStream2) {
output.extend(vec![
self.trait_service(),
self.struct_server(),
self.impl_serve_for_server(),
self.enum_request(),
self.enum_response(),
self.enum_response_future(),
self.impl_debug_for_response_future(),
self.impl_future_for_response_future(),
self.struct_client(),
self.impl_from_for_client(),
self.impl_client_new(),
self.impl_client_rpc_methods(),
])
}
} }
fn snake_to_camel(ident_str: &str) -> String { fn snake_to_camel(ident_str: &str) -> String {
let chars = ident_str.chars();
let mut camel_ty = String::with_capacity(ident_str.len()); let mut camel_ty = String::with_capacity(ident_str.len());
let mut last_char_was_underscore = true; let mut last_char_was_underscore = true;
for c in chars { for c in ident_str.chars() {
match c { match c {
'_' => last_char_was_underscore = true, '_' => last_char_was_underscore = true,
c if last_char_was_underscore => { c if last_char_was_underscore => {

144
plugins/tests/server.rs Normal file
View File

@@ -0,0 +1,144 @@
use assert_type_eq::assert_type_eq;
use futures::Future;
use std::pin::Pin;
use tarpc::context;
// these need to be out here rather than inside the function so that the
// assert_type_eq macro can pick them up.
#[tarpc::service]
trait Foo {
async fn two_part(s: String, i: i32) -> (String, i32);
async fn bar(s: String) -> String;
async fn baz();
}
#[test]
fn type_generation_works() {
#[tarpc::server]
impl Foo for () {
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
(s, i)
}
async fn bar(self, _: context::Context, s: String) -> String {
s
}
async fn baz(self, _: context::Context) {}
}
// the assert_type_eq macro can only be used once per block.
{
assert_type_eq!(
<() as Foo>::TwoPartFut,
Pin<Box<dyn Future<Output = (String, i32)> + Send>>
);
}
{
assert_type_eq!(
<() as Foo>::BarFut,
Pin<Box<dyn Future<Output = String> + Send>>
);
}
{
assert_type_eq!(
<() as Foo>::BazFut,
Pin<Box<dyn Future<Output = ()> + Send>>
);
}
}
#[allow(non_camel_case_types)]
#[test]
fn raw_idents_work() {
type r#yield = String;
#[tarpc::service]
trait r#trait {
async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32);
async fn r#fn(r#impl: r#yield) -> r#yield;
async fn r#async();
}
#[tarpc::server]
impl r#trait for () {
async fn r#await(
self,
_: context::Context,
r#struct: r#yield,
r#enum: i32,
) -> (r#yield, i32) {
(r#struct, r#enum)
}
async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
r#impl
}
async fn r#async(self, _: context::Context) {}
}
}
#[test]
fn syntax() {
#[tarpc::service]
trait Syntax {
#[deny(warnings)]
#[allow(non_snake_case)]
async fn TestCamelCaseDoesntConflict();
async fn hello() -> String;
#[doc = "attr"]
async fn attr(s: String) -> String;
async fn no_args_no_return();
async fn no_args() -> ();
async fn one_arg(one: String) -> i32;
async fn two_args_no_return(one: String, two: u64);
async fn two_args(one: String, two: u64) -> String;
async fn no_args_ret_error() -> i32;
async fn one_arg_ret_error(one: String) -> String;
async fn no_arg_implicit_return_error();
#[doc = "attr"]
async fn one_arg_implicit_return_error(one: String);
}
#[tarpc::server]
impl Syntax for () {
#[deny(warnings)]
#[allow(non_snake_case)]
async fn TestCamelCaseDoesntConflict(self, _: context::Context) {}
async fn hello(self, _: context::Context) -> String {
String::new()
}
async fn attr(self, _: context::Context, _s: String) -> String {
String::new()
}
async fn no_args_no_return(self, _: context::Context) {}
async fn no_args(self, _: context::Context) -> () {}
async fn one_arg(self, _: context::Context, _one: String) -> i32 {
0
}
async fn two_args_no_return(self, _: context::Context, _one: String, _two: u64) {}
async fn two_args(self, _: context::Context, _one: String, _two: u64) -> String {
String::new()
}
async fn no_args_ret_error(self, _: context::Context) -> i32 {
0
}
async fn one_arg_ret_error(self, _: context::Context, _one: String) -> String {
String::new()
}
async fn no_arg_implicit_return_error(self, _: context::Context) {}
async fn one_arg_implicit_return_error(self, _: context::Context, _one: String) {}
}
}

View File

@@ -29,6 +29,38 @@ fn att_service_trait() {
} }
} }
#[allow(non_camel_case_types)]
#[test]
fn raw_idents() {
use futures::future::{ready, Ready};
type r#yield = String;
#[tarpc::service]
trait r#trait {
async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32);
async fn r#fn(r#impl: r#yield) -> r#yield;
async fn r#async();
}
impl r#trait for () {
type AwaitFut = Ready<(r#yield, i32)>;
fn r#await(self, _: context::Context, r#struct: r#yield, r#enum: i32) -> Self::AwaitFut {
ready((r#struct, r#enum))
}
type FnFut = Ready<r#yield>;
fn r#fn(self, _: context::Context, r#impl: r#yield) -> Self::FnFut {
ready(r#impl)
}
type AsyncFut = Ready<()>;
fn r#async(self, _: context::Context) -> Self::AsyncFut {
ready(())
}
}
}
#[test] #[test]
fn syntax() { fn syntax() {
#[tarpc::service] #[tarpc::service]

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "tarpc" name = "tarpc"
version = "0.20.0" version = "0.21.0"
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"] authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
edition = "2018" edition = "2018"
license = "MIT" license = "MIT"
@@ -30,13 +30,12 @@ fnv = "1.0"
futures = "0.3" futures = "0.3"
humantime = "1.0" humantime = "1.0"
log = "0.4" log = "0.4"
pin-project = "0.4" pin-project = "0.4.17"
raii-counter = "0.2"
rand = "0.7" rand = "0.7"
tokio = { version = "0.2", features = ["time"] } tokio = { version = "0.2", features = ["time"] }
serde = { optional = true, version = "1.0", features = ["derive"] } serde = { optional = true, version = "1.0", features = ["derive"] }
tokio-util = { optional = true, version = "0.2" } tokio-util = { optional = true, version = "0.2" }
tarpc-plugins = { path = "../plugins", version = "0.7" } tarpc-plugins = { path = "../plugins", version = "0.8" }
tokio-serde = { optional = true, version = "0.6" } tokio-serde = { optional = true, version = "0.6" }
[dev-dependencies] [dev-dependencies]
@@ -61,4 +60,3 @@ required-features = ["full"]
[[example]] [[example]]
name = "pubsub" name = "pubsub"
required-features = ["full"] required-features = ["full"]

View File

@@ -50,7 +50,7 @@
//! Add to your `Cargo.toml` dependencies: //! Add to your `Cargo.toml` dependencies:
//! //!
//! ```toml //! ```toml
//! tarpc = "0.20.0" //! tarpc = "0.21.0"
//! ``` //! ```
//! //!
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service. //! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
@@ -215,7 +215,6 @@ pub mod trace;
/// Rpc methods are specified, mirroring trait syntax: /// Rpc methods are specified, mirroring trait syntax:
/// ///
/// ``` /// ```
/// # fn main() {}
/// #[tarpc::service] /// #[tarpc::service]
/// trait Service { /// trait Service {
/// /// Say hello /// /// Say hello
@@ -234,3 +233,59 @@ pub mod trace;
/// * `Client` -- a client stub with a fn for each RPC. /// * `Client` -- a client stub with a fn for each RPC.
/// * `fn new_stub` -- creates a new Client stub. /// * `fn new_stub` -- creates a new Client stub.
pub use tarpc_plugins::service; pub use tarpc_plugins::service;
/// A utility macro that can be used for RPC server implementations.
///
/// Syntactic sugar to make using async functions in the server implementation
/// easier. It does this by rewriting code like this, which would normally not
/// compile because async functions are disallowed in trait implementations:
///
/// ```rust
/// # use tarpc::context;
/// # use std::net::SocketAddr;
/// #[tarpc::service]
/// trait World {
/// async fn hello(name: String) -> String;
/// }
///
/// #[derive(Clone)]
/// struct HelloServer(SocketAddr);
///
/// #[tarpc::server]
/// impl World for HelloServer {
/// async fn hello(self, _: context::Context, name: String) -> String {
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
/// }
/// }
/// ```
///
/// Into code like this, which matches the service trait definition:
///
/// ```rust
/// # use tarpc::context;
/// # use std::pin::Pin;
/// # use futures::Future;
/// # use std::net::SocketAddr;
/// #[derive(Clone)]
/// struct HelloServer(SocketAddr);
///
/// #[tarpc::service]
/// trait World {
/// async fn hello(name: String) -> String;
/// }
///
/// impl World for HelloServer {
/// type HelloFut = Pin<Box<dyn Future<Output = String> + Send>>;
///
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
/// + Send>> {
/// Box::pin(async move {
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
/// })
/// }
/// }
/// ```
///
/// Note that this won't touch functions unless they have been annotated with
/// `async`, meaning that this should not break existing code.
pub use tarpc_plugins::server;

View File

@@ -78,14 +78,21 @@ impl<'a, Req, Resp> Future for Send<'a, Req, Resp> {
#[must_use = "futures do nothing unless polled"] #[must_use = "futures do nothing unless polled"]
pub struct Call<'a, Req, Resp> { pub struct Call<'a, Req, Resp> {
#[pin] #[pin]
fut: AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>, fut: tokio::time::Timeout<AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>>,
} }
impl<'a, Req, Resp> Future for Call<'a, Req, Resp> { impl<'a, Req, Resp> Future for Call<'a, Req, Resp> {
type Output = io::Result<Resp>; type Output = io::Result<Resp>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.as_mut().project().fut.poll(cx) let resp = ready!(self.as_mut().project().fut.poll(cx));
Poll::Ready(match resp {
Ok(resp) => resp,
Err(tokio::time::Elapsed { .. }) => Err(io::Error::new(
io::ErrorKind::TimedOut,
"Client dropped expired request.".to_string(),
)),
})
} }
} }
@@ -97,13 +104,6 @@ impl<Req, Resp> Channel<Req, Resp> {
ctx.trace_context.parent_id = Some(ctx.trace_context.span_id); ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng()); ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
let timeout = ctx.deadline.time_until();
trace!(
"[{}] Queuing request with timeout {:?}.",
ctx.trace_id(),
timeout,
);
let (response_completion, response) = oneshot::channel(); let (response_completion, response) = oneshot::channel();
let cancellation = self.cancellation.clone(); let cancellation = self.cancellation.clone();
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed); let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
@@ -116,7 +116,7 @@ impl<Req, Resp> Channel<Req, Resp> {
response_completion, response_completion,
})), })),
DispatchResponse { DispatchResponse {
response: tokio::time::timeout(timeout, response), response,
complete: false, complete: false,
request_id, request_id,
cancellation, cancellation,
@@ -128,9 +128,16 @@ impl<Req, Resp> Channel<Req, Resp> {
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
/// resolves to the response. /// resolves to the response.
pub fn call(&mut self, context: context::Context, request: Req) -> Call<Req, Resp> { pub fn call(&mut self, ctx: context::Context, request: Req) -> Call<Req, Resp> {
let timeout = ctx.deadline.time_until();
trace!(
"[{}] Queuing request with timeout {:?}.",
ctx.trace_id(),
timeout,
);
Call { Call {
fut: AndThenIdent::new(self.send(context, request)), fut: tokio::time::timeout(timeout, AndThenIdent::new(self.send(ctx, request))),
} }
} }
} }
@@ -140,7 +147,7 @@ impl<Req, Resp> Channel<Req, Resp> {
#[pin_project(PinnedDrop)] #[pin_project(PinnedDrop)]
#[derive(Debug)] #[derive(Debug)]
struct DispatchResponse<Resp> { struct DispatchResponse<Resp> {
response: tokio::time::Timeout<oneshot::Receiver<Response<Resp>>>, response: oneshot::Receiver<Response<Resp>>,
ctx: context::Context, ctx: context::Context,
complete: bool, complete: bool,
cancellation: RequestCancellation, cancellation: RequestCancellation,
@@ -152,24 +159,15 @@ impl<Resp> Future for DispatchResponse<Resp> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> {
let resp = ready!(self.response.poll_unpin(cx)); let resp = ready!(self.response.poll_unpin(cx));
self.complete = true;
Poll::Ready(match resp { Poll::Ready(match resp {
Ok(resp) => { Ok(resp) => Ok(resp.message?),
self.complete = true; Err(oneshot::Canceled) => {
match resp { // The oneshot is Canceled when the dispatch task ends. In that case,
Ok(resp) => Ok(resp.message?), // there's nothing listening on the other side, so there's no point in
Err(oneshot::Canceled) => { // propagating cancellation.
// The oneshot is Canceled when the dispatch task ends. In that case, Err(io::Error::from(io::ErrorKind::ConnectionReset))
// there's nothing listening on the other side, so there's no point in
// propagating cancellation.
Err(io::Error::from(io::ErrorKind::ConnectionReset))
}
}
} }
Err(tokio::time::Elapsed { .. }) => Err(io::Error::new(
io::ErrorKind::TimedOut,
"Client dropped expired request.".to_string(),
)),
}) })
} }
} }
@@ -189,7 +187,7 @@ impl<Resp> PinnedDrop for DispatchResponse<Resp> {
// closing the receiver before sending the cancel message, it is guaranteed that if the // closing the receiver before sending the cancel message, it is guaranteed that if the
// dispatch task misses an early-arriving cancellation message, then it will see the // dispatch task misses an early-arriving cancellation message, then it will see the
// receiver as closed. // receiver as closed.
self.response.get_mut().close(); self.response.close();
let request_id = self.request_id; let request_id = self.request_id;
self.cancellation.cancel(request_id); self.cancellation.cancel(request_id);
} }
@@ -385,9 +383,7 @@ where
context: context::Context { context: context::Context {
deadline: dispatch_request.ctx.deadline, deadline: dispatch_request.ctx.deadline,
trace_context: dispatch_request.ctx.trace_context, trace_context: dispatch_request.ctx.trace_context,
_non_exhaustive: (),
}, },
_non_exhaustive: (),
}); });
self.as_mut().project().transport.start_send(request)?; self.as_mut().project().transport.start_send(request)?;
self.as_mut().project().in_flight_requests.insert( self.as_mut().project().in_flight_requests.insert(
@@ -632,11 +628,12 @@ where
} }
} }
#[pin_project(project = TryChainProj)]
#[must_use = "futures do nothing unless polled"] #[must_use = "futures do nothing unless polled"]
#[derive(Debug)] #[derive(Debug)]
enum TryChain<Fut1, Fut2> { enum TryChain<Fut1, Fut2> {
First(Fut1), First(#[pin] Fut1),
Second(Fut2), Second(#[pin] Fut2),
Empty, Empty,
} }
@@ -658,7 +655,7 @@ where
} }
fn poll<F>( fn poll<F>(
self: Pin<&mut Self>, mut self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
f: F, f: F,
) -> Poll<Result<Fut2::Ok, Fut2::Error>> ) -> Poll<Result<Fut2::Ok, Fut2::Error>>
@@ -667,31 +664,28 @@ where
{ {
let mut f = Some(f); let mut f = Some(f);
// Safe to call `get_unchecked_mut` because we won't move the futures.
let this = unsafe { Pin::get_unchecked_mut(self) };
loop { loop {
let output = match this { let output = match self.as_mut().project() {
TryChain::First(fut1) => { TryChainProj::First(fut1) => {
// Poll the first future // Poll the first future
match unsafe { Pin::new_unchecked(fut1) }.try_poll(cx) { match fut1.try_poll(cx) {
Poll::Pending => return Poll::Pending, Poll::Pending => return Poll::Pending,
Poll::Ready(output) => output, Poll::Ready(output) => output,
} }
} }
TryChain::Second(fut2) => { TryChainProj::Second(fut2) => {
// Poll the second future // Poll the second future
return unsafe { Pin::new_unchecked(fut2) }.try_poll(cx); return fut2.try_poll(cx);
} }
TryChain::Empty => { TryChainProj::Empty => {
panic!("future must not be polled after it returned `Poll::Ready`"); panic!("future must not be polled after it returned `Poll::Ready`");
} }
}; };
*this = TryChain::Empty; // Drop fut1 self.set(TryChain::Empty); // Drop fut1
let f = f.take().unwrap(); let f = f.take().unwrap();
match f(output) { match f(output) {
TryChainAction::Future(fut2) => *this = TryChain::Second(fut2), TryChainAction::Future(fut2) => self.set(TryChain::Second(fut2)),
TryChainAction::Output(output) => return Poll::Ready(output), TryChainAction::Output(output) => return Poll::Ready(output),
} }
} }
@@ -716,24 +710,21 @@ mod tests {
prelude::*, prelude::*,
task::*, task::*,
}; };
use std::time::Duration;
use std::{pin::Pin, sync::atomic::AtomicU64, sync::Arc}; use std::{pin::Pin, sync::atomic::AtomicU64, sync::Arc};
#[tokio::test(threaded_scheduler)] #[tokio::test(threaded_scheduler)]
async fn dispatch_response_cancels_on_timeout() { async fn dispatch_response_cancels_on_drop() {
let (_response_completion, response) = oneshot::channel();
let (cancellation, mut canceled_requests) = cancellations(); let (cancellation, mut canceled_requests) = cancellations();
let resp = DispatchResponse::<u64> { let (_, response) = oneshot::channel();
// Timeout in the past should cause resp to error out when polled. drop(DispatchResponse::<u32> {
response: tokio::time::timeout(Duration::from_secs(0), response), response,
cancellation,
complete: false, complete: false,
request_id: 3, request_id: 3,
cancellation,
ctx: context::current(), ctx: context::current(),
}; });
let _ = futures::poll!(resp);
// resp's drop() is run, which should send a cancel message. // resp's drop() is run, which should send a cancel message.
assert!(canceled_requests.0.try_next().unwrap() == Some(3)); assert_eq!(canceled_requests.0.try_next().unwrap(), Some(3));
} }
#[tokio::test(threaded_scheduler)] #[tokio::test(threaded_scheduler)]
@@ -768,7 +759,6 @@ mod tests {
Response { Response {
request_id: 0, request_id: 0,
message: Ok("hello".into()), message: Ok("hello".into()),
_non_exhaustive: (),
}, },
) )
.await; .await;
@@ -822,7 +812,7 @@ mod tests {
// i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request // i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
// map. // map.
let mut resp = send_request(&mut channel, "hi").await; let mut resp = send_request(&mut channel, "hi").await;
resp.response.get_mut().close(); resp.response.close();
assert!(dispatch.poll_next_request(cx).is_pending()); assert!(dispatch.poll_next_request(cx).is_pending());
} }

View File

@@ -104,6 +104,7 @@ where
/// Settings that control the behavior of the client. /// Settings that control the behavior of the client.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
#[non_exhaustive]
pub struct Config { pub struct Config {
/// The number of requests that can be in flight at once. /// The number of requests that can be in flight at once.
/// `max_in_flight_requests` controls the size of the map used by the client /// `max_in_flight_requests` controls the size of the map used by the client
@@ -113,8 +114,6 @@ pub struct Config {
/// `pending_requests_buffer` controls the size of the channel clients use /// `pending_requests_buffer` controls the size of the channel clients use
/// to communicate with the request dispatch task. /// to communicate with the request dispatch task.
pub pending_request_buffer: usize, pub pending_request_buffer: usize,
#[doc(hidden)]
_non_exhaustive: (),
} }
impl Default for Config { impl Default for Config {
@@ -122,7 +121,6 @@ impl Default for Config {
Config { Config {
max_in_flight_requests: 1_000, max_in_flight_requests: 1_000,
pending_request_buffer: 100, pending_request_buffer: 100,
_non_exhaustive: (),
} }
} }
} }

View File

@@ -16,6 +16,7 @@ use std::time::{Duration, SystemTime};
/// The context should not be stored directly in a server implementation, because the context will /// The context should not be stored directly in a server implementation, because the context will
/// be different for each request in scope. /// be different for each request in scope.
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Context { pub struct Context {
/// When the client expects the request to be complete by. The server should cancel the request /// When the client expects the request to be complete by. The server should cancel the request
@@ -35,9 +36,6 @@ pub struct Context {
/// include the same `trace_id` as that included on the original request. This way, /// include the same `trace_id` as that included on the original request. This way,
/// users can trace related actions across a distributed system. /// users can trace related actions across a distributed system.
pub trace_context: trace::Context, pub trace_context: trace::Context,
#[doc(hidden)]
#[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
pub(crate) _non_exhaustive: (),
} }
#[cfg(feature = "serde1")] #[cfg(feature = "serde1")]
@@ -51,7 +49,6 @@ pub fn current() -> Context {
Context { Context {
deadline: SystemTime::now() + Duration::from_secs(10), deadline: SystemTime::now() + Duration::from_secs(10),
trace_context: trace::Context::new_root(), trace_context: trace::Context::new_root(),
_non_exhaustive: (),
} }
} }

View File

@@ -38,6 +38,7 @@ use std::{io, time::SystemTime};
/// A message from a client to a server. /// A message from a client to a server.
#[derive(Debug)] #[derive(Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum ClientMessage<T> { pub enum ClientMessage<T> {
/// A request initiated by a user. The server responds to a request by invoking a /// A request initiated by a user. The server responds to a request by invoking a
/// service-provided request handler. The handler completes with a [`response`](Response), which /// service-provided request handler. The handler completes with a [`response`](Response), which
@@ -58,12 +59,11 @@ pub enum ClientMessage<T> {
/// The ID of the request to cancel. /// The ID of the request to cancel.
request_id: u64, request_id: u64,
}, },
#[doc(hidden)]
_NonExhaustive,
} }
/// A request from a client to a server. /// A request from a client to a server.
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Request<T> { pub struct Request<T> {
/// Trace context, deadline, and other cross-cutting concerns. /// Trace context, deadline, and other cross-cutting concerns.
@@ -72,26 +72,22 @@ pub struct Request<T> {
pub id: u64, pub id: u64,
/// The request body. /// The request body.
pub message: T, pub message: T,
#[doc(hidden)]
#[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
_non_exhaustive: (),
} }
/// A response from a server to a client. /// A response from a server to a client.
#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Response<T> { pub struct Response<T> {
/// The ID of the request being responded to. /// The ID of the request being responded to.
pub request_id: u64, pub request_id: u64,
/// The response body, or an error if the request failed. /// The response body, or an error if the request failed.
pub message: Result<T, ServerError>, pub message: Result<T, ServerError>,
#[doc(hidden)]
#[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
_non_exhaustive: (),
} }
/// An error response from a server to a client. /// An error response from a server to a client.
#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct ServerError { pub struct ServerError {
#[cfg_attr( #[cfg_attr(
@@ -106,9 +102,6 @@ pub struct ServerError {
pub kind: io::ErrorKind, pub kind: io::ErrorKind,
/// A message describing more detail about the error that occurred. /// A message describing more detail about the error that occurred.
pub detail: Option<String>, pub detail: Option<String>,
#[doc(hidden)]
#[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
_non_exhaustive: (),
} }
impl From<ServerError> for io::Error { impl From<ServerError> for io::Error {

View File

@@ -12,7 +12,6 @@ use fnv::FnvHashMap;
use futures::{channel::mpsc, future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*}; use futures::{channel::mpsc, future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
use log::{debug, info, trace}; use log::{debug, info, trace};
use pin_project::pin_project; use pin_project::pin_project;
use raii_counter::{Counter, WeakCounter};
use std::sync::{Arc, Weak}; use std::sync::{Arc, Weak};
use std::{ use std::{
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin, collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin,
@@ -32,7 +31,7 @@ where
dropped_keys: mpsc::UnboundedReceiver<K>, dropped_keys: mpsc::UnboundedReceiver<K>,
#[pin] #[pin]
dropped_keys_tx: mpsc::UnboundedSender<K>, dropped_keys_tx: mpsc::UnboundedSender<K>,
key_counts: FnvHashMap<K, TrackerPrototype<K>>, key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
keymaker: F, keymaker: F,
} }
@@ -42,37 +41,22 @@ where
pub struct TrackedChannel<C, K> { pub struct TrackedChannel<C, K> {
#[pin] #[pin]
inner: C, inner: C,
tracker: Tracker<K>, tracker: Arc<Tracker<K>>,
} }
#[derive(Clone, Debug)] #[derive(Debug)]
struct Tracker<K> { struct Tracker<K> {
key: Option<Arc<K>>, key: Option<K>,
counter: Counter,
dropped_keys: mpsc::UnboundedSender<K>, dropped_keys: mpsc::UnboundedSender<K>,
} }
impl<K> Drop for Tracker<K> { impl<K> Drop for Tracker<K> {
fn drop(&mut self) { fn drop(&mut self) {
if self.counter.count() <= 1 { // Don't care if the listener is dropped.
// Don't care if the listener is dropped. let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap());
match Arc::try_unwrap(self.key.take().unwrap()) {
Ok(key) => {
let _ = self.dropped_keys.unbounded_send(key);
}
_ => unreachable!(),
}
}
} }
} }
#[derive(Clone, Debug)]
struct TrackerPrototype<K> {
key: Weak<K>,
counter: WeakCounter,
dropped_keys: mpsc::UnboundedSender<K>,
}
impl<C, K> Stream for TrackedChannel<C, K> impl<C, K> Stream for TrackedChannel<C, K>
where where
C: Stream, C: Stream,
@@ -181,7 +165,7 @@ where
trace!( trace!(
"[{}] Opening channel ({}/{}) channels for key.", "[{}] Opening channel ({}/{}) channels for key.",
key, key,
tracker.counter.count(), Arc::strong_count(&tracker),
self.as_mut().project().channels_per_key self.as_mut().project().channels_per_key
); );
@@ -191,28 +175,22 @@ where
}) })
} }
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Tracker<K>, K> { fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
let channels_per_key = self.channels_per_key; let channels_per_key = self.channels_per_key;
let dropped_keys = self.dropped_keys_tx.clone(); let dropped_keys = self.dropped_keys_tx.clone();
let key_counts = &mut self.as_mut().project().key_counts; let key_counts = &mut self.as_mut().project().key_counts;
match key_counts.entry(key.clone()) { match key_counts.entry(key.clone()) {
Entry::Vacant(vacant) => { Entry::Vacant(vacant) => {
let key = Arc::new(key); let tracker = Arc::new(Tracker {
let counter = WeakCounter::new();
vacant.insert(TrackerPrototype {
key: Arc::downgrade(&key),
counter: counter.clone(),
dropped_keys: dropped_keys.clone(),
});
Ok(Tracker {
key: Some(key), key: Some(key),
counter: counter.upgrade(),
dropped_keys, dropped_keys,
}) });
vacant.insert(Arc::downgrade(&tracker));
Ok(tracker)
} }
Entry::Occupied(o) => { Entry::Occupied(mut o) => {
let count = o.get().counter.count(); let count = o.get().strong_count();
if count >= channels_per_key.try_into().unwrap() { if count >= channels_per_key.try_into().unwrap() {
info!( info!(
"[{}] Opened max channels from key ({}/{}).", "[{}] Opened max channels from key ({}/{}).",
@@ -220,16 +198,15 @@ where
); );
Err(key) Err(key)
} else { } else {
let TrackerPrototype { Ok(o.get().upgrade().unwrap_or_else(|| {
key, let tracker = Arc::new(Tracker {
counter, key: Some(key),
dropped_keys, dropped_keys,
} = o.get().clone(); });
Ok(Tracker {
counter: counter.upgrade(), *o.get_mut() = Arc::downgrade(&tracker);
key: Some(key.upgrade().unwrap()), tracker
dropped_keys, }))
})
} }
} }
} }
@@ -302,12 +279,10 @@ fn ctx() -> Context<'static> {
#[test] #[test]
fn tracker_drop() { fn tracker_drop() {
use assert_matches::assert_matches; use assert_matches::assert_matches;
use raii_counter::Counter;
let (tx, mut rx) = mpsc::unbounded(); let (tx, mut rx) = mpsc::unbounded();
Tracker { Tracker {
key: Some(Arc::new(1)), key: Some(1),
counter: Counter::new(),
dropped_keys: tx, dropped_keys: tx,
}; };
assert_matches!(rx.try_next(), Ok(Some(1))); assert_matches!(rx.try_next(), Ok(Some(1)));
@@ -317,17 +292,15 @@ fn tracker_drop() {
fn tracked_channel_stream() { fn tracked_channel_stream() {
use assert_matches::assert_matches; use assert_matches::assert_matches;
use pin_utils::pin_mut; use pin_utils::pin_mut;
use raii_counter::Counter;
let (chan_tx, chan) = mpsc::unbounded(); let (chan_tx, chan) = mpsc::unbounded();
let (dropped_keys, _) = mpsc::unbounded(); let (dropped_keys, _) = mpsc::unbounded();
let channel = TrackedChannel { let channel = TrackedChannel {
inner: chan, inner: chan,
tracker: Tracker { tracker: Arc::new(Tracker {
key: Some(Arc::new(1)), key: Some(1),
counter: Counter::new(),
dropped_keys, dropped_keys,
}, }),
}; };
chan_tx.unbounded_send("test").unwrap(); chan_tx.unbounded_send("test").unwrap();
@@ -339,17 +312,15 @@ fn tracked_channel_stream() {
fn tracked_channel_sink() { fn tracked_channel_sink() {
use assert_matches::assert_matches; use assert_matches::assert_matches;
use pin_utils::pin_mut; use pin_utils::pin_mut;
use raii_counter::Counter;
let (chan, mut chan_rx) = mpsc::unbounded(); let (chan, mut chan_rx) = mpsc::unbounded();
let (dropped_keys, _) = mpsc::unbounded(); let (dropped_keys, _) = mpsc::unbounded();
let channel = TrackedChannel { let channel = TrackedChannel {
inner: chan, inner: chan,
tracker: Tracker { tracker: Arc::new(Tracker {
key: Some(Arc::new(1)), key: Some(1),
counter: Counter::new(),
dropped_keys, dropped_keys,
}, }),
}; };
pin_mut!(channel); pin_mut!(channel);
@@ -371,12 +342,12 @@ fn channel_filter_increment_channels_for_key() {
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter); pin_mut!(filter);
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap(); let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
assert_eq!(tracker1.counter.count(), 1); assert_eq!(Arc::strong_count(&tracker1), 1);
let tracker2 = filter.as_mut().increment_channels_for_key("key").unwrap(); let tracker2 = filter.as_mut().increment_channels_for_key("key").unwrap();
assert_eq!(tracker1.counter.count(), 2); assert_eq!(Arc::strong_count(&tracker1), 2);
assert_matches!(filter.increment_channels_for_key("key"), Err("key")); assert_matches!(filter.increment_channels_for_key("key"), Err("key"));
drop(tracker2); drop(tracker2);
assert_eq!(tracker1.counter.count(), 1); assert_eq!(Arc::strong_count(&tracker1), 1);
} }
#[test] #[test]
@@ -395,20 +366,20 @@ fn channel_filter_handle_new_channel() {
.as_mut() .as_mut()
.handle_new_channel(TestChannel { key: "key" }) .handle_new_channel(TestChannel { key: "key" })
.unwrap(); .unwrap();
assert_eq!(channel1.tracker.counter.count(), 1); assert_eq!(Arc::strong_count(&channel1.tracker), 1);
let channel2 = filter let channel2 = filter
.as_mut() .as_mut()
.handle_new_channel(TestChannel { key: "key" }) .handle_new_channel(TestChannel { key: "key" })
.unwrap(); .unwrap();
assert_eq!(channel1.tracker.counter.count(), 2); assert_eq!(Arc::strong_count(&channel1.tracker), 2);
assert_matches!( assert_matches!(
filter.handle_new_channel(TestChannel { key: "key" }), filter.handle_new_channel(TestChannel { key: "key" }),
Err("key") Err("key")
); );
drop(channel2); drop(channel2);
assert_eq!(channel1.tracker.counter.count(), 1); assert_eq!(Arc::strong_count(&channel1.tracker), 1);
} }
#[test] #[test]
@@ -429,14 +400,14 @@ fn channel_filter_poll_listener() {
.unwrap(); .unwrap();
let channel1 = let channel1 =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c); assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
assert_eq!(channel1.tracker.counter.count(), 1); assert_eq!(Arc::strong_count(&channel1.tracker), 1);
new_channels new_channels
.unbounded_send(TestChannel { key: "key" }) .unbounded_send(TestChannel { key: "key" })
.unwrap(); .unwrap();
let _channel2 = let _channel2 =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c); assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
assert_eq!(channel1.tracker.counter.count(), 2); assert_eq!(Arc::strong_count(&channel1.tracker), 2);
new_channels new_channels
.unbounded_send(TestChannel { key: "key" }) .unbounded_send(TestChannel { key: "key" })
@@ -444,7 +415,7 @@ fn channel_filter_poll_listener() {
let key = let key =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k); assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k);
assert_eq!(key, "key"); assert_eq!(key, "key");
assert_eq!(channel1.tracker.counter.count(), 2); assert_eq!(Arc::strong_count(&channel1.tracker), 2);
} }
#[test] #[test]

View File

@@ -304,7 +304,6 @@ where
} => { } => {
self.as_mut().cancel_request(&trace_context, request_id); self.as_mut().cancel_request(&trace_context, request_id);
} }
ClientMessage::_NonExhaustive => unreachable!(),
}, },
None => return Poll::Ready(None), None => return Poll::Ready(None),
} }
@@ -569,11 +568,9 @@ where
"Response did not complete before deadline of {}s.", "Response did not complete before deadline of {}s.",
format_rfc3339(self.deadline) format_rfc3339(self.deadline)
)), )),
_non_exhaustive: (),
}) })
} }
}, },
_non_exhaustive: (),
}); });
*self.as_mut().project().state = RespState::PollReady; *self.as_mut().project().state = RespState::PollReady;
} }
@@ -653,11 +650,9 @@ where
pub fn execute(self) -> impl Future<Output = ()> { pub fn execute(self) -> impl Future<Output = ()> {
use log::info; use log::info;
self.try_for_each(|request_handler| { self.try_for_each(|request_handler| async {
async { tokio::spawn(request_handler);
tokio::spawn(request_handler); Ok(())
Ok(())
}
}) })
.unwrap_or_else(|e| info!("ClientHandler errored out: {}", e)) .unwrap_or_else(|e| info!("ClientHandler errored out: {}", e))
} }

View File

@@ -87,11 +87,9 @@ impl<Req, Resp> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
context: context::Context { context: context::Context {
deadline: SystemTime::UNIX_EPOCH, deadline: SystemTime::UNIX_EPOCH,
trace_context: Default::default(), trace_context: Default::default(),
_non_exhaustive: (),
}, },
id, id,
message, message,
_non_exhaustive: (),
})); }));
} }
} }

View File

@@ -61,9 +61,7 @@ where
message: Err(ServerError { message: Err(ServerError {
kind: io::ErrorKind::WouldBlock, kind: io::ErrorKind::WouldBlock,
detail: Some("Server throttled the request.".into()), detail: Some("Server throttled the request.".into()),
_non_exhaustive: (),
}), }),
_non_exhaustive: (),
})?; })?;
} }
None => return Poll::Ready(None), None => return Poll::Ready(None),
@@ -311,7 +309,6 @@ fn throttler_start_send() {
.start_send(Response { .start_send(Response {
request_id: 0, request_id: 0,
message: Ok(1), message: Ok(1),
_non_exhaustive: (),
}) })
.unwrap(); .unwrap();
assert!(throttler.inner.in_flight_requests.is_empty()); assert!(throttler.inner.in_flight_requests.is_empty());
@@ -320,7 +317,6 @@ fn throttler_start_send() {
Some(&Response { Some(&Response {
request_id: 0, request_id: 0,
message: Ok(1), message: Ok(1),
_non_exhaustive: ()
}) })
); );
} }