diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index ca63107..b6ce0cf 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -12,7 +12,7 @@ extern crate quote; extern crate syn; use proc_macro::TokenStream; -use proc_macro2::TokenStream as TokenStream2; +use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{format_ident, quote, ToTokens}; use syn::{ braced, @@ -20,12 +20,24 @@ use syn::{ parenthesized, parse::{Parse, ParseStream}, parse_macro_input, parse_quote, parse_str, - punctuated::Punctuated, + spanned::Spanned, token::Comma, Attribute, FnArg, Ident, ImplItem, ImplItemMethod, ImplItemType, ItemImpl, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type, Visibility, }; +/// Accumulates multiple errors into a result. +/// Only use this for recoverable errors, i.e. non-parse errors. Fatal errors should early exit to +/// avoid further complications. +macro_rules! extend_errors { + ($errors: ident, $e: expr) => { + match $errors { + Ok(_) => $errors = Err($e), + Err(ref mut errors) => errors.extend($e), + } + }; +} + struct Service { attrs: Vec, vis: Visibility, @@ -52,20 +64,31 @@ impl Parse for Service { while !content.is_empty() { rpcs.push(content.parse()?); } + let mut ident_errors = Ok(()); for rpc in &rpcs { if rpc.ident == "new" { - return Err(input.error(format!( - "method name conflicts with generated fn `{}Client::new`", - ident.unraw() - ))); + extend_errors!( + ident_errors, + syn::Error::new( + rpc.ident.span(), + format!( + "method name conflicts with generated fn `{}Client::new`", + ident.unraw() + ) + ) + ); } if rpc.ident == "serve" { - return Err(input.error(format!( - "method name conflicts with generated fn `{}::serve`", - ident - ))); + extend_errors!( + ident_errors, + syn::Error::new( + rpc.ident.span(), + format!("method name conflicts with generated fn `{}::serve`", ident) + ) + ); } } + ident_errors?; Ok(Self { attrs, @@ -84,17 +107,28 @@ impl Parse for RpcMethod { let ident = input.parse()?; let content; parenthesized!(content in input); - let args: Punctuated = content.parse_terminated(FnArg::parse)?; - let args = args - .into_iter() - .map(|arg| match arg { - FnArg::Typed(captured) => match *captured.pat { - Pat::Ident(_) => Ok(captured), - _ => Err(input.error("patterns aren't allowed in RPC args")), - }, - FnArg::Receiver(_) => Err(input.error("method args cannot start with self")), - }) - .collect::>()?; + let mut args = Vec::new(); + let mut errors = Ok(()); + for arg in content.parse_terminated::(FnArg::parse)? { + match arg { + FnArg::Typed(captured) if matches!(&*captured.pat, Pat::Ident(_)) => { + args.push(captured); + } + FnArg::Typed(captured) => { + extend_errors!( + errors, + syn::Error::new(captured.pat.span(), "patterns aren't allowed in RPC args") + ); + } + FnArg::Receiver(_) => { + extend_errors!( + errors, + syn::Error::new(arg.span(), "method args cannot start with self") + ); + } + } + } + errors?; let output = input.parse()?; input.parse::()?; @@ -113,32 +147,71 @@ struct DeriveSerde(bool); impl Parse for DeriveSerde { fn parse(input: ParseStream) -> syn::Result { - if input.is_empty() { - return Ok(Self(cfg!(feature = "serde1"))); - } - match input.parse::()? { - MetaNameValue { - ref path, ref lit, .. - } if path.segments.len() == 1 - && path.segments.first().unwrap().ident == "derive_serde" => - { - match lit { - Lit::Bool(LitBool { value: true, .. }) if cfg!(feature = "serde1") => { - Ok(Self(true)) - } - Lit::Bool(LitBool { value: true, .. }) => { - Err(input - .error("To enable serde, first enable the `serde1` feature of tarpc")) - } - Lit::Bool(LitBool { value: false, .. }) => Ok(Self(false)), - _ => Err(input.error("`derive_serde` expects a value of type `bool`")), + let mut result = Ok(None); + let mut derive_serde = Vec::new(); + let meta_items = input.parse_terminated::(MetaNameValue::parse)?; + for meta in meta_items { + if meta.path.segments.len() != 1 { + extend_errors!( + result, + syn::Error::new( + meta.span(), + format!("tarpc::service does not support this meta item") + ) + ); + continue; + } + let segment = meta.path.segments.first().unwrap(); + if segment.ident != "derive_serde" { + extend_errors!( + result, + syn::Error::new( + meta.span(), + format!("tarpc::service does not support this meta item") + ) + ); + continue; + } + match meta.lit { + Lit::Bool(LitBool { value: true, .. }) if cfg!(feature = "serde1") => { + result = result.and(Ok(Some(true))) } + Lit::Bool(LitBool { value: true, .. }) => { + extend_errors!( + result, + syn::Error::new( + meta.span(), + "To enable serde, first enable the `serde1` feature of tarpc" + ) + ); + } + Lit::Bool(LitBool { value: false, .. }) => result = result.and(Ok(Some(false))), + _ => extend_errors!( + result, + syn::Error::new( + meta.lit.span(), + "`derive_serde` expects a value of type `bool`" + ) + ), } - _ => { - Err(input - .error("tarpc::service only supports one meta item, `derive_serde = {bool}`")) + derive_serde.push(meta); + } + if derive_serde.len() > 1 { + for (i, derive_serde) in derive_serde.iter().enumerate() { + extend_errors!( + result, + syn::Error::new( + derive_serde.span(), + format!( + "`derive_serde` appears more than once (occurrence #{})", + i + 1 + ) + ) + ); } } + let derive_serde = result?.unwrap_or(cfg!(feature = "serde1")); + Ok(Self(derive_serde)) } } @@ -212,6 +285,12 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .into() } +/// generate an identifier consisting of the method name to CamelCase with +/// Fut appended to it. +fn associated_type_for_rpc(method: &ImplItemMethod) -> String { + snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut" +} + /// Transforms an async function into a sync one, returning a type declaration /// for the return type (a future). fn transform_method(method: &mut ImplItemMethod) -> ImplItemType { @@ -223,9 +302,7 @@ fn transform_method(method: &mut ImplItemMethod) -> ImplItemType { ReturnType::Type(_, ret) => quote!(#ret), }; - // generate an identifier consisting of the method name to CamelCase with - // Fut appended to it. - let fut_name = snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut"; + let fut_name = associated_type_for_rpc(method); let fut_name_ident = Ident::new(&fut_name, method.sig.ident.span()); // generate the updated return signature. @@ -308,22 +385,37 @@ fn transform_method(method: &mut ImplItemMethod) -> ImplItemType { #[proc_macro_attribute] pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream { let mut item = syn::parse_macro_input!(input as ItemImpl); + let span = item.span(); // the generated type declarations let mut types: Vec = Vec::new(); + let mut expected_non_async_types: Vec<(&ImplItemMethod, String)> = Vec::new(); + let mut found_non_async_types: Vec<&ImplItemType> = Vec::new(); for inner in &mut item.items { - if let ImplItem::Method(method) = inner { - let sig = &method.sig; - - // if this function is declared async, transform it into a regular function - if sig.asyncness.is_some() { - let typedecl = transform_method(method); - types.push(typedecl); + match inner { + ImplItem::Method(method) => { + if method.sig.asyncness.is_some() { + // if this function is declared async, transform it into a regular function + let typedecl = transform_method(method); + types.push(typedecl); + } else { + // If it's not async, keep track of all required associated types for better + // error reporting. + expected_non_async_types.push((method, associated_type_for_rpc(method))); + } } + ImplItem::Type(typedecl) => found_non_async_types.push(typedecl), + _ => {} } } + if let Err(e) = + verify_types_were_provided(span, &expected_non_async_types, &found_non_async_types) + { + return TokenStream::from(e.to_compile_error()); + } + // add the type declarations into the impl block for t in types.into_iter() { item.items.push(syn::ImplItem::Type(t)); @@ -332,6 +424,39 @@ pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream { TokenStream::from(quote!(#item)) } +fn verify_types_were_provided( + span: Span, + expected: &[(&ImplItemMethod, String)], + provided: &[&ImplItemType], +) -> syn::Result<()> { + let mut result = Ok(()); + for (method, expected) in expected { + if provided + .iter() + .find(|typedecl| typedecl.ident == expected) + .is_none() + { + let mut e = syn::Error::new( + span, + format!("not all trait items implemented, missing: `{}`", expected), + ); + let fn_span = method.sig.fn_token.span(); + e.extend(syn::Error::new( + fn_span.join(method.sig.ident.span()).unwrap_or(fn_span), + format!( + "hint: `#[tarpc::server]` only rewrites async fns, and `fn {}` is not async", + method.sig.ident + ), + )); + match result { + Ok(_) => result = Err(e), + Err(ref mut error) => error.extend(Some(e)), + } + } + } + result +} + // Things needed to generate the service items: trait, serve impl, request/response enums, and // the client stub. struct ServiceGenerator<'a> {