diff --git a/plugins/Cargo.toml b/plugins/Cargo.toml index 30e68fe..a468d15 100644 --- a/plugins/Cargo.toml +++ b/plugins/Cargo.toml @@ -30,3 +30,4 @@ proc-macro = true futures = "0.3" serde = { version = "1.0", features = ["derive"] } tarpc = { path = "../tarpc" } +assert-type-eq = "0.1.0" diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 41a34cb..ca63107 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -22,8 +22,8 @@ use syn::{ parse_macro_input, parse_quote, parse_str, punctuated::Punctuated, token::Comma, - Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, ReturnType, Token, Type, - Visibility, + Attribute, FnArg, Ident, ImplItem, ImplItemMethod, ImplItemType, ItemImpl, Lit, LitBool, + MetaNameValue, Pat, PatType, ReturnType, Token, Type, Visibility, }; struct Service { @@ -212,6 +212,126 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .into() } +/// Transforms an async function into a sync one, returning a type declaration +/// for the return type (a future). +fn transform_method(method: &mut ImplItemMethod) -> ImplItemType { + method.sig.asyncness = None; + + // get either the return type or (). + let ret = match &method.sig.output { + ReturnType::Default => quote!(()), + ReturnType::Type(_, ret) => quote!(#ret), + }; + + // generate an identifier consisting of the method name to CamelCase with + // Fut appended to it. + let fut_name = snake_to_camel(&method.sig.ident.unraw().to_string()) + "Fut"; + let fut_name_ident = Ident::new(&fut_name, method.sig.ident.span()); + + // generate the updated return signature. + method.sig.output = parse_quote! { + -> ::core::pin::Pin + ::core::marker::Send + >> + }; + + // transform the body of the method into Box::pin(async move { body }). + let block = method.block.clone(); + method.block = parse_quote! [{ + Box::pin(async move + #block + ) + }]; + + // generate and return type declaration for return type. + let t: ImplItemType = parse_quote! { + type #fut_name_ident = ::core::pin::Pin + ::core::marker::Send>>; + }; + + t +} + +/// Syntactic sugar to make using async functions in the server implementation +/// easier. It does this by rewriting code like this, which would normally not +/// compile because async functions are disallowed in trait implementations: +/// +/// ```rust +/// # extern crate tarpc; +/// # use tarpc::context; +/// # use std::net::SocketAddr; +/// #[tarpc_plugins::service] +/// trait World { +/// async fn hello(name: String) -> String; +/// } +/// +/// #[derive(Clone)] +/// struct HelloServer(SocketAddr); +/// +/// #[tarpc_plugins::server] +/// impl World for HelloServer { +/// async fn hello(self, _: context::Context, name: String) -> String { +/// format!("Hello, {}! You are connected from {:?}.", name, self.0) +/// } +/// } +/// ``` +/// +/// Into code like this, which matches the service trait definition: +/// +/// ```rust +/// # extern crate tarpc; +/// # use tarpc::context; +/// # use std::pin::Pin; +/// # use futures::Future; +/// # use std::net::SocketAddr; +/// #[tarpc_plugins::service] +/// trait World { +/// async fn hello(name: String) -> String; +/// } +/// +/// #[derive(Clone)] +/// struct HelloServer(SocketAddr); +/// +/// impl World for HelloServer { +/// type HelloFut = Pin + Send>>; +/// +/// fn hello(self, _: context::Context, name: String) -> Pin +/// + Send>> { +/// Box::pin(async move { +/// format!("Hello, {}! You are connected from {:?}.", name, self.0) +/// }) +/// } +/// } +/// ``` +/// +/// Note that this won't touch functions unless they have been annotated with +/// `async`, meaning that this should not break existing code. +#[proc_macro_attribute] +pub fn server(_attr: TokenStream, input: TokenStream) -> TokenStream { + let mut item = syn::parse_macro_input!(input as ItemImpl); + + // the generated type declarations + let mut types: Vec = Vec::new(); + + for inner in &mut item.items { + if let ImplItem::Method(method) = inner { + let sig = &method.sig; + + // if this function is declared async, transform it into a regular function + if sig.asyncness.is_some() { + let typedecl = transform_method(method); + types.push(typedecl); + } + } + } + + // add the type declarations into the impl block + for t in types.into_iter() { + item.items.push(syn::ImplItem::Type(t)); + } + + TokenStream::from(quote!(#item)) +} + // Things needed to generate the service items: trait, serve impl, request/response enums, and // the client stub. struct ServiceGenerator<'a> { diff --git a/plugins/tests/server.rs b/plugins/tests/server.rs new file mode 100644 index 0000000..f0222ff --- /dev/null +++ b/plugins/tests/server.rs @@ -0,0 +1,144 @@ +use assert_type_eq::assert_type_eq; +use futures::Future; +use std::pin::Pin; +use tarpc::context; + +// these need to be out here rather than inside the function so that the +// assert_type_eq macro can pick them up. +#[tarpc::service] +trait Foo { + async fn two_part(s: String, i: i32) -> (String, i32); + async fn bar(s: String) -> String; + async fn baz(); +} + +#[test] +fn type_generation_works() { + #[tarpc::server] + impl Foo for () { + async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) { + (s, i) + } + + async fn bar(self, _: context::Context, s: String) -> String { + s + } + + async fn baz(self, _: context::Context) {} + } + + // the assert_type_eq macro can only be used once per block. + { + assert_type_eq!( + <() as Foo>::TwoPartFut, + Pin + Send>> + ); + } + { + assert_type_eq!( + <() as Foo>::BarFut, + Pin + Send>> + ); + } + { + assert_type_eq!( + <() as Foo>::BazFut, + Pin + Send>> + ); + } +} + +#[allow(non_camel_case_types)] +#[test] +fn raw_idents_work() { + type r#yield = String; + + #[tarpc::service] + trait r#trait { + async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32); + async fn r#fn(r#impl: r#yield) -> r#yield; + async fn r#async(); + } + + #[tarpc::server] + impl r#trait for () { + async fn r#await( + self, + _: context::Context, + r#struct: r#yield, + r#enum: i32, + ) -> (r#yield, i32) { + (r#struct, r#enum) + } + + async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield { + r#impl + } + + async fn r#async(self, _: context::Context) {} + } +} + +#[test] +fn syntax() { + #[tarpc::service] + trait Syntax { + #[deny(warnings)] + #[allow(non_snake_case)] + async fn TestCamelCaseDoesntConflict(); + async fn hello() -> String; + #[doc = "attr"] + async fn attr(s: String) -> String; + async fn no_args_no_return(); + async fn no_args() -> (); + async fn one_arg(one: String) -> i32; + async fn two_args_no_return(one: String, two: u64); + async fn two_args(one: String, two: u64) -> String; + async fn no_args_ret_error() -> i32; + async fn one_arg_ret_error(one: String) -> String; + async fn no_arg_implicit_return_error(); + #[doc = "attr"] + async fn one_arg_implicit_return_error(one: String); + } + + #[tarpc::server] + impl Syntax for () { + #[deny(warnings)] + #[allow(non_snake_case)] + async fn TestCamelCaseDoesntConflict(self, _: context::Context) {} + + async fn hello(self, _: context::Context) -> String { + String::new() + } + + async fn attr(self, _: context::Context, _s: String) -> String { + String::new() + } + + async fn no_args_no_return(self, _: context::Context) {} + + async fn no_args(self, _: context::Context) -> () {} + + async fn one_arg(self, _: context::Context, _one: String) -> i32 { + 0 + } + + async fn two_args_no_return(self, _: context::Context, _one: String, _two: u64) {} + + async fn two_args(self, _: context::Context, _one: String, _two: u64) -> String { + String::new() + } + + async fn no_args_ret_error(self, _: context::Context) -> i32 { + 0 + } + + async fn one_arg_ret_error(self, _: context::Context, _one: String) -> String { + String::new() + } + + async fn no_arg_implicit_return_error(self, _: context::Context) {} + + async fn one_arg_implicit_return_error(self, _: context::Context, _one: String) {} + } +} diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 23849dd..a97f86d 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -233,3 +233,59 @@ pub mod trace; /// * `Client` -- a client stub with a fn for each RPC. /// * `fn new_stub` -- creates a new Client stub. pub use tarpc_plugins::service; + +/// A utility macro that can be used for RPC server implementations. +/// +/// Syntactic sugar to make using async functions in the server implementation +/// easier. It does this by rewriting code like this, which would normally not +/// compile because async functions are disallowed in trait implementations: +/// +/// ```rust +/// # use tarpc::context; +/// # use std::net::SocketAddr; +/// #[tarpc::service] +/// trait World { +/// async fn hello(name: String) -> String; +/// } +/// +/// #[derive(Clone)] +/// struct HelloServer(SocketAddr); +/// +/// #[tarpc::server] +/// impl World for HelloServer { +/// async fn hello(self, _: context::Context, name: String) -> String { +/// format!("Hello, {}! You are connected from {:?}.", name, self.0) +/// } +/// } +/// ``` +/// +/// Into code like this, which matches the service trait definition: +/// +/// ```rust +/// # use tarpc::context; +/// # use std::pin::Pin; +/// # use futures::Future; +/// # use std::net::SocketAddr; +/// #[derive(Clone)] +/// struct HelloServer(SocketAddr); +/// +/// #[tarpc::service] +/// trait World { +/// async fn hello(name: String) -> String; +/// } +/// +/// impl World for HelloServer { +/// type HelloFut = Pin + Send>>; +/// +/// fn hello(self, _: context::Context, name: String) -> Pin +/// + Send>> { +/// Box::pin(async move { +/// format!("Hello, {}! You are connected from {:?}.", name, self.0) +/// }) +/// } +/// } +/// ``` +/// +/// Note that this won't touch functions unless they have been annotated with +/// `async`, meaning that this should not break existing code. +pub use tarpc_plugins::server;