diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 754e6a7..41a34cb 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -12,11 +12,14 @@ extern crate quote; extern crate syn; use proc_macro::TokenStream; -use quote::{format_ident, quote}; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote, ToTokens}; use syn::{ - braced, parenthesized, + braced, + ext::IdentExt, + parenthesized, parse::{Parse, ParseStream}, - parse_macro_input, parse_quote, + parse_macro_input, parse_quote, parse_str, punctuated::Punctuated, token::Comma, Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type, @@ -42,7 +45,7 @@ impl Parse for Service { let attrs = input.call(Attribute::parse_outer)?; let vis = input.parse()?; input.parse::()?; - let ident = input.parse()?; + let ident: Ident = input.parse()?; let content; braced!(content in input); let mut rpcs = Vec::::new(); @@ -53,7 +56,7 @@ impl Parse for Service { if rpc.ident == "new" { return Err(input.error(format!( "method name conflicts with generated fn `{}Client::new`", - ident + ident.unraw() ))); } if rpc.ident == "serve" { @@ -63,6 +66,7 @@ impl Parse for Service { ))); } } + Ok(Self { attrs, vis, @@ -156,19 +160,19 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { ref rpcs, } = parse_macro_input!(input as Service); - let camel_case_fn_names: &[String] = &rpcs + let camel_case_fn_names: &Vec<_> = &rpcs .iter() - .map(|rpc| snake_to_camel(&rpc.ident.to_string())) - .collect::>(); + .map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string())) + .collect(); let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::>(); - let response_fut_name = &format!("{}ResponseFut", ident); - let derive_serialize = quote!(#[derive(serde::Serialize, serde::Deserialize)]); + let response_fut_name = &format!("{}ResponseFut", ident.unraw()); + let derive_serialize = if derive_serde.0 { + Some(quote!(#[derive(serde::Serialize, serde::Deserialize)])) + } else { + None + }; - let gen_args = &GenArgs { - attrs, - vis, - rpcs, - args, + ServiceGenerator { response_fut_name, service_ident: ident, server_ident: &format_ident!("Serve{}", ident), @@ -176,8 +180,12 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { client_ident: &format_ident!("{}Client", ident), request_ident: &format_ident!("{}Request", ident), response_ident: &format_ident!("{}Response", ident), + vis, + args, method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::>(), - method_names: &rpcs.iter().map(|rpc| &rpc.ident).collect::>(), + method_idents: &rpcs.iter().map(|rpc| &rpc.ident).collect::>(), + attrs, + rpcs, return_types: &rpcs .iter() .map(|rpc| match rpc.output { @@ -185,7 +193,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { ReturnType::Default => unit_type, }) .collect::>(), - arg_vars: &args + arg_pats: &args .iter() .map(|args| args.iter().map(|arg| &*arg.pat).collect()) .collect::>(), @@ -194,43 +202,19 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .zip(camel_case_fn_names.iter()) .map(|(rpc, name)| Ident::new(name, rpc.ident.span())) .collect::>(), - future_idents: &camel_case_fn_names + future_types: &camel_case_fn_names .iter() - .map(|name| format_ident!("{}Fut", name)) + .map(|name| parse_str(&format!("{}Fut", name)).unwrap()) .collect::>(), - derive_serialize: if derive_serde.0 { - Some(&derive_serialize) - } else { - None - }, - }; - - let tokens = vec![ - trait_service(gen_args), - struct_server(gen_args), - impl_serve_for_server(gen_args), - enum_request(gen_args), - enum_response(gen_args), - enum_response_future(gen_args), - impl_debug_for_response_future(gen_args), - impl_future_for_response_future(gen_args), - struct_client(gen_args), - impl_from_for_client(gen_args), - impl_client_new(gen_args), - impl_client_rpc_methods(gen_args), - ]; - - tokens - .into_iter() - .collect::() - .into() + derive_serialize: derive_serialize.as_ref(), + } + .into_token_stream() + .into() } // Things needed to generate the service items: trait, serve impl, request/response enums, and // the client stub. -struct GenArgs<'a> { - attrs: &'a [Attribute], - rpcs: &'a [RpcMethod], +struct ServiceGenerator<'a> { service_ident: &'a Ident, server_ident: &'a Ident, response_fut_ident: &'a Ident, @@ -238,331 +222,356 @@ struct GenArgs<'a> { client_ident: &'a Ident, request_ident: &'a Ident, response_ident: &'a Ident, - method_attrs: &'a [&'a [Attribute]], vis: &'a Visibility, - method_names: &'a [&'a Ident], + attrs: &'a [Attribute], + rpcs: &'a [RpcMethod], + camel_case_idents: &'a [Ident], + future_types: &'a [Type], + method_idents: &'a [&'a Ident], + method_attrs: &'a [&'a [Attribute]], args: &'a [&'a [PatType]], return_types: &'a [&'a Type], - arg_vars: &'a [Vec<&'a Pat>], - camel_case_idents: &'a [Ident], - future_idents: &'a [Ident], - derive_serialize: Option<&'a proc_macro2::TokenStream>, + arg_pats: &'a [Vec<&'a Pat>], + derive_serialize: Option<&'a TokenStream2>, } -fn trait_service( - &GenArgs { - attrs, - rpcs, - vis, - future_idents, - return_types, - service_ident, - server_ident, - .. - }: &GenArgs, -) -> proc_macro2::TokenStream { - let types_and_fns = rpcs - .iter() - .zip(future_idents.iter()) - .zip(return_types.iter()) - .map( - |( - ( - RpcMethod { - attrs, ident, args, .. - }, - future_type, - ), - output, - )| { - let ty_doc = format!("The response future returned by {}.", ident); - quote! { - #[doc = #ty_doc] - type #future_type: std::future::Future; +impl<'a> ServiceGenerator<'a> { + fn trait_service(&self) -> TokenStream2 { + let &Self { + attrs, + rpcs, + vis, + future_types, + return_types, + service_ident, + server_ident, + .. + } = self; - #( #attrs )* - fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type; - } - }, - ); + let types_and_fns = rpcs + .iter() + .zip(future_types.iter()) + .zip(return_types.iter()) + .map( + |( + ( + RpcMethod { + attrs, ident, args, .. + }, + future_type, + ), + output, + )| { + let ty_doc = format!("The response future returned by {}.", ident); + quote! { + #[doc = #ty_doc] + type #future_type: std::future::Future; - quote! { - #( #attrs )* - #vis trait #service_ident: Clone { - #( #types_and_fns )* + #( #attrs )* + fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type; + } + }, + ); - /// Returns a serving function to use with tarpc::server::Server. - fn serve(self) -> #server_ident { - #server_ident { service: self } - } - } - } -} + quote! { + #( #attrs )* + #vis trait #service_ident: Clone { + #( #types_and_fns )* -fn struct_server( - &GenArgs { - vis, server_ident, .. - }: &GenArgs, -) -> proc_macro2::TokenStream { - quote! { - #[derive(Clone)] - #vis struct #server_ident { - service: S, - } - } -} - -fn impl_serve_for_server( - &GenArgs { - request_ident, - server_ident, - service_ident, - response_ident, - response_fut_ident, - camel_case_idents, - arg_vars, - method_names, - .. - }: &GenArgs, -) -> proc_macro2::TokenStream { - quote! { - impl tarpc::server::Serve<#request_ident> for #server_ident - where S: #service_ident - { - type Resp = #response_ident; - type Fut = #response_fut_ident; - - fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut { - match req { - #( - #request_ident::#camel_case_idents{ #( #arg_vars ),* } => { - #response_fut_ident::#camel_case_idents( - #service_ident::#method_names( - self.service, ctx, #( #arg_vars ),*)) - } - )* + /// Returns a serving function to use with tarpc::server::Server. + fn serve(self) -> #server_ident { + #server_ident { service: self } } } } } -} -fn enum_request( - &GenArgs { - derive_serialize, - vis, - request_ident, - camel_case_idents, - args, - .. - }: &GenArgs, -) -> proc_macro2::TokenStream { - quote! { - /// The request sent over the wire from the client to the server. - #[derive(Debug)] - #derive_serialize - #vis enum #request_ident { - #( #camel_case_idents{ #( #args ),* } ),* - } - } -} + fn struct_server(&self) -> TokenStream2 { + let &Self { + vis, server_ident, .. + } = self; -fn enum_response( - &GenArgs { - derive_serialize, - vis, - response_ident, - camel_case_idents, - return_types, - .. - }: &GenArgs, -) -> proc_macro2::TokenStream { - quote! { - /// The response sent over the wire from the server to the client. - #[derive(Debug)] - #derive_serialize - #vis enum #response_ident { - #( #camel_case_idents(#return_types) ),* - } - } -} - -fn enum_response_future( - &GenArgs { - vis, - service_ident, - response_fut_ident, - camel_case_idents, - future_idents, - .. - }: &GenArgs, -) -> proc_macro2::TokenStream { - quote! { - /// A future resolving to a server response. - #vis enum #response_fut_ident { - #( #camel_case_idents(::#future_idents) ),* - } - } -} - -fn impl_debug_for_response_future( - &GenArgs { - service_ident, - response_fut_ident, - response_fut_name, - .. - }: &GenArgs, -) -> proc_macro2::TokenStream { - quote! { - impl std::fmt::Debug for #response_fut_ident { - fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { - fmt.debug_struct(#response_fut_name).finish() + quote! { + #[derive(Clone)] + #vis struct #server_ident { + service: S, } } } -} -fn impl_future_for_response_future( - &GenArgs { - service_ident, - response_fut_ident, - response_ident, - camel_case_idents, - .. - }: &GenArgs, -) -> proc_macro2::TokenStream { - quote! { - impl std::future::Future for #response_fut_ident { - type Output = #response_ident; + fn impl_serve_for_server(&self) -> TokenStream2 { + let &Self { + request_ident, + server_ident, + service_ident, + response_ident, + response_fut_ident, + camel_case_idents, + arg_pats, + method_idents, + .. + } = self; - fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) - -> std::task::Poll<#response_ident> + quote! { + impl tarpc::server::Serve<#request_ident> for #server_ident + where S: #service_ident { - unsafe { - match std::pin::Pin::get_unchecked_mut(self) { + type Resp = #response_ident; + type Fut = #response_fut_ident; + + fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut { + match req { #( - #response_fut_ident::#camel_case_idents(resp) => - std::pin::Pin::new_unchecked(resp) - .poll(cx) - .map(#response_ident::#camel_case_idents), + #request_ident::#camel_case_idents{ #( #arg_pats ),* } => { + #response_fut_ident::#camel_case_idents( + #service_ident::#method_idents( + self.service, ctx, #( #arg_pats ),* + ) + ) + } )* } } } } } -} -fn struct_client( - &GenArgs { - vis, - client_ident, - request_ident, - response_ident, - .. - }: &GenArgs, -) -> proc_macro2::TokenStream { - quote! { - #[allow(unused)] - #[derive(Clone, Debug)] - /// The client stub that makes RPC calls to the server. Exposes a Future interface. - #vis struct #client_ident>(C); - } -} + fn enum_request(&self) -> TokenStream2 { + let &Self { + derive_serialize, + vis, + request_ident, + camel_case_idents, + args, + .. + } = self; -fn impl_from_for_client( - &GenArgs { - client_ident, - request_ident, - response_ident, - .. - }: &GenArgs, -) -> proc_macro2::TokenStream { - quote! { - impl From for #client_ident - where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident> - { - fn from(client: C) -> Self { - #client_ident(client) + quote! { + /// The request sent over the wire from the client to the server. + #[derive(Debug)] + #derive_serialize + #vis enum #request_ident { + #( #camel_case_idents{ #( #args ),* } ),* } } } -} -fn impl_client_new( - &GenArgs { - client_ident, - vis, - request_ident, - response_ident, - .. - }: &GenArgs, -) -> proc_macro2::TokenStream { - quote! { - impl #client_ident { - /// Returns a new client stub that sends requests over the given transport. - #vis fn new(config: tarpc::client::Config, transport: T) - -> tarpc::client::NewClient< - Self, - tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T>> - where - T: tarpc::Transport, tarpc::Response<#response_ident>> - { - let new_client = tarpc::client::new(config, transport); - tarpc::client::NewClient { - client: #client_ident(new_client.client), - dispatch: new_client.dispatch, + fn enum_response(&self) -> TokenStream2 { + let &Self { + derive_serialize, + vis, + response_ident, + camel_case_idents, + return_types, + .. + } = self; + + quote! { + /// The response sent over the wire from the server to the client. + #[derive(Debug)] + #derive_serialize + #vis enum #response_ident { + #( #camel_case_idents(#return_types) ),* + } + } + } + + fn enum_response_future(&self) -> TokenStream2 { + let &Self { + vis, + service_ident, + response_fut_ident, + camel_case_idents, + future_types, + .. + } = self; + + quote! { + /// A future resolving to a server response. + #vis enum #response_fut_ident { + #( #camel_case_idents(::#future_types) ),* + } + } + } + + fn impl_debug_for_response_future(&self) -> TokenStream2 { + let &Self { + service_ident, + response_fut_ident, + response_fut_name, + .. + } = self; + + quote! { + impl std::fmt::Debug for #response_fut_ident { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + fmt.debug_struct(#response_fut_name).finish() } } - } } -} -fn impl_client_rpc_methods( - &GenArgs { - client_ident, - request_ident, - response_ident, - method_attrs, - vis, - method_names, - args, - return_types, - arg_vars, - camel_case_idents, - .. - }: &GenArgs, -) -> proc_macro2::TokenStream { - quote! { - impl #client_ident - where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident> - { - #( - #[allow(unused)] - #( #method_attrs )* - #vis fn #method_names(&mut self, ctx: tarpc::context::Context, #( #args ),*) - -> impl std::future::Future> + '_ { - let request = #request_ident::#camel_case_idents { #( #arg_vars ),* }; - let resp = tarpc::Client::call(&mut self.0, ctx, request); - async move { - match resp.await? { - #response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg), - _ => unreachable!(), + fn impl_future_for_response_future(&self) -> TokenStream2 { + let &Self { + service_ident, + response_fut_ident, + response_ident, + camel_case_idents, + .. + } = self; + + quote! { + impl std::future::Future for #response_fut_ident { + type Output = #response_ident; + + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) + -> std::task::Poll<#response_ident> + { + unsafe { + match std::pin::Pin::get_unchecked_mut(self) { + #( + #response_fut_ident::#camel_case_idents(resp) => + std::pin::Pin::new_unchecked(resp) + .poll(cx) + .map(#response_ident::#camel_case_idents), + )* } } } - )* + } } } + + fn struct_client(&self) -> TokenStream2 { + let &Self { + vis, + client_ident, + request_ident, + response_ident, + .. + } = self; + + quote! { + #[allow(unused)] + #[derive(Clone, Debug)] + /// The client stub that makes RPC calls to the server. Exposes a Future interface. + #vis struct #client_ident>(C); + } + } + + fn impl_from_for_client(&self) -> TokenStream2 { + let &Self { + client_ident, + request_ident, + response_ident, + .. + } = self; + + quote! { + impl From for #client_ident + where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident> + { + fn from(client: C) -> Self { + #client_ident(client) + } + } + } + } + + fn impl_client_new(&self) -> TokenStream2 { + let &Self { + client_ident, + vis, + request_ident, + response_ident, + .. + } = self; + + quote! { + impl #client_ident { + /// Returns a new client stub that sends requests over the given transport. + #vis fn new(config: tarpc::client::Config, transport: T) + -> tarpc::client::NewClient< + Self, + tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T> + > + where + T: tarpc::Transport, tarpc::Response<#response_ident>> + { + let new_client = tarpc::client::new(config, transport); + tarpc::client::NewClient { + client: #client_ident(new_client.client), + dispatch: new_client.dispatch, + } + } + + } + } + } + + fn impl_client_rpc_methods(&self) -> TokenStream2 { + let &Self { + client_ident, + request_ident, + response_ident, + method_attrs, + vis, + method_idents, + args, + return_types, + arg_pats, + camel_case_idents, + .. + } = self; + + quote! { + impl #client_ident + where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident> + { + #( + #[allow(unused)] + #( #method_attrs )* + #vis fn #method_idents(&mut self, ctx: tarpc::context::Context, #( #args ),*) + -> impl std::future::Future> + '_ { + let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; + let resp = tarpc::Client::call(&mut self.0, ctx, request); + async move { + match resp.await? { + #response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg), + _ => unreachable!(), + } + } + } + )* + } + } + } +} + +impl<'a> ToTokens for ServiceGenerator<'a> { + fn to_tokens(&self, output: &mut TokenStream2) { + output.extend(vec![ + self.trait_service(), + self.struct_server(), + self.impl_serve_for_server(), + self.enum_request(), + self.enum_response(), + self.enum_response_future(), + self.impl_debug_for_response_future(), + self.impl_future_for_response_future(), + self.struct_client(), + self.impl_from_for_client(), + self.impl_client_new(), + self.impl_client_rpc_methods(), + ]) + } } fn snake_to_camel(ident_str: &str) -> String { - let chars = ident_str.chars(); let mut camel_ty = String::with_capacity(ident_str.len()); let mut last_char_was_underscore = true; - for c in chars { + for c in ident_str.chars() { match c { '_' => last_char_was_underscore = true, c if last_char_was_underscore => { diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index 457ae2c..b37cbce 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -29,6 +29,38 @@ fn att_service_trait() { } } +#[allow(non_camel_case_types)] +#[test] +fn raw_idents() { + use futures::future::{ready, Ready}; + + type r#yield = String; + + #[tarpc::service] + trait r#trait { + async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32); + async fn r#fn(r#impl: r#yield) -> r#yield; + async fn r#async(); + } + + impl r#trait for () { + type AwaitFut = Ready<(r#yield, i32)>; + fn r#await(self, _: context::Context, r#struct: r#yield, r#enum: i32) -> Self::AwaitFut { + ready((r#struct, r#enum)) + } + + type FnFut = Ready; + fn r#fn(self, _: context::Context, r#impl: r#yield) -> Self::FnFut { + ready(r#impl) + } + + type AsyncFut = Ready<()>; + fn r#async(self, _: context::Context) -> Self::AsyncFut { + ready(()) + } + } +} + #[test] fn syntax() { #[tarpc::service]