diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index f77908c..754e6a7 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -148,49 +148,123 @@ impl Parse for DeriveSerde { #[proc_macro_attribute] pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { let derive_serde = parse_macro_input!(attr as DeriveSerde); - let unit_type: &Type = &parse_quote!(()); - let Service { - attrs, - vis, - ident, - rpcs, + ref attrs, + ref vis, + ref ident, + ref rpcs, } = parse_macro_input!(input as Service); - let camel_case_fn_names: &Vec = &rpcs + let camel_case_fn_names: &[String] = &rpcs .iter() .map(|rpc| snake_to_camel(&rpc.ident.to_string())) - .collect(); - let output_types: &Vec<&Type> = &rpcs - .iter() - .map(|rpc| match rpc.output { - ReturnType::Type(_, ref ty) => ty, - ReturnType::Default => unit_type, - }) - .collect(); - let future_types: &Vec = &camel_case_fn_names - .iter() - .map(|name| format_ident!("{}Fut", name)) - .collect(); - let camel_case_idents: &Vec = &rpcs - .iter() - .zip(camel_case_fn_names.iter()) - .map(|(rpc, name)| Ident::new(name, rpc.ident.span())) - .collect(); + .collect::>(); + let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::>(); + let response_fut_name = &format!("{}ResponseFut", ident); + let derive_serialize = quote!(#[derive(serde::Serialize, serde::Deserialize)]); - let args: &Vec<&Vec> = &rpcs.iter().map(|rpc| &rpc.args).collect(); - let arg_vars: &Vec> = &args - .iter() - .map(|args| args.iter().map(|arg| &*arg.pat).collect()) - .collect(); - let method_names: &Vec<&Ident> = &rpcs.iter().map(|rpc| &rpc.ident).collect(); - let method_attrs: Vec<_> = rpcs.iter().map(|rpc| &rpc.attrs).collect(); + let gen_args = &GenArgs { + attrs, + vis, + rpcs, + args, + response_fut_name, + service_ident: ident, + server_ident: &format_ident!("Serve{}", ident), + response_fut_ident: &Ident::new(&response_fut_name, ident.span()), + client_ident: &format_ident!("{}Client", ident), + request_ident: &format_ident!("{}Request", ident), + response_ident: &format_ident!("{}Response", ident), + method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::>(), + method_names: &rpcs.iter().map(|rpc| &rpc.ident).collect::>(), + return_types: &rpcs + .iter() + .map(|rpc| match rpc.output { + ReturnType::Type(_, ref ty) => ty, + ReturnType::Default => unit_type, + }) + .collect::>(), + arg_vars: &args + .iter() + .map(|args| args.iter().map(|arg| &*arg.pat).collect()) + .collect::>(), + camel_case_idents: &rpcs + .iter() + .zip(camel_case_fn_names.iter()) + .map(|(rpc, name)| Ident::new(name, rpc.ident.span())) + .collect::>(), + future_idents: &camel_case_fn_names + .iter() + .map(|name| format_ident!("{}Fut", name)) + .collect::>(), + derive_serialize: if derive_serde.0 { + Some(&derive_serialize) + } else { + None + }, + }; + let tokens = vec![ + trait_service(gen_args), + struct_server(gen_args), + impl_serve_for_server(gen_args), + enum_request(gen_args), + enum_response(gen_args), + enum_response_future(gen_args), + impl_debug_for_response_future(gen_args), + impl_future_for_response_future(gen_args), + struct_client(gen_args), + impl_from_for_client(gen_args), + impl_client_new(gen_args), + impl_client_rpc_methods(gen_args), + ]; + + tokens + .into_iter() + .collect::() + .into() +} + +// Things needed to generate the service items: trait, serve impl, request/response enums, and +// the client stub. +struct GenArgs<'a> { + attrs: &'a [Attribute], + rpcs: &'a [RpcMethod], + service_ident: &'a Ident, + server_ident: &'a Ident, + response_fut_ident: &'a Ident, + response_fut_name: &'a str, + client_ident: &'a Ident, + request_ident: &'a Ident, + response_ident: &'a Ident, + method_attrs: &'a [&'a [Attribute]], + vis: &'a Visibility, + method_names: &'a [&'a Ident], + args: &'a [&'a [PatType]], + return_types: &'a [&'a Type], + arg_vars: &'a [Vec<&'a Pat>], + camel_case_idents: &'a [Ident], + future_idents: &'a [Ident], + derive_serialize: Option<&'a proc_macro2::TokenStream>, +} + +fn trait_service( + &GenArgs { + attrs, + rpcs, + vis, + future_idents, + return_types, + service_ident, + server_ident, + .. + }: &GenArgs, +) -> proc_macro2::TokenStream { let types_and_fns = rpcs .iter() - .zip(future_types.iter()) - .zip(output_types.iter()) + .zip(future_idents.iter()) + .zip(return_types.iter()) .map( |( ( @@ -212,24 +286,9 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { }, ); - let service_name = &ident; - - let client_ident = format_ident!("{}Client", ident); - let request_ident = format_ident!("{}Request", ident); - let response_ident = format_ident!("{}Response", ident); - let response_fut_name = format!("{}ResponseFut", ident); - let response_fut_ident = Ident::new(&response_fut_name, ident.span()); - let server_ident = format_ident!("Serve{}", ident); - - let derive_serialize = if derive_serde.0 { - Some(quote!(#[derive(serde::Serialize, serde::Deserialize)])) - } else { - None - }; - - let tokens = quote! { + quote! { #( #attrs )* - #vis trait #ident: Clone { + #vis trait #service_ident: Clone { #( #types_and_fns )* /// Returns a serving function to use with tarpc::server::Server. @@ -237,14 +296,38 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { #server_ident { service: self } } } + } +} +fn struct_server( + &GenArgs { + vis, server_ident, .. + }: &GenArgs, +) -> proc_macro2::TokenStream { + quote! { #[derive(Clone)] #vis struct #server_ident { service: S, } + } +} +fn impl_serve_for_server( + &GenArgs { + request_ident, + server_ident, + service_ident, + response_ident, + response_fut_ident, + camel_case_idents, + arg_vars, + method_names, + .. + }: &GenArgs, +) -> proc_macro2::TokenStream { + quote! { impl tarpc::server::Serve<#request_ident> for #server_ident - where S: #ident + where S: #service_ident { type Resp = #response_ident; type Fut = #response_fut_ident; @@ -254,40 +337,102 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { #( #request_ident::#camel_case_idents{ #( #arg_vars ),* } => { #response_fut_ident::#camel_case_idents( - #service_name::#method_names( + #service_ident::#method_names( self.service, ctx, #( #arg_vars ),*)) } )* } } } + } +} +fn enum_request( + &GenArgs { + derive_serialize, + vis, + request_ident, + camel_case_idents, + args, + .. + }: &GenArgs, +) -> proc_macro2::TokenStream { + quote! { /// The request sent over the wire from the client to the server. #[derive(Debug)] #derive_serialize #vis enum #request_ident { #( #camel_case_idents{ #( #args ),* } ),* } + } +} +fn enum_response( + &GenArgs { + derive_serialize, + vis, + response_ident, + camel_case_idents, + return_types, + .. + }: &GenArgs, +) -> proc_macro2::TokenStream { + quote! { /// The response sent over the wire from the server to the client. #[derive(Debug)] #derive_serialize #vis enum #response_ident { - #( #camel_case_idents(#output_types) ),* + #( #camel_case_idents(#return_types) ),* } + } +} +fn enum_response_future( + &GenArgs { + vis, + service_ident, + response_fut_ident, + camel_case_idents, + future_idents, + .. + }: &GenArgs, +) -> proc_macro2::TokenStream { + quote! { /// A future resolving to a server response. - #vis enum #response_fut_ident { - #( #camel_case_idents(::#future_types) ),* + #vis enum #response_fut_ident { + #( #camel_case_idents(::#future_idents) ),* } + } +} - impl std::fmt::Debug for #response_fut_ident { +fn impl_debug_for_response_future( + &GenArgs { + service_ident, + response_fut_ident, + response_fut_name, + .. + }: &GenArgs, +) -> proc_macro2::TokenStream { + quote! { + impl std::fmt::Debug for #response_fut_ident { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { fmt.debug_struct(#response_fut_name).finish() } } + } +} - impl std::future::Future for #response_fut_ident { +fn impl_future_for_response_future( + &GenArgs { + service_ident, + response_fut_ident, + response_ident, + camel_case_idents, + .. + }: &GenArgs, +) -> proc_macro2::TokenStream { + quote! { + impl std::future::Future for #response_fut_ident { type Output = #response_ident; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) @@ -305,12 +450,35 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { } } } + } +} +fn struct_client( + &GenArgs { + vis, + client_ident, + request_ident, + response_ident, + .. + }: &GenArgs, +) -> proc_macro2::TokenStream { + quote! { #[allow(unused)] #[derive(Clone, Debug)] /// The client stub that makes RPC calls to the server. Exposes a Future interface. #vis struct #client_ident>(C); + } +} +fn impl_from_for_client( + &GenArgs { + client_ident, + request_ident, + response_ident, + .. + }: &GenArgs, +) -> proc_macro2::TokenStream { + quote! { impl From for #client_ident where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident> { @@ -318,7 +486,19 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { #client_ident(client) } } + } +} +fn impl_client_new( + &GenArgs { + client_ident, + vis, + request_ident, + response_ident, + .. + }: &GenArgs, +) -> proc_macro2::TokenStream { + quote! { impl #client_ident { /// Returns a new client stub that sends requests over the given transport. #vis fn new(config: tarpc::client::Config, transport: T) @@ -336,7 +516,25 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { } } + } +} +fn impl_client_rpc_methods( + &GenArgs { + client_ident, + request_ident, + response_ident, + method_attrs, + vis, + method_names, + args, + return_types, + arg_vars, + camel_case_idents, + .. + }: &GenArgs, +) -> proc_macro2::TokenStream { + quote! { impl #client_ident where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident> { @@ -344,7 +542,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { #[allow(unused)] #( #method_attrs )* #vis fn #method_names(&mut self, ctx: tarpc::context::Context, #( #args ),*) - -> impl std::future::Future> + '_ { + -> impl std::future::Future> + '_ { let request = #request_ident::#camel_case_idents { #( #arg_vars ),* }; let resp = tarpc::Client::call(&mut self.0, ctx, request); async move { @@ -356,9 +554,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { } )* } - }; - - tokens.into() + } } fn snake_to_camel(ident_str: &str) -> String {