Updated and simplified macros (#290)

* syn updated to latest version
* quote updated to latest version
* proc-macro-2 updated to latest version
* Performance improvements
* Don't create unnecessary TokenStreams for output types
This commit is contained in:
Oleg Nosov
2019-12-11 23:28:24 +03:00
committed by Tim
parent 45af6ccdeb
commit 85d49477f5
3 changed files with 75 additions and 101 deletions

View File

@@ -19,9 +19,9 @@ serde1 = []
travis-ci = { repository = "google/tarpc" }
[dependencies]
syn = { version = "0.15", features = ["full"] }
quote = "0.6"
proc-macro2 = "0.4"
syn = { version = "1.0.11", features = ["full"] }
quote = "1.0.2"
proc-macro2 = "1.0.6"
[lib]
proc-macro = true

View File

@@ -12,16 +12,14 @@ extern crate quote;
extern crate syn;
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use quote::{format_ident, quote};
use syn::{
braced, parenthesized,
parse::{Parse, ParseStream},
parse_macro_input,
parse_macro_input, parse_quote,
punctuated::Punctuated,
spanned::Spanned,
token::Comma,
ArgCaptured, Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, ReturnType, Token,
Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type,
Visibility,
};
@@ -35,7 +33,7 @@ struct Service {
struct RpcMethod {
attrs: Vec<Attribute>,
ident: Ident,
args: Punctuated<ArgCaptured, Comma>,
args: Vec<PatType>,
output: ReturnType,
}
@@ -53,22 +51,19 @@ impl Parse for Service {
}
for rpc in &rpcs {
if rpc.ident == "new" {
return Err(syn::Error::new(
rpc.ident.span(),
format!(
"method name conflicts with generated fn `{}Client::new`",
ident
),
));
return Err(input.error(format!(
"method name conflicts with generated fn `{}Client::new`",
ident
)));
}
if rpc.ident == "serve" {
return Err(syn::Error::new(
rpc.ident.span(),
format!("method name conflicts with generated fn `{}::serve`", ident),
));
return Err(input.error(format!(
"method name conflicts with generated fn `{}::serve`",
ident
)));
}
}
Ok(Service {
Ok(Self {
attrs,
vis,
ident,
@@ -89,31 +84,17 @@ impl Parse for RpcMethod {
let args = args
.into_iter()
.map(|arg| match arg {
FnArg::Captured(captured) => match captured.pat {
FnArg::Typed(captured) => match *captured.pat {
Pat::Ident(_) => Ok(captured),
_ => Err(syn::Error::new(
captured.pat.span(),
"patterns aren't allowed in RPC args",
)),
_ => Err(input.error("patterns aren't allowed in RPC args")),
},
FnArg::SelfRef(self_ref) => Err(syn::Error::new(
self_ref.span(),
"method args cannot start with self",
)),
FnArg::SelfValue(self_val) => Err(syn::Error::new(
self_val.span(),
"method args cannot start with self",
)),
arg => Err(syn::Error::new(
arg.span(),
"method args must be explicitly typed patterns",
)),
FnArg::Receiver(_) => Err(input.error("method args cannot start with self")),
})
.collect::<Result<_, _>>()?;
let output = input.parse()?;
input.parse::<Token![;]>()?;
Ok(RpcMethod {
Ok(Self {
attrs,
ident,
args,
@@ -129,29 +110,30 @@ struct DeriveSerde(bool);
impl Parse for DeriveSerde {
fn parse(input: ParseStream) -> syn::Result<Self> {
if input.is_empty() {
return Ok(DeriveSerde(cfg!(feature = "serde1")));
return Ok(Self(cfg!(feature = "serde1")));
}
match input.parse::<MetaNameValue>()? {
MetaNameValue {
ref ident, ref lit, ..
} if ident == "derive_serde" => match lit {
Lit::Bool(LitBool { value: true, .. }) if cfg!(feature = "serde1") => {
Ok(DeriveSerde(true))
ref path, ref lit, ..
} if path.segments.len() == 1
&& path.segments.first().unwrap().ident == "derive_serde" =>
{
match lit {
Lit::Bool(LitBool { value: true, .. }) if cfg!(feature = "serde1") => {
Ok(Self(true))
}
Lit::Bool(LitBool { value: true, .. }) => {
Err(input
.error("To enable serde, first enable the `serde1` feature of tarpc"))
}
Lit::Bool(LitBool { value: false, .. }) => Ok(Self(false)),
_ => Err(input.error("`derive_serde` expects a value of type `bool`")),
}
Lit::Bool(LitBool { value: true, .. }) => Err(syn::Error::new(
lit.span(),
"To enable serde, first enable the `serde1` feature of tarpc",
)),
Lit::Bool(LitBool { value: false, .. }) => Ok(DeriveSerde(false)),
lit => Err(syn::Error::new(
lit.span(),
"`derive_serde` expects a value of type `bool`",
)),
},
MetaNameValue { ident, .. } => Err(syn::Error::new(
ident.span(),
"tarpc::service only supports one meta item, `derive_serde = {bool}`",
)),
}
_ => {
Err(input
.error("tarpc::service only supports one meta item, `derive_serde = {bool}`"))
}
}
}
}
@@ -167,49 +149,48 @@ impl Parse for DeriveSerde {
pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
let derive_serde = parse_macro_input!(attr as DeriveSerde);
let unit_type: &Type = &parse_quote!(());
let Service {
attrs,
vis,
ident,
rpcs,
} = parse_macro_input!(input as Service);
let vis_repeated = std::iter::repeat(vis.clone());
let camel_case_fn_names: Vec<String> = rpcs
let camel_case_fn_names: &Vec<String> = &rpcs
.iter()
.map(|rpc| snake_to_camel(&rpc.ident.to_string()))
.collect();
let outputs: &Vec<TokenStream2> = &rpcs
let output_types: &Vec<&Type> = &rpcs
.iter()
.map(|rpc| match rpc.output {
ReturnType::Type(_, ref ty) => quote!(#ty),
ReturnType::Default => quote!(()),
ReturnType::Type(_, ref ty) => ty,
ReturnType::Default => unit_type,
})
.collect();
let future_types: Vec<Ident> = camel_case_fn_names
let future_types: &Vec<Ident> = &camel_case_fn_names
.iter()
.map(|name| Ident::new(&format!("{}Fut", name), ident.span()))
.map(|name| format_ident!("{}Fut", name))
.collect();
let camel_case_idents: &Vec<Ident> = &rpcs
.iter()
.zip(camel_case_fn_names.iter())
.map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
.collect();
let camel_case_idents2 = camel_case_idents;
let args: &Vec<&Punctuated<ArgCaptured, Comma>> = &rpcs.iter().map(|rpc| &rpc.args).collect();
let arg_vars: &Vec<Punctuated<&Pat, Comma>> = &args
let args: &Vec<&Vec<PatType>> = &rpcs.iter().map(|rpc| &rpc.args).collect();
let arg_vars: &Vec<Vec<&Pat>> = &args
.iter()
.map(|args| args.iter().map(|arg| &arg.pat).collect())
.map(|args| args.iter().map(|arg| &*arg.pat).collect())
.collect();
let arg_vars2 = arg_vars;
let method_names: &Vec<&Ident> = &rpcs.iter().map(|rpc| &rpc.ident).collect();
let method_attrs: Vec<_> = rpcs.iter().map(|rpc| &rpc.attrs).collect();
let types_and_fns = rpcs
.iter()
.zip(future_types.iter())
.zip(outputs.iter())
.zip(output_types.iter())
.map(
|(
(
@@ -226,31 +207,24 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
type #future_type: std::future::Future<Output = #output>;
#( #attrs )*
fn #ident(self, context: tarpc::context::Context, #args) -> Self::#future_type;
fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type;
}
},
);
let service_name_repeated = std::iter::repeat(ident.clone());
let service_name_repeated2 = service_name_repeated.clone();
let service_name = &ident;
let client_ident = Ident::new(&format!("{}Client", ident), ident.span());
let request_ident = Ident::new(&format!("{}Request", ident), ident.span());
let request_ident_repeated = std::iter::repeat(request_ident.clone());
let request_ident_repeated2 = request_ident_repeated.clone();
let response_ident = Ident::new(&format!("{}Response", ident), ident.span());
let response_ident_repeated = std::iter::repeat(response_ident.clone());
let response_ident_repeated2 = response_ident_repeated.clone();
let client_ident = format_ident!("{}Client", ident);
let request_ident = format_ident!("{}Request", ident);
let response_ident = format_ident!("{}Response", ident);
let response_fut_name = format!("{}ResponseFut", ident);
let response_fut_ident = Ident::new(&response_fut_name, ident.span());
let response_fut_ident_repeated = std::iter::repeat(response_fut_ident.clone());
let response_fut_ident_repeated2 = response_fut_ident_repeated.clone();
let server_ident = Ident::new(&format!("Serve{}", ident), ident.span());
let server_ident = format_ident!("Serve{}", ident);
let derive_serialize = if derive_serde.0 {
quote!(#[derive(serde::Serialize, serde::Deserialize)])
Some(quote!(#[derive(serde::Serialize, serde::Deserialize)]))
} else {
quote!()
None
};
let tokens = quote! {
@@ -278,10 +252,10 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
match req {
#(
#request_ident_repeated::#camel_case_idents{ #arg_vars } => {
#response_fut_ident_repeated2::#camel_case_idents2(
#service_name_repeated2::#method_names(
self.service, ctx, #arg_vars2))
#request_ident::#camel_case_idents{ #( #arg_vars ),* } => {
#response_fut_ident::#camel_case_idents(
#service_name::#method_names(
self.service, ctx, #( #arg_vars ),*))
}
)*
}
@@ -292,19 +266,19 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
#[derive(Debug)]
#derive_serialize
#vis enum #request_ident {
#( #camel_case_idents{ #args } ),*
#( #camel_case_idents{ #( #args ),* } ),*
}
/// The response sent over the wire from the server to the client.
#[derive(Debug)]
#derive_serialize
#vis enum #response_ident {
#( #camel_case_idents(#outputs) ),*
#( #camel_case_idents(#output_types) ),*
}
/// A future resolving to a server response.
#vis enum #response_fut_ident<S: #ident> {
#( #camel_case_idents(<S as #service_name_repeated>::#future_types) ),*
#( #camel_case_idents(<S as #service_name>::#future_types) ),*
}
impl<S: #ident> std::fmt::Debug for #response_fut_ident<S> {
@@ -322,10 +296,10 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
unsafe {
match std::pin::Pin::get_unchecked_mut(self) {
#(
#response_fut_ident_repeated::#camel_case_idents(resp) =>
#response_fut_ident::#camel_case_idents(resp) =>
std::pin::Pin::new_unchecked(resp)
.poll(cx)
.map(#response_ident_repeated::#camel_case_idents2),
.map(#response_ident::#camel_case_idents),
)*
}
}
@@ -369,13 +343,13 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
#(
#[allow(unused)]
#( #method_attrs )*
#vis_repeated fn #method_names(&mut self, ctx: tarpc::context::Context, #args)
-> impl std::future::Future<Output = std::io::Result<#outputs>> + '_ {
let request = #request_ident_repeated2::#camel_case_idents { #arg_vars };
#vis fn #method_names(&mut self, ctx: tarpc::context::Context, #( #args ),*)
-> impl std::future::Future<Output = std::io::Result<#output_types>> + '_ {
let request = #request_ident::#camel_case_idents { #( #arg_vars ),* };
let resp = tarpc::Client::call(&mut self.0, ctx, request);
async move {
match resp.await? {
#response_ident_repeated2::#camel_case_idents2(msg) => std::result::Result::Ok(msg),
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
_ => unreachable!(),
}
}
@@ -388,8 +362,8 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
}
fn snake_to_camel(ident_str: &str) -> String {
let mut camel_ty = String::new();
let chars = ident_str.chars();
let mut camel_ty = String::with_capacity(ident_str.len());
let mut last_char_was_underscore = true;
for c in chars {
@@ -403,6 +377,7 @@ fn snake_to_camel(ident_str: &str) -> String {
}
}
camel_ty.shrink_to_fit();
camel_ty
}

View File

@@ -37,7 +37,6 @@ serde = { optional = true, version = "1.0", features = ["derive"] }
tokio = { optional = true, version = "0.2", features = ["time"] }
tokio-util = { optional = true, version = "0.2" }
tarpc-plugins = { path = "../plugins" }
tokio-serde = { optional = true, version = "0.6" }
[dev-dependencies]