From 50879d2acb6a984526da00762f3450ac2e8c2d06 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Mon, 22 Jul 2019 13:13:08 -0700 Subject: [PATCH] Don't bake in Send + 'static. Send + 'static was baked in to make it possible to spawn futures onto the default executor. We can accomplish the same thing by offering helper fns that do the spawning while not requiring it for the rest of the functionality. Fixes https://github.com/google/tarpc/issues/212 --- README.md | 2 +- example-service/src/client.rs | 2 +- example-service/src/server.rs | 4 +- plugins/src/lib.rs | 69 +++-- rpc/src/client/channel.rs | 60 ++--- rpc/src/client/mod.rs | 46 +++- rpc/src/server/mod.rs | 339 +++++++++++++++--------- rpc/src/server/testing.rs | 3 +- rpc/src/server/throttle.rs | 6 +- rpc/src/transport/channel.rs | 2 +- tarpc/Cargo.toml | 2 + tarpc/examples/pubsub.rs | 4 +- tarpc/examples/readme.rs | 3 +- tarpc/examples/server_calling_server.rs | 4 +- tarpc/tests/service_functional.rs | 92 ++++++- 15 files changed, 428 insertions(+), 210 deletions(-) diff --git a/README.md b/README.md index e098a00..608b5b4 100644 --- a/README.md +++ b/README.md @@ -182,7 +182,7 @@ async fn main() -> io::Result<()> { // WorldClient is generated by the macro. It has a constructor `new` that takes a config and // any Transport as input - let mut client = WorldClient::new(client::Config::default(), client_transport).await?; + let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?; // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 5074b0e..c0e8615 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -46,7 +46,7 @@ async fn main() -> io::Result<()> { // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. - let mut client = service::WorldClient::new(client::Config::default(), transport).await?; + let mut client = service::WorldClient::new(client::Config::default(), transport).spawn()?; // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context diff --git a/example-service/src/server.rs b/example-service/src/server.rs index b60d7d2..acd6766 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -75,11 +75,11 @@ async fn main() -> io::Result<()> { // the generated World trait. .map(|channel| { let server = HelloServer(channel.as_ref().as_ref().peer_addr().unwrap()); - channel.respond_with(server.serve()) + channel.respond_with(server.serve()).execute() }) // Max 10 channels. .buffer_unordered(10) - .for_each(|_| futures::future::ready(())) + .for_each(|_| async {}) .await; Ok(()) diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 5cb9269..b73b2c9 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -21,7 +21,7 @@ use syn::{ punctuated::Punctuated, spanned::Spanned, token::Comma, - ArgCaptured, Attribute, FnArg, Ident, Pat, ReturnType, Token, Visibility, + ArgCaptured, Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, ReturnType, Token, Visibility, }; struct Service { @@ -126,6 +126,40 @@ impl Parse for RpcMethod { } } +// If `derive_serde` meta item is not present, defaults to cfg!(feature = "serde1"). +// `derive_serde` can only be true when serde1 is enabled. +struct DeriveSerde(bool); + +impl Parse for DeriveSerde { + fn parse(input: ParseStream) -> syn::Result { + if input.is_empty() { + return Ok(DeriveSerde(cfg!(feature = "serde1"))) + } + match input.parse::()? { + MetaNameValue { ref ident, ref lit, .. } if ident == "derive_serde" => { + match lit { + Lit::Bool(LitBool{value: true, ..}) if cfg!(feature = "serde1") => Ok(DeriveSerde(true)), + Lit::Bool(LitBool{value: true, ..}) => Err(syn::Error::new( + lit.span(), + "To enable serde, first enable the `serde1` feature of tarpc", + )), + Lit::Bool(LitBool{value: false, ..}) => Ok(DeriveSerde(false)), + lit => Err(syn::Error::new( + lit.span(), + "`derive_serde` expects a value of type `bool`", + )), + } + } + MetaNameValue { ident, .. } => { + Err(syn::Error::new( + ident.span(), + "tarpc::service only supports one meta item, `derive_serde = {bool}`", + )) + } + } + } +} + /// Generates: /// - service trait /// - serve fn @@ -135,13 +169,7 @@ impl Parse for RpcMethod { /// - ResponseFut Future #[proc_macro_attribute] pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { - struct EmptyArgs; - impl Parse for EmptyArgs { - fn parse(_: ParseStream) -> syn::Result { - Ok(EmptyArgs) - } - } - parse_macro_input!(attr as EmptyArgs); + let derive_serde = parse_macro_input!(attr as DeriveSerde); let Service { attrs, @@ -223,14 +251,15 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { let response_fut_ident_repeated2 = response_fut_ident_repeated.clone(); let server_ident = Ident::new(&format!("Serve{}", ident), ident.span()); - #[cfg(feature = "serde1")] - let derive_serialize = quote!(#[derive(serde::Serialize, serde::Deserialize)]); - #[cfg(not(feature = "serde1"))] - let derive_serialize = quote!(); + let derive_serialize = if derive_serde.0 { + quote!(#[derive(serde::Serialize, serde::Deserialize)]) + } else { + quote!() + }; let tokens = quote! { #( #attrs )* - #vis trait #ident: Clone + Send + 'static { + #vis trait #ident: Clone { #( #types_and_fns )* /// Returns a serving function to use with tarpc::server::Server. @@ -322,12 +351,18 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { impl #client_ident { /// Returns a new client stub that sends requests over the given transport. - #vis async fn new(config: tarpc::client::Config, transport: T) - -> std::io::Result + #vis fn new(config: tarpc::client::Config, transport: T) + -> tarpc::client::NewClient< + Self, + tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T>> where - T: tarpc::Transport, tarpc::Response<#response_ident>> + Send + 'static + T: tarpc::Transport, tarpc::Response<#response_ident>> { - Ok(#client_ident(tarpc::client::new(config, transport).await?)) + let new_client = tarpc::client::new(config, transport); + tarpc::client::NewClient { + client: #client_ident(new_client.client), + dispatch: new_client.dispatch, + } } } diff --git a/rpc/src/client/channel.rs b/rpc/src/client/channel.rs index c6bb428..6c91fb5 100644 --- a/rpc/src/client/channel.rs +++ b/rpc/src/client/channel.rs @@ -19,11 +19,11 @@ use futures::{ Poll, }; use humantime::format_rfc3339; -use log::{debug, error, info, trace}; +use log::{debug, info, trace}; use pin_utils::{unsafe_pinned, unsafe_unpinned}; use std::{ io, - marker::{self, Unpin}, + marker::Unpin, pin::Pin, sync::{ atomic::{AtomicU64, Ordering}, @@ -33,7 +33,7 @@ use std::{ }; use trace::SpanId; -use super::Config; +use super::{Config, NewClient}; /// Handles communication from the client to request dispatch. #[derive(Debug)] @@ -246,48 +246,39 @@ impl Drop for DispatchResponse { } } -/// Spawns a dispatch task on the default executor that manages the lifecycle of requests initiated -/// by the returned [`Channel`]. -pub async fn spawn(config: Config, transport: C) -> io::Result> +/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the +/// channel. +pub fn new( + config: Config, + transport: C, +) -> NewClient, RequestDispatch> where - Req: marker::Send + 'static, - Resp: marker::Send + 'static, - C: Transport, Response> + marker::Send + 'static, + C: Transport, Response>, { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); let canceled_requests = canceled_requests.fuse(); - crate::spawn( - RequestDispatch { + NewClient { + client: Channel { + to_dispatch, + cancellation, + next_request_id: Arc::new(AtomicU64::new(0)), + }, + dispatch: RequestDispatch { config, canceled_requests, transport: transport.fuse(), in_flight_requests: FnvHashMap::default(), pending_requests: pending_requests.fuse(), - } - .unwrap_or_else(move |e| error!("Connection broken: {}", e)), - ) - .map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - format!( - "Could not spawn client dispatch task. Is shutdown: {}", - e.is_shutdown() - ), - ) - })?; - - Ok(Channel { - to_dispatch, - cancellation, - next_request_id: Arc::new(AtomicU64::new(0)), - }) + }, + } } /// Handles the lifecycle of requests, writing requests to the wire, managing cancellations, /// and dispatching responses to the appropriate channel. -struct RequestDispatch { +#[derive(Debug)] +pub struct RequestDispatch { /// Writes requests to the wire and reads responses off the wire. transport: Fuse, /// Requests waiting to be written to the wire. @@ -302,8 +293,6 @@ struct RequestDispatch { impl RequestDispatch where - Req: marker::Send, - Resp: marker::Send, C: Transport, Response>, { unsafe_pinned!(in_flight_requests: FnvHashMap>); @@ -492,8 +481,6 @@ where impl Future for RequestDispatch where - Req: marker::Send, - Resp: marker::Send, C: Transport, Response>, { type Output = io::Result<()>; @@ -532,6 +519,7 @@ struct DispatchRequest { response_completion: oneshot::Sender>, } +#[derive(Debug)] struct InFlightData { ctx: context::Context, response_completion: oneshot::Sender>, @@ -776,7 +764,7 @@ mod tests { }; use futures_test::task::noop_waker_ref; use std::time::Duration; - use std::{marker, pin::Pin, sync::atomic::AtomicU64, sync::Arc, time::Instant}; + use std::{pin::Pin, sync::atomic::AtomicU64, sync::Arc, time::Instant}; #[test] fn dispatch_response_cancels_on_timeout() { @@ -955,7 +943,7 @@ mod tests { impl PollTest for Poll>> where - E: ::std::fmt::Display + marker::Send + 'static, + E: ::std::fmt::Display, { type T = Option; diff --git a/rpc/src/client/mod.rs b/rpc/src/client/mod.rs index ccb0758..fdec802 100644 --- a/rpc/src/client/mod.rs +++ b/rpc/src/client/mod.rs @@ -6,13 +6,14 @@ //! Provides a client that connects to a server and sends multiplexed requests. -use crate::{context, ClientMessage, Response, Transport}; +use crate::context; use futures::prelude::*; +use log::error; use std::io; /// Provides a [`Client`] backed by a transport. pub mod channel; -pub use self::channel::Channel; +pub use channel::{new, Channel}; /// Sends multiplexed requests to, and receives responses from, a server. pub trait Client<'a, Req> { @@ -125,15 +126,34 @@ impl Default for Config { } } -/// Creates a new Client by wrapping a [`Transport`] and spawning a dispatch task -/// that manages the lifecycle of requests. -/// -/// Must only be called from on an executor. -pub async fn new(config: Config, transport: T) -> io::Result> -where - Req: Send + 'static, - Resp: Send + 'static, - T: Transport, Response> + Send + 'static, -{ - Ok(channel::spawn(config, transport).await?) +/// A channel and dispatch pair. The dispatch drives the sending and receiving of requests +/// and must be polled continuously or spawned. +#[derive(Debug)] +pub struct NewClient { + /// The new client. + pub client: C, + /// The client's dispatch. + pub dispatch: D, +} + +impl NewClient +where + D: Future> + Send + 'static, +{ + /// Helper method to spawn the dispatch on the default executor. + pub fn spawn(self) -> io::Result { + let dispatch = self + .dispatch + .unwrap_or_else(move |e| error!("Connection broken: {}", e)); + crate::spawn(dispatch).map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!( + "Could not spawn client dispatch task. Is shutdown: {}", + e.is_shutdown() + ), + ) + })?; + Ok(self.client) + } } diff --git a/rpc/src/server/mod.rs b/rpc/src/server/mod.rs index 60c6124..cec87b3 100644 --- a/rpc/src/server/mod.rs +++ b/rpc/src/server/mod.rs @@ -78,7 +78,7 @@ impl Config { /// Returns a channel backed by `transport` and configured with `self`. pub fn channel(self, transport: T) -> BaseChannel where - T: Transport, ClientMessage> + Send, + T: Transport, ClientMessage>, { BaseChannel::new(self, transport) } @@ -101,49 +101,13 @@ impl Server { /// Returns a stream of server channels. pub fn incoming(self, listener: S) -> impl Stream> where - Req: Send, - Resp: Send, S: Stream, - T: Transport, ClientMessage> + Send, + T: Transport, ClientMessage>, { listener.map(move |t| BaseChannel::new(self.config.clone(), t)) } } -/// The future driving the server. -#[derive(Debug)] -pub struct Running { - incoming: St, - server: Se, -} - -impl Running { - unsafe_pinned!(incoming: St); - unsafe_unpinned!(server: Se); -} - -impl Future for Running -where - St: Sized + Stream, - C: Channel + Send + 'static, - Se: Serve + Send + 'static, - Se::Fut: Send + 'static -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - while let Some(channel) = ready!(self.as_mut().incoming().poll_next(cx)) { - if let Err(e) = - crate::spawn(channel.respond_with(self.as_mut().server().clone())) - { - warn!("Failed to spawn channel handler: {:?}", e); - } - } - info!("Server shutting down."); - Poll::Ready(()) - } -} - /// Basically a Fn(Req) -> impl Future; pub trait Serve: Sized + Clone { /// Type of response. @@ -191,8 +155,7 @@ where /// Responds to all requests with `server`. fn respond_with(self, server: S) -> Running where - S: Serve + Send + 'static, - S::Fut: Send + 'static, + S: Serve, { Running { incoming: self, @@ -226,7 +189,7 @@ impl BaseChannel { impl BaseChannel where - T: Transport, ClientMessage> + Send, + T: Transport, ClientMessage>, { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -288,10 +251,10 @@ where Self: Transport::Resp>, Request<::Req>>, { /// Type of request item. - type Req: Send + 'static; + type Req; /// Type of response sink item. - type Resp: Send + 'static; + type Resp; /// Configuration of the channel. fn config(&self) -> &Config; @@ -314,16 +277,15 @@ where /// Respond to requests coming over the channel with `f`. Returns a future that drives the /// responses and resolves when the connection is closed. - fn respond_with(self, server: S) -> ResponseHandler + fn respond_with(self, server: S) -> ClientHandler where - S: Serve + Send + 'static, - S::Fut: Send + 'static, + S: Serve, Self: Sized, { let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer); let responses = responses.fuse(); - ResponseHandler { + ClientHandler { channel: self, server, pending_responses: responses, @@ -334,9 +296,7 @@ where impl Stream for BaseChannel where - T: Transport, ClientMessage> + Send + 'static, - Req: Send + 'static, - Resp: Send + 'static, + T: Transport, ClientMessage>, { type Item = io::Result>; @@ -362,9 +322,7 @@ where impl Sink> for BaseChannel where - T: Transport, ClientMessage> + Send + 'static, - Req: Send + 'static, - Resp: Send + 'static, + T: Transport, ClientMessage>, { type Error = io::Error; @@ -402,9 +360,7 @@ impl AsRef for BaseChannel { impl Channel for BaseChannel where - T: Transport, ClientMessage> + Send + 'static, - Req: Send + 'static, - Resp: Send + 'static, + T: Transport, ClientMessage>, { type Req = Req; type Resp = Resp; @@ -429,7 +385,7 @@ where /// A running handler serving all requests coming over a channel. #[derive(Debug)] -pub struct ResponseHandler +pub struct ClientHandler where C: Channel, { @@ -438,11 +394,11 @@ where pending_responses: Fuse)>>, /// Handed out to request handlers to fan in responses. responses_tx: mpsc::Sender<(context::Context, Response)>, - /// Request handler. + /// Server server: S, } -impl ResponseHandler +impl ClientHandler where C: Channel, { @@ -450,22 +406,21 @@ where unsafe_pinned!(pending_responses: Fuse)>>); unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response)>); // For this to be safe, field f must be private, and code in this module must never - // construct PinMut. + // construct PinMut. unsafe_unpinned!(server: S); } -impl ResponseHandler +impl ClientHandler where C: Channel, - S: Serve + Send + 'static, - S::Fut: Send + 'static, + S: Serve, { - fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> { + fn pump_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> PollIo> { match ready!(self.as_mut().channel().poll_next(cx)?) { - Some(request) => { - self.handle_request(request)?; - Poll::Ready(Some(Ok(()))) - } + Some(request) => Poll::Ready(Some(Ok(self.handle_request(request)))), None => Poll::Ready(None), } } @@ -518,13 +473,16 @@ where match ready!(self.as_mut().pending_responses().poll_next(cx)) { Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))), None => { - // This branch likely won't happen, since the ResponseHandler is holding a Sender. + // This branch likely won't happen, since the ClientHandler is holding a Sender. Poll::Ready(None) } } } - fn handle_request(mut self: Pin<&mut Self>, request: Request) -> io::Result<()> { + fn handle_request( + mut self: Pin<&mut Self>, + request: Request, + ) -> RequestHandler { let request_id = request.id; let deadline = request.context.deadline; let timeout = deadline.as_duration(); @@ -536,70 +494,144 @@ where ); let ctx = request.context; let request = request.message; - let mut response_tx = self.as_mut().responses_tx().clone(); - let trace_id = *ctx.trace_id(); let response = self.as_mut().server().clone().serve(ctx, request); - let response = deadline_compat::Deadline::new(response, Instant::now() + timeout).then( - move |result| { - async move { - let response = Response { - request_id, - message: match result { - Ok(message) => Ok(message), - Err(e) => Err(make_server_error(e, trace_id, deadline)), - }, - }; - trace!("[{}] Sending response.", trace_id); - response_tx - .send((ctx, response)) - .unwrap_or_else(|_| ()) - .await; - } - }, - ); + let response = Resp { + state: RespState::PollResp, + request_id, + ctx, + deadline, + f: deadline_compat::Deadline::new(response, Instant::now() + timeout), + response: None, + response_tx: self.as_mut().responses_tx().clone(), + }; let abort_registration = self.as_mut().channel().start_request(request_id); - let response = Abortable::new(response, abort_registration); - crate::spawn(response.map(|_| ())).map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - format!( - "Could not spawn response task. Is shutdown: {}", - e.is_shutdown() - ), - ) - })?; - Ok(()) + RequestHandler { + resp: Abortable::new(response, abort_registration), + } } } -impl Future for ResponseHandler +/// A future fulfilling a single client request. +#[derive(Debug)] +pub struct RequestHandler { + resp: Abortable>, +} + +impl RequestHandler { + unsafe_pinned!(resp: Abortable>); +} + +impl Future for RequestHandler where - C: Channel, - S: Serve + Send + 'static, - S::Fut: Send + 'static, + F: Future, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let _ = ready!(self.resp().poll(cx)); + Poll::Ready(()) + } +} + +#[derive(Debug)] +struct Resp { + state: RespState, + request_id: u64, + ctx: context::Context, + deadline: SystemTime, + f: deadline_compat::Deadline, + response: Option>, + response_tx: mpsc::Sender<(context::Context, Response)>, +} + +#[derive(Debug)] +enum RespState { + PollResp, + PollReady, + PollFlush, +} + +impl Resp { + unsafe_pinned!(f: deadline_compat::Deadline); + unsafe_pinned!(response_tx: mpsc::Sender<(context::Context, Response)>); + unsafe_unpinned!(response: Option>); + unsafe_unpinned!(state: RespState); +} + +impl Future for Resp +where + F: Future, { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - move || -> Poll> { - loop { - let read = self.as_mut().pump_read(cx)?; - match ( - read, - self.as_mut().pump_write(cx, read == Poll::Ready(None))?, - ) { - (Poll::Ready(None), Poll::Ready(None)) => { - return Poll::Ready(Ok(())); + loop { + match self.as_mut().state() { + RespState::PollResp => { + let result = ready!(self.as_mut().f().poll(cx)); + *self.as_mut().response() = Some(Response { + request_id: self.request_id, + message: match result { + Ok(message) => Ok(message), + Err(e) => { + Err(make_server_error(e, *self.ctx.trace_id(), self.deadline)) + } + }, + }); + *self.as_mut().state() = RespState::PollReady; + } + RespState::PollReady => { + let ready = ready!(self.as_mut().response_tx().poll_ready(cx)); + if ready.is_err() { + return Poll::Ready(()); } - (Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {} - _ => { - return Poll::Pending; + let resp = (self.ctx, self.as_mut().response().take().unwrap()); + if self.as_mut().response_tx().start_send(resp).is_err() { + return Poll::Ready(()); } + *self.as_mut().state() = RespState::PollFlush; + } + RespState::PollFlush => { + let ready = ready!(self.as_mut().response_tx().poll_flush(cx)); + if ready.is_err() { + return Poll::Ready(()); + } + return Poll::Ready(()); } } - }() - .map(|r| r.unwrap_or_else(|e| info!("ResponseHandler errored out: {}", e))) + } + } +} + +impl Stream for ClientHandler +where + C: Channel, + S: Serve, +{ + type Item = io::Result>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + let read = self.as_mut().pump_read(cx)?; + let read_closed = if let Poll::Ready(None) = read { + true + } else { + false + }; + match (read, self.as_mut().pump_write(cx, read_closed)?) { + (Poll::Ready(None), Poll::Ready(None)) => { + return Poll::Ready(None); + } + (Poll::Ready(Some(request_handler)), _) => { + return Poll::Ready(Some(Ok(request_handler))); + } + (_, Poll::Ready(Some(()))) => {} + _ => { + return Poll::Pending; + } + } + } } } @@ -641,3 +673,72 @@ fn make_server_error( } } } + +// Send + 'static execution helper methods. + +impl ClientHandler +where + C: Channel + 'static, + C::Req: Send + 'static, + C::Resp: Send + 'static, + S: Serve + Send + 'static, + S::Fut: Send + 'static, +{ + /// Runs the client handler until completion by spawning each + /// request handler onto the default executor. + pub fn execute(self) -> impl Future { + self.try_for_each(|request_handler| { + async { + crate::spawn(request_handler).map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!( + "Could not spawn response task. Is shutdown: {}", + e.is_shutdown() + ), + ) + }) + } + }) + .unwrap_or_else(|e| info!("ClientHandler errored out: {}", e)) + } +} + +/// A future that drives the server by spawning channels and request handlers on the default +/// executor. +#[derive(Debug)] +pub struct Running { + incoming: St, + server: Se, +} + +impl Running { + unsafe_pinned!(incoming: St); + unsafe_unpinned!(server: Se); +} + +impl Future for Running +where + St: Sized + Stream, + C: Channel + Send + 'static, + C::Req: Send + 'static, + C::Resp: Send + 'static, + Se: Serve + Send + 'static + Clone, + Se::Fut: Send + 'static, +{ + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + while let Some(channel) = ready!(self.as_mut().incoming().poll_next(cx)) { + if let Err(e) = crate::spawn( + channel + .respond_with(self.as_mut().server().clone()) + .execute(), + ) { + warn!("Failed to spawn channel handler: {:?}", e); + } + } + info!("Server shutting down."); + Poll::Ready(()) + } +} diff --git a/rpc/src/server/testing.rs b/rpc/src/server/testing.rs index 804e601..5ba0455 100644 --- a/rpc/src/server/testing.rs +++ b/rpc/src/server/testing.rs @@ -60,8 +60,7 @@ impl Sink> for FakeChannel> { impl Channel for FakeChannel>, Response> where - Req: Unpin + Send + 'static, - Resp: Send + 'static, + Req: Unpin, { type Req = Req; type Resp = Resp; diff --git a/rpc/src/server/throttle.rs b/rpc/src/server/throttle.rs index 4ba9f19..cedb219 100644 --- a/rpc/src/server/throttle.rs +++ b/rpc/src/server/throttle.rs @@ -286,11 +286,7 @@ fn throttler_poll_next_throttled_sink_not_ready() { Poll::Pending } } - impl Channel for PendingSink>, Response> - where - Req: Send + 'static, - Resp: Send + 'static, - { + impl Channel for PendingSink>, Response> { type Req = Req; type Resp = Resp; fn config(&self) -> &Config { diff --git a/rpc/src/transport/channel.rs b/rpc/src/transport/channel.rs index 309d12b..c38c28f 100644 --- a/rpc/src/transport/channel.rs +++ b/rpc/src/transport/channel.rs @@ -106,7 +106,7 @@ mod tests { ) .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - let mut client = client::new(client::Config::default(), client_channel).await?; + let mut client = client::new(client::Config::default(), client_channel).spawn()?; let response1 = client.call(context::current(), "123".into()).await?; let response2 = client.call(context::current(), "abc".into()).await?; diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 7e9efe3..47b9d36 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -31,10 +31,12 @@ bytes = { version = "0.4", features = ["serde"] } env_logger = "0.6" futures-preview = { version = "0.3.0-alpha.17", features = ["compat"] } humantime = "1.0" +log = "0.4" runtime = "0.3.0-alpha.6" runtime-tokio = "0.3.0-alpha.5" tokio-tcp = "0.1" pin-utils = "0.1.0-alpha.4" +tokio = "0.1" [[example]] name = "server_calling_server" diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 85dfb9d..9444f78 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -116,7 +116,7 @@ impl publisher::Publisher for Publisher { ) -> io::Result<()> { let conn = bincode_transport::connect(&addr).await?; let subscriber = - subscriber::SubscriberClient::new(client::Config::default(), conn).await?; + subscriber::SubscriberClient::new(client::Config::default(), conn).spawn()?; eprintln!("Subscribing {}.", id); clients.lock().unwrap().insert(id, subscriber); Ok(()) @@ -160,7 +160,7 @@ async fn main() -> io::Result<()> { let publisher_conn = bincode_transport::connect(&publisher_addr); let publisher_conn = publisher_conn.await?; let mut publisher = - publisher::PublisherClient::new(client::Config::default(), publisher_conn).await?; + publisher::PublisherClient::new(client::Config::default(), publisher_conn).spawn()?; if let Err(e) = publisher .subscribe(context::current(), 0, subscriber1) diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 5d44be5..c4c7669 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -58,6 +58,7 @@ async fn main() -> io::Result<()> { // serve_world is generated by the tarpc::service attribute. It takes as input any type // implementing the generated World trait. .respond_with(HelloServer.serve()) + .execute() .await; }; let _ = runtime::spawn(server); @@ -66,7 +67,7 @@ async fn main() -> io::Result<()> { // WorldClient is generated by the tarpc::service attribute. It has a constructor `new` that // takes a config and any Transport as input. - let mut client = WorldClient::new(client::Config::default(), transport).await?; + let mut client = WorldClient::new(client::Config::default(), transport).spawn()?; // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context diff --git a/tarpc/examples/server_calling_server.rs b/tarpc/examples/server_calling_server.rs index 025d36c..400c69a 100644 --- a/tarpc/examples/server_calling_server.rs +++ b/tarpc/examples/server_calling_server.rs @@ -78,7 +78,7 @@ async fn main() -> io::Result<()> { let _ = runtime::spawn(add_server); let to_add_server = bincode_transport::connect(&addr).await?; - let add_client = add::AddClient::new(client::Config::default(), to_add_server).await?; + let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn()?; let double_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())? .filter_map(|r| future::ready(r.ok())); @@ -91,7 +91,7 @@ async fn main() -> io::Result<()> { let to_double_server = bincode_transport::connect(&addr).await?; let mut double_client = - double::DoubleClient::new(client::Config::default(), to_double_server).await?; + double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?; for i in 1..=5 { eprintln!("{:?}", double_client.double(context::current(), i).await?); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index fa7f3b5..5d3db88 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -1,21 +1,38 @@ #![feature(async_await)] -#[cfg(not(feature = "serde1"))] -use std::rc::Rc; use assert_matches::assert_matches; use futures::{ future::{ready, Ready}, prelude::*, }; - -use std::io; +use std::{rc::Rc, io}; use tarpc::{ - client, context, + client::{self, NewClient}, context, server::{self, BaseChannel, Channel, Handler}, transport::channel, }; +trait RuntimeExt { + fn exec_bg(&mut self, future: impl Future + 'static); + fn exec(&mut self, future: F) -> Result + where + F: Future>; +} + +impl RuntimeExt for tokio::runtime::current_thread::Runtime { + fn exec_bg(&mut self, future: impl Future + 'static) { + self.spawn(Box::pin(future.unit_error()).compat()); + } + + fn exec(&mut self, future: F) -> Result + where + F: Future>, + { + self.block_on(futures::compat::Compat::new(Box::pin(future))) + } +} + #[tarpc_plugins::service] trait Service { async fn add(x: i32, y: i32) -> i32; @@ -48,9 +65,10 @@ async fn sequential() -> io::Result<()> { let _ = runtime::spawn( BaseChannel::new(server::Config::default(), rx) .respond_with(Server.serve()) + .execute() ); - let mut client = ServiceClient::new(client::Config::default(), tx).await?; + let mut client = ServiceClient::new(client::Config::default(), tx).spawn()?; assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); assert_matches!( @@ -74,7 +92,7 @@ async fn serde() -> io::Result<()> { ); let transport = bincode_transport::connect(&addr).await?; - let mut client = ServiceClient::new(client::Config::default(), transport).await?; + let mut client = ServiceClient::new(client::Config::default(), transport).spawn()?; assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); assert_matches!( @@ -96,7 +114,7 @@ async fn concurrent() -> io::Result<()> { .respond_with(Server.serve()), ); - let client = ServiceClient::new(client::Config::default(), tx).await?; + let client = ServiceClient::new(client::Config::default(), tx).spawn()?; let mut c = client.clone(); let req1 = c.add(context::current(), 1, 2); @@ -113,3 +131,61 @@ async fn concurrent() -> io::Result<()> { Ok(()) } + +#[tarpc::service(derive_serde = false)] +trait InMemory { + async fn strong_count(rc: Rc<()>) -> usize; + async fn weak_count(rc: Rc<()>) -> usize; +} + +impl InMemory for () { + type StrongCountFut = Ready; + fn strong_count(self, _: context::Context, rc: Rc<()>) -> Self::StrongCountFut { + ready(Rc::strong_count(&rc)) + } + + type WeakCountFut = Ready; + fn weak_count(self, _: context::Context, rc: Rc<()>) -> Self::WeakCountFut { + ready(Rc::weak_count(&rc)) + } +} + +#[test] +fn in_memory_single_threaded() -> io::Result<()> { + use log::warn; + + let _ = env_logger::try_init(); + let mut runtime = tokio::runtime::current_thread::Runtime::new()?; + + let (tx, rx) = channel::unbounded(); + + let server = BaseChannel::new(server::Config::default(), rx) + .respond_with(().serve()) + .try_for_each(|r| async move { Ok(r.await) }); + runtime.exec_bg(async { + if let Err(e) = server.await { + warn!("Error while running server: {}", e); + } + }); + + let NewClient{mut client, dispatch} = InMemoryClient::new(client::Config::default(), tx); + runtime.exec_bg(async move { + if let Err(e) = dispatch.await { + warn!("Error while running client dispatch: {}", e) + } + }); + + let rc = Rc::new(()); + assert_matches!( + runtime.exec(client.strong_count(context::current(), rc.clone())), + Ok(2) + ); + + let _weak = Rc::downgrade(&rc); + assert_matches!( + runtime.exec(client.weak_count(context::current(), rc)), + Ok(1) + ); + + Ok(()) +}