Add service name to generated items.

With this change, the service definitions don't need to be isolated in their own modules.
This commit is contained in:
Tim Kuehn
2019-07-23 01:47:14 -07:00
parent 5c485fe608
commit 2f24842b2d
10 changed files with 167 additions and 124 deletions

View File

@@ -196,6 +196,19 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
let service_name_repeated2 = service_name_repeated.clone();
let client_ident = Ident::new(&format!("{}Client", ident), ident.span());
let request_ident = Ident::new(&format!("{}Request", ident), ident.span());
let request_ident_repeated = std::iter::repeat(request_ident.clone());
let request_ident_repeated2 = request_ident_repeated.clone();
let response_ident = Ident::new(&format!("{}Response", ident), ident.span());
let response_ident_repeated = std::iter::repeat(response_ident.clone());
let response_ident_repeated2 = response_ident_repeated.clone();
let response_fut_name = format!("{}ResponseFut", ident);
let response_fut_ident = Ident::new(&response_fut_name, ident.span());
let response_fut_ident_repeated = std::iter::repeat(response_fut_ident.clone());
let response_fut_ident_repeated2 = response_fut_ident_repeated.clone();
let snake_ident = camel_to_snake(&ident.to_string());
let serve_ident = Ident::new(&format!("serve_{}", snake_ident), ident.span());
let stub_ident = Ident::new(&format!("{}_stub", snake_ident), ident.span());
#[cfg(feature = "serde1")]
let derive_serialize = quote!(#[derive(serde::Serialize, serde::Deserialize)]);
@@ -211,41 +224,41 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
/// The request sent over the wire from the client to the server.
#[derive(Debug)]
#derive_serialize
#vis enum Request {
#vis enum #request_ident {
#( #camel_case_idents{ #args } ),*
}
/// The response sent over the wire from the server to the client.
#[derive(Debug)]
#derive_serialize
#vis enum Response {
#vis enum #response_ident {
#( #camel_case_idents(#outputs) ),*
}
/// A future resolving to a server [`Response`].
#vis enum ResponseFut<S: #ident> {
/// A future resolving to a server response.
#vis enum #response_fut_ident<S: #ident> {
#( #camel_case_idents(<S as #service_name_repeated>::#future_types) ),*
}
impl<S: #ident> std::fmt::Debug for ResponseFut<S> {
impl<S: #ident> std::fmt::Debug for #response_fut_ident<S> {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct("ResponseFut").finish()
fmt.debug_struct(#response_fut_name).finish()
}
}
impl<S: #ident> std::future::Future for ResponseFut<S> {
type Output = std::io::Result<Response>;
impl<S: #ident> std::future::Future for #response_fut_ident<S> {
type Output = std::io::Result<#response_ident>;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>)
-> std::task::Poll<std::io::Result<Response>>
-> std::task::Poll<std::io::Result<#response_ident>>
{
unsafe {
match std::pin::Pin::get_unchecked_mut(self) {
#(
ResponseFut::#camel_case_idents(resp) =>
#response_fut_ident_repeated::#camel_case_idents(resp) =>
std::pin::Pin::new_unchecked(resp)
.poll(cx)
.map(Response::#camel_case_idents2)
.map(#response_ident_repeated::#camel_case_idents2)
.map(Ok),
)*
}
@@ -254,13 +267,13 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
}
/// Returns a serving function to use with tarpc::server::Server.
#vis fn serve<S: #ident>(service: S)
-> impl FnOnce(tarpc::context::Context, Request) -> ResponseFut<S> + Send + 'static + Clone {
#vis fn #serve_ident<S: #ident>(service: S)
-> impl FnOnce(tarpc::context::Context, #request_ident) -> #response_fut_ident<S> + Send + 'static + Clone {
move |ctx, req| {
match req {
#(
Request::#camel_case_idents{ #arg_vars } => {
ResponseFut::#camel_case_idents2(
#request_ident_repeated::#camel_case_idents{ #arg_vars } => {
#response_fut_ident_repeated2::#camel_case_idents2(
#service_name_repeated2::#method_names(
service.clone(), ctx, #arg_vars2))
}
@@ -272,19 +285,19 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
#[allow(unused)]
#[derive(Clone, Debug)]
/// The client stub that makes RPC calls to the server. Exposes a Future interface.
#vis struct #client_ident<C = tarpc::client::Channel<Request, Response>>(C);
#vis struct #client_ident<C = tarpc::client::Channel<#request_ident, #response_ident>>(C);
/// Returns a new client stub that sends requests over the given transport.
#vis async fn new_stub<T>(config: tarpc::client::Config, transport: T)
#vis async fn #stub_ident<T>(config: tarpc::client::Config, transport: T)
-> std::io::Result<#client_ident>
where
T: tarpc::Transport<tarpc::ClientMessage<Request>, tarpc::Response<Response>> + Send + 'static,
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>> + Send + 'static,
{
Ok(#client_ident(tarpc::client::new(config, transport).await?))
}
impl<C> From<C> for #client_ident<C>
where for <'a> C: tarpc::Client<'a, Request, Response = Response>
where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
{
fn from(client: C) -> Self {
#client_ident(client)
@@ -292,18 +305,18 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
}
impl<C> #client_ident<C>
where for<'a> C: tarpc::Client<'a, Request, Response = Response>
where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident>
{
#(
#[allow(unused)]
#( #method_attrs )*
pub fn #method_names(&mut self, ctx: tarpc::context::Context, #args)
-> impl std::future::Future<Output = std::io::Result<#outputs>> + '_ {
let request = Request::#camel_case_idents { #arg_vars };
let request = #request_ident_repeated2::#camel_case_idents { #arg_vars };
let resp = tarpc::Client::call(&mut self.0, ctx, request);
async move {
match resp.await? {
Response::#camel_case_idents2(msg) => std::result::Result::Ok(msg),
#response_ident_repeated2::#camel_case_idents2(msg) => std::result::Result::Ok(msg),
_ => unreachable!(),
}
}
@@ -334,6 +347,28 @@ fn snake_to_camel(ident_str: &str) -> String {
camel_ty
}
// Really basic camel to snake that assumes capitals are always the start of a new segment.
fn camel_to_snake(ident_str: &str) -> String {
let mut snake = String::new();
let mut chars = ident_str.chars();
if let Some(c) = chars.next() {
snake.extend(c.to_lowercase());
}
while let Some(c) = chars.next() {
if c.is_uppercase() {
// New word
snake.push('_');
snake.extend(c.to_lowercase());
} else {
// Same word
snake.push(c)
}
}
snake
}
#[test]
fn snake_to_camel_basic() {
assert_eq!(snake_to_camel("abc_def"), "AbcDef");
@@ -358,3 +393,8 @@ fn snake_to_camel_underscore_consecutive() {
fn snake_to_camel_capital_in_middle() {
assert_eq!(snake_to_camel("aBc_dEf"), "AbcDef");
}
#[test]
fn camel_to_snake_basic() {
assert_eq!(camel_to_snake("AbcDef"), "abc_def");
}