mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-02-23 15:49:54 +01:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7e521768ab | ||
|
|
e9b1e7d101 | ||
|
|
f0322fb892 | ||
|
|
617daebb88 | ||
|
|
a11d4fff58 | ||
|
|
bf42a04d83 | ||
|
|
06528d6953 | ||
|
|
9f00395746 | ||
|
|
e0674cd57f | ||
|
|
7e49bd9ee7 | ||
|
|
8a1baa9c4e | ||
|
|
31c713d188 |
18
README.md
18
README.md
@@ -1,6 +1,16 @@
|
||||
[](https://github.com/google/tarpc/actions?query=workflow%3A%22Continuous+integration%22)
|
||||
[](https://crates.io/crates/tarpc)
|
||||
[](https://discordapp.com/channels/647529123996237854)
|
||||
[![Crates.io][crates-badge]][crates-url]
|
||||
[![MIT licensed][mit-badge]][mit-url]
|
||||
[![Build status][gh-actions-badge]][gh-actions-url]
|
||||
[![Discord chat][discord-badge]][discord-url]
|
||||
|
||||
[crates-badge]: https://img.shields.io/crates/v/tarpc.svg
|
||||
[crates-url]: https://crates.io/crates/tarpc
|
||||
[mit-badge]: https://img.shields.io/badge/license-MIT-blue.svg
|
||||
[mit-url]: LICENSE
|
||||
[gh-actions-badge]: https://github.com/google/tarpc/workflows/Continuous%20integration/badge.svg
|
||||
[gh-actions-url]: https://github.com/google/tarpc/actions?query=workflow%3A%22Continuous+integration%22
|
||||
[discord-badge]: https://img.shields.io/discord/647529123996237854.svg?logo=discord&style=flat-square
|
||||
[discord-url]: https://discord.gg/gXwpdSt
|
||||
|
||||
# tarpc
|
||||
|
||||
@@ -47,7 +57,7 @@ Some other features of tarpc:
|
||||
Add to your `Cargo.toml` dependencies:
|
||||
|
||||
```toml
|
||||
tarpc = { version = "0.18.0", features = ["full"] }
|
||||
tarpc = { version = "0.21.0", features = ["full"] }
|
||||
```
|
||||
|
||||
The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||
|
||||
23
RELEASES.md
23
RELEASES.md
@@ -1,3 +1,26 @@
|
||||
## 0.21.0 (2020-06-26)
|
||||
|
||||
### New Features
|
||||
|
||||
A new proc macro, `#[tarpc::server]` was added! This enables service impls to elide the boilerplate
|
||||
of specifying associated types for each RPC. With the ubiquity of async-await, most code won't have
|
||||
nameable futures and will just be boxing the return type anyway. This macro does that for you.
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- Enums had _non_exhaustive fields replaced with the #[non_exhaustive] attribute.
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- https://github.com/google/tarpc/issues/304
|
||||
|
||||
A race condition in code that limits number of connections per client caused occasional panics.
|
||||
|
||||
- https://github.com/google/tarpc/pull/295
|
||||
|
||||
Made request timeouts account for time spent in the outbound buffer. Previously, a large outbound
|
||||
queue would lead to requests not timing out correctly.
|
||||
|
||||
## 0.20.0 (2019-12-11)
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
@@ -16,7 +16,7 @@ description = "An example server built on tarpc."
|
||||
clap = "2.0"
|
||||
futures = "0.3"
|
||||
serde = { version = "1.0" }
|
||||
tarpc = { version = "0.20", path = "../tarpc", features = ["full"] }
|
||||
tarpc = { version = "0.21", path = "../tarpc", features = ["full"] }
|
||||
tokio = { version = "0.2", features = ["full"] }
|
||||
tokio-serde = { version = "0.6", features = ["json"] }
|
||||
env_logger = "0.6"
|
||||
|
||||
@@ -89,12 +89,12 @@ if [ "$?" == 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
try_run "Building ... " cargo build --color=always
|
||||
try_run "Testing ... " cargo test --color=always
|
||||
try_run "Testing with all features enabled ... " cargo test --all-features --color=always
|
||||
for EXAMPLE in $(cargo run --example 2>&1 | grep ' ' | awk '{print $1}')
|
||||
try_run "Building ... " cargo +stable build --color=always
|
||||
try_run "Testing ... " cargo +stable test --color=always
|
||||
try_run "Testing with all features enabled ... " cargo +stable test --all-features --color=always
|
||||
for EXAMPLE in $(cargo +stable run --example 2>&1 | grep ' ' | awk '{print $1}')
|
||||
do
|
||||
try_run "Running example \"$EXAMPLE\" ... " cargo run --example $EXAMPLE
|
||||
try_run "Running example \"$EXAMPLE\" ... " cargo +stable run --example $EXAMPLE
|
||||
done
|
||||
|
||||
fi
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc-plugins"
|
||||
version = "0.7.0"
|
||||
version = "0.8.0"
|
||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||
edition = "2018"
|
||||
license = "MIT"
|
||||
@@ -30,3 +30,4 @@ proc-macro = true
|
||||
futures = "0.3"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tarpc = { path = "../tarpc" }
|
||||
assert-type-eq = "0.1.0"
|
||||
|
||||
@@ -12,15 +12,18 @@ extern crate quote;
|
||||
extern crate syn;
|
||||
|
||||
use proc_macro::TokenStream;
|
||||
use quote::{format_ident, quote};
|
||||
use proc_macro2::TokenStream as TokenStream2;
|
||||
use quote::{format_ident, quote, ToTokens};
|
||||
use syn::{
|
||||
braced, parenthesized,
|
||||
braced,
|
||||
ext::IdentExt,
|
||||
parenthesized,
|
||||
parse::{Parse, ParseStream},
|
||||
parse_macro_input, parse_quote,
|
||||
parse_macro_input, parse_quote, parse_str,
|
||||
punctuated::Punctuated,
|
||||
token::Comma,
|
||||
Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type,
|
||||
Visibility,
|
||||
Attribute, FnArg, Ident, ImplItem, ImplItemMethod, ImplItemType, ItemImpl, Lit, LitBool,
|
||||
MetaNameValue, Pat, PatType, ReturnType, Token, Type, Visibility,
|
||||
};
|
||||
|
||||
struct Service {
|
||||
@@ -42,7 +45,7 @@ impl Parse for Service {
|
||||
let attrs = input.call(Attribute::parse_outer)?;
|
||||
let vis = input.parse()?;
|
||||
input.parse::<Token![trait]>()?;
|
||||
let ident = input.parse()?;
|
||||
let ident: Ident = input.parse()?;
|
||||
let content;
|
||||
braced!(content in input);
|
||||
let mut rpcs = Vec::<RpcMethod>::new();
|
||||
@@ -53,7 +56,7 @@ impl Parse for Service {
|
||||
if rpc.ident == "new" {
|
||||
return Err(input.error(format!(
|
||||
"method name conflicts with generated fn `{}Client::new`",
|
||||
ident
|
||||
ident.unraw()
|
||||
)));
|
||||
}
|
||||
if rpc.ident == "serve" {
|
||||
@@ -63,6 +66,7 @@ impl Parse for Service {
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
attrs,
|
||||
vis,
|
||||
@@ -156,19 +160,19 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
ref rpcs,
|
||||
} = parse_macro_input!(input as Service);
|
||||
|
||||
let camel_case_fn_names: &[String] = &rpcs
|
||||
let camel_case_fn_names: &Vec<_> = &rpcs
|
||||
.iter()
|
||||
.map(|rpc| snake_to_camel(&rpc.ident.to_string()))
|
||||
.collect::<Vec<_>>();
|
||||
.map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
|
||||
.collect();
|
||||
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
|
||||
let response_fut_name = &format!("{}ResponseFut", ident);
|
||||
let derive_serialize = quote!(#[derive(serde::Serialize, serde::Deserialize)]);
|
||||
let response_fut_name = &format!("{}ResponseFut", ident.unraw());
|
||||
let derive_serialize = if derive_serde.0 {
|
||||
Some(quote!(#[derive(serde::Serialize, serde::Deserialize)]))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let gen_args = &GenArgs {
|
||||
attrs,
|
||||
vis,
|
||||
rpcs,
|
||||
args,
|
||||
ServiceGenerator {
|
||||
response_fut_name,
|
||||
service_ident: ident,
|
||||
server_ident: &format_ident!("Serve{}", ident),
|
||||
@@ -176,8 +180,12 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
client_ident: &format_ident!("{}Client", ident),
|
||||
request_ident: &format_ident!("{}Request", ident),
|
||||
response_ident: &format_ident!("{}Response", ident),
|
||||
vis,
|
||||
args,
|
||||
method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
|
||||
method_names: &rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>(),
|
||||
method_idents: &rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>(),
|
||||
attrs,
|
||||
rpcs,
|
||||
return_types: &rpcs
|
||||
.iter()
|
||||
.map(|rpc| match rpc.output {
|
||||
@@ -185,7 +193,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
ReturnType::Default => unit_type,
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
arg_vars: &args
|
||||
arg_pats: &args
|
||||
.iter()
|
||||
.map(|args| args.iter().map(|arg| &*arg.pat).collect())
|
||||
.collect::<Vec<_>>(),
|
||||
@@ -194,43 +202,139 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
.zip(camel_case_fn_names.iter())
|
||||
.map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
|
||||
.collect::<Vec<_>>(),
|
||||
future_idents: &camel_case_fn_names
|
||||
future_types: &camel_case_fn_names
|
||||
.iter()
|
||||
.map(|name| format_ident!("{}Fut", name))
|
||||
.map(|name| parse_str(&format!("{}Fut", name)).unwrap())
|
||||
.collect::<Vec<_>>(),
|
||||
derive_serialize: if derive_serde.0 {
|
||||
Some(&derive_serialize)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
derive_serialize: derive_serialize.as_ref(),
|
||||
}
|
||||
.into_token_stream()
|
||||
.into()
|
||||
}
|
||||
|
||||
/// Transforms an async function into a sync one, returning a type declaration
|
||||
/// for the return type (a future).
|
||||
fn transform_method(method: &mut ImplItemMethod) -> ImplItemType {
|
||||
method.sig.asyncness = None;
|
||||
|
||||
// get either the return type or ().
|
||||
let ret = match &method.sig.output {
|
||||
ReturnType::Default => quote!(()),
|
||||
ReturnType::Type(_, ret) => quote!(#ret),
|
||||
};
|
||||
|
||||
let tokens = vec![
|
||||
trait_service(gen_args),
|
||||
struct_server(gen_args),
|
||||
impl_serve_for_server(gen_args),
|
||||
enum_request(gen_args),
|
||||
enum_response(gen_args),
|
||||
enum_response_future(gen_args),
|
||||
impl_debug_for_response_future(gen_args),
|
||||
impl_future_for_response_future(gen_args),
|
||||
struct_client(gen_args),
|
||||
impl_from_for_client(gen_args),
|
||||
impl_client_new(gen_args),
|
||||
impl_client_rpc_methods(gen_args),
|
||||
];
|
||||
// generate an identifier consisting of the method name to CamelCase with
|
||||
// Fut appended to it.
|
||||
let fut_name = snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut";
|
||||
let fut_name_ident = Ident::new(&fut_name, method.sig.ident.span());
|
||||
|
||||
tokens
|
||||
.into_iter()
|
||||
.collect::<proc_macro2::TokenStream>()
|
||||
.into()
|
||||
// generate the updated return signature.
|
||||
method.sig.output = parse_quote! {
|
||||
-> ::core::pin::Pin<Box<
|
||||
dyn ::core::future::Future<Output = #ret> + ::core::marker::Send
|
||||
>>
|
||||
};
|
||||
|
||||
// transform the body of the method into Box::pin(async move { body }).
|
||||
let block = method.block.clone();
|
||||
method.block = parse_quote! [{
|
||||
Box::pin(async move
|
||||
#block
|
||||
)
|
||||
}];
|
||||
|
||||
// generate and return type declaration for return type.
|
||||
let t: ImplItemType = parse_quote! {
|
||||
type #fut_name_ident = ::core::pin::Pin<Box<dyn ::core::future::Future<Output = #ret> + ::core::marker::Send>>;
|
||||
};
|
||||
|
||||
t
|
||||
}
|
||||
|
||||
/// Syntactic sugar to make using async functions in the server implementation
|
||||
/// easier. It does this by rewriting code like this, which would normally not
|
||||
/// compile because async functions are disallowed in trait implementations:
|
||||
///
|
||||
/// ```rust
|
||||
/// # extern crate tarpc;
|
||||
/// # use tarpc::context;
|
||||
/// # use std::net::SocketAddr;
|
||||
/// #[tarpc_plugins::service]
|
||||
/// trait World {
|
||||
/// async fn hello(name: String) -> String;
|
||||
/// }
|
||||
///
|
||||
/// #[derive(Clone)]
|
||||
/// struct HelloServer(SocketAddr);
|
||||
///
|
||||
/// #[tarpc_plugins::server]
|
||||
/// impl World for HelloServer {
|
||||
/// async fn hello(self, _: context::Context, name: String) -> String {
|
||||
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// Into code like this, which matches the service trait definition:
|
||||
///
|
||||
/// ```rust
|
||||
/// # extern crate tarpc;
|
||||
/// # use tarpc::context;
|
||||
/// # use std::pin::Pin;
|
||||
/// # use futures::Future;
|
||||
/// # use std::net::SocketAddr;
|
||||
/// #[tarpc_plugins::service]
|
||||
/// trait World {
|
||||
/// async fn hello(name: String) -> String;
|
||||
/// }
|
||||
///
|
||||
/// #[derive(Clone)]
|
||||
/// struct HelloServer(SocketAddr);
|
||||
///
|
||||
/// impl World for HelloServer {
|
||||
/// type HelloFut = Pin<Box<dyn Future<Output = String> + Send>>;
|
||||
///
|
||||
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
|
||||
/// + Send>> {
|
||||
/// Box::pin(async move {
|
||||
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
|
||||
/// })
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// Note that this won't touch functions unless they have been annotated with
|
||||
/// `async`, meaning that this should not break existing code.
|
||||
#[proc_macro_attribute]
|
||||
pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
let mut item = syn::parse_macro_input!(input as ItemImpl);
|
||||
|
||||
// the generated type declarations
|
||||
let mut types: Vec<ImplItemType> = Vec::new();
|
||||
|
||||
for inner in &mut item.items {
|
||||
if let ImplItem::Method(method) = inner {
|
||||
let sig = &method.sig;
|
||||
|
||||
// if this function is declared async, transform it into a regular function
|
||||
if sig.asyncness.is_some() {
|
||||
let typedecl = transform_method(method);
|
||||
types.push(typedecl);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add the type declarations into the impl block
|
||||
for t in types.into_iter() {
|
||||
item.items.push(syn::ImplItem::Type(t));
|
||||
}
|
||||
|
||||
TokenStream::from(quote!(#item))
|
||||
}
|
||||
|
||||
// Things needed to generate the service items: trait, serve impl, request/response enums, and
|
||||
// the client stub.
|
||||
struct GenArgs<'a> {
|
||||
attrs: &'a [Attribute],
|
||||
rpcs: &'a [RpcMethod],
|
||||
struct ServiceGenerator<'a> {
|
||||
service_ident: &'a Ident,
|
||||
server_ident: &'a Ident,
|
||||
response_fut_ident: &'a Ident,
|
||||
@@ -238,331 +342,356 @@ struct GenArgs<'a> {
|
||||
client_ident: &'a Ident,
|
||||
request_ident: &'a Ident,
|
||||
response_ident: &'a Ident,
|
||||
method_attrs: &'a [&'a [Attribute]],
|
||||
vis: &'a Visibility,
|
||||
method_names: &'a [&'a Ident],
|
||||
attrs: &'a [Attribute],
|
||||
rpcs: &'a [RpcMethod],
|
||||
camel_case_idents: &'a [Ident],
|
||||
future_types: &'a [Type],
|
||||
method_idents: &'a [&'a Ident],
|
||||
method_attrs: &'a [&'a [Attribute]],
|
||||
args: &'a [&'a [PatType]],
|
||||
return_types: &'a [&'a Type],
|
||||
arg_vars: &'a [Vec<&'a Pat>],
|
||||
camel_case_idents: &'a [Ident],
|
||||
future_idents: &'a [Ident],
|
||||
derive_serialize: Option<&'a proc_macro2::TokenStream>,
|
||||
arg_pats: &'a [Vec<&'a Pat>],
|
||||
derive_serialize: Option<&'a TokenStream2>,
|
||||
}
|
||||
|
||||
fn trait_service(
|
||||
&GenArgs {
|
||||
attrs,
|
||||
rpcs,
|
||||
vis,
|
||||
future_idents,
|
||||
return_types,
|
||||
service_ident,
|
||||
server_ident,
|
||||
..
|
||||
}: &GenArgs,
|
||||
) -> proc_macro2::TokenStream {
|
||||
let types_and_fns = rpcs
|
||||
.iter()
|
||||
.zip(future_idents.iter())
|
||||
.zip(return_types.iter())
|
||||
.map(
|
||||
|(
|
||||
(
|
||||
RpcMethod {
|
||||
attrs, ident, args, ..
|
||||
},
|
||||
future_type,
|
||||
),
|
||||
output,
|
||||
)| {
|
||||
let ty_doc = format!("The response future returned by {}.", ident);
|
||||
quote! {
|
||||
#[doc = #ty_doc]
|
||||
type #future_type: std::future::Future<Output = #output>;
|
||||
impl<'a> ServiceGenerator<'a> {
|
||||
fn trait_service(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
attrs,
|
||||
rpcs,
|
||||
vis,
|
||||
future_types,
|
||||
return_types,
|
||||
service_ident,
|
||||
server_ident,
|
||||
..
|
||||
} = self;
|
||||
|
||||
#( #attrs )*
|
||||
fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type;
|
||||
}
|
||||
},
|
||||
);
|
||||
let types_and_fns = rpcs
|
||||
.iter()
|
||||
.zip(future_types.iter())
|
||||
.zip(return_types.iter())
|
||||
.map(
|
||||
|(
|
||||
(
|
||||
RpcMethod {
|
||||
attrs, ident, args, ..
|
||||
},
|
||||
future_type,
|
||||
),
|
||||
output,
|
||||
)| {
|
||||
let ty_doc = format!("The response future returned by {}.", ident);
|
||||
quote! {
|
||||
#[doc = #ty_doc]
|
||||
type #future_type: std::future::Future<Output = #output>;
|
||||
|
||||
quote! {
|
||||
#( #attrs )*
|
||||
#vis trait #service_ident: Clone {
|
||||
#( #types_and_fns )*
|
||||
#( #attrs )*
|
||||
fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type;
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
/// Returns a serving function to use with tarpc::server::Server.
|
||||
fn serve(self) -> #server_ident<Self> {
|
||||
#server_ident { service: self }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
quote! {
|
||||
#( #attrs )*
|
||||
#vis trait #service_ident: Clone {
|
||||
#( #types_and_fns )*
|
||||
|
||||
fn struct_server(
|
||||
&GenArgs {
|
||||
vis, server_ident, ..
|
||||
}: &GenArgs,
|
||||
) -> proc_macro2::TokenStream {
|
||||
quote! {
|
||||
#[derive(Clone)]
|
||||
#vis struct #server_ident<S> {
|
||||
service: S,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_serve_for_server(
|
||||
&GenArgs {
|
||||
request_ident,
|
||||
server_ident,
|
||||
service_ident,
|
||||
response_ident,
|
||||
response_fut_ident,
|
||||
camel_case_idents,
|
||||
arg_vars,
|
||||
method_names,
|
||||
..
|
||||
}: &GenArgs,
|
||||
) -> proc_macro2::TokenStream {
|
||||
quote! {
|
||||
impl<S> tarpc::server::Serve<#request_ident> for #server_ident<S>
|
||||
where S: #service_ident
|
||||
{
|
||||
type Resp = #response_ident;
|
||||
type Fut = #response_fut_ident<S>;
|
||||
|
||||
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
|
||||
match req {
|
||||
#(
|
||||
#request_ident::#camel_case_idents{ #( #arg_vars ),* } => {
|
||||
#response_fut_ident::#camel_case_idents(
|
||||
#service_ident::#method_names(
|
||||
self.service, ctx, #( #arg_vars ),*))
|
||||
}
|
||||
)*
|
||||
/// Returns a serving function to use with tarpc::server::Server.
|
||||
fn serve(self) -> #server_ident<Self> {
|
||||
#server_ident { service: self }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn enum_request(
|
||||
&GenArgs {
|
||||
derive_serialize,
|
||||
vis,
|
||||
request_ident,
|
||||
camel_case_idents,
|
||||
args,
|
||||
..
|
||||
}: &GenArgs,
|
||||
) -> proc_macro2::TokenStream {
|
||||
quote! {
|
||||
/// The request sent over the wire from the client to the server.
|
||||
#[derive(Debug)]
|
||||
#derive_serialize
|
||||
#vis enum #request_ident {
|
||||
#( #camel_case_idents{ #( #args ),* } ),*
|
||||
}
|
||||
}
|
||||
}
|
||||
fn struct_server(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
vis, server_ident, ..
|
||||
} = self;
|
||||
|
||||
fn enum_response(
|
||||
&GenArgs {
|
||||
derive_serialize,
|
||||
vis,
|
||||
response_ident,
|
||||
camel_case_idents,
|
||||
return_types,
|
||||
..
|
||||
}: &GenArgs,
|
||||
) -> proc_macro2::TokenStream {
|
||||
quote! {
|
||||
/// The response sent over the wire from the server to the client.
|
||||
#[derive(Debug)]
|
||||
#derive_serialize
|
||||
#vis enum #response_ident {
|
||||
#( #camel_case_idents(#return_types) ),*
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn enum_response_future(
|
||||
&GenArgs {
|
||||
vis,
|
||||
service_ident,
|
||||
response_fut_ident,
|
||||
camel_case_idents,
|
||||
future_idents,
|
||||
..
|
||||
}: &GenArgs,
|
||||
) -> proc_macro2::TokenStream {
|
||||
quote! {
|
||||
/// A future resolving to a server response.
|
||||
#vis enum #response_fut_ident<S: #service_ident> {
|
||||
#( #camel_case_idents(<S as #service_ident>::#future_idents) ),*
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_debug_for_response_future(
|
||||
&GenArgs {
|
||||
service_ident,
|
||||
response_fut_ident,
|
||||
response_fut_name,
|
||||
..
|
||||
}: &GenArgs,
|
||||
) -> proc_macro2::TokenStream {
|
||||
quote! {
|
||||
impl<S: #service_ident> std::fmt::Debug for #response_fut_ident<S> {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
fmt.debug_struct(#response_fut_name).finish()
|
||||
quote! {
|
||||
#[derive(Clone)]
|
||||
#vis struct #server_ident<S> {
|
||||
service: S,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_future_for_response_future(
|
||||
&GenArgs {
|
||||
service_ident,
|
||||
response_fut_ident,
|
||||
response_ident,
|
||||
camel_case_idents,
|
||||
..
|
||||
}: &GenArgs,
|
||||
) -> proc_macro2::TokenStream {
|
||||
quote! {
|
||||
impl<S: #service_ident> std::future::Future for #response_fut_ident<S> {
|
||||
type Output = #response_ident;
|
||||
fn impl_serve_for_server(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
request_ident,
|
||||
server_ident,
|
||||
service_ident,
|
||||
response_ident,
|
||||
response_fut_ident,
|
||||
camel_case_idents,
|
||||
arg_pats,
|
||||
method_idents,
|
||||
..
|
||||
} = self;
|
||||
|
||||
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>)
|
||||
-> std::task::Poll<#response_ident>
|
||||
quote! {
|
||||
impl<S> tarpc::server::Serve<#request_ident> for #server_ident<S>
|
||||
where S: #service_ident
|
||||
{
|
||||
unsafe {
|
||||
match std::pin::Pin::get_unchecked_mut(self) {
|
||||
type Resp = #response_ident;
|
||||
type Fut = #response_fut_ident<S>;
|
||||
|
||||
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
|
||||
match req {
|
||||
#(
|
||||
#response_fut_ident::#camel_case_idents(resp) =>
|
||||
std::pin::Pin::new_unchecked(resp)
|
||||
.poll(cx)
|
||||
.map(#response_ident::#camel_case_idents),
|
||||
#request_ident::#camel_case_idents{ #( #arg_pats ),* } => {
|
||||
#response_fut_ident::#camel_case_idents(
|
||||
#service_ident::#method_idents(
|
||||
self.service, ctx, #( #arg_pats ),*
|
||||
)
|
||||
)
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn struct_client(
|
||||
&GenArgs {
|
||||
vis,
|
||||
client_ident,
|
||||
request_ident,
|
||||
response_ident,
|
||||
..
|
||||
}: &GenArgs,
|
||||
) -> proc_macro2::TokenStream {
|
||||
quote! {
|
||||
#[allow(unused)]
|
||||
#[derive(Clone, Debug)]
|
||||
/// The client stub that makes RPC calls to the server. Exposes a Future interface.
|
||||
#vis struct #client_ident<C = tarpc::client::Channel<#request_ident, #response_ident>>(C);
|
||||
}
|
||||
}
|
||||
fn enum_request(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
derive_serialize,
|
||||
vis,
|
||||
request_ident,
|
||||
camel_case_idents,
|
||||
args,
|
||||
..
|
||||
} = self;
|
||||
|
||||
fn impl_from_for_client(
|
||||
&GenArgs {
|
||||
client_ident,
|
||||
request_ident,
|
||||
response_ident,
|
||||
..
|
||||
}: &GenArgs,
|
||||
) -> proc_macro2::TokenStream {
|
||||
quote! {
|
||||
impl<C> From<C> for #client_ident<C>
|
||||
where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
|
||||
{
|
||||
fn from(client: C) -> Self {
|
||||
#client_ident(client)
|
||||
quote! {
|
||||
/// The request sent over the wire from the client to the server.
|
||||
#[derive(Debug)]
|
||||
#derive_serialize
|
||||
#vis enum #request_ident {
|
||||
#( #camel_case_idents{ #( #args ),* } ),*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_client_new(
|
||||
&GenArgs {
|
||||
client_ident,
|
||||
vis,
|
||||
request_ident,
|
||||
response_ident,
|
||||
..
|
||||
}: &GenArgs,
|
||||
) -> proc_macro2::TokenStream {
|
||||
quote! {
|
||||
impl #client_ident {
|
||||
/// Returns a new client stub that sends requests over the given transport.
|
||||
#vis fn new<T>(config: tarpc::client::Config, transport: T)
|
||||
-> tarpc::client::NewClient<
|
||||
Self,
|
||||
tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T>>
|
||||
where
|
||||
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>>
|
||||
{
|
||||
let new_client = tarpc::client::new(config, transport);
|
||||
tarpc::client::NewClient {
|
||||
client: #client_ident(new_client.client),
|
||||
dispatch: new_client.dispatch,
|
||||
fn enum_response(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
derive_serialize,
|
||||
vis,
|
||||
response_ident,
|
||||
camel_case_idents,
|
||||
return_types,
|
||||
..
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
/// The response sent over the wire from the server to the client.
|
||||
#[derive(Debug)]
|
||||
#derive_serialize
|
||||
#vis enum #response_ident {
|
||||
#( #camel_case_idents(#return_types) ),*
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn enum_response_future(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
vis,
|
||||
service_ident,
|
||||
response_fut_ident,
|
||||
camel_case_idents,
|
||||
future_types,
|
||||
..
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
/// A future resolving to a server response.
|
||||
#vis enum #response_fut_ident<S: #service_ident> {
|
||||
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_debug_for_response_future(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
service_ident,
|
||||
response_fut_ident,
|
||||
response_fut_name,
|
||||
..
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
impl<S: #service_ident> std::fmt::Debug for #response_fut_ident<S> {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
fmt.debug_struct(#response_fut_name).finish()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_client_rpc_methods(
|
||||
&GenArgs {
|
||||
client_ident,
|
||||
request_ident,
|
||||
response_ident,
|
||||
method_attrs,
|
||||
vis,
|
||||
method_names,
|
||||
args,
|
||||
return_types,
|
||||
arg_vars,
|
||||
camel_case_idents,
|
||||
..
|
||||
}: &GenArgs,
|
||||
) -> proc_macro2::TokenStream {
|
||||
quote! {
|
||||
impl<C> #client_ident<C>
|
||||
where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
|
||||
{
|
||||
#(
|
||||
#[allow(unused)]
|
||||
#( #method_attrs )*
|
||||
#vis fn #method_names(&mut self, ctx: tarpc::context::Context, #( #args ),*)
|
||||
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
|
||||
let request = #request_ident::#camel_case_idents { #( #arg_vars ),* };
|
||||
let resp = tarpc::Client::call(&mut self.0, ctx, request);
|
||||
async move {
|
||||
match resp.await? {
|
||||
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
|
||||
_ => unreachable!(),
|
||||
fn impl_future_for_response_future(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
service_ident,
|
||||
response_fut_ident,
|
||||
response_ident,
|
||||
camel_case_idents,
|
||||
..
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
impl<S: #service_ident> std::future::Future for #response_fut_ident<S> {
|
||||
type Output = #response_ident;
|
||||
|
||||
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>)
|
||||
-> std::task::Poll<#response_ident>
|
||||
{
|
||||
unsafe {
|
||||
match std::pin::Pin::get_unchecked_mut(self) {
|
||||
#(
|
||||
#response_fut_ident::#camel_case_idents(resp) =>
|
||||
std::pin::Pin::new_unchecked(resp)
|
||||
.poll(cx)
|
||||
.map(#response_ident::#camel_case_idents),
|
||||
)*
|
||||
}
|
||||
}
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn struct_client(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
vis,
|
||||
client_ident,
|
||||
request_ident,
|
||||
response_ident,
|
||||
..
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
#[allow(unused)]
|
||||
#[derive(Clone, Debug)]
|
||||
/// The client stub that makes RPC calls to the server. Exposes a Future interface.
|
||||
#vis struct #client_ident<C = tarpc::client::Channel<#request_ident, #response_ident>>(C);
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_from_for_client(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
client_ident,
|
||||
request_ident,
|
||||
response_ident,
|
||||
..
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
impl<C> From<C> for #client_ident<C>
|
||||
where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
|
||||
{
|
||||
fn from(client: C) -> Self {
|
||||
#client_ident(client)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_client_new(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
client_ident,
|
||||
vis,
|
||||
request_ident,
|
||||
response_ident,
|
||||
..
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
impl #client_ident {
|
||||
/// Returns a new client stub that sends requests over the given transport.
|
||||
#vis fn new<T>(config: tarpc::client::Config, transport: T)
|
||||
-> tarpc::client::NewClient<
|
||||
Self,
|
||||
tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T>
|
||||
>
|
||||
where
|
||||
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>>
|
||||
{
|
||||
let new_client = tarpc::client::new(config, transport);
|
||||
tarpc::client::NewClient {
|
||||
client: #client_ident(new_client.client),
|
||||
dispatch: new_client.dispatch,
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_client_rpc_methods(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
client_ident,
|
||||
request_ident,
|
||||
response_ident,
|
||||
method_attrs,
|
||||
vis,
|
||||
method_idents,
|
||||
args,
|
||||
return_types,
|
||||
arg_pats,
|
||||
camel_case_idents,
|
||||
..
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
impl<C> #client_ident<C>
|
||||
where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
|
||||
{
|
||||
#(
|
||||
#[allow(unused)]
|
||||
#( #method_attrs )*
|
||||
#vis fn #method_idents(&mut self, ctx: tarpc::context::Context, #( #args ),*)
|
||||
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
|
||||
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
|
||||
let resp = tarpc::Client::call(&mut self.0, ctx, request);
|
||||
async move {
|
||||
match resp.await? {
|
||||
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ToTokens for ServiceGenerator<'a> {
|
||||
fn to_tokens(&self, output: &mut TokenStream2) {
|
||||
output.extend(vec![
|
||||
self.trait_service(),
|
||||
self.struct_server(),
|
||||
self.impl_serve_for_server(),
|
||||
self.enum_request(),
|
||||
self.enum_response(),
|
||||
self.enum_response_future(),
|
||||
self.impl_debug_for_response_future(),
|
||||
self.impl_future_for_response_future(),
|
||||
self.struct_client(),
|
||||
self.impl_from_for_client(),
|
||||
self.impl_client_new(),
|
||||
self.impl_client_rpc_methods(),
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
fn snake_to_camel(ident_str: &str) -> String {
|
||||
let chars = ident_str.chars();
|
||||
let mut camel_ty = String::with_capacity(ident_str.len());
|
||||
|
||||
let mut last_char_was_underscore = true;
|
||||
for c in chars {
|
||||
for c in ident_str.chars() {
|
||||
match c {
|
||||
'_' => last_char_was_underscore = true,
|
||||
c if last_char_was_underscore => {
|
||||
|
||||
144
plugins/tests/server.rs
Normal file
144
plugins/tests/server.rs
Normal file
@@ -0,0 +1,144 @@
|
||||
use assert_type_eq::assert_type_eq;
|
||||
use futures::Future;
|
||||
use std::pin::Pin;
|
||||
use tarpc::context;
|
||||
|
||||
// these need to be out here rather than inside the function so that the
|
||||
// assert_type_eq macro can pick them up.
|
||||
#[tarpc::service]
|
||||
trait Foo {
|
||||
async fn two_part(s: String, i: i32) -> (String, i32);
|
||||
async fn bar(s: String) -> String;
|
||||
async fn baz();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn type_generation_works() {
|
||||
#[tarpc::server]
|
||||
impl Foo for () {
|
||||
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
|
||||
(s, i)
|
||||
}
|
||||
|
||||
async fn bar(self, _: context::Context, s: String) -> String {
|
||||
s
|
||||
}
|
||||
|
||||
async fn baz(self, _: context::Context) {}
|
||||
}
|
||||
|
||||
// the assert_type_eq macro can only be used once per block.
|
||||
{
|
||||
assert_type_eq!(
|
||||
<() as Foo>::TwoPartFut,
|
||||
Pin<Box<dyn Future<Output = (String, i32)> + Send>>
|
||||
);
|
||||
}
|
||||
{
|
||||
assert_type_eq!(
|
||||
<() as Foo>::BarFut,
|
||||
Pin<Box<dyn Future<Output = String> + Send>>
|
||||
);
|
||||
}
|
||||
{
|
||||
assert_type_eq!(
|
||||
<() as Foo>::BazFut,
|
||||
Pin<Box<dyn Future<Output = ()> + Send>>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
#[test]
|
||||
fn raw_idents_work() {
|
||||
type r#yield = String;
|
||||
|
||||
#[tarpc::service]
|
||||
trait r#trait {
|
||||
async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32);
|
||||
async fn r#fn(r#impl: r#yield) -> r#yield;
|
||||
async fn r#async();
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl r#trait for () {
|
||||
async fn r#await(
|
||||
self,
|
||||
_: context::Context,
|
||||
r#struct: r#yield,
|
||||
r#enum: i32,
|
||||
) -> (r#yield, i32) {
|
||||
(r#struct, r#enum)
|
||||
}
|
||||
|
||||
async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
|
||||
r#impl
|
||||
}
|
||||
|
||||
async fn r#async(self, _: context::Context) {}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn syntax() {
|
||||
#[tarpc::service]
|
||||
trait Syntax {
|
||||
#[deny(warnings)]
|
||||
#[allow(non_snake_case)]
|
||||
async fn TestCamelCaseDoesntConflict();
|
||||
async fn hello() -> String;
|
||||
#[doc = "attr"]
|
||||
async fn attr(s: String) -> String;
|
||||
async fn no_args_no_return();
|
||||
async fn no_args() -> ();
|
||||
async fn one_arg(one: String) -> i32;
|
||||
async fn two_args_no_return(one: String, two: u64);
|
||||
async fn two_args(one: String, two: u64) -> String;
|
||||
async fn no_args_ret_error() -> i32;
|
||||
async fn one_arg_ret_error(one: String) -> String;
|
||||
async fn no_arg_implicit_return_error();
|
||||
#[doc = "attr"]
|
||||
async fn one_arg_implicit_return_error(one: String);
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl Syntax for () {
|
||||
#[deny(warnings)]
|
||||
#[allow(non_snake_case)]
|
||||
async fn TestCamelCaseDoesntConflict(self, _: context::Context) {}
|
||||
|
||||
async fn hello(self, _: context::Context) -> String {
|
||||
String::new()
|
||||
}
|
||||
|
||||
async fn attr(self, _: context::Context, _s: String) -> String {
|
||||
String::new()
|
||||
}
|
||||
|
||||
async fn no_args_no_return(self, _: context::Context) {}
|
||||
|
||||
async fn no_args(self, _: context::Context) -> () {}
|
||||
|
||||
async fn one_arg(self, _: context::Context, _one: String) -> i32 {
|
||||
0
|
||||
}
|
||||
|
||||
async fn two_args_no_return(self, _: context::Context, _one: String, _two: u64) {}
|
||||
|
||||
async fn two_args(self, _: context::Context, _one: String, _two: u64) -> String {
|
||||
String::new()
|
||||
}
|
||||
|
||||
async fn no_args_ret_error(self, _: context::Context) -> i32 {
|
||||
0
|
||||
}
|
||||
|
||||
async fn one_arg_ret_error(self, _: context::Context, _one: String) -> String {
|
||||
String::new()
|
||||
}
|
||||
|
||||
async fn no_arg_implicit_return_error(self, _: context::Context) {}
|
||||
|
||||
async fn one_arg_implicit_return_error(self, _: context::Context, _one: String) {}
|
||||
}
|
||||
}
|
||||
@@ -29,6 +29,38 @@ fn att_service_trait() {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
#[test]
|
||||
fn raw_idents() {
|
||||
use futures::future::{ready, Ready};
|
||||
|
||||
type r#yield = String;
|
||||
|
||||
#[tarpc::service]
|
||||
trait r#trait {
|
||||
async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32);
|
||||
async fn r#fn(r#impl: r#yield) -> r#yield;
|
||||
async fn r#async();
|
||||
}
|
||||
|
||||
impl r#trait for () {
|
||||
type AwaitFut = Ready<(r#yield, i32)>;
|
||||
fn r#await(self, _: context::Context, r#struct: r#yield, r#enum: i32) -> Self::AwaitFut {
|
||||
ready((r#struct, r#enum))
|
||||
}
|
||||
|
||||
type FnFut = Ready<r#yield>;
|
||||
fn r#fn(self, _: context::Context, r#impl: r#yield) -> Self::FnFut {
|
||||
ready(r#impl)
|
||||
}
|
||||
|
||||
type AsyncFut = Ready<()>;
|
||||
fn r#async(self, _: context::Context) -> Self::AsyncFut {
|
||||
ready(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn syntax() {
|
||||
#[tarpc::service]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "tarpc"
|
||||
version = "0.20.0"
|
||||
version = "0.21.0"
|
||||
authors = ["Adam Wright <adam.austin.wright@gmail.com>", "Tim Kuehn <timothy.j.kuehn@gmail.com>"]
|
||||
edition = "2018"
|
||||
license = "MIT"
|
||||
@@ -30,13 +30,12 @@ fnv = "1.0"
|
||||
futures = "0.3"
|
||||
humantime = "1.0"
|
||||
log = "0.4"
|
||||
pin-project = "0.4"
|
||||
raii-counter = "0.2"
|
||||
pin-project = "0.4.17"
|
||||
rand = "0.7"
|
||||
tokio = { version = "0.2", features = ["time"] }
|
||||
serde = { optional = true, version = "1.0", features = ["derive"] }
|
||||
tokio-util = { optional = true, version = "0.2" }
|
||||
tarpc-plugins = { path = "../plugins", version = "0.7" }
|
||||
tarpc-plugins = { path = "../plugins", version = "0.8" }
|
||||
tokio-serde = { optional = true, version = "0.6" }
|
||||
|
||||
[dev-dependencies]
|
||||
@@ -61,4 +60,3 @@ required-features = ["full"]
|
||||
[[example]]
|
||||
name = "pubsub"
|
||||
required-features = ["full"]
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@
|
||||
//! Add to your `Cargo.toml` dependencies:
|
||||
//!
|
||||
//! ```toml
|
||||
//! tarpc = "0.20.0"
|
||||
//! tarpc = "0.21.0"
|
||||
//! ```
|
||||
//!
|
||||
//! The `tarpc::service` attribute expands to a collection of items that form an rpc service.
|
||||
@@ -215,7 +215,6 @@ pub mod trace;
|
||||
/// Rpc methods are specified, mirroring trait syntax:
|
||||
///
|
||||
/// ```
|
||||
/// # fn main() {}
|
||||
/// #[tarpc::service]
|
||||
/// trait Service {
|
||||
/// /// Say hello
|
||||
@@ -234,3 +233,59 @@ pub mod trace;
|
||||
/// * `Client` -- a client stub with a fn for each RPC.
|
||||
/// * `fn new_stub` -- creates a new Client stub.
|
||||
pub use tarpc_plugins::service;
|
||||
|
||||
/// A utility macro that can be used for RPC server implementations.
|
||||
///
|
||||
/// Syntactic sugar to make using async functions in the server implementation
|
||||
/// easier. It does this by rewriting code like this, which would normally not
|
||||
/// compile because async functions are disallowed in trait implementations:
|
||||
///
|
||||
/// ```rust
|
||||
/// # use tarpc::context;
|
||||
/// # use std::net::SocketAddr;
|
||||
/// #[tarpc::service]
|
||||
/// trait World {
|
||||
/// async fn hello(name: String) -> String;
|
||||
/// }
|
||||
///
|
||||
/// #[derive(Clone)]
|
||||
/// struct HelloServer(SocketAddr);
|
||||
///
|
||||
/// #[tarpc::server]
|
||||
/// impl World for HelloServer {
|
||||
/// async fn hello(self, _: context::Context, name: String) -> String {
|
||||
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// Into code like this, which matches the service trait definition:
|
||||
///
|
||||
/// ```rust
|
||||
/// # use tarpc::context;
|
||||
/// # use std::pin::Pin;
|
||||
/// # use futures::Future;
|
||||
/// # use std::net::SocketAddr;
|
||||
/// #[derive(Clone)]
|
||||
/// struct HelloServer(SocketAddr);
|
||||
///
|
||||
/// #[tarpc::service]
|
||||
/// trait World {
|
||||
/// async fn hello(name: String) -> String;
|
||||
/// }
|
||||
///
|
||||
/// impl World for HelloServer {
|
||||
/// type HelloFut = Pin<Box<dyn Future<Output = String> + Send>>;
|
||||
///
|
||||
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
|
||||
/// + Send>> {
|
||||
/// Box::pin(async move {
|
||||
/// format!("Hello, {}! You are connected from {:?}.", name, self.0)
|
||||
/// })
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// Note that this won't touch functions unless they have been annotated with
|
||||
/// `async`, meaning that this should not break existing code.
|
||||
pub use tarpc_plugins::server;
|
||||
|
||||
@@ -78,14 +78,21 @@ impl<'a, Req, Resp> Future for Send<'a, Req, Resp> {
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct Call<'a, Req, Resp> {
|
||||
#[pin]
|
||||
fut: AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>,
|
||||
fut: tokio::time::Timeout<AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>>,
|
||||
}
|
||||
|
||||
impl<'a, Req, Resp> Future for Call<'a, Req, Resp> {
|
||||
type Output = io::Result<Resp>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.as_mut().project().fut.poll(cx)
|
||||
let resp = ready!(self.as_mut().project().fut.poll(cx));
|
||||
Poll::Ready(match resp {
|
||||
Ok(resp) => resp,
|
||||
Err(tokio::time::Elapsed { .. }) => Err(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"Client dropped expired request.".to_string(),
|
||||
)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,13 +104,6 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
|
||||
ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
|
||||
|
||||
let timeout = ctx.deadline.time_until();
|
||||
trace!(
|
||||
"[{}] Queuing request with timeout {:?}.",
|
||||
ctx.trace_id(),
|
||||
timeout,
|
||||
);
|
||||
|
||||
let (response_completion, response) = oneshot::channel();
|
||||
let cancellation = self.cancellation.clone();
|
||||
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
||||
@@ -116,7 +116,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
response_completion,
|
||||
})),
|
||||
DispatchResponse {
|
||||
response: tokio::time::timeout(timeout, response),
|
||||
response,
|
||||
complete: false,
|
||||
request_id,
|
||||
cancellation,
|
||||
@@ -128,9 +128,16 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
|
||||
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
|
||||
/// resolves to the response.
|
||||
pub fn call(&mut self, context: context::Context, request: Req) -> Call<Req, Resp> {
|
||||
pub fn call(&mut self, ctx: context::Context, request: Req) -> Call<Req, Resp> {
|
||||
let timeout = ctx.deadline.time_until();
|
||||
trace!(
|
||||
"[{}] Queuing request with timeout {:?}.",
|
||||
ctx.trace_id(),
|
||||
timeout,
|
||||
);
|
||||
|
||||
Call {
|
||||
fut: AndThenIdent::new(self.send(context, request)),
|
||||
fut: tokio::time::timeout(timeout, AndThenIdent::new(self.send(ctx, request))),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -140,7 +147,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
#[pin_project(PinnedDrop)]
|
||||
#[derive(Debug)]
|
||||
struct DispatchResponse<Resp> {
|
||||
response: tokio::time::Timeout<oneshot::Receiver<Response<Resp>>>,
|
||||
response: oneshot::Receiver<Response<Resp>>,
|
||||
ctx: context::Context,
|
||||
complete: bool,
|
||||
cancellation: RequestCancellation,
|
||||
@@ -152,24 +159,15 @@ impl<Resp> Future for DispatchResponse<Resp> {
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> {
|
||||
let resp = ready!(self.response.poll_unpin(cx));
|
||||
|
||||
self.complete = true;
|
||||
Poll::Ready(match resp {
|
||||
Ok(resp) => {
|
||||
self.complete = true;
|
||||
match resp {
|
||||
Ok(resp) => Ok(resp.message?),
|
||||
Err(oneshot::Canceled) => {
|
||||
// The oneshot is Canceled when the dispatch task ends. In that case,
|
||||
// there's nothing listening on the other side, so there's no point in
|
||||
// propagating cancellation.
|
||||
Err(io::Error::from(io::ErrorKind::ConnectionReset))
|
||||
}
|
||||
}
|
||||
Ok(resp) => Ok(resp.message?),
|
||||
Err(oneshot::Canceled) => {
|
||||
// The oneshot is Canceled when the dispatch task ends. In that case,
|
||||
// there's nothing listening on the other side, so there's no point in
|
||||
// propagating cancellation.
|
||||
Err(io::Error::from(io::ErrorKind::ConnectionReset))
|
||||
}
|
||||
Err(tokio::time::Elapsed { .. }) => Err(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"Client dropped expired request.".to_string(),
|
||||
)),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -189,7 +187,7 @@ impl<Resp> PinnedDrop for DispatchResponse<Resp> {
|
||||
// closing the receiver before sending the cancel message, it is guaranteed that if the
|
||||
// dispatch task misses an early-arriving cancellation message, then it will see the
|
||||
// receiver as closed.
|
||||
self.response.get_mut().close();
|
||||
self.response.close();
|
||||
let request_id = self.request_id;
|
||||
self.cancellation.cancel(request_id);
|
||||
}
|
||||
@@ -385,9 +383,7 @@ where
|
||||
context: context::Context {
|
||||
deadline: dispatch_request.ctx.deadline,
|
||||
trace_context: dispatch_request.ctx.trace_context,
|
||||
_non_exhaustive: (),
|
||||
},
|
||||
_non_exhaustive: (),
|
||||
});
|
||||
self.as_mut().project().transport.start_send(request)?;
|
||||
self.as_mut().project().in_flight_requests.insert(
|
||||
@@ -632,11 +628,12 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project(project = TryChainProj)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
#[derive(Debug)]
|
||||
enum TryChain<Fut1, Fut2> {
|
||||
First(Fut1),
|
||||
Second(Fut2),
|
||||
First(#[pin] Fut1),
|
||||
Second(#[pin] Fut2),
|
||||
Empty,
|
||||
}
|
||||
|
||||
@@ -658,7 +655,7 @@ where
|
||||
}
|
||||
|
||||
fn poll<F>(
|
||||
self: Pin<&mut Self>,
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
f: F,
|
||||
) -> Poll<Result<Fut2::Ok, Fut2::Error>>
|
||||
@@ -667,31 +664,28 @@ where
|
||||
{
|
||||
let mut f = Some(f);
|
||||
|
||||
// Safe to call `get_unchecked_mut` because we won't move the futures.
|
||||
let this = unsafe { Pin::get_unchecked_mut(self) };
|
||||
|
||||
loop {
|
||||
let output = match this {
|
||||
TryChain::First(fut1) => {
|
||||
let output = match self.as_mut().project() {
|
||||
TryChainProj::First(fut1) => {
|
||||
// Poll the first future
|
||||
match unsafe { Pin::new_unchecked(fut1) }.try_poll(cx) {
|
||||
match fut1.try_poll(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(output) => output,
|
||||
}
|
||||
}
|
||||
TryChain::Second(fut2) => {
|
||||
TryChainProj::Second(fut2) => {
|
||||
// Poll the second future
|
||||
return unsafe { Pin::new_unchecked(fut2) }.try_poll(cx);
|
||||
return fut2.try_poll(cx);
|
||||
}
|
||||
TryChain::Empty => {
|
||||
TryChainProj::Empty => {
|
||||
panic!("future must not be polled after it returned `Poll::Ready`");
|
||||
}
|
||||
};
|
||||
|
||||
*this = TryChain::Empty; // Drop fut1
|
||||
self.set(TryChain::Empty); // Drop fut1
|
||||
let f = f.take().unwrap();
|
||||
match f(output) {
|
||||
TryChainAction::Future(fut2) => *this = TryChain::Second(fut2),
|
||||
TryChainAction::Future(fut2) => self.set(TryChain::Second(fut2)),
|
||||
TryChainAction::Output(output) => return Poll::Ready(output),
|
||||
}
|
||||
}
|
||||
@@ -716,24 +710,21 @@ mod tests {
|
||||
prelude::*,
|
||||
task::*,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use std::{pin::Pin, sync::atomic::AtomicU64, sync::Arc};
|
||||
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
async fn dispatch_response_cancels_on_timeout() {
|
||||
let (_response_completion, response) = oneshot::channel();
|
||||
async fn dispatch_response_cancels_on_drop() {
|
||||
let (cancellation, mut canceled_requests) = cancellations();
|
||||
let resp = DispatchResponse::<u64> {
|
||||
// Timeout in the past should cause resp to error out when polled.
|
||||
response: tokio::time::timeout(Duration::from_secs(0), response),
|
||||
let (_, response) = oneshot::channel();
|
||||
drop(DispatchResponse::<u32> {
|
||||
response,
|
||||
cancellation,
|
||||
complete: false,
|
||||
request_id: 3,
|
||||
cancellation,
|
||||
ctx: context::current(),
|
||||
};
|
||||
let _ = futures::poll!(resp);
|
||||
});
|
||||
// resp's drop() is run, which should send a cancel message.
|
||||
assert!(canceled_requests.0.try_next().unwrap() == Some(3));
|
||||
assert_eq!(canceled_requests.0.try_next().unwrap(), Some(3));
|
||||
}
|
||||
|
||||
#[tokio::test(threaded_scheduler)]
|
||||
@@ -768,7 +759,6 @@ mod tests {
|
||||
Response {
|
||||
request_id: 0,
|
||||
message: Ok("hello".into()),
|
||||
_non_exhaustive: (),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
@@ -822,7 +812,7 @@ mod tests {
|
||||
// i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
|
||||
// map.
|
||||
let mut resp = send_request(&mut channel, "hi").await;
|
||||
resp.response.get_mut().close();
|
||||
resp.response.close();
|
||||
|
||||
assert!(dispatch.poll_next_request(cx).is_pending());
|
||||
}
|
||||
|
||||
@@ -104,6 +104,7 @@ where
|
||||
|
||||
/// Settings that control the behavior of the client.
|
||||
#[derive(Clone, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub struct Config {
|
||||
/// The number of requests that can be in flight at once.
|
||||
/// `max_in_flight_requests` controls the size of the map used by the client
|
||||
@@ -113,8 +114,6 @@ pub struct Config {
|
||||
/// `pending_requests_buffer` controls the size of the channel clients use
|
||||
/// to communicate with the request dispatch task.
|
||||
pub pending_request_buffer: usize,
|
||||
#[doc(hidden)]
|
||||
_non_exhaustive: (),
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
@@ -122,7 +121,6 @@ impl Default for Config {
|
||||
Config {
|
||||
max_in_flight_requests: 1_000,
|
||||
pending_request_buffer: 100,
|
||||
_non_exhaustive: (),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ use std::time::{Duration, SystemTime};
|
||||
/// The context should not be stored directly in a server implementation, because the context will
|
||||
/// be different for each request in scope.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Context {
|
||||
/// When the client expects the request to be complete by. The server should cancel the request
|
||||
@@ -35,9 +36,6 @@ pub struct Context {
|
||||
/// include the same `trace_id` as that included on the original request. This way,
|
||||
/// users can trace related actions across a distributed system.
|
||||
pub trace_context: trace::Context,
|
||||
#[doc(hidden)]
|
||||
#[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
|
||||
pub(crate) _non_exhaustive: (),
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde1")]
|
||||
@@ -51,7 +49,6 @@ pub fn current() -> Context {
|
||||
Context {
|
||||
deadline: SystemTime::now() + Duration::from_secs(10),
|
||||
trace_context: trace::Context::new_root(),
|
||||
_non_exhaustive: (),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ use std::{io, time::SystemTime};
|
||||
/// A message from a client to a server.
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[non_exhaustive]
|
||||
pub enum ClientMessage<T> {
|
||||
/// A request initiated by a user. The server responds to a request by invoking a
|
||||
/// service-provided request handler. The handler completes with a [`response`](Response), which
|
||||
@@ -58,12 +59,11 @@ pub enum ClientMessage<T> {
|
||||
/// The ID of the request to cancel.
|
||||
request_id: u64,
|
||||
},
|
||||
#[doc(hidden)]
|
||||
_NonExhaustive,
|
||||
}
|
||||
|
||||
/// A request from a client to a server.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Request<T> {
|
||||
/// Trace context, deadline, and other cross-cutting concerns.
|
||||
@@ -72,26 +72,22 @@ pub struct Request<T> {
|
||||
pub id: u64,
|
||||
/// The request body.
|
||||
pub message: T,
|
||||
#[doc(hidden)]
|
||||
#[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
|
||||
_non_exhaustive: (),
|
||||
}
|
||||
|
||||
/// A response from a server to a client.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Response<T> {
|
||||
/// The ID of the request being responded to.
|
||||
pub request_id: u64,
|
||||
/// The response body, or an error if the request failed.
|
||||
pub message: Result<T, ServerError>,
|
||||
#[doc(hidden)]
|
||||
#[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
|
||||
_non_exhaustive: (),
|
||||
}
|
||||
|
||||
/// An error response from a server to a client.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ServerError {
|
||||
#[cfg_attr(
|
||||
@@ -106,9 +102,6 @@ pub struct ServerError {
|
||||
pub kind: io::ErrorKind,
|
||||
/// A message describing more detail about the error that occurred.
|
||||
pub detail: Option<String>,
|
||||
#[doc(hidden)]
|
||||
#[cfg_attr(feature = "serde1", serde(skip_serializing, default))]
|
||||
_non_exhaustive: (),
|
||||
}
|
||||
|
||||
impl From<ServerError> for io::Error {
|
||||
|
||||
@@ -12,7 +12,6 @@ use fnv::FnvHashMap;
|
||||
use futures::{channel::mpsc, future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
|
||||
use log::{debug, info, trace};
|
||||
use pin_project::pin_project;
|
||||
use raii_counter::{Counter, WeakCounter};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::{
|
||||
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin,
|
||||
@@ -32,7 +31,7 @@ where
|
||||
dropped_keys: mpsc::UnboundedReceiver<K>,
|
||||
#[pin]
|
||||
dropped_keys_tx: mpsc::UnboundedSender<K>,
|
||||
key_counts: FnvHashMap<K, TrackerPrototype<K>>,
|
||||
key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
|
||||
keymaker: F,
|
||||
}
|
||||
|
||||
@@ -42,37 +41,22 @@ where
|
||||
pub struct TrackedChannel<C, K> {
|
||||
#[pin]
|
||||
inner: C,
|
||||
tracker: Tracker<K>,
|
||||
tracker: Arc<Tracker<K>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Debug)]
|
||||
struct Tracker<K> {
|
||||
key: Option<Arc<K>>,
|
||||
counter: Counter,
|
||||
key: Option<K>,
|
||||
dropped_keys: mpsc::UnboundedSender<K>,
|
||||
}
|
||||
|
||||
impl<K> Drop for Tracker<K> {
|
||||
fn drop(&mut self) {
|
||||
if self.counter.count() <= 1 {
|
||||
// Don't care if the listener is dropped.
|
||||
match Arc::try_unwrap(self.key.take().unwrap()) {
|
||||
Ok(key) => {
|
||||
let _ = self.dropped_keys.unbounded_send(key);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
// Don't care if the listener is dropped.
|
||||
let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct TrackerPrototype<K> {
|
||||
key: Weak<K>,
|
||||
counter: WeakCounter,
|
||||
dropped_keys: mpsc::UnboundedSender<K>,
|
||||
}
|
||||
|
||||
impl<C, K> Stream for TrackedChannel<C, K>
|
||||
where
|
||||
C: Stream,
|
||||
@@ -181,7 +165,7 @@ where
|
||||
trace!(
|
||||
"[{}] Opening channel ({}/{}) channels for key.",
|
||||
key,
|
||||
tracker.counter.count(),
|
||||
Arc::strong_count(&tracker),
|
||||
self.as_mut().project().channels_per_key
|
||||
);
|
||||
|
||||
@@ -191,28 +175,22 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Tracker<K>, K> {
|
||||
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
|
||||
let channels_per_key = self.channels_per_key;
|
||||
let dropped_keys = self.dropped_keys_tx.clone();
|
||||
let key_counts = &mut self.as_mut().project().key_counts;
|
||||
match key_counts.entry(key.clone()) {
|
||||
Entry::Vacant(vacant) => {
|
||||
let key = Arc::new(key);
|
||||
let counter = WeakCounter::new();
|
||||
|
||||
vacant.insert(TrackerPrototype {
|
||||
key: Arc::downgrade(&key),
|
||||
counter: counter.clone(),
|
||||
dropped_keys: dropped_keys.clone(),
|
||||
});
|
||||
Ok(Tracker {
|
||||
let tracker = Arc::new(Tracker {
|
||||
key: Some(key),
|
||||
counter: counter.upgrade(),
|
||||
dropped_keys,
|
||||
})
|
||||
});
|
||||
|
||||
vacant.insert(Arc::downgrade(&tracker));
|
||||
Ok(tracker)
|
||||
}
|
||||
Entry::Occupied(o) => {
|
||||
let count = o.get().counter.count();
|
||||
Entry::Occupied(mut o) => {
|
||||
let count = o.get().strong_count();
|
||||
if count >= channels_per_key.try_into().unwrap() {
|
||||
info!(
|
||||
"[{}] Opened max channels from key ({}/{}).",
|
||||
@@ -220,16 +198,15 @@ where
|
||||
);
|
||||
Err(key)
|
||||
} else {
|
||||
let TrackerPrototype {
|
||||
key,
|
||||
counter,
|
||||
dropped_keys,
|
||||
} = o.get().clone();
|
||||
Ok(Tracker {
|
||||
counter: counter.upgrade(),
|
||||
key: Some(key.upgrade().unwrap()),
|
||||
dropped_keys,
|
||||
})
|
||||
Ok(o.get().upgrade().unwrap_or_else(|| {
|
||||
let tracker = Arc::new(Tracker {
|
||||
key: Some(key),
|
||||
dropped_keys,
|
||||
});
|
||||
|
||||
*o.get_mut() = Arc::downgrade(&tracker);
|
||||
tracker
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -302,12 +279,10 @@ fn ctx() -> Context<'static> {
|
||||
#[test]
|
||||
fn tracker_drop() {
|
||||
use assert_matches::assert_matches;
|
||||
use raii_counter::Counter;
|
||||
|
||||
let (tx, mut rx) = mpsc::unbounded();
|
||||
Tracker {
|
||||
key: Some(Arc::new(1)),
|
||||
counter: Counter::new(),
|
||||
key: Some(1),
|
||||
dropped_keys: tx,
|
||||
};
|
||||
assert_matches!(rx.try_next(), Ok(Some(1)));
|
||||
@@ -317,17 +292,15 @@ fn tracker_drop() {
|
||||
fn tracked_channel_stream() {
|
||||
use assert_matches::assert_matches;
|
||||
use pin_utils::pin_mut;
|
||||
use raii_counter::Counter;
|
||||
|
||||
let (chan_tx, chan) = mpsc::unbounded();
|
||||
let (dropped_keys, _) = mpsc::unbounded();
|
||||
let channel = TrackedChannel {
|
||||
inner: chan,
|
||||
tracker: Tracker {
|
||||
key: Some(Arc::new(1)),
|
||||
counter: Counter::new(),
|
||||
tracker: Arc::new(Tracker {
|
||||
key: Some(1),
|
||||
dropped_keys,
|
||||
},
|
||||
}),
|
||||
};
|
||||
|
||||
chan_tx.unbounded_send("test").unwrap();
|
||||
@@ -339,17 +312,15 @@ fn tracked_channel_stream() {
|
||||
fn tracked_channel_sink() {
|
||||
use assert_matches::assert_matches;
|
||||
use pin_utils::pin_mut;
|
||||
use raii_counter::Counter;
|
||||
|
||||
let (chan, mut chan_rx) = mpsc::unbounded();
|
||||
let (dropped_keys, _) = mpsc::unbounded();
|
||||
let channel = TrackedChannel {
|
||||
inner: chan,
|
||||
tracker: Tracker {
|
||||
key: Some(Arc::new(1)),
|
||||
counter: Counter::new(),
|
||||
tracker: Arc::new(Tracker {
|
||||
key: Some(1),
|
||||
dropped_keys,
|
||||
},
|
||||
}),
|
||||
};
|
||||
|
||||
pin_mut!(channel);
|
||||
@@ -371,12 +342,12 @@ fn channel_filter_increment_channels_for_key() {
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
|
||||
assert_eq!(tracker1.counter.count(), 1);
|
||||
assert_eq!(Arc::strong_count(&tracker1), 1);
|
||||
let tracker2 = filter.as_mut().increment_channels_for_key("key").unwrap();
|
||||
assert_eq!(tracker1.counter.count(), 2);
|
||||
assert_eq!(Arc::strong_count(&tracker1), 2);
|
||||
assert_matches!(filter.increment_channels_for_key("key"), Err("key"));
|
||||
drop(tracker2);
|
||||
assert_eq!(tracker1.counter.count(), 1);
|
||||
assert_eq!(Arc::strong_count(&tracker1), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -395,20 +366,20 @@ fn channel_filter_handle_new_channel() {
|
||||
.as_mut()
|
||||
.handle_new_channel(TestChannel { key: "key" })
|
||||
.unwrap();
|
||||
assert_eq!(channel1.tracker.counter.count(), 1);
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
|
||||
|
||||
let channel2 = filter
|
||||
.as_mut()
|
||||
.handle_new_channel(TestChannel { key: "key" })
|
||||
.unwrap();
|
||||
assert_eq!(channel1.tracker.counter.count(), 2);
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
|
||||
|
||||
assert_matches!(
|
||||
filter.handle_new_channel(TestChannel { key: "key" }),
|
||||
Err("key")
|
||||
);
|
||||
drop(channel2);
|
||||
assert_eq!(channel1.tracker.counter.count(), 1);
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -429,14 +400,14 @@ fn channel_filter_poll_listener() {
|
||||
.unwrap();
|
||||
let channel1 =
|
||||
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
|
||||
assert_eq!(channel1.tracker.counter.count(), 1);
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
|
||||
|
||||
new_channels
|
||||
.unbounded_send(TestChannel { key: "key" })
|
||||
.unwrap();
|
||||
let _channel2 =
|
||||
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
|
||||
assert_eq!(channel1.tracker.counter.count(), 2);
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
|
||||
|
||||
new_channels
|
||||
.unbounded_send(TestChannel { key: "key" })
|
||||
@@ -444,7 +415,7 @@ fn channel_filter_poll_listener() {
|
||||
let key =
|
||||
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k);
|
||||
assert_eq!(key, "key");
|
||||
assert_eq!(channel1.tracker.counter.count(), 2);
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -304,7 +304,6 @@ where
|
||||
} => {
|
||||
self.as_mut().cancel_request(&trace_context, request_id);
|
||||
}
|
||||
ClientMessage::_NonExhaustive => unreachable!(),
|
||||
},
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
@@ -569,11 +568,9 @@ where
|
||||
"Response did not complete before deadline of {}s.",
|
||||
format_rfc3339(self.deadline)
|
||||
)),
|
||||
_non_exhaustive: (),
|
||||
})
|
||||
}
|
||||
},
|
||||
_non_exhaustive: (),
|
||||
});
|
||||
*self.as_mut().project().state = RespState::PollReady;
|
||||
}
|
||||
@@ -653,11 +650,9 @@ where
|
||||
pub fn execute(self) -> impl Future<Output = ()> {
|
||||
use log::info;
|
||||
|
||||
self.try_for_each(|request_handler| {
|
||||
async {
|
||||
tokio::spawn(request_handler);
|
||||
Ok(())
|
||||
}
|
||||
self.try_for_each(|request_handler| async {
|
||||
tokio::spawn(request_handler);
|
||||
Ok(())
|
||||
})
|
||||
.unwrap_or_else(|e| info!("ClientHandler errored out: {}", e))
|
||||
}
|
||||
|
||||
@@ -87,11 +87,9 @@ impl<Req, Resp> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
|
||||
context: context::Context {
|
||||
deadline: SystemTime::UNIX_EPOCH,
|
||||
trace_context: Default::default(),
|
||||
_non_exhaustive: (),
|
||||
},
|
||||
id,
|
||||
message,
|
||||
_non_exhaustive: (),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,9 +61,7 @@ where
|
||||
message: Err(ServerError {
|
||||
kind: io::ErrorKind::WouldBlock,
|
||||
detail: Some("Server throttled the request.".into()),
|
||||
_non_exhaustive: (),
|
||||
}),
|
||||
_non_exhaustive: (),
|
||||
})?;
|
||||
}
|
||||
None => return Poll::Ready(None),
|
||||
@@ -311,7 +309,6 @@ fn throttler_start_send() {
|
||||
.start_send(Response {
|
||||
request_id: 0,
|
||||
message: Ok(1),
|
||||
_non_exhaustive: (),
|
||||
})
|
||||
.unwrap();
|
||||
assert!(throttler.inner.in_flight_requests.is_empty());
|
||||
@@ -320,7 +317,6 @@ fn throttler_start_send() {
|
||||
Some(&Response {
|
||||
request_id: 0,
|
||||
message: Ok(1),
|
||||
_non_exhaustive: ()
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user