Allow raw identifiers + fixed naming + place all code generation methods in impl (#291)

Allows defining services using raw identifiers like:

```rust
pub mod service {
    #[tarpc::service]
    pub trait r#trait {
        async fn r#fn(x: i32) -> Result<u8, String>;
    }
}
```

Also:

- Refactored names (ident -> type)
- All code generation methods placed in impl
This commit is contained in:
Oleg Nosov
2019-12-12 21:13:57 +03:00
committed by Tim
parent d905bc1591
commit 31c713d188
2 changed files with 364 additions and 323 deletions

View File

@@ -12,11 +12,14 @@ extern crate quote;
extern crate syn;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote, ToTokens};
use syn::{
braced, parenthesized,
braced,
ext::IdentExt,
parenthesized,
parse::{Parse, ParseStream},
parse_macro_input, parse_quote,
parse_macro_input, parse_quote, parse_str,
punctuated::Punctuated,
token::Comma,
Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type,
@@ -42,7 +45,7 @@ impl Parse for Service {
let attrs = input.call(Attribute::parse_outer)?;
let vis = input.parse()?;
input.parse::<Token![trait]>()?;
let ident = input.parse()?;
let ident: Ident = input.parse()?;
let content;
braced!(content in input);
let mut rpcs = Vec::<RpcMethod>::new();
@@ -53,7 +56,7 @@ impl Parse for Service {
if rpc.ident == "new" {
return Err(input.error(format!(
"method name conflicts with generated fn `{}Client::new`",
ident
ident.unraw()
)));
}
if rpc.ident == "serve" {
@@ -63,6 +66,7 @@ impl Parse for Service {
)));
}
}
Ok(Self {
attrs,
vis,
@@ -156,19 +160,19 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
ref rpcs,
} = parse_macro_input!(input as Service);
let camel_case_fn_names: &[String] = &rpcs
let camel_case_fn_names: &Vec<_> = &rpcs
.iter()
.map(|rpc| snake_to_camel(&rpc.ident.to_string()))
.collect::<Vec<_>>();
.map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
.collect();
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
let response_fut_name = &format!("{}ResponseFut", ident);
let derive_serialize = quote!(#[derive(serde::Serialize, serde::Deserialize)]);
let response_fut_name = &format!("{}ResponseFut", ident.unraw());
let derive_serialize = if derive_serde.0 {
Some(quote!(#[derive(serde::Serialize, serde::Deserialize)]))
} else {
None
};
let gen_args = &GenArgs {
attrs,
vis,
rpcs,
args,
ServiceGenerator {
response_fut_name,
service_ident: ident,
server_ident: &format_ident!("Serve{}", ident),
@@ -176,8 +180,12 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
client_ident: &format_ident!("{}Client", ident),
request_ident: &format_ident!("{}Request", ident),
response_ident: &format_ident!("{}Response", ident),
vis,
args,
method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
method_names: &rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>(),
method_idents: &rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>(),
attrs,
rpcs,
return_types: &rpcs
.iter()
.map(|rpc| match rpc.output {
@@ -185,7 +193,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
ReturnType::Default => unit_type,
})
.collect::<Vec<_>>(),
arg_vars: &args
arg_pats: &args
.iter()
.map(|args| args.iter().map(|arg| &*arg.pat).collect())
.collect::<Vec<_>>(),
@@ -194,43 +202,19 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
.zip(camel_case_fn_names.iter())
.map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
.collect::<Vec<_>>(),
future_idents: &camel_case_fn_names
future_types: &camel_case_fn_names
.iter()
.map(|name| format_ident!("{}Fut", name))
.map(|name| parse_str(&format!("{}Fut", name)).unwrap())
.collect::<Vec<_>>(),
derive_serialize: if derive_serde.0 {
Some(&derive_serialize)
} else {
None
},
};
let tokens = vec![
trait_service(gen_args),
struct_server(gen_args),
impl_serve_for_server(gen_args),
enum_request(gen_args),
enum_response(gen_args),
enum_response_future(gen_args),
impl_debug_for_response_future(gen_args),
impl_future_for_response_future(gen_args),
struct_client(gen_args),
impl_from_for_client(gen_args),
impl_client_new(gen_args),
impl_client_rpc_methods(gen_args),
];
tokens
.into_iter()
.collect::<proc_macro2::TokenStream>()
.into()
derive_serialize: derive_serialize.as_ref(),
}
.into_token_stream()
.into()
}
// Things needed to generate the service items: trait, serve impl, request/response enums, and
// the client stub.
struct GenArgs<'a> {
attrs: &'a [Attribute],
rpcs: &'a [RpcMethod],
struct ServiceGenerator<'a> {
service_ident: &'a Ident,
server_ident: &'a Ident,
response_fut_ident: &'a Ident,
@@ -238,331 +222,356 @@ struct GenArgs<'a> {
client_ident: &'a Ident,
request_ident: &'a Ident,
response_ident: &'a Ident,
method_attrs: &'a [&'a [Attribute]],
vis: &'a Visibility,
method_names: &'a [&'a Ident],
attrs: &'a [Attribute],
rpcs: &'a [RpcMethod],
camel_case_idents: &'a [Ident],
future_types: &'a [Type],
method_idents: &'a [&'a Ident],
method_attrs: &'a [&'a [Attribute]],
args: &'a [&'a [PatType]],
return_types: &'a [&'a Type],
arg_vars: &'a [Vec<&'a Pat>],
camel_case_idents: &'a [Ident],
future_idents: &'a [Ident],
derive_serialize: Option<&'a proc_macro2::TokenStream>,
arg_pats: &'a [Vec<&'a Pat>],
derive_serialize: Option<&'a TokenStream2>,
}
fn trait_service(
&GenArgs {
attrs,
rpcs,
vis,
future_idents,
return_types,
service_ident,
server_ident,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
let types_and_fns = rpcs
.iter()
.zip(future_idents.iter())
.zip(return_types.iter())
.map(
|(
(
RpcMethod {
attrs, ident, args, ..
},
future_type,
),
output,
)| {
let ty_doc = format!("The response future returned by {}.", ident);
quote! {
#[doc = #ty_doc]
type #future_type: std::future::Future<Output = #output>;
impl<'a> ServiceGenerator<'a> {
fn trait_service(&self) -> TokenStream2 {
let &Self {
attrs,
rpcs,
vis,
future_types,
return_types,
service_ident,
server_ident,
..
} = self;
#( #attrs )*
fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type;
}
},
);
let types_and_fns = rpcs
.iter()
.zip(future_types.iter())
.zip(return_types.iter())
.map(
|(
(
RpcMethod {
attrs, ident, args, ..
},
future_type,
),
output,
)| {
let ty_doc = format!("The response future returned by {}.", ident);
quote! {
#[doc = #ty_doc]
type #future_type: std::future::Future<Output = #output>;
quote! {
#( #attrs )*
#vis trait #service_ident: Clone {
#( #types_and_fns )*
#( #attrs )*
fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> Self::#future_type;
}
},
);
/// Returns a serving function to use with tarpc::server::Server.
fn serve(self) -> #server_ident<Self> {
#server_ident { service: self }
}
}
}
}
quote! {
#( #attrs )*
#vis trait #service_ident: Clone {
#( #types_and_fns )*
fn struct_server(
&GenArgs {
vis, server_ident, ..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
#[derive(Clone)]
#vis struct #server_ident<S> {
service: S,
}
}
}
fn impl_serve_for_server(
&GenArgs {
request_ident,
server_ident,
service_ident,
response_ident,
response_fut_ident,
camel_case_idents,
arg_vars,
method_names,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
impl<S> tarpc::server::Serve<#request_ident> for #server_ident<S>
where S: #service_ident
{
type Resp = #response_ident;
type Fut = #response_fut_ident<S>;
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
match req {
#(
#request_ident::#camel_case_idents{ #( #arg_vars ),* } => {
#response_fut_ident::#camel_case_idents(
#service_ident::#method_names(
self.service, ctx, #( #arg_vars ),*))
}
)*
/// Returns a serving function to use with tarpc::server::Server.
fn serve(self) -> #server_ident<Self> {
#server_ident { service: self }
}
}
}
}
}
fn enum_request(
&GenArgs {
derive_serialize,
vis,
request_ident,
camel_case_idents,
args,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
/// The request sent over the wire from the client to the server.
#[derive(Debug)]
#derive_serialize
#vis enum #request_ident {
#( #camel_case_idents{ #( #args ),* } ),*
}
}
}
fn struct_server(&self) -> TokenStream2 {
let &Self {
vis, server_ident, ..
} = self;
fn enum_response(
&GenArgs {
derive_serialize,
vis,
response_ident,
camel_case_idents,
return_types,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
/// The response sent over the wire from the server to the client.
#[derive(Debug)]
#derive_serialize
#vis enum #response_ident {
#( #camel_case_idents(#return_types) ),*
}
}
}
fn enum_response_future(
&GenArgs {
vis,
service_ident,
response_fut_ident,
camel_case_idents,
future_idents,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
/// A future resolving to a server response.
#vis enum #response_fut_ident<S: #service_ident> {
#( #camel_case_idents(<S as #service_ident>::#future_idents) ),*
}
}
}
fn impl_debug_for_response_future(
&GenArgs {
service_ident,
response_fut_ident,
response_fut_name,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
impl<S: #service_ident> std::fmt::Debug for #response_fut_ident<S> {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct(#response_fut_name).finish()
quote! {
#[derive(Clone)]
#vis struct #server_ident<S> {
service: S,
}
}
}
}
fn impl_future_for_response_future(
&GenArgs {
service_ident,
response_fut_ident,
response_ident,
camel_case_idents,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
impl<S: #service_ident> std::future::Future for #response_fut_ident<S> {
type Output = #response_ident;
fn impl_serve_for_server(&self) -> TokenStream2 {
let &Self {
request_ident,
server_ident,
service_ident,
response_ident,
response_fut_ident,
camel_case_idents,
arg_pats,
method_idents,
..
} = self;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>)
-> std::task::Poll<#response_ident>
quote! {
impl<S> tarpc::server::Serve<#request_ident> for #server_ident<S>
where S: #service_ident
{
unsafe {
match std::pin::Pin::get_unchecked_mut(self) {
type Resp = #response_ident;
type Fut = #response_fut_ident<S>;
fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
match req {
#(
#response_fut_ident::#camel_case_idents(resp) =>
std::pin::Pin::new_unchecked(resp)
.poll(cx)
.map(#response_ident::#camel_case_idents),
#request_ident::#camel_case_idents{ #( #arg_pats ),* } => {
#response_fut_ident::#camel_case_idents(
#service_ident::#method_idents(
self.service, ctx, #( #arg_pats ),*
)
)
}
)*
}
}
}
}
}
}
fn struct_client(
&GenArgs {
vis,
client_ident,
request_ident,
response_ident,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
#[allow(unused)]
#[derive(Clone, Debug)]
/// The client stub that makes RPC calls to the server. Exposes a Future interface.
#vis struct #client_ident<C = tarpc::client::Channel<#request_ident, #response_ident>>(C);
}
}
fn enum_request(&self) -> TokenStream2 {
let &Self {
derive_serialize,
vis,
request_ident,
camel_case_idents,
args,
..
} = self;
fn impl_from_for_client(
&GenArgs {
client_ident,
request_ident,
response_ident,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
impl<C> From<C> for #client_ident<C>
where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
{
fn from(client: C) -> Self {
#client_ident(client)
quote! {
/// The request sent over the wire from the client to the server.
#[derive(Debug)]
#derive_serialize
#vis enum #request_ident {
#( #camel_case_idents{ #( #args ),* } ),*
}
}
}
}
fn impl_client_new(
&GenArgs {
client_ident,
vis,
request_ident,
response_ident,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
impl #client_ident {
/// Returns a new client stub that sends requests over the given transport.
#vis fn new<T>(config: tarpc::client::Config, transport: T)
-> tarpc::client::NewClient<
Self,
tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T>>
where
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>>
{
let new_client = tarpc::client::new(config, transport);
tarpc::client::NewClient {
client: #client_ident(new_client.client),
dispatch: new_client.dispatch,
fn enum_response(&self) -> TokenStream2 {
let &Self {
derive_serialize,
vis,
response_ident,
camel_case_idents,
return_types,
..
} = self;
quote! {
/// The response sent over the wire from the server to the client.
#[derive(Debug)]
#derive_serialize
#vis enum #response_ident {
#( #camel_case_idents(#return_types) ),*
}
}
}
fn enum_response_future(&self) -> TokenStream2 {
let &Self {
vis,
service_ident,
response_fut_ident,
camel_case_idents,
future_types,
..
} = self;
quote! {
/// A future resolving to a server response.
#vis enum #response_fut_ident<S: #service_ident> {
#( #camel_case_idents(<S as #service_ident>::#future_types) ),*
}
}
}
fn impl_debug_for_response_future(&self) -> TokenStream2 {
let &Self {
service_ident,
response_fut_ident,
response_fut_name,
..
} = self;
quote! {
impl<S: #service_ident> std::fmt::Debug for #response_fut_ident<S> {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct(#response_fut_name).finish()
}
}
}
}
}
fn impl_client_rpc_methods(
&GenArgs {
client_ident,
request_ident,
response_ident,
method_attrs,
vis,
method_names,
args,
return_types,
arg_vars,
camel_case_idents,
..
}: &GenArgs,
) -> proc_macro2::TokenStream {
quote! {
impl<C> #client_ident<C>
where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
{
#(
#[allow(unused)]
#( #method_attrs )*
#vis fn #method_names(&mut self, ctx: tarpc::context::Context, #( #args ),*)
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
let request = #request_ident::#camel_case_idents { #( #arg_vars ),* };
let resp = tarpc::Client::call(&mut self.0, ctx, request);
async move {
match resp.await? {
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
_ => unreachable!(),
fn impl_future_for_response_future(&self) -> TokenStream2 {
let &Self {
service_ident,
response_fut_ident,
response_ident,
camel_case_idents,
..
} = self;
quote! {
impl<S: #service_ident> std::future::Future for #response_fut_ident<S> {
type Output = #response_ident;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>)
-> std::task::Poll<#response_ident>
{
unsafe {
match std::pin::Pin::get_unchecked_mut(self) {
#(
#response_fut_ident::#camel_case_idents(resp) =>
std::pin::Pin::new_unchecked(resp)
.poll(cx)
.map(#response_ident::#camel_case_idents),
)*
}
}
}
)*
}
}
}
fn struct_client(&self) -> TokenStream2 {
let &Self {
vis,
client_ident,
request_ident,
response_ident,
..
} = self;
quote! {
#[allow(unused)]
#[derive(Clone, Debug)]
/// The client stub that makes RPC calls to the server. Exposes a Future interface.
#vis struct #client_ident<C = tarpc::client::Channel<#request_ident, #response_ident>>(C);
}
}
fn impl_from_for_client(&self) -> TokenStream2 {
let &Self {
client_ident,
request_ident,
response_ident,
..
} = self;
quote! {
impl<C> From<C> for #client_ident<C>
where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
{
fn from(client: C) -> Self {
#client_ident(client)
}
}
}
}
fn impl_client_new(&self) -> TokenStream2 {
let &Self {
client_ident,
vis,
request_ident,
response_ident,
..
} = self;
quote! {
impl #client_ident {
/// Returns a new client stub that sends requests over the given transport.
#vis fn new<T>(config: tarpc::client::Config, transport: T)
-> tarpc::client::NewClient<
Self,
tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T>
>
where
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>>
{
let new_client = tarpc::client::new(config, transport);
tarpc::client::NewClient {
client: #client_ident(new_client.client),
dispatch: new_client.dispatch,
}
}
}
}
}
fn impl_client_rpc_methods(&self) -> TokenStream2 {
let &Self {
client_ident,
request_ident,
response_ident,
method_attrs,
vis,
method_idents,
args,
return_types,
arg_pats,
camel_case_idents,
..
} = self;
quote! {
impl<C> #client_ident<C>
where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
{
#(
#[allow(unused)]
#( #method_attrs )*
#vis fn #method_idents(&mut self, ctx: tarpc::context::Context, #( #args ),*)
-> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
let resp = tarpc::Client::call(&mut self.0, ctx, request);
async move {
match resp.await? {
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
_ => unreachable!(),
}
}
}
)*
}
}
}
}
impl<'a> ToTokens for ServiceGenerator<'a> {
fn to_tokens(&self, output: &mut TokenStream2) {
output.extend(vec![
self.trait_service(),
self.struct_server(),
self.impl_serve_for_server(),
self.enum_request(),
self.enum_response(),
self.enum_response_future(),
self.impl_debug_for_response_future(),
self.impl_future_for_response_future(),
self.struct_client(),
self.impl_from_for_client(),
self.impl_client_new(),
self.impl_client_rpc_methods(),
])
}
}
fn snake_to_camel(ident_str: &str) -> String {
let chars = ident_str.chars();
let mut camel_ty = String::with_capacity(ident_str.len());
let mut last_char_was_underscore = true;
for c in chars {
for c in ident_str.chars() {
match c {
'_' => last_char_was_underscore = true,
c if last_char_was_underscore => {

View File

@@ -29,6 +29,38 @@ fn att_service_trait() {
}
}
#[allow(non_camel_case_types)]
#[test]
fn raw_idents() {
use futures::future::{ready, Ready};
type r#yield = String;
#[tarpc::service]
trait r#trait {
async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32);
async fn r#fn(r#impl: r#yield) -> r#yield;
async fn r#async();
}
impl r#trait for () {
type AwaitFut = Ready<(r#yield, i32)>;
fn r#await(self, _: context::Context, r#struct: r#yield, r#enum: i32) -> Self::AwaitFut {
ready((r#struct, r#enum))
}
type FnFut = Ready<r#yield>;
fn r#fn(self, _: context::Context, r#impl: r#yield) -> Self::FnFut {
ready(r#impl)
}
type AsyncFut = Ready<()>;
fn r#async(self, _: context::Context) -> Self::AsyncFut {
ready(())
}
}
}
#[test]
fn syntax() {
#[tarpc::service]