mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-01-04 18:43:54 +01:00
Use async fn in generated traits!!
The major breaking change is that Channel::execute no longer internally
spawns RPC handlers, because it is no longer possible to place a Send
bound on the return type of Serve::serve. Instead, Channel::execute
returns a stream of RPC handler futures.
Service authors can reproduce the old behavior by spawning each response
handler (the compiler knows whether or not the futures can be spawned;
it's just that the bounds can't be expressed generically):
channel.execute(server.serve())
.for_each(|rpc| { tokio::spawn(rpc); })
This commit is contained in:
@@ -12,18 +12,18 @@ extern crate quote;
|
||||
extern crate syn;
|
||||
|
||||
use proc_macro::TokenStream;
|
||||
use proc_macro2::{Span, TokenStream as TokenStream2};
|
||||
use proc_macro2::TokenStream as TokenStream2;
|
||||
use quote::{format_ident, quote, ToTokens};
|
||||
use syn::{
|
||||
braced,
|
||||
ext::IdentExt,
|
||||
parenthesized,
|
||||
parse::{Parse, ParseStream},
|
||||
parse_macro_input, parse_quote, parse_str,
|
||||
parse_macro_input, parse_quote,
|
||||
spanned::Spanned,
|
||||
token::Comma,
|
||||
Attribute, FnArg, Ident, ImplItem, ImplItemMethod, ImplItemType, ItemImpl, Lit, LitBool,
|
||||
MetaNameValue, Pat, PatType, ReturnType, Token, Type, Visibility,
|
||||
Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type,
|
||||
Visibility,
|
||||
};
|
||||
|
||||
/// Accumulates multiple errors into a result.
|
||||
@@ -257,7 +257,6 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
.map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
|
||||
.collect();
|
||||
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
|
||||
let response_fut_name = &format!("{}ResponseFut", ident.unraw());
|
||||
let derive_serialize = if derive_serde.0 {
|
||||
Some(
|
||||
quote! {#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
|
||||
@@ -274,11 +273,9 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
ServiceGenerator {
|
||||
response_fut_name,
|
||||
service_ident: ident,
|
||||
client_stub_ident: &format_ident!("{}Stub", ident),
|
||||
server_ident: &format_ident!("Serve{}", ident),
|
||||
response_fut_ident: &Ident::new(response_fut_name, ident.span()),
|
||||
client_ident: &format_ident!("{}Client", ident),
|
||||
request_ident: &format_ident!("{}Request", ident),
|
||||
response_ident: &format_ident!("{}Response", ident),
|
||||
@@ -305,138 +302,18 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
.zip(camel_case_fn_names.iter())
|
||||
.map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
|
||||
.collect::<Vec<_>>(),
|
||||
future_types: &camel_case_fn_names
|
||||
.iter()
|
||||
.map(|name| parse_str(&format!("{name}Fut")).unwrap())
|
||||
.collect::<Vec<_>>(),
|
||||
derive_serialize: derive_serialize.as_ref(),
|
||||
}
|
||||
.into_token_stream()
|
||||
.into()
|
||||
}
|
||||
|
||||
/// generate an identifier consisting of the method name to CamelCase with
|
||||
/// Fut appended to it.
|
||||
fn associated_type_for_rpc(method: &ImplItemMethod) -> String {
|
||||
snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut"
|
||||
}
|
||||
|
||||
/// Transforms an async function into a sync one, returning a type declaration
|
||||
/// for the return type (a future).
|
||||
fn transform_method(method: &mut ImplItemMethod) -> ImplItemType {
|
||||
method.sig.asyncness = None;
|
||||
|
||||
// get either the return type or ().
|
||||
let ret = match &method.sig.output {
|
||||
ReturnType::Default => quote!(()),
|
||||
ReturnType::Type(_, ret) => quote!(#ret),
|
||||
};
|
||||
|
||||
let fut_name = associated_type_for_rpc(method);
|
||||
let fut_name_ident = Ident::new(&fut_name, method.sig.ident.span());
|
||||
|
||||
// generate the updated return signature.
|
||||
method.sig.output = parse_quote! {
|
||||
-> ::core::pin::Pin<Box<
|
||||
dyn ::core::future::Future<Output = #ret> + ::core::marker::Send
|
||||
>>
|
||||
};
|
||||
|
||||
// transform the body of the method into Box::pin(async move { body }).
|
||||
let block = method.block.clone();
|
||||
method.block = parse_quote! [{
|
||||
Box::pin(async move
|
||||
#block
|
||||
)
|
||||
}];
|
||||
|
||||
// generate and return type declaration for return type.
|
||||
let t: ImplItemType = parse_quote! {
|
||||
type #fut_name_ident = ::core::pin::Pin<Box<dyn ::core::future::Future<Output = #ret> + ::core::marker::Send>>;
|
||||
};
|
||||
|
||||
t
|
||||
}
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
let mut item = syn::parse_macro_input!(input as ItemImpl);
|
||||
let span = item.span();
|
||||
|
||||
// the generated type declarations
|
||||
let mut types: Vec<ImplItemType> = Vec::new();
|
||||
let mut expected_non_async_types: Vec<(&ImplItemMethod, String)> = Vec::new();
|
||||
let mut found_non_async_types: Vec<&ImplItemType> = Vec::new();
|
||||
|
||||
for inner in &mut item.items {
|
||||
match inner {
|
||||
ImplItem::Method(method) => {
|
||||
if method.sig.asyncness.is_some() {
|
||||
// if this function is declared async, transform it into a regular function
|
||||
let typedecl = transform_method(method);
|
||||
types.push(typedecl);
|
||||
} else {
|
||||
// If it's not async, keep track of all required associated types for better
|
||||
// error reporting.
|
||||
expected_non_async_types.push((method, associated_type_for_rpc(method)));
|
||||
}
|
||||
}
|
||||
ImplItem::Type(typedecl) => found_non_async_types.push(typedecl),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(e) =
|
||||
verify_types_were_provided(span, &expected_non_async_types, &found_non_async_types)
|
||||
{
|
||||
return TokenStream::from(e.to_compile_error());
|
||||
}
|
||||
|
||||
// add the type declarations into the impl block
|
||||
for t in types.into_iter() {
|
||||
item.items.push(syn::ImplItem::Type(t));
|
||||
}
|
||||
|
||||
TokenStream::from(quote!(#item))
|
||||
}
|
||||
|
||||
fn verify_types_were_provided(
|
||||
span: Span,
|
||||
expected: &[(&ImplItemMethod, String)],
|
||||
provided: &[&ImplItemType],
|
||||
) -> syn::Result<()> {
|
||||
let mut result = Ok(());
|
||||
for (method, expected) in expected {
|
||||
if !provided.iter().any(|typedecl| typedecl.ident == expected) {
|
||||
let mut e = syn::Error::new(
|
||||
span,
|
||||
format!("not all trait items implemented, missing: `{expected}`"),
|
||||
);
|
||||
let fn_span = method.sig.fn_token.span();
|
||||
e.extend(syn::Error::new(
|
||||
fn_span.join(method.sig.ident.span()).unwrap_or(fn_span),
|
||||
format!(
|
||||
"hint: `#[tarpc::server]` only rewrites async fns, and `fn {}` is not async",
|
||||
method.sig.ident
|
||||
),
|
||||
));
|
||||
match result {
|
||||
Ok(_) => result = Err(e),
|
||||
Err(ref mut error) => error.extend(Some(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
// Things needed to generate the service items: trait, serve impl, request/response enums, and
|
||||
// the client stub.
|
||||
struct ServiceGenerator<'a> {
|
||||
service_ident: &'a Ident,
|
||||
client_stub_ident: &'a Ident,
|
||||
server_ident: &'a Ident,
|
||||
response_fut_ident: &'a Ident,
|
||||
response_fut_name: &'a str,
|
||||
client_ident: &'a Ident,
|
||||
request_ident: &'a Ident,
|
||||
response_ident: &'a Ident,
|
||||
@@ -444,7 +321,6 @@ struct ServiceGenerator<'a> {
|
||||
attrs: &'a [Attribute],
|
||||
rpcs: &'a [RpcMethod],
|
||||
camel_case_idents: &'a [Ident],
|
||||
future_types: &'a [Type],
|
||||
method_idents: &'a [&'a Ident],
|
||||
request_names: &'a [String],
|
||||
method_attrs: &'a [&'a [Attribute]],
|
||||
@@ -460,7 +336,6 @@ impl<'a> ServiceGenerator<'a> {
|
||||
attrs,
|
||||
rpcs,
|
||||
vis,
|
||||
future_types,
|
||||
return_types,
|
||||
service_ident,
|
||||
client_stub_ident,
|
||||
@@ -470,27 +345,19 @@ impl<'a> ServiceGenerator<'a> {
|
||||
..
|
||||
} = self;
|
||||
|
||||
let types_and_fns = rpcs
|
||||
let rpc_fns = rpcs
|
||||
.iter()
|
||||
.zip(future_types.iter())
|
||||
.zip(return_types.iter())
|
||||
.map(
|
||||
|(
|
||||
(
|
||||
RpcMethod {
|
||||
attrs, ident, args, ..
|
||||
},
|
||||
future_type,
|
||||
),
|
||||
RpcMethod {
|
||||
attrs, ident, args, ..
|
||||
},
|
||||
output,
|
||||
)| {
|
||||
let ty_doc = format!("The response future returned by [`{service_ident}::{ident}`].");
|
||||
quote! {
|
||||
#[doc = #ty_doc]
|
||||
type #future_type: std::future::Future<Output = #output>;
|
||||
|
||||
#( #attrs )*
|
||||
fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type;
|
||||
async fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> #output;
|
||||
}
|
||||
},
|
||||
);
|
||||
@@ -499,7 +366,7 @@ impl<'a> ServiceGenerator<'a> {
|
||||
quote! {
|
||||
#( #attrs )*
|
||||
#vis trait #service_ident: Sized {
|
||||
#( #types_and_fns )*
|
||||
#( #rpc_fns )*
|
||||
|
||||
/// Returns a serving function to use with
|
||||
/// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute).
|
||||
@@ -539,7 +406,6 @@ impl<'a> ServiceGenerator<'a> {
|
||||
server_ident,
|
||||
service_ident,
|
||||
response_ident,
|
||||
response_fut_ident,
|
||||
camel_case_idents,
|
||||
arg_pats,
|
||||
method_idents,
|
||||
@@ -553,7 +419,6 @@ impl<'a> ServiceGenerator<'a> {
|
||||
{
|
||||
type Req = #request_ident;
|
||||
type Resp = #response_ident;
|
||||
type Fut = #response_fut_ident<S>;
|
||||
|
||||
fn method(&self, req: &#request_ident) -> Option<&'static str> {
|
||||
Some(match req {
|
||||
@@ -565,15 +430,16 @@ impl<'a> ServiceGenerator<'a> {
|
||||
})
|
||||
}
|
||||
|
||||
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
|
||||
async fn serve(self, ctx: tarpc::context::Context, req: #request_ident)
|
||||
-> Result<#response_ident, tarpc::ServerError> {
|
||||
match req {
|
||||
#(
|
||||
#request_ident::#camel_case_idents{ #( #arg_pats ),* } => {
|
||||
#response_fut_ident::#camel_case_idents(
|
||||
Ok(#response_ident::#camel_case_idents(
|
||||
#service_ident::#method_idents(
|
||||
self.service, ctx, #( #arg_pats ),*
|
||||
)
|
||||
)
|
||||
).await
|
||||
))
|
||||
}
|
||||
)*
|
||||
}
|
||||
@@ -624,74 +490,6 @@ impl<'a> ServiceGenerator<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn enum_response_future(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
vis,
|
||||
service_ident,
|
||||
response_fut_ident,
|
||||
camel_case_idents,
|
||||
future_types,
|
||||
..
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
/// A future resolving to a server response.
|
||||
#[allow(missing_docs)]
|
||||
#vis enum #response_fut_ident<S: #service_ident> {
|
||||
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_debug_for_response_future(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
service_ident,
|
||||
response_fut_ident,
|
||||
response_fut_name,
|
||||
..
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
impl<S: #service_ident> std::fmt::Debug for #response_fut_ident<S> {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
fmt.debug_struct(#response_fut_name).finish()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_future_for_response_future(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
service_ident,
|
||||
response_fut_ident,
|
||||
response_ident,
|
||||
camel_case_idents,
|
||||
..
|
||||
} = self;
|
||||
|
||||
quote! {
|
||||
impl<S: #service_ident> std::future::Future for #response_fut_ident<S> {
|
||||
type Output = Result<#response_ident, tarpc::ServerError>;
|
||||
|
||||
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>)
|
||||
-> std::task::Poll<Result<#response_ident, tarpc::ServerError>>
|
||||
{
|
||||
unsafe {
|
||||
match std::pin::Pin::get_unchecked_mut(self) {
|
||||
#(
|
||||
#response_fut_ident::#camel_case_idents(resp) =>
|
||||
std::pin::Pin::new_unchecked(resp)
|
||||
.poll(cx)
|
||||
.map(#response_ident::#camel_case_idents)
|
||||
.map(Ok),
|
||||
)*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn struct_client(&self) -> TokenStream2 {
|
||||
let &Self {
|
||||
vis,
|
||||
@@ -804,9 +602,6 @@ impl<'a> ToTokens for ServiceGenerator<'a> {
|
||||
self.impl_serve_for_server(),
|
||||
self.enum_request(),
|
||||
self.enum_response(),
|
||||
self.enum_response_future(),
|
||||
self.impl_debug_for_response_future(),
|
||||
self.impl_future_for_response_future(),
|
||||
self.struct_client(),
|
||||
self.impl_client_new(),
|
||||
self.impl_client_rpc_methods(),
|
||||
|
||||
Reference in New Issue
Block a user