Use async fn in generated traits!!

The major breaking change is that Channel::execute no longer internally
spawns RPC handlers, because it is no longer possible to place a Send
bound on the return type of Serve::serve. Instead, Channel::execute
returns a stream of RPC handler futures.

Service authors can reproduce the old behavior by spawning each response
handler (the compiler knows whether or not the futures can be spawned;
it's just that the bounds can't be expressed generically):

    channel.execute(server.serve())
           .for_each(|rpc| { tokio::spawn(rpc); })
This commit is contained in:
Tim Kuehn
2022-11-23 01:36:51 -08:00
committed by Tim
parent 7c5afa97bb
commit 8dc3711a80
31 changed files with 421 additions and 838 deletions

View File

@@ -4,6 +4,9 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use std::env;
use tracing_subscriber::{fmt::format::FmtSpan, prelude::*};

View File

@@ -4,6 +4,9 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use clap::Parser;
use futures::{future, prelude::*};
use rand::{
@@ -34,7 +37,6 @@ struct Flags {
#[derive(Clone)]
struct HelloServer(SocketAddr);
#[tarpc::server]
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
let sleep_time =
@@ -44,6 +46,10 @@ impl World for HelloServer {
}
}
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let flags = Flags::parse();
@@ -66,7 +72,7 @@ async fn main() -> anyhow::Result<()> {
// the generated World trait.
.map(|channel| {
let server = HelloServer(channel.transport().peer_addr().unwrap());
channel.execute(server.serve())
channel.execute(server.serve()).for_each(spawn)
})
// Max 10 channels.
.buffer_unordered(10)

View File

@@ -12,18 +12,18 @@ extern crate quote;
extern crate syn;
use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote, ToTokens};
use syn::{
braced,
ext::IdentExt,
parenthesized,
parse::{Parse, ParseStream},
parse_macro_input, parse_quote, parse_str,
parse_macro_input, parse_quote,
spanned::Spanned,
token::Comma,
Attribute, FnArg, Ident, ImplItem, ImplItemMethod, ImplItemType, ItemImpl, Lit, LitBool,
MetaNameValue, Pat, PatType, ReturnType, Token, Type, Visibility,
Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type,
Visibility,
};
/// Accumulates multiple errors into a result.
@@ -257,7 +257,6 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
.map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
.collect();
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
let response_fut_name = &format!("{}ResponseFut", ident.unraw());
let derive_serialize = if derive_serde.0 {
Some(
quote! {#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
@@ -274,11 +273,9 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
.collect::<Vec<_>>();
ServiceGenerator {
response_fut_name,
service_ident: ident,
client_stub_ident: &format_ident!("{}Stub", ident),
server_ident: &format_ident!("Serve{}", ident),
response_fut_ident: &Ident::new(response_fut_name, ident.span()),
client_ident: &format_ident!("{}Client", ident),
request_ident: &format_ident!("{}Request", ident),
response_ident: &format_ident!("{}Response", ident),
@@ -305,138 +302,18 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
.zip(camel_case_fn_names.iter())
.map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
.collect::<Vec<_>>(),
future_types: &camel_case_fn_names
.iter()
.map(|name| parse_str(&format!("{name}Fut")).unwrap())
.collect::<Vec<_>>(),
derive_serialize: derive_serialize.as_ref(),
}
.into_token_stream()
.into()
}
/// generate an identifier consisting of the method name to CamelCase with
/// Fut appended to it.
fn associated_type_for_rpc(method: &ImplItemMethod) -> String {
snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut"
}
/// Transforms an async function into a sync one, returning a type declaration
/// for the return type (a future).
fn transform_method(method: &mut ImplItemMethod) -> ImplItemType {
method.sig.asyncness = None;
// get either the return type or ().
let ret = match &method.sig.output {
ReturnType::Default => quote!(()),
ReturnType::Type(_, ret) => quote!(#ret),
};
let fut_name = associated_type_for_rpc(method);
let fut_name_ident = Ident::new(&fut_name, method.sig.ident.span());
// generate the updated return signature.
method.sig.output = parse_quote! {
-> ::core::pin::Pin<Box<
dyn ::core::future::Future<Output = #ret> + ::core::marker::Send
>>
};
// transform the body of the method into Box::pin(async move { body }).
let block = method.block.clone();
method.block = parse_quote! [{
Box::pin(async move
#block
)
}];
// generate and return type declaration for return type.
let t: ImplItemType = parse_quote! {
type #fut_name_ident = ::core::pin::Pin<Box<dyn ::core::future::Future<Output = #ret> + ::core::marker::Send>>;
};
t
}
#[proc_macro_attribute]
pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream {
let mut item = syn::parse_macro_input!(input as ItemImpl);
let span = item.span();
// the generated type declarations
let mut types: Vec<ImplItemType> = Vec::new();
let mut expected_non_async_types: Vec<(&ImplItemMethod, String)> = Vec::new();
let mut found_non_async_types: Vec<&ImplItemType> = Vec::new();
for inner in &mut item.items {
match inner {
ImplItem::Method(method) => {
if method.sig.asyncness.is_some() {
// if this function is declared async, transform it into a regular function
let typedecl = transform_method(method);
types.push(typedecl);
} else {
// If it's not async, keep track of all required associated types for better
// error reporting.
expected_non_async_types.push((method, associated_type_for_rpc(method)));
}
}
ImplItem::Type(typedecl) => found_non_async_types.push(typedecl),
_ => {}
}
}
if let Err(e) =
verify_types_were_provided(span, &expected_non_async_types, &found_non_async_types)
{
return TokenStream::from(e.to_compile_error());
}
// add the type declarations into the impl block
for t in types.into_iter() {
item.items.push(syn::ImplItem::Type(t));
}
TokenStream::from(quote!(#item))
}
fn verify_types_were_provided(
span: Span,
expected: &[(&ImplItemMethod, String)],
provided: &[&ImplItemType],
) -> syn::Result<()> {
let mut result = Ok(());
for (method, expected) in expected {
if !provided.iter().any(|typedecl| typedecl.ident == expected) {
let mut e = syn::Error::new(
span,
format!("not all trait items implemented, missing: `{expected}`"),
);
let fn_span = method.sig.fn_token.span();
e.extend(syn::Error::new(
fn_span.join(method.sig.ident.span()).unwrap_or(fn_span),
format!(
"hint: `#[tarpc::server]` only rewrites async fns, and `fn {}` is not async",
method.sig.ident
),
));
match result {
Ok(_) => result = Err(e),
Err(ref mut error) => error.extend(Some(e)),
}
}
}
result
}
// Things needed to generate the service items: trait, serve impl, request/response enums, and
// the client stub.
struct ServiceGenerator<'a> {
service_ident: &'a Ident,
client_stub_ident: &'a Ident,
server_ident: &'a Ident,
response_fut_ident: &'a Ident,
response_fut_name: &'a str,
client_ident: &'a Ident,
request_ident: &'a Ident,
response_ident: &'a Ident,
@@ -444,7 +321,6 @@ struct ServiceGenerator<'a> {
attrs: &'a [Attribute],
rpcs: &'a [RpcMethod],
camel_case_idents: &'a [Ident],
future_types: &'a [Type],
method_idents: &'a [&'a Ident],
request_names: &'a [String],
method_attrs: &'a [&'a [Attribute]],
@@ -460,7 +336,6 @@ impl<'a> ServiceGenerator<'a> {
attrs,
rpcs,
vis,
future_types,
return_types,
service_ident,
client_stub_ident,
@@ -470,27 +345,19 @@ impl<'a> ServiceGenerator<'a> {
..
} = self;
let types_and_fns = rpcs
let rpc_fns = rpcs
.iter()
.zip(future_types.iter())
.zip(return_types.iter())
.map(
|(
(
RpcMethod {
attrs, ident, args, ..
},
future_type,
),
RpcMethod {
attrs, ident, args, ..
},
output,
)| {
let ty_doc = format!("The response future returned by [`{service_ident}::{ident}`].");
quote! {
#[doc = #ty_doc]
type #future_type: std::future::Future<Output = #output>;
#( #attrs )*
fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type;
async fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> #output;
}
},
);
@@ -499,7 +366,7 @@ impl<'a> ServiceGenerator<'a> {
quote! {
#( #attrs )*
#vis trait #service_ident: Sized {
#( #types_and_fns )*
#( #rpc_fns )*
/// Returns a serving function to use with
/// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute).
@@ -539,7 +406,6 @@ impl<'a> ServiceGenerator<'a> {
server_ident,
service_ident,
response_ident,
response_fut_ident,
camel_case_idents,
arg_pats,
method_idents,
@@ -553,7 +419,6 @@ impl<'a> ServiceGenerator<'a> {
{
type Req = #request_ident;
type Resp = #response_ident;
type Fut = #response_fut_ident<S>;
fn method(&self, req: &#request_ident) -> Option<&'static str> {
Some(match req {
@@ -565,15 +430,16 @@ impl<'a> ServiceGenerator<'a> {
})
}
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
async fn serve(self, ctx: tarpc::context::Context, req: #request_ident)
-> Result<#response_ident, tarpc::ServerError> {
match req {
#(
#request_ident::#camel_case_idents{ #( #arg_pats ),* } => {
#response_fut_ident::#camel_case_idents(
Ok(#response_ident::#camel_case_idents(
#service_ident::#method_idents(
self.service, ctx, #( #arg_pats ),*
)
)
).await
))
}
)*
}
@@ -624,74 +490,6 @@ impl<'a> ServiceGenerator<'a> {
}
}
fn enum_response_future(&self) -> TokenStream2 {
let &Self {
vis,
service_ident,
response_fut_ident,
camel_case_idents,
future_types,
..
} = self;
quote! {
/// A future resolving to a server response.
#[allow(missing_docs)]
#vis enum #response_fut_ident<S: #service_ident> {
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
}
}
}
fn impl_debug_for_response_future(&self) -> TokenStream2 {
let &Self {
service_ident,
response_fut_ident,
response_fut_name,
..
} = self;
quote! {
impl<S: #service_ident> std::fmt::Debug for #response_fut_ident<S> {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct(#response_fut_name).finish()
}
}
}
}
fn impl_future_for_response_future(&self) -> TokenStream2 {
let &Self {
service_ident,
response_fut_ident,
response_ident,
camel_case_idents,
..
} = self;
quote! {
impl<S: #service_ident> std::future::Future for #response_fut_ident<S> {
type Output = Result<#response_ident, tarpc::ServerError>;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>)
-> std::task::Poll<Result<#response_ident, tarpc::ServerError>>
{
unsafe {
match std::pin::Pin::get_unchecked_mut(self) {
#(
#response_fut_ident::#camel_case_idents(resp) =>
std::pin::Pin::new_unchecked(resp)
.poll(cx)
.map(#response_ident::#camel_case_idents)
.map(Ok),
)*
}
}
}
}
}
}
fn struct_client(&self) -> TokenStream2 {
let &Self {
vis,
@@ -804,9 +602,6 @@ impl<'a> ToTokens for ServiceGenerator<'a> {
self.impl_serve_for_server(),
self.enum_request(),
self.enum_response(),
self.enum_response_future(),
self.impl_debug_for_response_future(),
self.impl_future_for_response_future(),
self.struct_client(),
self.impl_client_new(),
self.impl_client_rpc_methods(),

View File

@@ -1,7 +1,5 @@
use assert_type_eq::assert_type_eq;
use futures::Future;
use std::pin::Pin;
use tarpc::context;
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
// these need to be out here rather than inside the function so that the
// assert_type_eq macro can pick them up.
@@ -12,42 +10,6 @@ trait Foo {
async fn baz();
}
#[test]
fn type_generation_works() {
#[tarpc::server]
impl Foo for () {
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
(s, i)
}
async fn bar(self, _: context::Context, s: String) -> String {
s
}
async fn baz(self, _: context::Context) {}
}
// the assert_type_eq macro can only be used once per block.
{
assert_type_eq!(
<() as Foo>::TwoPartFut,
Pin<Box<dyn Future<Output = (String, i32)> + Send>>
);
}
{
assert_type_eq!(
<() as Foo>::BarFut,
Pin<Box<dyn Future<Output = String> + Send>>
);
}
{
assert_type_eq!(
<() as Foo>::BazFut,
Pin<Box<dyn Future<Output = ()> + Send>>
);
}
}
#[allow(non_camel_case_types)]
#[test]
fn raw_idents_work() {
@@ -59,24 +21,6 @@ fn raw_idents_work() {
async fn r#fn(r#impl: r#yield) -> r#yield;
async fn r#async();
}
#[tarpc::server]
impl r#trait for () {
async fn r#await(
self,
_: context::Context,
r#struct: r#yield,
r#enum: i32,
) -> (r#yield, i32) {
(r#struct, r#enum)
}
async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
r#impl
}
async fn r#async(self, _: context::Context) {}
}
}
#[test]
@@ -100,45 +44,4 @@ fn syntax() {
#[doc = "attr"]
async fn one_arg_implicit_return_error(one: String);
}
#[tarpc::server]
impl Syntax for () {
#[deny(warnings)]
#[allow(non_snake_case)]
async fn TestCamelCaseDoesntConflict(self, _: context::Context) {}
async fn hello(self, _: context::Context) -> String {
String::new()
}
async fn attr(self, _: context::Context, _s: String) -> String {
String::new()
}
async fn no_args_no_return(self, _: context::Context) {}
async fn no_args(self, _: context::Context) -> () {}
async fn one_arg(self, _: context::Context, _one: String) -> i32 {
0
}
async fn two_args_no_return(self, _: context::Context, _one: String, _two: u64) {}
async fn two_args(self, _: context::Context, _one: String, _two: u64) -> String {
String::new()
}
async fn no_args_ret_error(self, _: context::Context) -> i32 {
0
}
async fn one_arg_ret_error(self, _: context::Context, _one: String) -> String {
String::new()
}
async fn no_arg_implicit_return_error(self, _: context::Context) {}
async fn one_arg_implicit_return_error(self, _: context::Context, _one: String) {}
}
}

View File

@@ -1,9 +1,10 @@
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use tarpc::context;
#[test]
fn att_service_trait() {
use futures::future::{ready, Ready};
#[tarpc::service]
trait Foo {
async fn two_part(s: String, i: i32) -> (String, i32);
@@ -12,19 +13,16 @@ fn att_service_trait() {
}
impl Foo for () {
type TwoPartFut = Ready<(String, i32)>;
fn two_part(self, _: context::Context, s: String, i: i32) -> Self::TwoPartFut {
ready((s, i))
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
(s, i)
}
type BarFut = Ready<String>;
fn bar(self, _: context::Context, s: String) -> Self::BarFut {
ready(s)
async fn bar(self, _: context::Context, s: String) -> String {
s
}
type BazFut = Ready<()>;
fn baz(self, _: context::Context) -> Self::BazFut {
ready(())
async fn baz(self, _: context::Context) {
()
}
}
}
@@ -32,8 +30,6 @@ fn att_service_trait() {
#[allow(non_camel_case_types)]
#[test]
fn raw_idents() {
use futures::future::{ready, Ready};
type r#yield = String;
#[tarpc::service]
@@ -44,19 +40,21 @@ fn raw_idents() {
}
impl r#trait for () {
type AwaitFut = Ready<(r#yield, i32)>;
fn r#await(self, _: context::Context, r#struct: r#yield, r#enum: i32) -> Self::AwaitFut {
ready((r#struct, r#enum))
async fn r#await(
self,
_: context::Context,
r#struct: r#yield,
r#enum: i32,
) -> (r#yield, i32) {
(r#struct, r#enum)
}
type FnFut = Ready<r#yield>;
fn r#fn(self, _: context::Context, r#impl: r#yield) -> Self::FnFut {
ready(r#impl)
async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
r#impl
}
type AsyncFut = Ready<()>;
fn r#async(self, _: context::Context) -> Self::AsyncFut {
ready(())
async fn r#async(self, _: context::Context) {
()
}
}
}

View File

@@ -75,7 +75,8 @@ opentelemetry-jaeger = { version = "0.17.0", features = ["rt-tokio"] }
pin-utils = "0.1.0-alpha"
serde_bytes = "0.11"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tokio = { version = "1", features = ["full", "test-util"] }
tokio = { version = "1", features = ["full", "test-util", "tracing"] }
console-subscriber = "0.1"
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
trybuild = "1.0"
tokio-rustls = "0.23"

View File

@@ -1,5 +1,14 @@
// Copyright 2022 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression};
use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt};
use futures::{prelude::*, Sink, SinkExt, Stream, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use serde_bytes::ByteBuf;
use std::{io, io::Read, io::Write};
@@ -99,13 +108,16 @@ pub trait World {
#[derive(Clone, Debug)]
struct HelloServer;
#[tarpc::server]
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
format!("Hey, {name}!")
}
}
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let mut incoming = tcp::listen("localhost:0", Bincode::default).await?;
@@ -114,6 +126,7 @@ async fn main() -> anyhow::Result<()> {
let transport = incoming.next().await.unwrap().unwrap();
BaseChannel::with_defaults(add_compression(transport))
.execute(HelloServer.serve())
.for_each(spawn)
.await;
});

View File

@@ -1,3 +1,13 @@
// Copyright 2022 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use futures::prelude::*;
use tarpc::context::Context;
use tarpc::serde_transport as transport;
use tarpc::server::{BaseChannel, Channel};
@@ -13,7 +23,6 @@ pub trait PingService {
#[derive(Clone)]
struct Service;
#[tarpc::server]
impl PingService for Service {
async fn ping(self, _: Context) {}
}
@@ -26,13 +35,18 @@ async fn main() -> anyhow::Result<()> {
let listener = UnixListener::bind(bind_addr).unwrap();
let codec_builder = LengthDelimitedCodec::builder();
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
tokio::spawn(async move {
loop {
let (conn, _addr) = listener.accept().await.unwrap();
let framed = codec_builder.new_framed(conn);
let transport = transport::new(framed, Bincode::default());
let fut = BaseChannel::with_defaults(transport).execute(Service.serve());
let fut = BaseChannel::with_defaults(transport)
.execute(Service.serve())
.for_each(spawn);
tokio::spawn(fut);
}
});

View File

@@ -4,6 +4,9 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
/// - The PubSub server sets up TCP listeners on 2 ports, the "subscriber" port and the "publisher"
/// port. Because both publishers and subscribers initiate their connections to the PubSub
/// server, the server requires no prior knowledge of either publishers or subscribers.
@@ -79,7 +82,6 @@ struct Subscriber {
topics: Vec<String>,
}
#[tarpc::server]
impl subscriber::Subscriber for Subscriber {
async fn topics(self, _: context::Context) -> Vec<String> {
self.topics.clone()
@@ -117,7 +119,8 @@ impl Subscriber {
))
}
};
let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve()));
let (handler, abort_handle) =
future::abortable(handler.execute(subscriber.serve()).for_each(spawn));
tokio::spawn(async move {
match handler.await {
Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."),
@@ -143,6 +146,10 @@ struct PublisherAddrs {
subscriptions: SocketAddr,
}
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
impl Publisher {
async fn start(self) -> io::Result<PublisherAddrs> {
let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?;
@@ -162,6 +169,7 @@ impl Publisher {
server::BaseChannel::with_defaults(publisher)
.execute(self.serve())
.for_each(spawn)
.await
});
@@ -257,7 +265,6 @@ impl Publisher {
}
}
#[tarpc::server]
impl publisher::Publisher for Publisher {
async fn publish(self, _: context::Context, topic: String, message: String) {
info!("received message to publish.");

View File

@@ -4,7 +4,10 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use futures::future::{self, Ready};
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use futures::prelude::*;
use tarpc::{
client, context,
server::{self, Channel},
@@ -23,22 +26,21 @@ pub trait World {
struct HelloServer;
impl World for HelloServer {
// Each defined rpc generates two items in the trait, a fn that serves the RPC, and
// an associated type representing the future output by the fn.
type HelloFut = Ready<String>;
fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
future::ready(format!("Hello, {name}!"))
async fn hello(self, _: context::Context, name: String) -> String {
format!("Hello, {name}!")
}
}
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
let server = server::BaseChannel::with_defaults(server_transport);
tokio::spawn(server.execute(HelloServer.serve()));
tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn));
// WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
// that takes a config and any Transport as input.

View File

@@ -1,3 +1,13 @@
// Copyright 2023 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use futures::prelude::*;
use rustls_pemfile::certs;
use std::io::{BufReader, Cursor};
use std::net::{IpAddr, Ipv4Addr};
@@ -23,7 +33,6 @@ pub trait PingService {
#[derive(Clone)]
struct Service;
#[tarpc::server]
impl PingService for Service {
async fn ping(self, _: Context) -> String {
"🔒".to_owned()
@@ -65,6 +74,10 @@ pub fn load_private_key(key: &str) -> rustls::PrivateKey {
panic!("no keys found in {:?} (encrypted keys not supported)", key);
}
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// -------------------- start here to setup tls tcp tokio stream --------------------------
@@ -100,7 +113,9 @@ async fn main() -> anyhow::Result<()> {
let transport = transport::new(framed, Bincode::default());
let fut = BaseChannel::with_defaults(transport).execute(Service.serve());
let fut = BaseChannel::with_defaults(transport)
.execute(Service.serve())
.for_each(spawn);
tokio::spawn(fut);
}
});

View File

@@ -4,7 +4,8 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![feature(type_alias_impl_trait)]
#![allow(incomplete_features)]
#![feature(async_fn_in_trait, type_alias_impl_trait)]
use crate::{
add::{Add as AddService, AddStub},
@@ -25,7 +26,10 @@ use tarpc::{
RpcError,
},
context, serde_transport,
server::{incoming::Incoming, BaseChannel, Serve},
server::{
incoming::{spawn_incoming, Incoming},
BaseChannel, Serve,
},
tokio_serde::formats::Json,
ClientMessage, Response, ServerError, Transport,
};
@@ -51,7 +55,6 @@ pub mod double {
#[derive(Clone)]
struct AddServer;
#[tarpc::server]
impl AddService for AddServer {
async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
x + y
@@ -63,7 +66,6 @@ struct DoubleServer<Stub> {
add_client: add::AddClient<Stub>,
}
#[tarpc::server]
impl<Stub> DoubleService for DoubleServer<Stub>
where
Stub: AddStub + Clone + Send + Sync + 'static,
@@ -158,9 +160,8 @@ async fn main() -> anyhow::Result<()> {
});
let add_server = add_listener1
.chain(add_listener2)
.map(BaseChannel::with_defaults)
.execute(server);
tokio::spawn(add_server);
.map(BaseChannel::with_defaults);
tokio::spawn(spawn_incoming(add_server.execute(server)));
let add_client = add::AddClient::from(make_stub([
tarpc::serde_transport::tcp::connect(addr1, Json::default).await?,
@@ -171,11 +172,9 @@ async fn main() -> anyhow::Result<()> {
.await?
.filter_map(|r| future::ready(r.ok()));
let addr = double_listener.get_ref().local_addr();
let double_server = double_listener
.map(BaseChannel::with_defaults)
.take(1)
.execute(DoubleServer { add_client }.serve());
tokio::spawn(double_server);
let double_server = double_listener.map(BaseChannel::with_defaults).take(1);
let server = DoubleServer { add_client }.serve();
tokio::spawn(spawn_incoming(double_server.execute(server)));
let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
let double_client =

View File

@@ -80,6 +80,8 @@
//! First, let's set up the dependencies and service definition.
//!
//! ```rust
//! #![allow(incomplete_features)]
//! #![feature(async_fn_in_trait)]
//! # extern crate futures;
//!
//! use futures::{
@@ -104,6 +106,8 @@
//! implement it for our Server struct.
//!
//! ```rust
//! # #![allow(incomplete_features)]
//! # #![feature(async_fn_in_trait)]
//! # extern crate futures;
//! # use futures::{
//! # future::{self, Ready},
@@ -126,13 +130,9 @@
//! struct HelloServer;
//!
//! impl World for HelloServer {
//! // Each defined rpc generates two items in the trait, a fn that serves the RPC, and
//! // an associated type representing the future output by the fn.
//!
//! type HelloFut = Ready<String>;
//!
//! fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
//! future::ready(format!("Hello, {name}!"))
//! // Each defined rpc generates an async fn that serves the RPC
//! async fn hello(self, _: context::Context, name: String) -> String {
//! format!("Hello, {name}!")
//! }
//! }
//! ```
@@ -143,6 +143,8 @@
//! available behind the `tcp` feature.
//!
//! ```rust
//! # #![allow(incomplete_features)]
//! # #![feature(async_fn_in_trait)]
//! # extern crate futures;
//! # use futures::{
//! # future::{self, Ready},
@@ -164,11 +166,9 @@
//! # #[derive(Clone)]
//! # struct HelloServer;
//! # impl World for HelloServer {
//! # // Each defined rpc generates two items in the trait, a fn that serves the RPC, and
//! # // an associated type representing the future output by the fn.
//! # type HelloFut = Ready<String>;
//! # fn hello(self, _: context::Context, name: String) -> Self::HelloFut {
//! # future::ready(format!("Hello, {name}!"))
//! // Each defined rpc generates an async fn that serves the RPC
//! # async fn hello(self, _: context::Context, name: String) -> String {
//! # format!("Hello, {name}!")
//! # }
//! # }
//! # #[cfg(not(feature = "tokio1"))]
@@ -179,7 +179,12 @@
//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
//!
//! let server = server::BaseChannel::with_defaults(server_transport);
//! tokio::spawn(server.execute(HelloServer.serve()));
//! tokio::spawn(
//! server.execute(HelloServer.serve())
//! // Handle all requests concurrently.
//! .for_each(|response| async move {
//! tokio::spawn(response);
//! }));
//!
//! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new`
//! // that takes a config and any Transport as input.
@@ -234,6 +239,7 @@ pub use tarpc_plugins::derive_serde;
/// Rpc methods are specified, mirroring trait syntax:
///
/// ```
/// #![feature(async_fn_in_trait)]
/// #[tarpc::service]
/// trait Service {
/// /// Say hello
@@ -253,62 +259,6 @@ pub use tarpc_plugins::derive_serde;
/// * `fn new_stub` -- creates a new Client stub.
pub use tarpc_plugins::service;
/// A utility macro that can be used for RPC server implementations.
///
/// Syntactic sugar to make using async functions in the server implementation
/// easier. It does this by rewriting code like this, which would normally not
/// compile because async functions are disallowed in trait implementations:
///
/// ```rust
/// # use tarpc::context;
/// # use std::net::SocketAddr;
/// #[tarpc::service]
/// trait World {
/// async fn hello(name: String) -> String;
/// }
///
/// #[derive(Clone)]
/// struct HelloServer(SocketAddr);
///
/// #[tarpc::server]
/// impl World for HelloServer {
/// async fn hello(self, _: context::Context, name: String) -> String {
/// format!("Hello, {name}! You are connected from {:?}.", self.0)
/// }
/// }
/// ```
///
/// Into code like this, which matches the service trait definition:
///
/// ```rust
/// # use tarpc::context;
/// # use std::pin::Pin;
/// # use futures::Future;
/// # use std::net::SocketAddr;
/// #[derive(Clone)]
/// struct HelloServer(SocketAddr);
///
/// #[tarpc::service]
/// trait World {
/// async fn hello(name: String) -> String;
/// }
///
/// impl World for HelloServer {
/// type HelloFut = Pin<Box<dyn Future<Output = String> + Send>>;
///
/// fn hello(self, _: context::Context, name: String) -> Pin<Box<dyn Future<Output = String>
/// + Send>> {
/// Box::pin(async move {
/// format!("Hello, {name}! You are connected from {:?}.", self.0)
/// })
/// }
/// }
/// ```
///
/// Note that this won't touch functions unless they have been annotated with
/// `async`, meaning that this should not break existing code.
pub use tarpc_plugins::server;
pub(crate) mod cancellations;
pub mod client;
pub mod context;

View File

@@ -35,11 +35,6 @@ pub mod limits;
/// Provides helper methods for streams of Channels.
pub mod incoming;
/// Provides convenience functionality for tokio-enabled applications.
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
pub mod tokio;
use request_hook::{
AfterRequest, AfterRequestHook, BeforeAndAfterRequestHook, BeforeRequest, BeforeRequestHook,
};
@@ -79,11 +74,8 @@ pub trait Serve {
/// Type of response.
type Resp;
/// Type of response future.
type Fut: Future<Output = Result<Self::Resp, ServerError>>;
/// Responds to a single request.
fn serve(self, ctx: context::Context, req: Self::Req) -> Self::Fut;
async fn serve(self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError>;
/// Extracts a method name from the request.
fn method(&self, _request: &Self::Req) -> Option<&'static str> {
@@ -274,10 +266,9 @@ where
{
type Req = Req;
type Resp = Resp;
type Fut = Fut;
fn serve(self, ctx: context::Context, req: Req) -> Self::Fut {
(self.f)(ctx, req)
async fn serve(self, ctx: context::Context, req: Req) -> Result<Resp, ServerError> {
(self.f)(ctx, req).await
}
}
@@ -533,34 +524,42 @@ where
}
}
/// Runs the channel until completion by executing all requests using the given service
/// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's
/// default executor.
/// Returns a stream of request execution futures. Each future represents an in-flight request
/// being responded to by the server. The futures must be awaited or spawned to complete their
/// requests.
///
/// # Example
///
/// ```rust
/// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport};
/// use futures::prelude::*;
/// use tracing_subscriber::prelude::*;
///
/// #[derive(PartialEq, Eq, Debug)]
/// struct MyInt(i32);
///
/// # #[cfg(not(feature = "tokio1"))]
/// # fn main() {}
/// # #[cfg(feature = "tokio1")]
/// #[tokio::main]
/// async fn main() {
/// let (tx, rx) = transport::channel::unbounded();
/// let client = client::new(client::Config::default(), tx).spawn();
/// let channel = BaseChannel::new(server::Config::default(), rx);
/// tokio::spawn(channel.execute(serve(|_, i| async move { Ok(i + 1) })));
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
/// let channel = BaseChannel::with_defaults(rx);
/// tokio::spawn(
/// channel.execute(serve(|_, MyInt(i)| async move { Ok(MyInt(i + 1)) }))
/// .for_each(|response| async move {
/// tokio::spawn(response);
/// }));
/// assert_eq!(
/// client.call(context::current(), "AddOne", MyInt(1)).await.unwrap(),
/// MyInt(2));
/// }
/// ```
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
fn execute<S>(self, serve: S) -> self::tokio::TokioChannelExecutor<Requests<Self>, S>
fn execute<S>(self, serve: S) -> impl Stream<Item = impl Future<Output = ()>>
where
Self: Sized,
S: Serve<Req = Self::Req, Resp = Self::Resp> + Send + 'static,
S::Fut: Send,
Self::Req: Send + 'static,
Self::Resp: Send + 'static,
S: Serve<Req = Self::Req, Resp = Self::Resp> + Clone,
{
self.requests().execute(serve)
}
@@ -654,15 +653,17 @@ where
Poll::Pending => Pending,
};
tracing::trace!(
"Expired requests: {:?}, Inbound: {:?}",
expiration_status,
request_status
);
match cancellation_status
let status = cancellation_status
.combine(expiration_status)
.combine(request_status)
{
.combine(request_status);
tracing::trace!(
"Cancellations: {cancellation_status:?}, \
Expired requests: {expiration_status:?}, \
Inbound: {request_status:?}, \
Overall: {status:?}",
);
match status {
Ready => continue,
Closed => return Poll::Ready(None),
Pending => return Poll::Pending,
@@ -872,6 +873,51 @@ where
}
Poll::Ready(Some(Ok(())))
}
/// Returns a stream of request execution futures. Each future represents an in-flight request
/// being responded to by the server. The futures must be awaited or spawned to complete their
/// requests.
///
/// If the channel encounters an error, the stream is terminated and the error is logged.
///
/// # Example
///
/// ```rust
/// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport};
/// use futures::prelude::*;
///
/// # #[cfg(not(feature = "tokio1"))]
/// # fn main() {}
/// # #[cfg(feature = "tokio1")]
/// #[tokio::main]
/// async fn main() {
/// let (tx, rx) = transport::channel::unbounded();
/// let requests = BaseChannel::new(server::Config::default(), rx).requests();
/// let client = client::new(client::Config::default(), tx).spawn();
/// tokio::spawn(
/// requests.execute(serve(|_, i| async move { Ok(i + 1) }))
/// .for_each(|response| async move {
/// tokio::spawn(response);
/// }));
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
/// }
/// ```
pub fn execute<S>(self, serve: S) -> impl Stream<Item = impl Future<Output = ()>>
where
S: Serve<Req = C::Req, Resp = C::Resp> + Clone,
{
self.take_while(|result| {
if let Err(e) = result {
tracing::warn!("Requests stream errored out: {}", e);
}
futures::future::ready(result.is_ok())
})
.filter_map(|result| async move { result.ok() })
.map(move |request| {
let serve = serve.clone();
request.execute(serve)
})
}
}
impl<C> fmt::Debug for Requests<C>
@@ -1003,6 +1049,13 @@ impl<Req, Res> InFlightRequest<Req, Res> {
}
}
fn print_err(e: &(dyn Error + 'static)) -> String {
anyhow::Chain::new(e)
.map(|e| e.to_string())
.intersperse(": ".into())
.collect::<String>()
}
impl<C> Stream for Requests<C>
where
C: Channel,
@@ -1011,17 +1064,33 @@ where
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
let read = self.as_mut().pump_read(cx)?;
let read = self.as_mut().pump_read(cx).map_err(|e| {
tracing::trace!("read: {}", print_err(&e));
e
})?;
let read_closed = matches!(read, Poll::Ready(None));
match (read, self.as_mut().pump_write(cx, read_closed)?) {
let write = self.as_mut().pump_write(cx, read_closed).map_err(|e| {
tracing::trace!("write: {}", print_err(&e));
e
})?;
match (read, write) {
(Poll::Ready(None), Poll::Ready(None)) => {
tracing::trace!("read: Poll::Ready(None), write: Poll::Ready(None)");
return Poll::Ready(None);
}
(Poll::Ready(Some(request_handler)), _) => {
tracing::trace!("read: Poll::Ready(Some), write: _");
return Poll::Ready(Some(Ok(request_handler)));
}
(_, Poll::Ready(Some(()))) => {}
_ => {
(_, Poll::Ready(Some(()))) => {
tracing::trace!("read: _, write: Poll::Ready(Some)");
}
(read @ Poll::Pending, write) | (read, write @ Poll::Pending) => {
tracing::trace!(
"read pending: {}, write pending: {}",
read.is_pending(),
write.is_pending()
);
return Poll::Pending;
}
}

View File

@@ -1,13 +1,10 @@
use super::{
limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel},
Channel,
Channel, Serve,
};
use futures::prelude::*;
use std::{fmt, hash::Hash};
#[cfg(feature = "tokio1")]
use super::{tokio::TokioServerExecutor, Serve};
/// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel).
pub trait Incoming<C>
where
@@ -28,16 +25,62 @@ where
MaxRequestsPerChannel::new(self, n)
}
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled
/// concurrently by spawning on tokio's default executor, and each request will be also
/// be spawned on tokio's default executor.
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
/// Returns a stream of channels in execution. Each channel in execution is a stream of
/// futures, where each future is an in-flight request being rsponded to.
fn execute<S>(
self,
serve: S,
) -> impl Stream<Item = impl Stream<Item = impl Future<Output = ()>>>
where
S: Serve<Req = C::Req, Resp = C::Resp>,
S: Serve<Req = C::Req, Resp = C::Resp> + Clone,
{
TokioServerExecutor::new(self, serve)
self.map(move |channel| channel.execute(serve.clone()))
}
}
#[cfg(feature = "tokio1")]
/// Spawns all channels-in-execution, delegating to the tokio runtime to manage their completion.
/// Each channel is spawned, and each request from each channel is spawned.
/// Note that this function is generic over any stream-of-streams-of-futures, but it is intended
/// for spawning streams of channels.
///
/// # Example
/// ```rust
/// use tarpc::{
/// context,
/// client::{self, NewClient},
/// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve},
/// transport,
/// };
/// use futures::prelude::*;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, rx) = transport::channel::unbounded();
/// let NewClient { client, dispatch } = client::new(client::Config::default(), tx);
/// tokio::spawn(dispatch);
///
/// let incoming = stream::once(async move {
/// BaseChannel::new(server::Config::default(), rx)
/// }).execute(serve(|_, i| async move { Ok(i + 1) }));
/// tokio::spawn(spawn_incoming(incoming));
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
/// }
/// ```
pub async fn spawn_incoming(
incoming: impl Stream<
Item = impl Stream<Item = impl Future<Output = ()> + Send + 'static> + Send + 'static,
>,
) {
use futures::pin_mut;
pin_mut!(incoming);
while let Some(channel) = incoming.next().await {
tokio::spawn(async move {
pin_mut!(channel);
while let Some(request) = channel.next().await {
tokio::spawn(request);
}
});
}
}

View File

@@ -71,19 +71,17 @@ where
{
type Req = Serv::Req;
type Resp = Serv::Resp;
type Fut = AfterRequestHookFut<Serv, Hook>;
fn serve(self, mut ctx: context::Context, req: Serv::Req) -> Self::Fut {
async move {
let AfterRequestHook {
serve, mut hook, ..
} = self;
let mut resp = serve.serve(ctx, req).await;
hook.after(&mut ctx, &mut resp).await;
resp
}
async fn serve(
self,
mut ctx: context::Context,
req: Serv::Req,
) -> Result<Serv::Resp, ServerError> {
let AfterRequestHook {
serve, mut hook, ..
} = self;
let mut resp = serve.serve(ctx, req).await;
hook.after(&mut ctx, &mut resp).await;
resp
}
}
type AfterRequestHookFut<Serv: Serve, Hook: AfterRequest<Serv::Resp>> =
impl Future<Output = Result<Serv::Resp, ServerError>>;

View File

@@ -67,18 +67,16 @@ where
{
type Req = Serv::Req;
type Resp = Serv::Resp;
type Fut = BeforeRequestHookFut<Serv, Hook>;
fn serve(self, mut ctx: context::Context, req: Self::Req) -> Self::Fut {
async fn serve(
self,
mut ctx: context::Context,
req: Self::Req,
) -> Result<Serv::Resp, ServerError> {
let BeforeRequestHook {
serve, mut hook, ..
} = self;
async move {
hook.before(&mut ctx, &req).await?;
serve.serve(ctx, req).await
}
hook.before(&mut ctx, &req).await?;
serve.serve(ctx, req).await
}
}
type BeforeRequestHookFut<Serv: Serve, Hook: BeforeRequest<Serv::Req>> =
impl Future<Output = Result<Serv::Resp, ServerError>>;

View File

@@ -8,7 +8,6 @@
use super::{after::AfterRequest, before::BeforeRequest};
use crate::{context, server::Serve, ServerError};
use futures::prelude::*;
use std::marker::PhantomData;
/// A Service function that runs a hook both before and after request execution.
@@ -47,24 +46,14 @@ where
{
type Req = Req;
type Resp = Resp;
type Fut = BeforeAndAfterRequestHookFut<Req, Resp, Serv, Hook>;
fn serve(self, mut ctx: context::Context, req: Req) -> Self::Fut {
async move {
let BeforeAndAfterRequestHook {
serve, mut hook, ..
} = self;
hook.before(&mut ctx, &req).await?;
let mut resp = serve.serve(ctx, req).await;
hook.after(&mut ctx, &mut resp).await;
resp
}
async fn serve(self, mut ctx: context::Context, req: Req) -> Result<Serv::Resp, ServerError> {
let BeforeAndAfterRequestHook {
serve, mut hook, ..
} = self;
hook.before(&mut ctx, &req).await?;
let mut resp = serve.serve(ctx, req).await;
hook.after(&mut ctx, &mut resp).await;
resp
}
}
type BeforeAndAfterRequestHookFut<
Req,
Resp,
Serv: Serve<Req = Req, Resp = Resp>,
Hook: BeforeRequest<Req> + AfterRequest<Resp>,
> = impl Future<Output = Result<Serv::Resp, ServerError>>;

View File

@@ -1,129 +0,0 @@
use super::{Channel, Requests, Serve};
use futures::{prelude::*, ready, task::*};
use pin_project::pin_project;
use std::pin::Pin;
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
/// for each new channel. Returned by
/// [`Incoming::execute`](crate::server::incoming::Incoming::execute).
#[must_use]
#[pin_project]
#[derive(Debug)]
pub struct TokioServerExecutor<T, S> {
#[pin]
inner: T,
serve: S,
}
impl<T, S> TokioServerExecutor<T, S> {
pub(crate) fn new(inner: T, serve: S) -> Self {
Self { inner, serve }
}
}
/// A future that drives the server by [spawning](tokio::spawn) each [response
/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by
/// [`Channel::execute`](crate::server::Channel::execute).
#[must_use]
#[pin_project]
#[derive(Debug)]
pub struct TokioChannelExecutor<T, S> {
#[pin]
inner: T,
serve: S,
}
impl<T, S> TokioServerExecutor<T, S> {
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
self.as_mut().project().inner
}
}
impl<T, S> TokioChannelExecutor<T, S> {
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
self.as_mut().project().inner
}
}
// Send + 'static execution helper methods.
impl<C> Requests<C>
where
C: Channel,
C::Req: Send + 'static,
C::Resp: Send + 'static,
{
/// Executes all requests using the given service function. Requests are handled concurrently
/// by [spawning](::tokio::spawn) each handler on tokio's default executor.
///
/// # Example
///
/// ```rust
/// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport};
/// use futures::prelude::*;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, rx) = transport::channel::unbounded();
/// let requests = BaseChannel::new(server::Config::default(), rx).requests();
/// let client = client::new(client::Config::default(), tx).spawn();
/// tokio::spawn(requests.execute(serve(|_, i| async move { Ok(i + 1) })));
/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2);
/// }
/// ```
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
where
S: Serve<Req = C::Req, Resp = C::Resp> + Send + 'static,
{
TokioChannelExecutor { inner: self, serve }
}
}
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
where
St: Sized + Stream<Item = C>,
C: Channel + Send + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
Se: Serve<Req = C::Req, Resp = C::Resp> + Send + 'static + Clone,
Se::Fut: Send,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
tokio::spawn(channel.execute(self.serve.clone()));
}
tracing::info!("Server shutting down.");
Poll::Ready(())
}
}
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
where
C: Channel + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
S: Serve<Req = C::Req, Resp = C::Resp> + Send + 'static + Clone,
S::Fut: Send,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
match response_handler {
Ok(resp) => {
let server = self.serve.clone();
tokio::spawn(async move {
resp.execute(server).await;
});
}
Err(e) => {
tracing::warn!("Requests stream errored out: {}", e);
break;
}
}
}
Poll::Ready(())
}
}

View File

@@ -14,9 +14,15 @@ use tokio::sync::mpsc;
/// Errors that occur in the sending or receiving of messages over a channel.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError {
/// An error occurred sending over the channel.
#[error("an error occurred sending over the channel")]
/// An error occurred readying to send into the channel.
#[error("an error occurred readying to send into the channel")]
Ready(#[source] Box<dyn Error + Send + Sync + 'static>),
/// An error occurred sending into the channel.
#[error("an error occurred sending into the channel")]
Send(#[source] Box<dyn Error + Send + Sync + 'static>),
/// An error occurred receiving from the channel.
#[error("an error occurred receiving from the channel")]
Receive(#[source] Box<dyn Error + Send + Sync + 'static>),
}
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
@@ -48,7 +54,10 @@ impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Item, ChannelError>>> {
self.rx.poll_recv(cx).map(|option| option.map(Ok))
self.rx
.poll_recv(cx)
.map(|option| option.map(Ok))
.map_err(ChannelError::Receive)
}
}
@@ -59,7 +68,7 @@ impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(if self.tx.is_closed() {
Err(ChannelError::Send(CLOSED_MESSAGE.into()))
Err(ChannelError::Ready(CLOSED_MESSAGE.into()))
} else {
Ok(())
})
@@ -110,7 +119,11 @@ impl<Item, SinkItem> Stream for Channel<Item, SinkItem> {
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Item, ChannelError>>> {
self.project().rx.poll_next(cx).map(|option| option.map(Ok))
self.project()
.rx
.poll_next(cx)
.map(|option| option.map(Ok))
.map_err(ChannelError::Receive)
}
}
@@ -121,7 +134,7 @@ impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
self.project()
.tx
.poll_ready(cx)
.map_err(|e| ChannelError::Send(Box::new(e)))
.map_err(|e| ChannelError::Ready(Box::new(e)))
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
@@ -146,8 +159,7 @@ impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
}
}
#[cfg(test)]
#[cfg(feature = "tokio1")]
#[cfg(all(test, feature = "tokio1"))]
mod tests {
use crate::{
client::{self, RpcError},
@@ -186,7 +198,10 @@ mod tests {
format!("{request:?} is not an int"),
)
})
})),
}))
.for_each(|channel| async move {
tokio::spawn(channel.for_each(|response| response));
}),
);
let client = client::new(client::Config::default(), client_channel).spawn();

View File

@@ -2,8 +2,6 @@
fn ui() {
let t = trybuild::TestCases::new();
t.compile_fail("tests/compile_fail/*.rs");
#[cfg(feature = "tokio1")]
t.compile_fail("tests/compile_fail/tokio/*.rs");
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
t.compile_fail("tests/compile_fail/serde_transport/*.rs");
}

View File

@@ -1,3 +1,6 @@
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use tarpc::client;
#[tarpc::service]

View File

@@ -1,15 +1,15 @@
error: unused `RequestDispatch` that must be used
--> tests/compile_fail/must_use_request_dispatch.rs:13:9
--> tests/compile_fail/must_use_request_dispatch.rs:16:9
|
13 | WorldClient::new(client::Config::default(), client_transport).dispatch;
16 | WorldClient::new(client::Config::default(), client_transport).dispatch;
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/must_use_request_dispatch.rs:11:12
--> tests/compile_fail/must_use_request_dispatch.rs:14:12
|
11 | #[deny(unused_must_use)]
14 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^
help: use `let _ = ...` to ignore the resulting value
|
13 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch;
16 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch;
| +++++++

View File

@@ -1,15 +0,0 @@
#[tarpc::service(derive_serde = false)]
trait World {
async fn hello(name: String) -> String;
}
struct HelloServer;
#[tarpc::server]
impl World for HelloServer {
fn hello(name: String) -> String {
format!("Hello, {name}!", name)
}
}
fn main() {}

View File

@@ -1,15 +0,0 @@
error: not all trait items implemented, missing: `HelloFut`
--> tests/compile_fail/tarpc_server_missing_async.rs:9:1
|
9 | / impl World for HelloServer {
10 | | fn hello(name: String) -> String {
11 | | format!("Hello, {name}!", name)
12 | | }
13 | | }
| |_^
error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async
--> tests/compile_fail/tarpc_server_missing_async.rs:10:5
|
10 | fn hello(name: String) -> String {
| ^^^^^^^^

View File

@@ -1,29 +0,0 @@
use tarpc::{
context,
server::{self, Channel},
};
#[tarpc::service]
trait World {
async fn hello(name: String) -> String;
}
#[derive(Clone)]
struct HelloServer;
#[tarpc::server]
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
format!("Hello, {name}!")
}
}
fn main() {
let (_, server_transport) = tarpc::transport::channel::unbounded();
let server = server::BaseChannel::with_defaults(server_transport);
#[deny(unused_must_use)]
{
server.execute(HelloServer.serve());
}
}

View File

@@ -1,15 +0,0 @@
error: unused `TokioChannelExecutor` that must be used
--> tests/compile_fail/tokio/must_use_channel_executor.rs:27:9
|
27 | server.execute(HelloServer.serve());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/tokio/must_use_channel_executor.rs:25:12
|
25 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^
help: use `let _ = ...` to ignore the resulting value
|
27 | let _ = server.execute(HelloServer.serve());
| +++++++

View File

@@ -1,30 +0,0 @@
use futures::stream::once;
use tarpc::{
context,
server::{self, incoming::Incoming},
};
#[tarpc::service]
trait World {
async fn hello(name: String) -> String;
}
#[derive(Clone)]
struct HelloServer;
#[tarpc::server]
impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
format!("Hello, {name}!")
}
}
fn main() {
let (_, server_transport) = tarpc::transport::channel::unbounded();
let server = once(async move { server::BaseChannel::with_defaults(server_transport) });
#[deny(unused_must_use)]
{
server.execute(HelloServer.serve());
}
}

View File

@@ -1,15 +0,0 @@
error: unused `TokioServerExecutor` that must be used
--> tests/compile_fail/tokio/must_use_server_executor.rs:28:9
|
28 | server.execute(HelloServer.serve());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: the lint level is defined here
--> tests/compile_fail/tokio/must_use_server_executor.rs:26:12
|
26 | #[deny(unused_must_use)]
| ^^^^^^^^^^^^^^^
help: use `let _ = ...` to ignore the resulting value
|
28 | let _ = server.execute(HelloServer.serve());
| +++++++

View File

@@ -1,3 +1,6 @@
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use futures::prelude::*;
use tarpc::serde_transport;
use tarpc::{
@@ -21,7 +24,6 @@ pub trait ColorProtocol {
#[derive(Clone)]
struct ColorServer;
#[tarpc::server]
impl ColorProtocol for ColorServer {
async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData {
match color {
@@ -31,6 +33,11 @@ impl ColorProtocol for ColorServer {
}
}
#[cfg(test)]
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
#[tokio::test]
async fn test_call() -> anyhow::Result<()> {
let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?;
@@ -40,7 +47,9 @@ async fn test_call() -> anyhow::Result<()> {
.take(1)
.filter_map(|r| async { r.ok() })
.map(BaseChannel::with_defaults)
.execute(ColorServer.serve()),
.execute(ColorServer.serve())
.map(|channel| channel.for_each(spawn))
.for_each(spawn),
);
let transport = serde_transport::tcp::connect(addr, Json::default).await?;

View File

@@ -1,13 +1,16 @@
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]
use assert_matches::assert_matches;
use futures::{
future::{join_all, ready, Ready},
future::{join_all, ready},
prelude::*,
};
use std::time::{Duration, SystemTime};
use tarpc::{
client::{self},
context,
server::{self, incoming::Incoming, BaseChannel, Channel},
server::{incoming::Incoming, BaseChannel, Channel},
transport::channel,
};
use tokio::join;
@@ -22,39 +25,29 @@ trait Service {
struct Server;
impl Service for Server {
type AddFut = Ready<i32>;
fn add(self, _: context::Context, x: i32, y: i32) -> Self::AddFut {
ready(x + y)
async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
x + y
}
type HeyFut = Ready<String>;
fn hey(self, _: context::Context, name: String) -> Self::HeyFut {
ready(format!("Hey, {name}."))
async fn hey(self, _: context::Context, name: String) -> String {
format!("Hey, {name}.")
}
}
#[tokio::test]
async fn sequential() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
let (tx, rx) = channel::unbounded();
async fn sequential() {
let (tx, rx) = tarpc::transport::channel::unbounded();
let client = client::new(client::Config::default(), tx).spawn();
let channel = BaseChannel::with_defaults(rx);
tokio::spawn(
BaseChannel::new(server::Config::default(), rx)
.requests()
.execute(Server.serve()),
channel
.execute(tarpc::server::serve(|_, i| async move { Ok(i + 1) }))
.for_each(|response| response),
);
assert_eq!(
client.call(context::current(), "AddOne", 1).await.unwrap(),
2
);
let client = ServiceClient::new(client::Config::default(), tx).spawn();
assert_matches!(client.add(context::current(), 1, 2).await, Ok(3));
assert_matches!(
client.hey(context::current(), "Tim".into()).await,
Ok(ref s) if s == "Hey, Tim.");
Ok(())
}
#[tokio::test]
@@ -70,7 +63,6 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> {
#[derive(Debug)]
struct AllHandlersComplete;
#[tarpc::server]
impl Loop for LoopServer {
async fn r#loop(self, _: context::Context) {
loop {
@@ -121,7 +113,9 @@ async fn serde_tcp() -> anyhow::Result<()> {
.take(1)
.filter_map(|r| async { r.ok() })
.map(BaseChannel::with_defaults)
.execute(Server.serve()),
.execute(Server.serve())
.map(|channel| channel.for_each(spawn))
.for_each(spawn),
);
let transport = serde_transport::tcp::connect(addr, Json::default).await?;
@@ -151,7 +145,9 @@ async fn serde_uds() -> anyhow::Result<()> {
.take(1)
.filter_map(|r| async { r.ok() })
.map(BaseChannel::with_defaults)
.execute(Server.serve()),
.execute(Server.serve())
.map(|channel| channel.for_each(spawn))
.for_each(spawn),
);
let transport = serde_transport::unix::connect(&sock, Json::default).await?;
@@ -175,7 +171,9 @@ async fn concurrent() -> anyhow::Result<()> {
tokio::spawn(
stream::once(ready(rx))
.map(BaseChannel::with_defaults)
.execute(Server.serve()),
.execute(Server.serve())
.map(|channel| channel.for_each(spawn))
.for_each(spawn),
);
let client = ServiceClient::new(client::Config::default(), tx).spawn();
@@ -199,7 +197,9 @@ async fn concurrent_join() -> anyhow::Result<()> {
tokio::spawn(
stream::once(ready(rx))
.map(BaseChannel::with_defaults)
.execute(Server.serve()),
.execute(Server.serve())
.map(|channel| channel.for_each(spawn))
.for_each(spawn),
);
let client = ServiceClient::new(client::Config::default(), tx).spawn();
@@ -216,15 +216,20 @@ async fn concurrent_join() -> anyhow::Result<()> {
Ok(())
}
#[cfg(test)]
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
#[tokio::test]
async fn concurrent_join_all() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt::try_init();
let (tx, rx) = channel::unbounded();
tokio::spawn(
stream::once(ready(rx))
.map(BaseChannel::with_defaults)
.execute(Server.serve()),
BaseChannel::with_defaults(rx)
.execute(Server.serve())
.for_each(spawn),
);
let client = ServiceClient::new(client::Config::default(), tx).spawn();
@@ -249,11 +254,9 @@ async fn counter() -> anyhow::Result<()> {
struct CountService(u32);
impl Counter for &mut CountService {
type CountFut = futures::future::Ready<u32>;
fn count(self, _: context::Context) -> Self::CountFut {
async fn count(self, _: context::Context) -> u32 {
self.0 += 1;
futures::future::ready(self.0)
self.0
}
}