9 Commits

3 changed files with 117 additions and 37 deletions

View File

@@ -220,15 +220,15 @@ impl Parse for DeriveSerde {
/// Adds the following annotations to the annotated item: /// Adds the following annotations to the annotated item:
/// ///
/// ```rust /// ```rust
/// #[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)] /// #[derive(::tarpc::serde::Serialize, ::tarpc::serde::Deserialize)]
/// #[serde(crate = "tarpc::serde")] /// #[serde(crate = "tarpc::serde")]
/// # struct Foo; /// # struct Foo;
/// ``` /// ```
#[proc_macro_attribute] #[proc_macro_attribute]
pub fn derive_serde(_attr: TokenStream, item: TokenStream) -> TokenStream { pub fn derive_serde(_attr: TokenStream, item: TokenStream) -> TokenStream {
let mut gen: proc_macro2::TokenStream = quote! { let mut gen: proc_macro2::TokenStream = quote! {
#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)] #[derive(::tarpc::serde::Serialize, ::tarpc::serde::Deserialize)]
#[serde(crate = "tarpc::serde")] #[serde(crate = "::tarpc::serde")]
}; };
gen.extend(proc_macro2::TokenStream::from(item)); gen.extend(proc_macro2::TokenStream::from(item));
proc_macro::TokenStream::from(gen) proc_macro::TokenStream::from(gen)
@@ -259,8 +259,8 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>(); let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
let derive_serialize = if derive_serde.0 { let derive_serialize = if derive_serde.0 {
Some( Some(
quote! {#[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)] quote! {#[derive(::tarpc::serde::Serialize, ::tarpc::serde::Deserialize)]
#[serde(crate = "tarpc::serde")]}, #[serde(crate = "::tarpc::serde")]},
) )
} else { } else {
None None
@@ -357,7 +357,7 @@ impl<'a> ServiceGenerator<'a> {
)| { )| {
quote! { quote! {
#( #attrs )* #( #attrs )*
async fn #ident(self, context: tarpc::context::Context, #( #args ),*) -> #output; async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output;
} }
}, },
); );
@@ -365,22 +365,22 @@ impl<'a> ServiceGenerator<'a> {
let stub_doc = format!("The stub trait for service [`{service_ident}`]."); let stub_doc = format!("The stub trait for service [`{service_ident}`].");
quote! { quote! {
#( #attrs )* #( #attrs )*
#vis trait #service_ident: Sized { #vis trait #service_ident: ::core::marker::Sized {
#( #rpc_fns )* #( #rpc_fns )*
/// Returns a serving function to use with /// Returns a serving function to use with
/// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute). /// [InFlightRequest::execute](::tarpc::server::InFlightRequest::execute).
fn serve(self) -> #server_ident<Self> { fn serve(self) -> #server_ident<Self> {
#server_ident { service: self } #server_ident { service: self }
} }
} }
#[doc = #stub_doc] #[doc = #stub_doc]
#vis trait #client_stub_ident: tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident> { #vis trait #client_stub_ident: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident> {
} }
impl<S> #client_stub_ident for S impl<S> #client_stub_ident for S
where S: tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident> where S: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident>
{ {
} }
} }
@@ -392,7 +392,7 @@ impl<'a> ServiceGenerator<'a> {
} = self; } = self;
quote! { quote! {
/// A serving function to use with [tarpc::server::InFlightRequest::execute]. /// A serving function to use with [::tarpc::server::InFlightRequest::execute].
#[derive(Clone)] #[derive(Clone)]
#vis struct #server_ident<S> { #vis struct #server_ident<S> {
service: S, service: S,
@@ -414,14 +414,14 @@ impl<'a> ServiceGenerator<'a> {
} = self; } = self;
quote! { quote! {
impl<S> tarpc::server::Serve for #server_ident<S> impl<S> ::tarpc::server::Serve for #server_ident<S>
where S: #service_ident where S: #service_ident
{ {
type Req = #request_ident; type Req = #request_ident;
type Resp = #response_ident; type Resp = #response_ident;
fn method(&self, req: &#request_ident) -> Option<&'static str> { fn method(&self, req: &#request_ident) -> ::core::option::Option<&'static str> {
Some(match req { ::core::option::Option::Some(match req {
#( #(
#request_ident::#camel_case_idents{..} => { #request_ident::#camel_case_idents{..} => {
#request_names #request_names
@@ -430,12 +430,12 @@ impl<'a> ServiceGenerator<'a> {
}) })
} }
async fn serve(self, ctx: tarpc::context::Context, req: #request_ident) async fn serve(self, ctx: ::tarpc::context::Context, req: #request_ident)
-> Result<#response_ident, tarpc::ServerError> { -> ::core::result::Result<#response_ident, ::tarpc::ServerError> {
match req { match req {
#( #(
#request_ident::#camel_case_idents{ #( #arg_pats ),* } => { #request_ident::#camel_case_idents{ #( #arg_pats ),* } => {
Ok(#response_ident::#camel_case_idents( ::core::result::Result::Ok(#response_ident::#camel_case_idents(
#service_ident::#method_idents( #service_ident::#method_idents(
self.service, ctx, #( #arg_pats ),* self.service, ctx, #( #arg_pats ),*
).await ).await
@@ -503,9 +503,9 @@ impl<'a> ServiceGenerator<'a> {
#[allow(unused)] #[allow(unused)]
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
/// The client stub that makes RPC calls to the server. All request methods return /// The client stub that makes RPC calls to the server. All request methods return
/// [Futures](std::future::Future). /// [Futures](::core::future::Future).
#vis struct #client_ident< #vis struct #client_ident<
Stub = tarpc::client::Channel<#request_ident, #response_ident> Stub = ::tarpc::client::Channel<#request_ident, #response_ident>
>(Stub); >(Stub);
} }
} }
@@ -522,24 +522,24 @@ impl<'a> ServiceGenerator<'a> {
quote! { quote! {
impl #client_ident { impl #client_ident {
/// Returns a new client stub that sends requests over the given transport. /// Returns a new client stub that sends requests over the given transport.
#vis fn new<T>(config: tarpc::client::Config, transport: T) #vis fn new<T>(config: ::tarpc::client::Config, transport: T)
-> tarpc::client::NewClient< -> ::tarpc::client::NewClient<
Self, Self,
tarpc::client::RequestDispatch<#request_ident, #response_ident, T> ::tarpc::client::RequestDispatch<#request_ident, #response_ident, T>
> >
where where
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>> T: ::tarpc::Transport<::tarpc::ClientMessage<#request_ident>, ::tarpc::Response<#response_ident>>
{ {
let new_client = tarpc::client::new(config, transport); let new_client = ::tarpc::client::new(config, transport);
tarpc::client::NewClient { ::tarpc::client::NewClient {
client: #client_ident(new_client.client), client: #client_ident(new_client.client),
dispatch: new_client.dispatch, dispatch: new_client.dispatch,
} }
} }
} }
impl<Stub> From<Stub> for #client_ident<Stub> impl<Stub> ::core::convert::From<Stub> for #client_ident<Stub>
where Stub: tarpc::client::stub::Stub< where Stub: ::tarpc::client::stub::Stub<
Req = #request_ident, Req = #request_ident,
Resp = #response_ident> Resp = #response_ident>
{ {
@@ -570,21 +570,21 @@ impl<'a> ServiceGenerator<'a> {
quote! { quote! {
impl<Stub> #client_ident<Stub> impl<Stub> #client_ident<Stub>
where Stub: tarpc::client::stub::Stub< where Stub: ::tarpc::client::stub::Stub<
Req = #request_ident, Req = #request_ident,
Resp = #response_ident> Resp = #response_ident>
{ {
#( #(
#[allow(unused)] #[allow(unused)]
#( #method_attrs )* #( #method_attrs )*
#vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*) #vis fn #method_idents(&self, ctx: ::tarpc::context::Context, #( #args ),*)
-> impl std::future::Future<Output = Result<#return_types, tarpc::client::RpcError>> + '_ { -> impl ::core::future::Future<Output = ::core::result::Result<#return_types, ::tarpc::client::RpcError>> + '_ {
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
let resp = self.0.call(ctx, #request_names, request); let resp = self.0.call(ctx, #request_names, request);
async move { async move {
match resp.await? { match resp.await? {
#response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg), #response_ident::#camel_case_idents(msg) => ::core::result::Result::Ok(msg),
_ => unreachable!(), _ => ::core::unreachable!(),
} }
} }
} }

View File

@@ -210,7 +210,19 @@ pub mod tcp {
Codec: Serializer<SinkItem> + Deserializer<Item>, Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec, CodecFn: Fn() -> Codec,
{ {
let listener = TcpListener::bind(addr).await?; listen_on(TcpListener::bind(addr).await?, codec_fn).await
}
/// Wrap accepted connections from `listener` in TCP transports.
pub async fn listen_on<Item, SinkItem, Codec, CodecFn>(
listener: TcpListener,
codec_fn: CodecFn,
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
where
Item: for<'de> Deserialize<'de>,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
let local_addr = listener.local_addr()?; let local_addr = listener.local_addr()?;
Ok(Incoming { Ok(Incoming {
listener, listener,
@@ -364,7 +376,19 @@ pub mod unix {
Codec: Serializer<SinkItem> + Deserializer<Item>, Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec, CodecFn: Fn() -> Codec,
{ {
let listener = UnixListener::bind(path)?; listen_on(UnixListener::bind(path)?, codec_fn).await
}
/// Wrap accepted connections from `listener` in Unix Domain Socket transports.
pub async fn listen_on<Item, SinkItem, Codec, CodecFn>(
listener: UnixListener,
codec_fn: CodecFn,
) -> io::Result<Incoming<Item, SinkItem, Codec, CodecFn>>
where
Item: for<'de> Deserialize<'de>,
Codec: Serializer<SinkItem> + Deserializer<Item>,
CodecFn: Fn() -> Codec,
{
let local_addr = listener.local_addr()?; let local_addr = listener.local_addr()?;
Ok(Incoming { Ok(Incoming {
listener, listener,
@@ -537,7 +561,7 @@ pub mod unix {
mod tests { mod tests {
use super::Transport; use super::Transport;
use assert_matches::assert_matches; use assert_matches::assert_matches;
use futures::{task::*, Sink, Stream}; use futures::{task::*, Sink, SinkExt, Stream, StreamExt};
use pin_utils::pin_mut; use pin_utils::pin_mut;
use std::{ use std::{
io::{self, Cursor}, io::{self, Cursor},
@@ -631,7 +655,7 @@ mod tests {
); );
} }
#[cfg(tcp)] #[cfg(feature = "tcp")]
#[tokio::test] #[tokio::test]
async fn tcp() -> io::Result<()> { async fn tcp() -> io::Result<()> {
use super::tcp; use super::tcp;
@@ -650,11 +674,30 @@ mod tests {
Ok(()) Ok(())
} }
#[cfg(feature = "tcp")]
#[tokio::test]
async fn tcp_on_existing_transport() -> io::Result<()> {
use super::tcp;
let transport = tokio::net::TcpListener::bind("0.0.0.0:0").await?;
let mut listener = tcp::listen_on(transport, SymmetricalJson::<String>::default).await?;
let addr = listener.local_addr();
tokio::spawn(async move {
let mut transport = listener.next().await.unwrap().unwrap();
let message = transport.next().await.unwrap().unwrap();
transport.send(message).await.unwrap();
});
let mut transport = tcp::connect(addr, SymmetricalJson::<String>::default).await?;
transport.send(String::from("test")).await?;
assert_matches!(transport.next().await, Some(Ok(s)) if s == "test");
assert_matches!(transport.next().await, None);
Ok(())
}
#[cfg(all(unix, feature = "unix"))] #[cfg(all(unix, feature = "unix"))]
#[tokio::test] #[tokio::test]
async fn uds() -> io::Result<()> { async fn uds() -> io::Result<()> {
use super::unix; use super::unix;
use super::*;
let sock = unix::TempPathBuf::with_random("uds"); let sock = unix::TempPathBuf::with_random("uds");
let mut listener = unix::listen(&sock, SymmetricalJson::<String>::default).await?; let mut listener = unix::listen(&sock, SymmetricalJson::<String>::default).await?;
@@ -669,4 +712,24 @@ mod tests {
assert_matches!(transport.next().await, None); assert_matches!(transport.next().await, None);
Ok(()) Ok(())
} }
#[cfg(all(unix, feature = "unix"))]
#[tokio::test]
async fn uds_on_existing_transport() -> io::Result<()> {
use super::unix;
let sock = unix::TempPathBuf::with_random("uds");
let transport = tokio::net::UnixListener::bind(&sock)?;
let mut listener = unix::listen_on(transport, SymmetricalJson::<String>::default).await?;
tokio::spawn(async move {
let mut transport = listener.next().await.unwrap().unwrap();
let message = transport.next().await.unwrap().unwrap();
transport.send(message).await.unwrap();
});
let mut transport = unix::connect(&sock, SymmetricalJson::<String>::default).await?;
transport.send(String::from("test")).await?;
assert_matches!(transport.next().await, Some(Ok(s)) if s == "test");
assert_matches!(transport.next().await, None);
Ok(())
}
} }

View File

@@ -0,0 +1,17 @@
#![no_implicit_prelude]
extern crate tarpc as some_random_other_name;
#[cfg(feature = "serde1")]
mod serde1_feature {
#[::tarpc::derive_serde]
#[derive(Debug, PartialEq, Eq)]
pub enum TestData {
Black,
White,
}
}
#[::tarpc::service]
pub trait ColorProtocol {
async fn get_opposite_color(color: u8) -> u8;
}