diff --git a/README.md b/README.md index d639f63..e098a00 100644 --- a/README.md +++ b/README.md @@ -176,16 +176,13 @@ async fn main() -> io::Result<()> { // incoming() takes a stream of transports such as would be returned by // TcpListener::incoming (but a stream instead of an iterator). .incoming(stream::once(future::ready(server_transport))) - // serve_world is generated by the macro. It takes as input any type implementing - // the generated World trait. - .respond_with(serve_world(HelloServer)); + .respond_with(HelloServer.serve()); let _ = runtime::spawn(server); - // world_stub is generated by the macro. Like Server, it takes a config and any - // Transport as input, and returns a Client, also generated by the macro. - // by the service attribute. - let mut client = world_stub(client::Config::default(), client_transport).await?; + // WorldClient is generated by the macro. It has a constructor `new` that takes a config and + // any Transport as input + let mut client = WorldClient::new(client::Config::default(), client_transport).await?; // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 1d60fc9..5074b0e 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -44,9 +44,9 @@ async fn main() -> io::Result<()> { let transport = json_transport::connect(&server_addr).await?; - // world_stub is generated by the service attribute. Like Server, it takes a config and any - // Transport as input, and returns a Client, also generated by the attribute. - let mut client = service::world_stub(client::Config::default(), transport).await?; + // WorldClient is generated by the service attribute. It has a constructor `new` that takes a + // config and any Transport as input. + let mut client = service::WorldClient::new(client::Config::default(), transport).await?; // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 542505c..b60d7d2 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -11,6 +11,7 @@ use futures::{ future::{self, Ready}, prelude::*, }; +use service::World; use std::{io, net::SocketAddr}; use tarpc::{ context, @@ -22,7 +23,7 @@ use tarpc::{ #[derive(Clone)] struct HelloServer(SocketAddr); -impl service::World for HelloServer { +impl World for HelloServer { // Each defined rpc generates two items in the trait, a fn that serves the RPC, and // an associated type representing the future output by the fn. @@ -74,7 +75,7 @@ async fn main() -> io::Result<()> { // the generated World trait. .map(|channel| { let server = HelloServer(channel.as_ref().as_ref().peer_addr().unwrap()); - channel.respond_with(service::serve_world(server)) + channel.respond_with(server.serve()) }) // Max 10 channels. .buffer_unordered(10) diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 0ec0d9e..96db027 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -135,6 +135,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { ident, rpcs, } = parse_macro_input!(input as Service); + let vis_repeated = std::iter::repeat(vis.clone()); let camel_case_fn_names: Vec = rpcs .iter() @@ -206,9 +207,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { let response_fut_ident = Ident::new(&response_fut_name, ident.span()); let response_fut_ident_repeated = std::iter::repeat(response_fut_ident.clone()); let response_fut_ident_repeated2 = response_fut_ident_repeated.clone(); - let snake_ident = camel_to_snake(&ident.to_string()); - let serve_ident = Ident::new(&format!("serve_{}", snake_ident), ident.span()); - let stub_ident = Ident::new(&format!("{}_stub", snake_ident), ident.span()); + let server_ident = Ident::new(&format!("{}Server", ident), ident.span()); #[cfg(feature = "serde1")] let derive_serialize = quote!(#[derive(serde::Serialize, serde::Deserialize)]); @@ -219,6 +218,35 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { #( #attrs )* #vis trait #ident: Clone + Send + 'static { #( #types_and_fns )* + + /// Returns a serving function to use with tarpc::server::Server. + fn serve(self) -> #server_ident { + #server_ident { service: self } + } + } + + #[derive(Clone)] + #vis struct #server_ident { + service: S, + } + + impl tarpc::server::Serve<#request_ident> for #server_ident + where S: #ident + { + type Resp = #response_ident; + type Fut = #response_fut_ident; + + fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut { + match req { + #( + #request_ident_repeated::#camel_case_idents{ #arg_vars } => { + #response_fut_ident_repeated2::#camel_case_idents2( + #service_name_repeated2::#method_names( + self.service, ctx, #arg_vars2)) + } + )* + } + } } /// The request sent over the wire from the client to the server. @@ -247,10 +275,10 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { } impl std::future::Future for #response_fut_ident { - type Output = std::io::Result<#response_ident>; + type Output = #response_ident; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) - -> std::task::Poll> + -> std::task::Poll<#response_ident> { unsafe { match std::pin::Pin::get_unchecked_mut(self) { @@ -258,44 +286,18 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { #response_fut_ident_repeated::#camel_case_idents(resp) => std::pin::Pin::new_unchecked(resp) .poll(cx) - .map(#response_ident_repeated::#camel_case_idents2) - .map(Ok), + .map(#response_ident_repeated::#camel_case_idents2), )* } } } } - /// Returns a serving function to use with tarpc::server::Server. - #vis fn #serve_ident(service: S) - -> impl FnOnce(tarpc::context::Context, #request_ident) -> #response_fut_ident + Send + 'static + Clone { - move |ctx, req| { - match req { - #( - #request_ident_repeated::#camel_case_idents{ #arg_vars } => { - #response_fut_ident_repeated2::#camel_case_idents2( - #service_name_repeated2::#method_names( - service.clone(), ctx, #arg_vars2)) - } - )* - } - } - } - #[allow(unused)] #[derive(Clone, Debug)] /// The client stub that makes RPC calls to the server. Exposes a Future interface. #vis struct #client_ident>(C); - /// Returns a new client stub that sends requests over the given transport. - #vis async fn #stub_ident(config: tarpc::client::Config, transport: T) - -> std::io::Result<#client_ident> - where - T: tarpc::Transport, tarpc::Response<#response_ident>> + Send + 'static, - { - Ok(#client_ident(tarpc::client::new(config, transport).await?)) - } - impl From for #client_ident where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident> { @@ -304,13 +306,25 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { } } + impl #client_ident { + /// Returns a new client stub that sends requests over the given transport. + #vis async fn new(config: tarpc::client::Config, transport: T) + -> std::io::Result + where + T: tarpc::Transport, tarpc::Response<#response_ident>> + Send + 'static + { + Ok(#client_ident(tarpc::client::new(config, transport).await?)) + } + + } + impl #client_ident where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident> { #( #[allow(unused)] #( #method_attrs )* - pub fn #method_names(&mut self, ctx: tarpc::context::Context, #args) + #vis_repeated fn #method_names(&mut self, ctx: tarpc::context::Context, #args) -> impl std::future::Future> + '_ { let request = #request_ident_repeated2::#camel_case_idents { #arg_vars }; let resp = tarpc::Client::call(&mut self.0, ctx, request); @@ -347,28 +361,6 @@ fn snake_to_camel(ident_str: &str) -> String { camel_ty } -// Really basic camel to snake that assumes capitals are always the start of a new segment. -fn camel_to_snake(ident_str: &str) -> String { - let mut snake = String::new(); - let mut chars = ident_str.chars(); - if let Some(c) = chars.next() { - snake.extend(c.to_lowercase()); - } - - while let Some(c) = chars.next() { - if c.is_uppercase() { - // New word - snake.push('_'); - snake.extend(c.to_lowercase()); - } else { - // Same word - snake.push(c) - } - } - - snake -} - #[test] fn snake_to_camel_basic() { assert_eq!(snake_to_camel("abc_def"), "AbcDef"); @@ -393,8 +385,3 @@ fn snake_to_camel_underscore_consecutive() { fn snake_to_camel_capital_in_middle() { assert_eq!(snake_to_camel("aBc_dEf"), "AbcDef"); } - -#[test] -fn camel_to_snake_basic() { - assert_eq!(camel_to_snake("AbcDef"), "abc_def"); -} diff --git a/rpc/src/client/channel.rs b/rpc/src/client/channel.rs index 0eaa126..c6bb428 100644 --- a/rpc/src/client/channel.rs +++ b/rpc/src/client/channel.rs @@ -180,7 +180,15 @@ impl Future for DispatchResponse { Poll::Ready(match resp { Ok(resp) => { self.complete = true; - Ok(resp.message?) + match resp { + Ok(resp) => Ok(resp.message?), + Err(oneshot::Canceled) => { + // The oneshot is Canceled when the dispatch task ends. In that case, + // there's nothing listening on the other side, so there's no point in + // propagating cancellation. + Err(io::Error::from(io::ErrorKind::ConnectionReset)) + } + } } Err(e) => Err({ let trace_id = *self.as_mut().ctx().trace_id(); @@ -211,7 +219,7 @@ impl Future for DispatchResponse { self.complete = true; io::Error::from(io::ErrorKind::ConnectionReset) } else { - panic!("[{}] Unrecognized deadline error: {}", trace_id, e) + panic!("[{}] Unrecognized deadline error: {:?}", trace_id, e) } }), }) diff --git a/rpc/src/server/mod.rs b/rpc/src/server/mod.rs index 6885c2e..60c6124 100644 --- a/rpc/src/server/mod.rs +++ b/rpc/src/server/mod.rs @@ -23,7 +23,6 @@ use humantime::format_rfc3339; use log::{debug, error, info, trace, warn}; use pin_utils::{unsafe_pinned, unsafe_unpinned}; use std::{ - error::Error as StdError, fmt, hash::Hash, io, @@ -113,29 +112,29 @@ impl Server { /// The future driving the server. #[derive(Debug)] -pub struct Running { - incoming: S, - request_handler: F, +pub struct Running { + incoming: St, + server: Se, } -impl Running { - unsafe_pinned!(incoming: S); - unsafe_unpinned!(request_handler: F); +impl Running { + unsafe_pinned!(incoming: St); + unsafe_unpinned!(server: Se); } -impl Future for Running +impl Future for Running where - S: Sized + Stream, + St: Sized + Stream, C: Channel + Send + 'static, - F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone, - Fut: Future> + Send + 'static, + Se: Serve + Send + 'static, + Se::Fut: Send + 'static { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { while let Some(channel) = ready!(self.as_mut().incoming().poll_next(cx)) { if let Err(e) = - crate::spawn(channel.respond_with(self.as_mut().request_handler().clone())) + crate::spawn(channel.respond_with(self.as_mut().server().clone())) { warn!("Failed to spawn channel handler: {:?}", e); } @@ -145,6 +144,30 @@ where } } +/// Basically a Fn(Req) -> impl Future; +pub trait Serve: Sized + Clone { + /// Type of response. + type Resp; + + /// Type of response future. + type Fut: Future; + + /// Responds to a single request. + fn serve(self, ctx: context::Context, req: Req) -> Self::Fut; +} + +impl Serve for F +where F: FnOnce(context::Context, Req) -> Fut + Clone, + Fut: Future +{ + type Resp = Resp; + type Fut = Fut; + + fn serve(self, ctx: context::Context, req: Req) -> Self::Fut { + self(ctx, req) + } +} + /// A utility trait enabling a stream to fluently chain a request handler. pub trait Handler where @@ -165,15 +188,15 @@ where ThrottlerStream::new(self, n) } - /// Responds to all requests with `request_handler`. - fn respond_with(self, request_handler: F) -> Running + /// Responds to all requests with `server`. + fn respond_with(self, server: S) -> Running where - F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone, - Fut: Future> + Send + 'static, + S: Serve + Send + 'static, + S::Fut: Send + 'static, { Running { incoming: self, - request_handler, + server, } } } @@ -291,10 +314,10 @@ where /// Respond to requests coming over the channel with `f`. Returns a future that drives the /// responses and resolves when the connection is closed. - fn respond_with(self, f: F) -> ResponseHandler + fn respond_with(self, server: S) -> ResponseHandler where - F: FnOnce(context::Context, Self::Req) -> Fut + Send + 'static + Clone, - Fut: Future> + Send + 'static, + S: Serve + Send + 'static, + S::Fut: Send + 'static, Self: Sized, { let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer); @@ -302,7 +325,7 @@ where ResponseHandler { channel: self, - f, + server, pending_responses: responses, responses_tx, } @@ -406,7 +429,7 @@ where /// A running handler serving all requests coming over a channel. #[derive(Debug)] -pub struct ResponseHandler +pub struct ResponseHandler where C: Channel, { @@ -416,10 +439,10 @@ where /// Handed out to request handlers to fan in responses. responses_tx: mpsc::Sender<(context::Context, Response)>, /// Request handler. - f: F, + server: S, } -impl ResponseHandler +impl ResponseHandler where C: Channel, { @@ -428,14 +451,14 @@ where unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response)>); // For this to be safe, field f must be private, and code in this module must never // construct PinMut. - unsafe_unpinned!(f: F); + unsafe_unpinned!(server: S); } -impl ResponseHandler +impl ResponseHandler where C: Channel, - F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone, - Fut: Future> + Send + 'static, + S: Serve + Send + 'static, + S::Fut: Send + 'static, { fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> { match ready!(self.as_mut().channel().poll_next(cx)?) { @@ -516,7 +539,7 @@ where let mut response_tx = self.as_mut().responses_tx().clone(); let trace_id = *ctx.trace_id(); - let response = self.as_mut().f().clone()(ctx, request); + let response = self.as_mut().server().clone().serve(ctx, request); let response = deadline_compat::Deadline::new(response, Instant::now() + timeout).then( move |result| { async move { @@ -550,11 +573,11 @@ where } } -impl Future for ResponseHandler +impl Future for ResponseHandler where C: Channel, - F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone, - Fut: Future> + Send + 'static, + S: Serve + Send + 'static, + S::Fut: Send + 'static, { type Output = (); @@ -581,7 +604,7 @@ where } fn make_server_error( - e: timeout::Error, + e: timeout::Error<()>, trace_id: TraceId, deadline: SystemTime, ) -> ServerError { @@ -601,26 +624,20 @@ fn make_server_error( } } else if e.is_timer() { error!( - "[{}] Response failed because of an issue with a timer: {}", + "[{}] Response failed because of an issue with a timer: {:?}", trace_id, e ); ServerError { kind: io::ErrorKind::Other, - detail: Some(format!("{}", e)), - } - } else if e.is_inner() { - let e = e.into_inner().unwrap(); - ServerError { - kind: e.kind(), - detail: Some(e.description().into()), + detail: Some(format!("{:?}", e)), } } else { - error!("[{}] Unexpected response failure: {}", trace_id, e); + error!("[{}] Unexpected response failure: {:?}", trace_id, e); ServerError { kind: io::ErrorKind::Other, - detail: Some(format!("Server unexpectedly failed to respond: {}", e)), + detail: Some(format!("Server unexpectedly failed to respond: {:?}", e)), } } } diff --git a/rpc/src/transport/channel.rs b/rpc/src/transport/channel.rs index e2786c3..309d12b 100644 --- a/rpc/src/transport/channel.rs +++ b/rpc/src/transport/channel.rs @@ -93,9 +93,9 @@ mod tests { let (client_channel, server_channel) = transport::channel::unbounded(); crate::spawn( - Server::::default() + Server::default() .incoming(stream::once(future::ready(server_channel))) - .respond_with(|_ctx, request| { + .respond_with(|_ctx, request: String| { future::ready(request.parse::().map_err(|_| { io::Error::new( io::ErrorKind::InvalidInput, @@ -108,8 +108,8 @@ mod tests { let mut client = client::new(client::Config::default(), client_channel).await?; - let response1 = client.call(context::current(), "123".into()).await; - let response2 = client.call(context::current(), "abc".into()).await; + let response1 = client.call(context::current(), "123".into()).await?; + let response2 = client.call(context::current(), "abc".into()).await?; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/rpc/src/util/deadline_compat.rs b/rpc/src/util/deadline_compat.rs index c91f20d..7898df8 100644 --- a/rpc/src/util/deadline_compat.rs +++ b/rpc/src/util/deadline_compat.rs @@ -46,16 +46,15 @@ impl Deadline { } impl Future for Deadline where - T: TryFuture, + T: Future, { - type Output = Result>; + type Output = Result>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { // First, try polling the future - match self.as_mut().future().try_poll(cx) { - Poll::Ready(Ok(v)) => return Poll::Ready(Ok(v)), + match self.as_mut().future().poll(cx) { + Poll::Ready(v) => return Poll::Ready(Ok(v)), Poll::Pending => {} - Poll::Ready(Err(e)) => return Poll::Ready(Err(timeout::Error::inner(e))), } let delay = self.delay().poll_unpin(cx); diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index fdfbeba..85dfb9d 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -11,6 +11,7 @@ use futures::{ prelude::*, Future, }; +use publisher::Publisher as _; use rpc::{ client, context, server::{self, Handler}, @@ -23,6 +24,7 @@ use std::{ thread, time::Duration, }; +use subscriber::Subscriber as _; pub mod subscriber { #[tarpc::service] @@ -65,7 +67,7 @@ impl Subscriber { server::new(config) .incoming(incoming) .take(1) - .respond_with(subscriber::serve_subscriber(Subscriber { id })), + .respond_with(Subscriber { id }.serve()), ); Ok(addr) } @@ -114,7 +116,7 @@ impl publisher::Publisher for Publisher { ) -> io::Result<()> { let conn = bincode_transport::connect(&addr).await?; let subscriber = - subscriber::subscriber_stub(client::Config::default(), conn).await?; + subscriber::SubscriberClient::new(client::Config::default(), conn).await?; eprintln!("Subscribing {}.", id); clients.lock().unwrap().insert(id, subscriber); Ok(()) @@ -149,7 +151,7 @@ async fn main() -> io::Result<()> { transport .take(1) .map(server::BaseChannel::with_defaults) - .respond_with(publisher::serve_publisher(Publisher::new())), + .respond_with(Publisher::new().serve()), ); let subscriber1 = Subscriber::listen(0, server::Config::default()).await?; @@ -158,7 +160,7 @@ async fn main() -> io::Result<()> { let publisher_conn = bincode_transport::connect(&publisher_addr); let publisher_conn = publisher_conn.await?; let mut publisher = - publisher::publisher_stub(client::Config::default(), publisher_conn).await?; + publisher::PublisherClient::new(client::Config::default(), publisher_conn).await?; if let Err(e) = publisher .subscribe(context::current(), 0, subscriber1) diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 1f5d17b..5d44be5 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -57,16 +57,16 @@ async fn main() -> io::Result<()> { BaseChannel::with_defaults(client) // serve_world is generated by the tarpc::service attribute. It takes as input any type // implementing the generated World trait. - .respond_with(serve_world(HelloServer)) + .respond_with(HelloServer.serve()) .await; }; let _ = runtime::spawn(server); let transport = bincode_transport::connect(&addr).await?; - // world_stub is generated by the tarpc::service attribute. Like Server, it takes a config and - // any Transport as input, and returns a Client, also generated by the attribute. - let mut client = world_stub(client::Config::default(), transport).await?; + // WorldClient is generated by the tarpc::service attribute. It has a constructor `new` that + // takes a config and any Transport as input. + let mut client = WorldClient::new(client::Config::default(), transport).await?; // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context diff --git a/tarpc/examples/server_calling_server.rs b/tarpc/examples/server_calling_server.rs index 09cc166..025d36c 100644 --- a/tarpc/examples/server_calling_server.rs +++ b/tarpc/examples/server_calling_server.rs @@ -74,11 +74,11 @@ async fn main() -> io::Result<()> { let add_server = Server::default() .incoming(add_listener) .take(1) - .respond_with(add::serve_add(AddServer)); + .respond_with(AddServer.serve()); let _ = runtime::spawn(add_server); let to_add_server = bincode_transport::connect(&addr).await?; - let add_client = add::add_stub(client::Config::default(), to_add_server).await?; + let add_client = add::AddClient::new(client::Config::default(), to_add_server).await?; let double_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())? .filter_map(|r| future::ready(r.ok())); @@ -86,12 +86,12 @@ async fn main() -> io::Result<()> { let double_server = rpc::Server::default() .incoming(double_listener) .take(1) - .respond_with(double::serve_double(DoubleServer { add_client })); + .respond_with(DoubleServer { add_client }.serve()); let _ = runtime::spawn(double_server); let to_double_server = bincode_transport::connect(&addr).await?; let mut double_client = - double::double_stub(client::Config::default(), to_double_server).await?; + double::DoubleClient::new(client::Config::default(), to_double_server).await?; for i in 1..=5 { eprintln!("{:?}", double_client.double(context::current(), i).await?); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index b9913d3..fa7f3b5 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -47,10 +47,10 @@ async fn sequential() -> io::Result<()> { let _ = runtime::spawn( BaseChannel::new(server::Config::default(), rx) - .respond_with(serve_service(Server)) + .respond_with(Server.serve()) ); - let mut client = service_stub(client::Config::default(), tx).await?; + let mut client = ServiceClient::new(client::Config::default(), tx).await?; assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); assert_matches!( @@ -70,11 +70,11 @@ async fn serde() -> io::Result<()> { let _ = runtime::spawn( tarpc::Server::default() .incoming(transport.take(1).filter_map(|r| async { r.ok() })) - .respond_with(serve_service(Server)), + .respond_with(Server.serve()), ); let transport = bincode_transport::connect(&addr).await?; - let mut client = service_stub(client::Config::default(), transport).await?; + let mut client = ServiceClient::new(client::Config::default(), transport).await?; assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); assert_matches!( @@ -93,10 +93,10 @@ async fn concurrent() -> io::Result<()> { let _ = runtime::spawn( rpc::Server::default() .incoming(stream::once(ready(rx))) - .respond_with(serve_service(Server)), + .respond_with(Server.serve()), ); - let client = service_stub(client::Config::default(), tx).await?; + let client = ServiceClient::new(client::Config::default(), tx).await?; let mut c = client.clone(); let req1 = c.add(context::current(), 1, 2);