diff --git a/example-service/src/lib.rs b/example-service/src/lib.rs index bc38fe9..822d821 100644 --- a/example-service/src/lib.rs +++ b/example-service/src/lib.rs @@ -4,6 +4,9 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use std::env; use tracing_subscriber::{fmt::format::FmtSpan, prelude::*}; diff --git a/example-service/src/server.rs b/example-service/src/server.rs index b0281e9..6c78598 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -4,6 +4,9 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use clap::Parser; use futures::{future, prelude::*}; use rand::{ @@ -34,7 +37,6 @@ struct Flags { #[derive(Clone)] struct HelloServer(SocketAddr); -#[tarpc::server] impl World for HelloServer { async fn hello(self, _: context::Context, name: String) -> String { let sleep_time = @@ -44,6 +46,10 @@ impl World for HelloServer { } } +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::main] async fn main() -> anyhow::Result<()> { let flags = Flags::parse(); @@ -66,7 +72,7 @@ async fn main() -> anyhow::Result<()> { // the generated World trait. .map(|channel| { let server = HelloServer(channel.transport().peer_addr().unwrap()); - channel.execute(server.serve()) + channel.execute(server.serve()).for_each(spawn) }) // Max 10 channels. .buffer_unordered(10) diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index efab161..f33cea0 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -12,18 +12,18 @@ extern crate quote; extern crate syn; use proc_macro::TokenStream; -use proc_macro2::{Span, TokenStream as TokenStream2}; +use proc_macro2::TokenStream as TokenStream2; use quote::{format_ident, quote, ToTokens}; use syn::{ braced, ext::IdentExt, parenthesized, parse::{Parse, ParseStream}, - parse_macro_input, parse_quote, parse_str, + parse_macro_input, parse_quote, spanned::Spanned, token::Comma, - Attribute, FnArg, Ident, ImplItem, ImplItemMethod, ImplItemType, ItemImpl, Lit, LitBool, - MetaNameValue, Pat, PatType, ReturnType, Token, Type, Visibility, + Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type, + Visibility, }; /// Accumulates multiple errors into a result. @@ -257,7 +257,6 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string())) .collect(); let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::>(); - let response_fut_name = &format!("{}ResponseFut", ident.unraw()); let derive_serialize = if derive_serde.0 { Some( quote! {#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)] @@ -274,11 +273,9 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .collect::>(); ServiceGenerator { - response_fut_name, service_ident: ident, client_stub_ident: &format_ident!("{}Stub", ident), server_ident: &format_ident!("Serve{}", ident), - response_fut_ident: &Ident::new(response_fut_name, ident.span()), client_ident: &format_ident!("{}Client", ident), request_ident: &format_ident!("{}Request", ident), response_ident: &format_ident!("{}Response", ident), @@ -305,138 +302,18 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .zip(camel_case_fn_names.iter()) .map(|(rpc, name)| Ident::new(name, rpc.ident.span())) .collect::>(), - future_types: &camel_case_fn_names - .iter() - .map(|name| parse_str(&format!("{name}Fut")).unwrap()) - .collect::>(), derive_serialize: derive_serialize.as_ref(), } .into_token_stream() .into() } -/// generate an identifier consisting of the method name to CamelCase with -/// Fut appended to it. -fn associated_type_for_rpc(method: &ImplItemMethod) -> String { - snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut" -} - -/// Transforms an async function into a sync one, returning a type declaration -/// for the return type (a future). -fn transform_method(method: &mut ImplItemMethod) -> ImplItemType { - method.sig.asyncness = None; - - // get either the return type or (). - let ret = match &method.sig.output { - ReturnType::Default => quote!(()), - ReturnType::Type(_, ret) => quote!(#ret), - }; - - let fut_name = associated_type_for_rpc(method); - let fut_name_ident = Ident::new(&fut_name, method.sig.ident.span()); - - // generate the updated return signature. - method.sig.output = parse_quote! { - -> ::core::pin::Pin + ::core::marker::Send - >> - }; - - // transform the body of the method into Box::pin(async move { body }). - let block = method.block.clone(); - method.block = parse_quote! [{ - Box::pin(async move - #block - ) - }]; - - // generate and return type declaration for return type. - let t: ImplItemType = parse_quote! { - type #fut_name_ident = ::core::pin::Pin + ::core::marker::Send>>; - }; - - t -} - -#[proc_macro_attribute] -pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream { - let mut item = syn::parse_macro_input!(input as ItemImpl); - let span = item.span(); - - // the generated type declarations - let mut types: Vec = Vec::new(); - let mut expected_non_async_types: Vec<(&ImplItemMethod, String)> = Vec::new(); - let mut found_non_async_types: Vec<&ImplItemType> = Vec::new(); - - for inner in &mut item.items { - match inner { - ImplItem::Method(method) => { - if method.sig.asyncness.is_some() { - // if this function is declared async, transform it into a regular function - let typedecl = transform_method(method); - types.push(typedecl); - } else { - // If it's not async, keep track of all required associated types for better - // error reporting. - expected_non_async_types.push((method, associated_type_for_rpc(method))); - } - } - ImplItem::Type(typedecl) => found_non_async_types.push(typedecl), - _ => {} - } - } - - if let Err(e) = - verify_types_were_provided(span, &expected_non_async_types, &found_non_async_types) - { - return TokenStream::from(e.to_compile_error()); - } - - // add the type declarations into the impl block - for t in types.into_iter() { - item.items.push(syn::ImplItem::Type(t)); - } - - TokenStream::from(quote!(#item)) -} - -fn verify_types_were_provided( - span: Span, - expected: &[(&ImplItemMethod, String)], - provided: &[&ImplItemType], -) -> syn::Result<()> { - let mut result = Ok(()); - for (method, expected) in expected { - if !provided.iter().any(|typedecl| typedecl.ident == expected) { - let mut e = syn::Error::new( - span, - format!("not all trait items implemented, missing: `{expected}`"), - ); - let fn_span = method.sig.fn_token.span(); - e.extend(syn::Error::new( - fn_span.join(method.sig.ident.span()).unwrap_or(fn_span), - format!( - "hint: `#[tarpc::server]` only rewrites async fns, and `fn {}` is not async", - method.sig.ident - ), - )); - match result { - Ok(_) => result = Err(e), - Err(ref mut error) => error.extend(Some(e)), - } - } - } - result -} - // Things needed to generate the service items: trait, serve impl, request/response enums, and // the client stub. struct ServiceGenerator<'a> { service_ident: &'a Ident, client_stub_ident: &'a Ident, server_ident: &'a Ident, - response_fut_ident: &'a Ident, - response_fut_name: &'a str, client_ident: &'a Ident, request_ident: &'a Ident, response_ident: &'a Ident, @@ -444,7 +321,6 @@ struct ServiceGenerator<'a> { attrs: &'a [Attribute], rpcs: &'a [RpcMethod], camel_case_idents: &'a [Ident], - future_types: &'a [Type], method_idents: &'a [&'a Ident], request_names: &'a [String], method_attrs: &'a [&'a [Attribute]], @@ -460,7 +336,6 @@ impl<'a> ServiceGenerator<'a> { attrs, rpcs, vis, - future_types, return_types, service_ident, client_stub_ident, @@ -470,27 +345,19 @@ impl<'a> ServiceGenerator<'a> { .. } = self; - let types_and_fns = rpcs + let rpc_fns = rpcs .iter() - .zip(future_types.iter()) .zip(return_types.iter()) .map( |( - ( - RpcMethod { - attrs, ident, args, .. - }, - future_type, - ), + RpcMethod { + attrs, ident, args, .. + }, output, )| { - let ty_doc = format!("The response future returned by [`{service_ident}::{ident}`]."); quote! { - #[doc = #ty_doc] - type #future_type: std::future::Future; - #( #attrs )* - fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type; + async fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> #output; } }, ); @@ -499,7 +366,7 @@ impl<'a> ServiceGenerator<'a> { quote! { #( #attrs )* #vis trait #service_ident: Sized { - #( #types_and_fns )* + #( #rpc_fns )* /// Returns a serving function to use with /// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute). @@ -539,7 +406,6 @@ impl<'a> ServiceGenerator<'a> { server_ident, service_ident, response_ident, - response_fut_ident, camel_case_idents, arg_pats, method_idents, @@ -553,7 +419,6 @@ impl<'a> ServiceGenerator<'a> { { type Req = #request_ident; type Resp = #response_ident; - type Fut = #response_fut_ident; fn method(&self, req: &#request_ident) -> Option<&'static str> { Some(match req { @@ -565,15 +430,16 @@ impl<'a> ServiceGenerator<'a> { }) } - fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut { + async fn serve(self, ctx: tarpc::context::Context, req: #request_ident) + -> Result<#response_ident, tarpc::ServerError> { match req { #( #request_ident::#camel_case_idents{ #( #arg_pats ),* } => { - #response_fut_ident::#camel_case_idents( + Ok(#response_ident::#camel_case_idents( #service_ident::#method_idents( self.service, ctx, #( #arg_pats ),* - ) - ) + ).await + )) } )* } @@ -624,74 +490,6 @@ impl<'a> ServiceGenerator<'a> { } } - fn enum_response_future(&self) -> TokenStream2 { - let &Self { - vis, - service_ident, - response_fut_ident, - camel_case_idents, - future_types, - .. - } = self; - - quote! { - /// A future resolving to a server response. - #[allow(missing_docs)] - #vis enum #response_fut_ident { - #( #camel_case_idents(::#future_types) ),* - } - } - } - - fn impl_debug_for_response_future(&self) -> TokenStream2 { - let &Self { - service_ident, - response_fut_ident, - response_fut_name, - .. - } = self; - - quote! { - impl std::fmt::Debug for #response_fut_ident { - fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { - fmt.debug_struct(#response_fut_name).finish() - } - } - } - } - - fn impl_future_for_response_future(&self) -> TokenStream2 { - let &Self { - service_ident, - response_fut_ident, - response_ident, - camel_case_idents, - .. - } = self; - - quote! { - impl std::future::Future for #response_fut_ident { - type Output = Result<#response_ident, tarpc::ServerError>; - - fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) - -> std::task::Poll> - { - unsafe { - match std::pin::Pin::get_unchecked_mut(self) { - #( - #response_fut_ident::#camel_case_idents(resp) => - std::pin::Pin::new_unchecked(resp) - .poll(cx) - .map(#response_ident::#camel_case_idents) - .map(Ok), - )* - } - } - } - } - } - } - fn struct_client(&self) -> TokenStream2 { let &Self { vis, @@ -804,9 +602,6 @@ impl<'a> ToTokens for ServiceGenerator<'a> { self.impl_serve_for_server(), self.enum_request(), self.enum_response(), - self.enum_response_future(), - self.impl_debug_for_response_future(), - self.impl_future_for_response_future(), self.struct_client(), self.impl_client_new(), self.impl_client_rpc_methods(), diff --git a/plugins/tests/server.rs b/plugins/tests/server.rs index f0222ff..7fcec79 100644 --- a/plugins/tests/server.rs +++ b/plugins/tests/server.rs @@ -1,7 +1,5 @@ -use assert_type_eq::assert_type_eq; -use futures::Future; -use std::pin::Pin; -use tarpc::context; +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] // these need to be out here rather than inside the function so that the // assert_type_eq macro can pick them up. @@ -12,42 +10,6 @@ trait Foo { async fn baz(); } -#[test] -fn type_generation_works() { - #[tarpc::server] - impl Foo for () { - async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) { - (s, i) - } - - async fn bar(self, _: context::Context, s: String) -> String { - s - } - - async fn baz(self, _: context::Context) {} - } - - // the assert_type_eq macro can only be used once per block. - { - assert_type_eq!( - <() as Foo>::TwoPartFut, - Pin + Send>> - ); - } - { - assert_type_eq!( - <() as Foo>::BarFut, - Pin + Send>> - ); - } - { - assert_type_eq!( - <() as Foo>::BazFut, - Pin + Send>> - ); - } -} - #[allow(non_camel_case_types)] #[test] fn raw_idents_work() { @@ -59,24 +21,6 @@ fn raw_idents_work() { async fn r#fn(r#impl: r#yield) -> r#yield; async fn r#async(); } - - #[tarpc::server] - impl r#trait for () { - async fn r#await( - self, - _: context::Context, - r#struct: r#yield, - r#enum: i32, - ) -> (r#yield, i32) { - (r#struct, r#enum) - } - - async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield { - r#impl - } - - async fn r#async(self, _: context::Context) {} - } } #[test] @@ -100,45 +44,4 @@ fn syntax() { #[doc = "attr"] async fn one_arg_implicit_return_error(one: String); } - - #[tarpc::server] - impl Syntax for () { - #[deny(warnings)] - #[allow(non_snake_case)] - async fn TestCamelCaseDoesntConflict(self, _: context::Context) {} - - async fn hello(self, _: context::Context) -> String { - String::new() - } - - async fn attr(self, _: context::Context, _s: String) -> String { - String::new() - } - - async fn no_args_no_return(self, _: context::Context) {} - - async fn no_args(self, _: context::Context) -> () {} - - async fn one_arg(self, _: context::Context, _one: String) -> i32 { - 0 - } - - async fn two_args_no_return(self, _: context::Context, _one: String, _two: u64) {} - - async fn two_args(self, _: context::Context, _one: String, _two: u64) -> String { - String::new() - } - - async fn no_args_ret_error(self, _: context::Context) -> i32 { - 0 - } - - async fn one_arg_ret_error(self, _: context::Context, _one: String) -> String { - String::new() - } - - async fn no_arg_implicit_return_error(self, _: context::Context) {} - - async fn one_arg_implicit_return_error(self, _: context::Context, _one: String) {} - } } diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index b37cbce..38bd7f0 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -1,9 +1,10 @@ +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use tarpc::context; #[test] fn att_service_trait() { - use futures::future::{ready, Ready}; - #[tarpc::service] trait Foo { async fn two_part(s: String, i: i32) -> (String, i32); @@ -12,19 +13,16 @@ fn att_service_trait() { } impl Foo for () { - type TwoPartFut = Ready<(String, i32)>; - fn two_part(self, _: context::Context, s: String, i: i32) -> Self::TwoPartFut { - ready((s, i)) + async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) { + (s, i) } - type BarFut = Ready; - fn bar(self, _: context::Context, s: String) -> Self::BarFut { - ready(s) + async fn bar(self, _: context::Context, s: String) -> String { + s } - type BazFut = Ready<()>; - fn baz(self, _: context::Context) -> Self::BazFut { - ready(()) + async fn baz(self, _: context::Context) { + () } } } @@ -32,8 +30,6 @@ fn att_service_trait() { #[allow(non_camel_case_types)] #[test] fn raw_idents() { - use futures::future::{ready, Ready}; - type r#yield = String; #[tarpc::service] @@ -44,19 +40,21 @@ fn raw_idents() { } impl r#trait for () { - type AwaitFut = Ready<(r#yield, i32)>; - fn r#await(self, _: context::Context, r#struct: r#yield, r#enum: i32) -> Self::AwaitFut { - ready((r#struct, r#enum)) + async fn r#await( + self, + _: context::Context, + r#struct: r#yield, + r#enum: i32, + ) -> (r#yield, i32) { + (r#struct, r#enum) } - type FnFut = Ready; - fn r#fn(self, _: context::Context, r#impl: r#yield) -> Self::FnFut { - ready(r#impl) + async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield { + r#impl } - type AsyncFut = Ready<()>; - fn r#async(self, _: context::Context) -> Self::AsyncFut { - ready(()) + async fn r#async(self, _: context::Context) { + () } } } diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index c6f8064..8780877 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -75,7 +75,8 @@ opentelemetry-jaeger = { version = "0.17.0", features = ["rt-tokio"] } pin-utils = "0.1.0-alpha" serde_bytes = "0.11" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -tokio = { version = "1", features = ["full", "test-util"] } +tokio = { version = "1", features = ["full", "test-util", "tracing"] } +console-subscriber = "0.1" tokio-serde = { version = "0.8", features = ["json", "bincode"] } trybuild = "1.0" tokio-rustls = "0.23" diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 942fdc8..cc993f0 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -1,5 +1,14 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression}; -use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt}; +use futures::{prelude::*, Sink, SinkExt, Stream, StreamExt, TryStreamExt}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; @@ -99,13 +108,16 @@ pub trait World { #[derive(Clone, Debug)] struct HelloServer; -#[tarpc::server] impl World for HelloServer { async fn hello(self, _: context::Context, name: String) -> String { format!("Hey, {name}!") } } +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::main] async fn main() -> anyhow::Result<()> { let mut incoming = tcp::listen("localhost:0", Bincode::default).await?; @@ -114,6 +126,7 @@ async fn main() -> anyhow::Result<()> { let transport = incoming.next().await.unwrap().unwrap(); BaseChannel::with_defaults(add_compression(transport)) .execute(HelloServer.serve()) + .for_each(spawn) .await; }); diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index e7e2ce3..2c5fd4d 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -1,3 +1,13 @@ +// Copyright 2022 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + +use futures::prelude::*; use tarpc::context::Context; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; @@ -13,7 +23,6 @@ pub trait PingService { #[derive(Clone)] struct Service; -#[tarpc::server] impl PingService for Service { async fn ping(self, _: Context) {} } @@ -26,13 +35,18 @@ async fn main() -> anyhow::Result<()> { let listener = UnixListener::bind(bind_addr).unwrap(); let codec_builder = LengthDelimitedCodec::builder(); + async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); + } tokio::spawn(async move { loop { let (conn, _addr) = listener.accept().await.unwrap(); let framed = codec_builder.new_framed(conn); let transport = transport::new(framed, Bincode::default()); - let fut = BaseChannel::with_defaults(transport).execute(Service.serve()); + let fut = BaseChannel::with_defaults(transport) + .execute(Service.serve()) + .for_each(spawn); tokio::spawn(fut); } }); diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 910ab53..5b5b2ee 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -4,6 +4,9 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + /// - The PubSub server sets up TCP listeners on 2 ports, the "subscriber" port and the "publisher" /// port. Because both publishers and subscribers initiate their connections to the PubSub /// server, the server requires no prior knowledge of either publishers or subscribers. @@ -79,7 +82,6 @@ struct Subscriber { topics: Vec, } -#[tarpc::server] impl subscriber::Subscriber for Subscriber { async fn topics(self, _: context::Context) -> Vec { self.topics.clone() @@ -117,7 +119,8 @@ impl Subscriber { )) } }; - let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve())); + let (handler, abort_handle) = + future::abortable(handler.execute(subscriber.serve()).for_each(spawn)); tokio::spawn(async move { match handler.await { Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."), @@ -143,6 +146,10 @@ struct PublisherAddrs { subscriptions: SocketAddr, } +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + impl Publisher { async fn start(self) -> io::Result { let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?; @@ -162,6 +169,7 @@ impl Publisher { server::BaseChannel::with_defaults(publisher) .execute(self.serve()) + .for_each(spawn) .await }); @@ -257,7 +265,6 @@ impl Publisher { } } -#[tarpc::server] impl publisher::Publisher for Publisher { async fn publish(self, _: context::Context, topic: String, message: String) { info!("received message to publish."); diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 8079231..c6ef61e 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -4,7 +4,10 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use futures::future::{self, Ready}; +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + +use futures::prelude::*; use tarpc::{ client, context, server::{self, Channel}, @@ -23,22 +26,21 @@ pub trait World { struct HelloServer; impl World for HelloServer { - // Each defined rpc generates two items in the trait, a fn that serves the RPC, and - // an associated type representing the future output by the fn. - - type HelloFut = Ready; - - fn hello(self, _: context::Context, name: String) -> Self::HelloFut { - future::ready(format!("Hello, {name}!")) + async fn hello(self, _: context::Context, name: String) -> String { + format!("Hello, {name}!") } } +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::main] async fn main() -> anyhow::Result<()> { let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); let server = server::BaseChannel::with_defaults(server_transport); - tokio::spawn(server.execute(HelloServer.serve())); + tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn)); // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` // that takes a config and any Transport as input. diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index d7b0c02..a042740 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -1,3 +1,13 @@ +// Copyright 2023 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. + +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + +use futures::prelude::*; use rustls_pemfile::certs; use std::io::{BufReader, Cursor}; use std::net::{IpAddr, Ipv4Addr}; @@ -23,7 +33,6 @@ pub trait PingService { #[derive(Clone)] struct Service; -#[tarpc::server] impl PingService for Service { async fn ping(self, _: Context) -> String { "🔒".to_owned() @@ -65,6 +74,10 @@ pub fn load_private_key(key: &str) -> rustls::PrivateKey { panic!("no keys found in {:?} (encrypted keys not supported)", key); } +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::main] async fn main() -> anyhow::Result<()> { // -------------------- start here to setup tls tcp tokio stream -------------------------- @@ -100,7 +113,9 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(framed, Bincode::default()); - let fut = BaseChannel::with_defaults(transport).execute(Service.serve()); + let fut = BaseChannel::with_defaults(transport) + .execute(Service.serve()) + .for_each(spawn); tokio::spawn(fut); } }); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 589c16f..d37fbab 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -4,7 +4,8 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -#![feature(type_alias_impl_trait)] +#![allow(incomplete_features)] +#![feature(async_fn_in_trait, type_alias_impl_trait)] use crate::{ add::{Add as AddService, AddStub}, @@ -25,7 +26,10 @@ use tarpc::{ RpcError, }, context, serde_transport, - server::{incoming::Incoming, BaseChannel, Serve}, + server::{ + incoming::{spawn_incoming, Incoming}, + BaseChannel, Serve, + }, tokio_serde::formats::Json, ClientMessage, Response, ServerError, Transport, }; @@ -51,7 +55,6 @@ pub mod double { #[derive(Clone)] struct AddServer; -#[tarpc::server] impl AddService for AddServer { async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { x + y @@ -63,7 +66,6 @@ struct DoubleServer { add_client: add::AddClient, } -#[tarpc::server] impl DoubleService for DoubleServer where Stub: AddStub + Clone + Send + Sync + 'static, @@ -158,9 +160,8 @@ async fn main() -> anyhow::Result<()> { }); let add_server = add_listener1 .chain(add_listener2) - .map(BaseChannel::with_defaults) - .execute(server); - tokio::spawn(add_server); + .map(BaseChannel::with_defaults); + tokio::spawn(spawn_incoming(add_server.execute(server))); let add_client = add::AddClient::from(make_stub([ tarpc::serde_transport::tcp::connect(addr1, Json::default).await?, @@ -171,11 +172,9 @@ async fn main() -> anyhow::Result<()> { .await? .filter_map(|r| future::ready(r.ok())); let addr = double_listener.get_ref().local_addr(); - let double_server = double_listener - .map(BaseChannel::with_defaults) - .take(1) - .execute(DoubleServer { add_client }.serve()); - tokio::spawn(double_server); + let double_server = double_listener.map(BaseChannel::with_defaults).take(1); + let server = DoubleServer { add_client }.serve(); + tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; let double_client = diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index b47d13b..7e6ff6b 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -80,6 +80,8 @@ //! First, let's set up the dependencies and service definition. //! //! ```rust +//! #![allow(incomplete_features)] +//! #![feature(async_fn_in_trait)] //! # extern crate futures; //! //! use futures::{ @@ -104,6 +106,8 @@ //! implement it for our Server struct. //! //! ```rust +//! # #![allow(incomplete_features)] +//! # #![feature(async_fn_in_trait)] //! # extern crate futures; //! # use futures::{ //! # future::{self, Ready}, @@ -126,13 +130,9 @@ //! struct HelloServer; //! //! impl World for HelloServer { -//! // Each defined rpc generates two items in the trait, a fn that serves the RPC, and -//! // an associated type representing the future output by the fn. -//! -//! type HelloFut = Ready; -//! -//! fn hello(self, _: context::Context, name: String) -> Self::HelloFut { -//! future::ready(format!("Hello, {name}!")) +//! // Each defined rpc generates an async fn that serves the RPC +//! async fn hello(self, _: context::Context, name: String) -> String { +//! format!("Hello, {name}!") //! } //! } //! ``` @@ -143,6 +143,8 @@ //! available behind the `tcp` feature. //! //! ```rust +//! # #![allow(incomplete_features)] +//! # #![feature(async_fn_in_trait)] //! # extern crate futures; //! # use futures::{ //! # future::{self, Ready}, @@ -164,11 +166,9 @@ //! # #[derive(Clone)] //! # struct HelloServer; //! # impl World for HelloServer { -//! # // Each defined rpc generates two items in the trait, a fn that serves the RPC, and -//! # // an associated type representing the future output by the fn. -//! # type HelloFut = Ready; -//! # fn hello(self, _: context::Context, name: String) -> Self::HelloFut { -//! # future::ready(format!("Hello, {name}!")) +//! // Each defined rpc generates an async fn that serves the RPC +//! # async fn hello(self, _: context::Context, name: String) -> String { +//! # format!("Hello, {name}!") //! # } //! # } //! # #[cfg(not(feature = "tokio1"))] @@ -179,7 +179,12 @@ //! let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); //! //! let server = server::BaseChannel::with_defaults(server_transport); -//! tokio::spawn(server.execute(HelloServer.serve())); +//! tokio::spawn( +//! server.execute(HelloServer.serve()) +//! // Handle all requests concurrently. +//! .for_each(|response| async move { +//! tokio::spawn(response); +//! })); //! //! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` //! // that takes a config and any Transport as input. @@ -234,6 +239,7 @@ pub use tarpc_plugins::derive_serde; /// Rpc methods are specified, mirroring trait syntax: /// /// ``` +/// #![feature(async_fn_in_trait)] /// #[tarpc::service] /// trait Service { /// /// Say hello @@ -253,62 +259,6 @@ pub use tarpc_plugins::derive_serde; /// * `fn new_stub` -- creates a new Client stub. pub use tarpc_plugins::service; -/// A utility macro that can be used for RPC server implementations. -/// -/// Syntactic sugar to make using async functions in the server implementation -/// easier. It does this by rewriting code like this, which would normally not -/// compile because async functions are disallowed in trait implementations: -/// -/// ```rust -/// # use tarpc::context; -/// # use std::net::SocketAddr; -/// #[tarpc::service] -/// trait World { -/// async fn hello(name: String) -> String; -/// } -/// -/// #[derive(Clone)] -/// struct HelloServer(SocketAddr); -/// -/// #[tarpc::server] -/// impl World for HelloServer { -/// async fn hello(self, _: context::Context, name: String) -> String { -/// format!("Hello, {name}! You are connected from {:?}.", self.0) -/// } -/// } -/// ``` -/// -/// Into code like this, which matches the service trait definition: -/// -/// ```rust -/// # use tarpc::context; -/// # use std::pin::Pin; -/// # use futures::Future; -/// # use std::net::SocketAddr; -/// #[derive(Clone)] -/// struct HelloServer(SocketAddr); -/// -/// #[tarpc::service] -/// trait World { -/// async fn hello(name: String) -> String; -/// } -/// -/// impl World for HelloServer { -/// type HelloFut = Pin + Send>>; -/// -/// fn hello(self, _: context::Context, name: String) -> Pin -/// + Send>> { -/// Box::pin(async move { -/// format!("Hello, {name}! You are connected from {:?}.", self.0) -/// }) -/// } -/// } -/// ``` -/// -/// Note that this won't touch functions unless they have been annotated with -/// `async`, meaning that this should not break existing code. -pub use tarpc_plugins::server; - pub(crate) mod cancellations; pub mod client; pub mod context; diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 70f28d9..d0caa2c 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -35,11 +35,6 @@ pub mod limits; /// Provides helper methods for streams of Channels. pub mod incoming; -/// Provides convenience functionality for tokio-enabled applications. -#[cfg(feature = "tokio1")] -#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] -pub mod tokio; - use request_hook::{ AfterRequest, AfterRequestHook, BeforeAndAfterRequestHook, BeforeRequest, BeforeRequestHook, }; @@ -79,11 +74,8 @@ pub trait Serve { /// Type of response. type Resp; - /// Type of response future. - type Fut: Future>; - /// Responds to a single request. - fn serve(self, ctx: context::Context, req: Self::Req) -> Self::Fut; + async fn serve(self, ctx: context::Context, req: Self::Req) -> Result; /// Extracts a method name from the request. fn method(&self, _request: &Self::Req) -> Option<&'static str> { @@ -274,10 +266,9 @@ where { type Req = Req; type Resp = Resp; - type Fut = Fut; - fn serve(self, ctx: context::Context, req: Req) -> Self::Fut { - (self.f)(ctx, req) + async fn serve(self, ctx: context::Context, req: Req) -> Result { + (self.f)(ctx, req).await } } @@ -533,34 +524,42 @@ where } } - /// Runs the channel until completion by executing all requests using the given service - /// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's - /// default executor. + /// Returns a stream of request execution futures. Each future represents an in-flight request + /// being responded to by the server. The futures must be awaited or spawned to complete their + /// requests. /// /// # Example /// /// ```rust /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; /// use futures::prelude::*; + /// use tracing_subscriber::prelude::*; /// + /// #[derive(PartialEq, Eq, Debug)] + /// struct MyInt(i32); + /// + /// # #[cfg(not(feature = "tokio1"))] + /// # fn main() {} + /// # #[cfg(feature = "tokio1")] /// #[tokio::main] /// async fn main() { /// let (tx, rx) = transport::channel::unbounded(); /// let client = client::new(client::Config::default(), tx).spawn(); - /// let channel = BaseChannel::new(server::Config::default(), rx); - /// tokio::spawn(channel.execute(serve(|_, i| async move { Ok(i + 1) }))); - /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// let channel = BaseChannel::with_defaults(rx); + /// tokio::spawn( + /// channel.execute(serve(|_, MyInt(i)| async move { Ok(MyInt(i + 1)) })) + /// .for_each(|response| async move { + /// tokio::spawn(response); + /// })); + /// assert_eq!( + /// client.call(context::current(), "AddOne", MyInt(1)).await.unwrap(), + /// MyInt(2)); /// } /// ``` - #[cfg(feature = "tokio1")] - #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] - fn execute(self, serve: S) -> self::tokio::TokioChannelExecutor, S> + fn execute(self, serve: S) -> impl Stream> where Self: Sized, - S: Serve + Send + 'static, - S::Fut: Send, - Self::Req: Send + 'static, - Self::Resp: Send + 'static, + S: Serve + Clone, { self.requests().execute(serve) } @@ -654,15 +653,17 @@ where Poll::Pending => Pending, }; - tracing::trace!( - "Expired requests: {:?}, Inbound: {:?}", - expiration_status, - request_status - ); - match cancellation_status + let status = cancellation_status .combine(expiration_status) - .combine(request_status) - { + .combine(request_status); + + tracing::trace!( + "Cancellations: {cancellation_status:?}, \ + Expired requests: {expiration_status:?}, \ + Inbound: {request_status:?}, \ + Overall: {status:?}", + ); + match status { Ready => continue, Closed => return Poll::Ready(None), Pending => return Poll::Pending, @@ -872,6 +873,51 @@ where } Poll::Ready(Some(Ok(()))) } + + /// Returns a stream of request execution futures. Each future represents an in-flight request + /// being responded to by the server. The futures must be awaited or spawned to complete their + /// requests. + /// + /// If the channel encounters an error, the stream is terminated and the error is logged. + /// + /// # Example + /// + /// ```rust + /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use futures::prelude::*; + /// + /// # #[cfg(not(feature = "tokio1"))] + /// # fn main() {} + /// # #[cfg(feature = "tokio1")] + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = transport::channel::unbounded(); + /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); + /// let client = client::new(client::Config::default(), tx).spawn(); + /// tokio::spawn( + /// requests.execute(serve(|_, i| async move { Ok(i + 1) })) + /// .for_each(|response| async move { + /// tokio::spawn(response); + /// })); + /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); + /// } + /// ``` + pub fn execute(self, serve: S) -> impl Stream> + where + S: Serve + Clone, + { + self.take_while(|result| { + if let Err(e) = result { + tracing::warn!("Requests stream errored out: {}", e); + } + futures::future::ready(result.is_ok()) + }) + .filter_map(|result| async move { result.ok() }) + .map(move |request| { + let serve = serve.clone(); + request.execute(serve) + }) + } } impl fmt::Debug for Requests @@ -1003,6 +1049,13 @@ impl InFlightRequest { } } +fn print_err(e: &(dyn Error + 'static)) -> String { + anyhow::Chain::new(e) + .map(|e| e.to_string()) + .intersperse(": ".into()) + .collect::() +} + impl Stream for Requests where C: Channel, @@ -1011,17 +1064,33 @@ where fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { - let read = self.as_mut().pump_read(cx)?; + let read = self.as_mut().pump_read(cx).map_err(|e| { + tracing::trace!("read: {}", print_err(&e)); + e + })?; let read_closed = matches!(read, Poll::Ready(None)); - match (read, self.as_mut().pump_write(cx, read_closed)?) { + let write = self.as_mut().pump_write(cx, read_closed).map_err(|e| { + tracing::trace!("write: {}", print_err(&e)); + e + })?; + match (read, write) { (Poll::Ready(None), Poll::Ready(None)) => { + tracing::trace!("read: Poll::Ready(None), write: Poll::Ready(None)"); return Poll::Ready(None); } (Poll::Ready(Some(request_handler)), _) => { + tracing::trace!("read: Poll::Ready(Some), write: _"); return Poll::Ready(Some(Ok(request_handler))); } - (_, Poll::Ready(Some(()))) => {} - _ => { + (_, Poll::Ready(Some(()))) => { + tracing::trace!("read: _, write: Poll::Ready(Some)"); + } + (read @ Poll::Pending, write) | (read, write @ Poll::Pending) => { + tracing::trace!( + "read pending: {}, write pending: {}", + read.is_pending(), + write.is_pending() + ); return Poll::Pending; } } diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 931e876..9195ee3 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -1,13 +1,10 @@ use super::{ limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel}, - Channel, + Channel, Serve, }; use futures::prelude::*; use std::{fmt, hash::Hash}; -#[cfg(feature = "tokio1")] -use super::{tokio::TokioServerExecutor, Serve}; - /// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel). pub trait Incoming where @@ -28,16 +25,62 @@ where MaxRequestsPerChannel::new(self, n) } - /// [Executes](Channel::execute) each incoming channel. Each channel will be handled - /// concurrently by spawning on tokio's default executor, and each request will be also - /// be spawned on tokio's default executor. - #[cfg(feature = "tokio1")] - #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] - fn execute(self, serve: S) -> TokioServerExecutor + /// Returns a stream of channels in execution. Each channel in execution is a stream of + /// futures, where each future is an in-flight request being rsponded to. + fn execute( + self, + serve: S, + ) -> impl Stream>> where - S: Serve, + S: Serve + Clone, { - TokioServerExecutor::new(self, serve) + self.map(move |channel| channel.execute(serve.clone())) + } +} + +#[cfg(feature = "tokio1")] +/// Spawns all channels-in-execution, delegating to the tokio runtime to manage their completion. +/// Each channel is spawned, and each request from each channel is spawned. +/// Note that this function is generic over any stream-of-streams-of-futures, but it is intended +/// for spawning streams of channels. +/// +/// # Example +/// ```rust +/// use tarpc::{ +/// context, +/// client::{self, NewClient}, +/// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve}, +/// transport, +/// }; +/// use futures::prelude::*; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = transport::channel::unbounded(); +/// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); +/// tokio::spawn(dispatch); +/// +/// let incoming = stream::once(async move { +/// BaseChannel::new(server::Config::default(), rx) +/// }).execute(serve(|_, i| async move { Ok(i + 1) })); +/// tokio::spawn(spawn_incoming(incoming)); +/// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); +/// } +/// ``` +pub async fn spawn_incoming( + incoming: impl Stream< + Item = impl Stream + Send + 'static> + Send + 'static, + >, +) { + use futures::pin_mut; + pin_mut!(incoming); + while let Some(channel) = incoming.next().await { + tokio::spawn(async move { + pin_mut!(channel); + while let Some(request) = channel.next().await { + tokio::spawn(request); + } + }); } } diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index a3803ba..4fd48dd 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -71,19 +71,17 @@ where { type Req = Serv::Req; type Resp = Serv::Resp; - type Fut = AfterRequestHookFut; - fn serve(self, mut ctx: context::Context, req: Serv::Req) -> Self::Fut { - async move { - let AfterRequestHook { - serve, mut hook, .. - } = self; - let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; - resp - } + async fn serve( + self, + mut ctx: context::Context, + req: Serv::Req, + ) -> Result { + let AfterRequestHook { + serve, mut hook, .. + } = self; + let mut resp = serve.serve(ctx, req).await; + hook.after(&mut ctx, &mut resp).await; + resp } } - -type AfterRequestHookFut> = - impl Future>; diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index 38ad54d..2c478db 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -67,18 +67,16 @@ where { type Req = Serv::Req; type Resp = Serv::Resp; - type Fut = BeforeRequestHookFut; - fn serve(self, mut ctx: context::Context, req: Self::Req) -> Self::Fut { + async fn serve( + self, + mut ctx: context::Context, + req: Self::Req, + ) -> Result { let BeforeRequestHook { serve, mut hook, .. } = self; - async move { - hook.before(&mut ctx, &req).await?; - serve.serve(ctx, req).await - } + hook.before(&mut ctx, &req).await?; + serve.serve(ctx, req).await } } - -type BeforeRequestHookFut> = - impl Future>; diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index ca42460..ff61a53 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -8,7 +8,6 @@ use super::{after::AfterRequest, before::BeforeRequest}; use crate::{context, server::Serve, ServerError}; -use futures::prelude::*; use std::marker::PhantomData; /// A Service function that runs a hook both before and after request execution. @@ -47,24 +46,14 @@ where { type Req = Req; type Resp = Resp; - type Fut = BeforeAndAfterRequestHookFut; - fn serve(self, mut ctx: context::Context, req: Req) -> Self::Fut { - async move { - let BeforeAndAfterRequestHook { - serve, mut hook, .. - } = self; - hook.before(&mut ctx, &req).await?; - let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; - resp - } + async fn serve(self, mut ctx: context::Context, req: Req) -> Result { + let BeforeAndAfterRequestHook { + serve, mut hook, .. + } = self; + hook.before(&mut ctx, &req).await?; + let mut resp = serve.serve(ctx, req).await; + hook.after(&mut ctx, &mut resp).await; + resp } } - -type BeforeAndAfterRequestHookFut< - Req, - Resp, - Serv: Serve, - Hook: BeforeRequest + AfterRequest, -> = impl Future>; diff --git a/tarpc/src/server/tokio.rs b/tarpc/src/server/tokio.rs deleted file mode 100644 index e9ad842..0000000 --- a/tarpc/src/server/tokio.rs +++ /dev/null @@ -1,129 +0,0 @@ -use super::{Channel, Requests, Serve}; -use futures::{prelude::*, ready, task::*}; -use pin_project::pin_project; -use std::pin::Pin; - -/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor) -/// for each new channel. Returned by -/// [`Incoming::execute`](crate::server::incoming::Incoming::execute). -#[must_use] -#[pin_project] -#[derive(Debug)] -pub struct TokioServerExecutor { - #[pin] - inner: T, - serve: S, -} - -impl TokioServerExecutor { - pub(crate) fn new(inner: T, serve: S) -> Self { - Self { inner, serve } - } -} - -/// A future that drives the server by [spawning](tokio::spawn) each [response -/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by -/// [`Channel::execute`](crate::server::Channel::execute). -#[must_use] -#[pin_project] -#[derive(Debug)] -pub struct TokioChannelExecutor { - #[pin] - inner: T, - serve: S, -} - -impl TokioServerExecutor { - fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { - self.as_mut().project().inner - } -} - -impl TokioChannelExecutor { - fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { - self.as_mut().project().inner - } -} - -// Send + 'static execution helper methods. - -impl Requests -where - C: Channel, - C::Req: Send + 'static, - C::Resp: Send + 'static, -{ - /// Executes all requests using the given service function. Requests are handled concurrently - /// by [spawning](::tokio::spawn) each handler on tokio's default executor. - /// - /// # Example - /// - /// ```rust - /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; - /// use futures::prelude::*; - /// - /// #[tokio::main] - /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded(); - /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); - /// let client = client::new(client::Config::default(), tx).spawn(); - /// tokio::spawn(requests.execute(serve(|_, i| async move { Ok(i + 1) }))); - /// assert_eq!(client.call(context::current(), "AddOne", 1).await.unwrap(), 2); - /// } - /// ``` - pub fn execute(self, serve: S) -> TokioChannelExecutor - where - S: Serve + Send + 'static, - { - TokioChannelExecutor { inner: self, serve } - } -} - -impl Future for TokioServerExecutor -where - St: Sized + Stream, - C: Channel + Send + 'static, - C::Req: Send + 'static, - C::Resp: Send + 'static, - Se: Serve + Send + 'static + Clone, - Se::Fut: Send, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) { - tokio::spawn(channel.execute(self.serve.clone())); - } - tracing::info!("Server shutting down."); - Poll::Ready(()) - } -} - -impl Future for TokioChannelExecutor, S> -where - C: Channel + 'static, - C::Req: Send + 'static, - C::Resp: Send + 'static, - S: Serve + Send + 'static + Clone, - S::Fut: Send, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) { - match response_handler { - Ok(resp) => { - let server = self.serve.clone(); - tokio::spawn(async move { - resp.execute(server).await; - }); - } - Err(e) => { - tracing::warn!("Requests stream errored out: {}", e); - break; - } - } - } - Poll::Ready(()) - } -} diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 7f3035d..98ea0aa 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -14,9 +14,15 @@ use tokio::sync::mpsc; /// Errors that occur in the sending or receiving of messages over a channel. #[derive(thiserror::Error, Debug)] pub enum ChannelError { - /// An error occurred sending over the channel. - #[error("an error occurred sending over the channel")] + /// An error occurred readying to send into the channel. + #[error("an error occurred readying to send into the channel")] + Ready(#[source] Box), + /// An error occurred sending into the channel. + #[error("an error occurred sending into the channel")] Send(#[source] Box), + /// An error occurred receiving from the channel. + #[error("an error occurred receiving from the channel")] + Receive(#[source] Box), } /// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's @@ -48,7 +54,10 @@ impl Stream for UnboundedChannel { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - self.rx.poll_recv(cx).map(|option| option.map(Ok)) + self.rx + .poll_recv(cx) + .map(|option| option.map(Ok)) + .map_err(ChannelError::Receive) } } @@ -59,7 +68,7 @@ impl Sink for UnboundedChannel { fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(if self.tx.is_closed() { - Err(ChannelError::Send(CLOSED_MESSAGE.into())) + Err(ChannelError::Ready(CLOSED_MESSAGE.into())) } else { Ok(()) }) @@ -110,7 +119,11 @@ impl Stream for Channel { self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - self.project().rx.poll_next(cx).map(|option| option.map(Ok)) + self.project() + .rx + .poll_next(cx) + .map(|option| option.map(Ok)) + .map_err(ChannelError::Receive) } } @@ -121,7 +134,7 @@ impl Sink for Channel { self.project() .tx .poll_ready(cx) - .map_err(|e| ChannelError::Send(Box::new(e))) + .map_err(|e| ChannelError::Ready(Box::new(e))) } fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> { @@ -146,8 +159,7 @@ impl Sink for Channel { } } -#[cfg(test)] -#[cfg(feature = "tokio1")] +#[cfg(all(test, feature = "tokio1"))] mod tests { use crate::{ client::{self, RpcError}, @@ -186,7 +198,10 @@ mod tests { format!("{request:?} is not an int"), ) }) - })), + })) + .for_each(|channel| async move { + tokio::spawn(channel.for_each(|response| response)); + }), ); let client = client::new(client::Config::default(), client_channel).spawn(); diff --git a/tarpc/tests/compile_fail.rs b/tarpc/tests/compile_fail.rs index 4c5a28e..c28fe2f 100644 --- a/tarpc/tests/compile_fail.rs +++ b/tarpc/tests/compile_fail.rs @@ -2,8 +2,6 @@ fn ui() { let t = trybuild::TestCases::new(); t.compile_fail("tests/compile_fail/*.rs"); - #[cfg(feature = "tokio1")] - t.compile_fail("tests/compile_fail/tokio/*.rs"); #[cfg(all(feature = "serde-transport", feature = "tcp"))] t.compile_fail("tests/compile_fail/serde_transport/*.rs"); } diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.rs b/tarpc/tests/compile_fail/must_use_request_dispatch.rs index 2915d32..18cda0d 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.rs +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.rs @@ -1,3 +1,6 @@ +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use tarpc::client; #[tarpc::service] diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr index e652cc8..387e9b8 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr @@ -1,15 +1,15 @@ error: unused `RequestDispatch` that must be used - --> tests/compile_fail/must_use_request_dispatch.rs:13:9 + --> tests/compile_fail/must_use_request_dispatch.rs:16:9 | -13 | WorldClient::new(client::Config::default(), client_transport).dispatch; +16 | WorldClient::new(client::Config::default(), client_transport).dispatch; | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | note: the lint level is defined here - --> tests/compile_fail/must_use_request_dispatch.rs:11:12 + --> tests/compile_fail/must_use_request_dispatch.rs:14:12 | -11 | #[deny(unused_must_use)] +14 | #[deny(unused_must_use)] | ^^^^^^^^^^^^^^^ help: use `let _ = ...` to ignore the resulting value | -13 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch; +16 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch; | +++++++ diff --git a/tarpc/tests/compile_fail/tarpc_server_missing_async.rs b/tarpc/tests/compile_fail/tarpc_server_missing_async.rs deleted file mode 100644 index 99d858b..0000000 --- a/tarpc/tests/compile_fail/tarpc_server_missing_async.rs +++ /dev/null @@ -1,15 +0,0 @@ -#[tarpc::service(derive_serde = false)] -trait World { - async fn hello(name: String) -> String; -} - -struct HelloServer; - -#[tarpc::server] -impl World for HelloServer { - fn hello(name: String) -> String { - format!("Hello, {name}!", name) - } -} - -fn main() {} diff --git a/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr b/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr deleted file mode 100644 index d96cda8..0000000 --- a/tarpc/tests/compile_fail/tarpc_server_missing_async.stderr +++ /dev/null @@ -1,15 +0,0 @@ -error: not all trait items implemented, missing: `HelloFut` - --> tests/compile_fail/tarpc_server_missing_async.rs:9:1 - | -9 | / impl World for HelloServer { -10 | | fn hello(name: String) -> String { -11 | | format!("Hello, {name}!", name) -12 | | } -13 | | } - | |_^ - -error: hint: `#[tarpc::server]` only rewrites async fns, and `fn hello` is not async - --> tests/compile_fail/tarpc_server_missing_async.rs:10:5 - | -10 | fn hello(name: String) -> String { - | ^^^^^^^^ diff --git a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.rs b/tarpc/tests/compile_fail/tokio/must_use_channel_executor.rs deleted file mode 100644 index 6fc2f2b..0000000 --- a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.rs +++ /dev/null @@ -1,29 +0,0 @@ -use tarpc::{ - context, - server::{self, Channel}, -}; - -#[tarpc::service] -trait World { - async fn hello(name: String) -> String; -} - -#[derive(Clone)] -struct HelloServer; - -#[tarpc::server] -impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { - format!("Hello, {name}!") - } -} - -fn main() { - let (_, server_transport) = tarpc::transport::channel::unbounded(); - let server = server::BaseChannel::with_defaults(server_transport); - - #[deny(unused_must_use)] - { - server.execute(HelloServer.serve()); - } -} diff --git a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr b/tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr deleted file mode 100644 index d7ca6e3..0000000 --- a/tarpc/tests/compile_fail/tokio/must_use_channel_executor.stderr +++ /dev/null @@ -1,15 +0,0 @@ -error: unused `TokioChannelExecutor` that must be used - --> tests/compile_fail/tokio/must_use_channel_executor.rs:27:9 - | -27 | server.execute(HelloServer.serve()); - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - | -note: the lint level is defined here - --> tests/compile_fail/tokio/must_use_channel_executor.rs:25:12 - | -25 | #[deny(unused_must_use)] - | ^^^^^^^^^^^^^^^ -help: use `let _ = ...` to ignore the resulting value - | -27 | let _ = server.execute(HelloServer.serve()); - | +++++++ diff --git a/tarpc/tests/compile_fail/tokio/must_use_server_executor.rs b/tarpc/tests/compile_fail/tokio/must_use_server_executor.rs deleted file mode 100644 index 950cf74..0000000 --- a/tarpc/tests/compile_fail/tokio/must_use_server_executor.rs +++ /dev/null @@ -1,30 +0,0 @@ -use futures::stream::once; -use tarpc::{ - context, - server::{self, incoming::Incoming}, -}; - -#[tarpc::service] -trait World { - async fn hello(name: String) -> String; -} - -#[derive(Clone)] -struct HelloServer; - -#[tarpc::server] -impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { - format!("Hello, {name}!") - } -} - -fn main() { - let (_, server_transport) = tarpc::transport::channel::unbounded(); - let server = once(async move { server::BaseChannel::with_defaults(server_transport) }); - - #[deny(unused_must_use)] - { - server.execute(HelloServer.serve()); - } -} diff --git a/tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr b/tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr deleted file mode 100644 index f0bbb68..0000000 --- a/tarpc/tests/compile_fail/tokio/must_use_server_executor.stderr +++ /dev/null @@ -1,15 +0,0 @@ -error: unused `TokioServerExecutor` that must be used - --> tests/compile_fail/tokio/must_use_server_executor.rs:28:9 - | -28 | server.execute(HelloServer.serve()); - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - | -note: the lint level is defined here - --> tests/compile_fail/tokio/must_use_server_executor.rs:26:12 - | -26 | #[deny(unused_must_use)] - | ^^^^^^^^^^^^^^^ -help: use `let _ = ...` to ignore the resulting value - | -28 | let _ = server.execute(HelloServer.serve()); - | +++++++ diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 365594b..7cd3cb8 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,3 +1,6 @@ +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use futures::prelude::*; use tarpc::serde_transport; use tarpc::{ @@ -21,7 +24,6 @@ pub trait ColorProtocol { #[derive(Clone)] struct ColorServer; -#[tarpc::server] impl ColorProtocol for ColorServer { async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData { match color { @@ -31,6 +33,11 @@ impl ColorProtocol for ColorServer { } } +#[cfg(test)] +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::test] async fn test_call() -> anyhow::Result<()> { let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?; @@ -40,7 +47,9 @@ async fn test_call() -> anyhow::Result<()> { .take(1) .filter_map(|r| async { r.ok() }) .map(BaseChannel::with_defaults) - .execute(ColorServer.serve()), + .execute(ColorServer.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 50d19b0..9041aae 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -1,13 +1,16 @@ +#![allow(incomplete_features)] +#![feature(async_fn_in_trait)] + use assert_matches::assert_matches; use futures::{ - future::{join_all, ready, Ready}, + future::{join_all, ready}, prelude::*, }; use std::time::{Duration, SystemTime}; use tarpc::{ client::{self}, context, - server::{self, incoming::Incoming, BaseChannel, Channel}, + server::{incoming::Incoming, BaseChannel, Channel}, transport::channel, }; use tokio::join; @@ -22,39 +25,29 @@ trait Service { struct Server; impl Service for Server { - type AddFut = Ready; - - fn add(self, _: context::Context, x: i32, y: i32) -> Self::AddFut { - ready(x + y) + async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + x + y } - type HeyFut = Ready; - - fn hey(self, _: context::Context, name: String) -> Self::HeyFut { - ready(format!("Hey, {name}.")) + async fn hey(self, _: context::Context, name: String) -> String { + format!("Hey, {name}.") } } #[tokio::test] -async fn sequential() -> anyhow::Result<()> { - let _ = tracing_subscriber::fmt::try_init(); - - let (tx, rx) = channel::unbounded(); - +async fn sequential() { + let (tx, rx) = tarpc::transport::channel::unbounded(); + let client = client::new(client::Config::default(), tx).spawn(); + let channel = BaseChannel::with_defaults(rx); tokio::spawn( - BaseChannel::new(server::Config::default(), rx) - .requests() - .execute(Server.serve()), + channel + .execute(tarpc::server::serve(|_, i| async move { Ok(i + 1) })) + .for_each(|response| response), + ); + assert_eq!( + client.call(context::current(), "AddOne", 1).await.unwrap(), + 2 ); - - let client = ServiceClient::new(client::Config::default(), tx).spawn(); - - assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); - assert_matches!( - client.hey(context::current(), "Tim".into()).await, - Ok(ref s) if s == "Hey, Tim."); - - Ok(()) } #[tokio::test] @@ -70,7 +63,6 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { #[derive(Debug)] struct AllHandlersComplete; - #[tarpc::server] impl Loop for LoopServer { async fn r#loop(self, _: context::Context) { loop { @@ -121,7 +113,9 @@ async fn serde_tcp() -> anyhow::Result<()> { .take(1) .filter_map(|r| async { r.ok() }) .map(BaseChannel::with_defaults) - .execute(Server.serve()), + .execute(Server.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; @@ -151,7 +145,9 @@ async fn serde_uds() -> anyhow::Result<()> { .take(1) .filter_map(|r| async { r.ok() }) .map(BaseChannel::with_defaults) - .execute(Server.serve()), + .execute(Server.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let transport = serde_transport::unix::connect(&sock, Json::default).await?; @@ -175,7 +171,9 @@ async fn concurrent() -> anyhow::Result<()> { tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) - .execute(Server.serve()), + .execute(Server.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let client = ServiceClient::new(client::Config::default(), tx).spawn(); @@ -199,7 +197,9 @@ async fn concurrent_join() -> anyhow::Result<()> { tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) - .execute(Server.serve()), + .execute(Server.serve()) + .map(|channel| channel.for_each(spawn)) + .for_each(spawn), ); let client = ServiceClient::new(client::Config::default(), tx).spawn(); @@ -216,15 +216,20 @@ async fn concurrent_join() -> anyhow::Result<()> { Ok(()) } +#[cfg(test)] +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + #[tokio::test] async fn concurrent_join_all() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); let (tx, rx) = channel::unbounded(); tokio::spawn( - stream::once(ready(rx)) - .map(BaseChannel::with_defaults) - .execute(Server.serve()), + BaseChannel::with_defaults(rx) + .execute(Server.serve()) + .for_each(spawn), ); let client = ServiceClient::new(client::Config::default(), tx).spawn(); @@ -249,11 +254,9 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - type CountFut = futures::future::Ready; - - fn count(self, _: context::Context) -> Self::CountFut { + async fn count(self, _: context::Context) -> u32 { self.0 += 1; - futures::future::ready(self.0) + self.0 } }