mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-01-25 03:50:29 +01:00
Add better diagnostics for missing 'async' in impls using #[tarpc::server]
This commit is contained in:
@@ -12,7 +12,7 @@ extern crate quote;
|
||||
extern crate syn;
|
||||
|
||||
use proc_macro::TokenStream;
|
||||
use proc_macro2::TokenStream as TokenStream2;
|
||||
use proc_macro2::{Span, TokenStream as TokenStream2};
|
||||
use quote::{format_ident, quote, ToTokens};
|
||||
use syn::{
|
||||
braced,
|
||||
@@ -20,12 +20,24 @@ use syn::{
|
||||
parenthesized,
|
||||
parse::{Parse, ParseStream},
|
||||
parse_macro_input, parse_quote, parse_str,
|
||||
punctuated::Punctuated,
|
||||
spanned::Spanned,
|
||||
token::Comma,
|
||||
Attribute, FnArg, Ident, ImplItem, ImplItemMethod, ImplItemType, ItemImpl, Lit, LitBool,
|
||||
MetaNameValue, Pat, PatType, ReturnType, Token, Type, Visibility,
|
||||
};
|
||||
|
||||
/// Accumulates multiple errors into a result.
|
||||
/// Only use this for recoverable errors, i.e. non-parse errors. Fatal errors should early exit to
|
||||
/// avoid further complications.
|
||||
macro_rules! extend_errors {
|
||||
($errors: ident, $e: expr) => {
|
||||
match $errors {
|
||||
Ok(_) => $errors = Err($e),
|
||||
Err(ref mut errors) => errors.extend($e),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
struct Service {
|
||||
attrs: Vec<Attribute>,
|
||||
vis: Visibility,
|
||||
@@ -52,20 +64,31 @@ impl Parse for Service {
|
||||
while !content.is_empty() {
|
||||
rpcs.push(content.parse()?);
|
||||
}
|
||||
let mut ident_errors = Ok(());
|
||||
for rpc in &rpcs {
|
||||
if rpc.ident == "new" {
|
||||
return Err(input.error(format!(
|
||||
"method name conflicts with generated fn `{}Client::new`",
|
||||
ident.unraw()
|
||||
)));
|
||||
extend_errors!(
|
||||
ident_errors,
|
||||
syn::Error::new(
|
||||
rpc.ident.span(),
|
||||
format!(
|
||||
"method name conflicts with generated fn `{}Client::new`",
|
||||
ident.unraw()
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
if rpc.ident == "serve" {
|
||||
return Err(input.error(format!(
|
||||
"method name conflicts with generated fn `{}::serve`",
|
||||
ident
|
||||
)));
|
||||
extend_errors!(
|
||||
ident_errors,
|
||||
syn::Error::new(
|
||||
rpc.ident.span(),
|
||||
format!("method name conflicts with generated fn `{}::serve`", ident)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
ident_errors?;
|
||||
|
||||
Ok(Self {
|
||||
attrs,
|
||||
@@ -84,17 +107,28 @@ impl Parse for RpcMethod {
|
||||
let ident = input.parse()?;
|
||||
let content;
|
||||
parenthesized!(content in input);
|
||||
let args: Punctuated<FnArg, Comma> = content.parse_terminated(FnArg::parse)?;
|
||||
let args = args
|
||||
.into_iter()
|
||||
.map(|arg| match arg {
|
||||
FnArg::Typed(captured) => match *captured.pat {
|
||||
Pat::Ident(_) => Ok(captured),
|
||||
_ => Err(input.error("patterns aren't allowed in RPC args")),
|
||||
},
|
||||
FnArg::Receiver(_) => Err(input.error("method args cannot start with self")),
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
let mut args = Vec::new();
|
||||
let mut errors = Ok(());
|
||||
for arg in content.parse_terminated::<FnArg, Comma>(FnArg::parse)? {
|
||||
match arg {
|
||||
FnArg::Typed(captured) if matches!(&*captured.pat, Pat::Ident(_)) => {
|
||||
args.push(captured);
|
||||
}
|
||||
FnArg::Typed(captured) => {
|
||||
extend_errors!(
|
||||
errors,
|
||||
syn::Error::new(captured.pat.span(), "patterns aren't allowed in RPC args")
|
||||
);
|
||||
}
|
||||
FnArg::Receiver(_) => {
|
||||
extend_errors!(
|
||||
errors,
|
||||
syn::Error::new(arg.span(), "method args cannot start with self")
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
errors?;
|
||||
let output = input.parse()?;
|
||||
input.parse::<Token![;]>()?;
|
||||
|
||||
@@ -113,32 +147,71 @@ struct DeriveSerde(bool);
|
||||
|
||||
impl Parse for DeriveSerde {
|
||||
fn parse(input: ParseStream) -> syn::Result<Self> {
|
||||
if input.is_empty() {
|
||||
return Ok(Self(cfg!(feature = "serde1")));
|
||||
}
|
||||
match input.parse::<MetaNameValue>()? {
|
||||
MetaNameValue {
|
||||
ref path, ref lit, ..
|
||||
} if path.segments.len() == 1
|
||||
&& path.segments.first().unwrap().ident == "derive_serde" =>
|
||||
{
|
||||
match lit {
|
||||
Lit::Bool(LitBool { value: true, .. }) if cfg!(feature = "serde1") => {
|
||||
Ok(Self(true))
|
||||
}
|
||||
Lit::Bool(LitBool { value: true, .. }) => {
|
||||
Err(input
|
||||
.error("To enable serde, first enable the `serde1` feature of tarpc"))
|
||||
}
|
||||
Lit::Bool(LitBool { value: false, .. }) => Ok(Self(false)),
|
||||
_ => Err(input.error("`derive_serde` expects a value of type `bool`")),
|
||||
let mut result = Ok(None);
|
||||
let mut derive_serde = Vec::new();
|
||||
let meta_items = input.parse_terminated::<MetaNameValue, Comma>(MetaNameValue::parse)?;
|
||||
for meta in meta_items {
|
||||
if meta.path.segments.len() != 1 {
|
||||
extend_errors!(
|
||||
result,
|
||||
syn::Error::new(
|
||||
meta.span(),
|
||||
format!("tarpc::service does not support this meta item")
|
||||
)
|
||||
);
|
||||
continue;
|
||||
}
|
||||
let segment = meta.path.segments.first().unwrap();
|
||||
if segment.ident != "derive_serde" {
|
||||
extend_errors!(
|
||||
result,
|
||||
syn::Error::new(
|
||||
meta.span(),
|
||||
format!("tarpc::service does not support this meta item")
|
||||
)
|
||||
);
|
||||
continue;
|
||||
}
|
||||
match meta.lit {
|
||||
Lit::Bool(LitBool { value: true, .. }) if cfg!(feature = "serde1") => {
|
||||
result = result.and(Ok(Some(true)))
|
||||
}
|
||||
Lit::Bool(LitBool { value: true, .. }) => {
|
||||
extend_errors!(
|
||||
result,
|
||||
syn::Error::new(
|
||||
meta.span(),
|
||||
"To enable serde, first enable the `serde1` feature of tarpc"
|
||||
)
|
||||
);
|
||||
}
|
||||
Lit::Bool(LitBool { value: false, .. }) => result = result.and(Ok(Some(false))),
|
||||
_ => extend_errors!(
|
||||
result,
|
||||
syn::Error::new(
|
||||
meta.lit.span(),
|
||||
"`derive_serde` expects a value of type `bool`"
|
||||
)
|
||||
),
|
||||
}
|
||||
_ => {
|
||||
Err(input
|
||||
.error("tarpc::service only supports one meta item, `derive_serde = {bool}`"))
|
||||
derive_serde.push(meta);
|
||||
}
|
||||
if derive_serde.len() > 1 {
|
||||
for (i, derive_serde) in derive_serde.iter().enumerate() {
|
||||
extend_errors!(
|
||||
result,
|
||||
syn::Error::new(
|
||||
derive_serde.span(),
|
||||
format!(
|
||||
"`derive_serde` appears more than once (occurrence #{})",
|
||||
i + 1
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
let derive_serde = result?.unwrap_or(cfg!(feature = "serde1"));
|
||||
Ok(Self(derive_serde))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -212,6 +285,12 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
.into()
|
||||
}
|
||||
|
||||
/// generate an identifier consisting of the method name to CamelCase with
|
||||
/// Fut appended to it.
|
||||
fn associated_type_for_rpc(method: &ImplItemMethod) -> String {
|
||||
snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut"
|
||||
}
|
||||
|
||||
/// Transforms an async function into a sync one, returning a type declaration
|
||||
/// for the return type (a future).
|
||||
fn transform_method(method: &mut ImplItemMethod) -> ImplItemType {
|
||||
@@ -223,9 +302,7 @@ fn transform_method(method: &mut ImplItemMethod) -> ImplItemType {
|
||||
ReturnType::Type(_, ret) => quote!(#ret),
|
||||
};
|
||||
|
||||
// generate an identifier consisting of the method name to CamelCase with
|
||||
// Fut appended to it.
|
||||
let fut_name = snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut";
|
||||
let fut_name = associated_type_for_rpc(method);
|
||||
let fut_name_ident = Ident::new(&fut_name, method.sig.ident.span());
|
||||
|
||||
// generate the updated return signature.
|
||||
@@ -308,22 +385,37 @@ fn transform_method(method: &mut ImplItemMethod) -> ImplItemType {
|
||||
#[proc_macro_attribute]
|
||||
pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
let mut item = syn::parse_macro_input!(input as ItemImpl);
|
||||
let span = item.span();
|
||||
|
||||
// the generated type declarations
|
||||
let mut types: Vec<ImplItemType> = Vec::new();
|
||||
let mut expected_non_async_types: Vec<(&ImplItemMethod, String)> = Vec::new();
|
||||
let mut found_non_async_types: Vec<&ImplItemType> = Vec::new();
|
||||
|
||||
for inner in &mut item.items {
|
||||
if let ImplItem::Method(method) = inner {
|
||||
let sig = &method.sig;
|
||||
|
||||
// if this function is declared async, transform it into a regular function
|
||||
if sig.asyncness.is_some() {
|
||||
let typedecl = transform_method(method);
|
||||
types.push(typedecl);
|
||||
match inner {
|
||||
ImplItem::Method(method) => {
|
||||
if method.sig.asyncness.is_some() {
|
||||
// if this function is declared async, transform it into a regular function
|
||||
let typedecl = transform_method(method);
|
||||
types.push(typedecl);
|
||||
} else {
|
||||
// If it's not async, keep track of all required associated types for better
|
||||
// error reporting.
|
||||
expected_non_async_types.push((method, associated_type_for_rpc(method)));
|
||||
}
|
||||
}
|
||||
ImplItem::Type(typedecl) => found_non_async_types.push(typedecl),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(e) =
|
||||
verify_types_were_provided(span, &expected_non_async_types, &found_non_async_types)
|
||||
{
|
||||
return TokenStream::from(e.to_compile_error());
|
||||
}
|
||||
|
||||
// add the type declarations into the impl block
|
||||
for t in types.into_iter() {
|
||||
item.items.push(syn::ImplItem::Type(t));
|
||||
@@ -332,6 +424,39 @@ pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
TokenStream::from(quote!(#item))
|
||||
}
|
||||
|
||||
fn verify_types_were_provided(
|
||||
span: Span,
|
||||
expected: &[(&ImplItemMethod, String)],
|
||||
provided: &[&ImplItemType],
|
||||
) -> syn::Result<()> {
|
||||
let mut result = Ok(());
|
||||
for (method, expected) in expected {
|
||||
if provided
|
||||
.iter()
|
||||
.find(|typedecl| typedecl.ident == expected)
|
||||
.is_none()
|
||||
{
|
||||
let mut e = syn::Error::new(
|
||||
span,
|
||||
format!("not all trait items implemented, missing: `{}`", expected),
|
||||
);
|
||||
let fn_span = method.sig.fn_token.span();
|
||||
e.extend(syn::Error::new(
|
||||
fn_span.join(method.sig.ident.span()).unwrap_or(fn_span),
|
||||
format!(
|
||||
"hint: `#[tarpc::server]` only rewrites async fns, and `fn {}` is not async",
|
||||
method.sig.ident
|
||||
),
|
||||
));
|
||||
match result {
|
||||
Ok(_) => result = Err(e),
|
||||
Err(ref mut error) => error.extend(Some(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
// Things needed to generate the service items: trait, serve impl, request/response enums, and
|
||||
// the client stub.
|
||||
struct ServiceGenerator<'a> {
|
||||
|
||||
Reference in New Issue
Block a user