diff --git a/README.md b/README.md index e098a00..608b5b4 100644 --- a/README.md +++ b/README.md @@ -182,7 +182,7 @@ async fn main() -> io::Result<()> { // WorldClient is generated by the macro. It has a constructor `new` that takes a config and // any Transport as input - let mut client = WorldClient::new(client::Config::default(), client_transport).await?; + let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?; // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 5074b0e..c0e8615 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -46,7 +46,7 @@ async fn main() -> io::Result<()> { // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. - let mut client = service::WorldClient::new(client::Config::default(), transport).await?; + let mut client = service::WorldClient::new(client::Config::default(), transport).spawn()?; // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context diff --git a/example-service/src/server.rs b/example-service/src/server.rs index b60d7d2..acd6766 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -75,11 +75,11 @@ async fn main() -> io::Result<()> { // the generated World trait. .map(|channel| { let server = HelloServer(channel.as_ref().as_ref().peer_addr().unwrap()); - channel.respond_with(server.serve()) + channel.respond_with(server.serve()).execute() }) // Max 10 channels. .buffer_unordered(10) - .for_each(|_| futures::future::ready(())) + .for_each(|_| async {}) .await; Ok(()) diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 5cb9269..b73b2c9 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -21,7 +21,7 @@ use syn::{ punctuated::Punctuated, spanned::Spanned, token::Comma, - ArgCaptured, Attribute, FnArg, Ident, Pat, ReturnType, Token, Visibility, + ArgCaptured, Attribute, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, ReturnType, Token, Visibility, }; struct Service { @@ -126,6 +126,40 @@ impl Parse for RpcMethod { } } +// If `derive_serde` meta item is not present, defaults to cfg!(feature = "serde1"). +// `derive_serde` can only be true when serde1 is enabled. +struct DeriveSerde(bool); + +impl Parse for DeriveSerde { + fn parse(input: ParseStream) -> syn::Result { + if input.is_empty() { + return Ok(DeriveSerde(cfg!(feature = "serde1"))) + } + match input.parse::()? { + MetaNameValue { ref ident, ref lit, .. } if ident == "derive_serde" => { + match lit { + Lit::Bool(LitBool{value: true, ..}) if cfg!(feature = "serde1") => Ok(DeriveSerde(true)), + Lit::Bool(LitBool{value: true, ..}) => Err(syn::Error::new( + lit.span(), + "To enable serde, first enable the `serde1` feature of tarpc", + )), + Lit::Bool(LitBool{value: false, ..}) => Ok(DeriveSerde(false)), + lit => Err(syn::Error::new( + lit.span(), + "`derive_serde` expects a value of type `bool`", + )), + } + } + MetaNameValue { ident, .. } => { + Err(syn::Error::new( + ident.span(), + "tarpc::service only supports one meta item, `derive_serde = {bool}`", + )) + } + } + } +} + /// Generates: /// - service trait /// - serve fn @@ -135,13 +169,7 @@ impl Parse for RpcMethod { /// - ResponseFut Future #[proc_macro_attribute] pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { - struct EmptyArgs; - impl Parse for EmptyArgs { - fn parse(_: ParseStream) -> syn::Result { - Ok(EmptyArgs) - } - } - parse_macro_input!(attr as EmptyArgs); + let derive_serde = parse_macro_input!(attr as DeriveSerde); let Service { attrs, @@ -223,14 +251,15 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { let response_fut_ident_repeated2 = response_fut_ident_repeated.clone(); let server_ident = Ident::new(&format!("Serve{}", ident), ident.span()); - #[cfg(feature = "serde1")] - let derive_serialize = quote!(#[derive(serde::Serialize, serde::Deserialize)]); - #[cfg(not(feature = "serde1"))] - let derive_serialize = quote!(); + let derive_serialize = if derive_serde.0 { + quote!(#[derive(serde::Serialize, serde::Deserialize)]) + } else { + quote!() + }; let tokens = quote! { #( #attrs )* - #vis trait #ident: Clone + Send + 'static { + #vis trait #ident: Clone { #( #types_and_fns )* /// Returns a serving function to use with tarpc::server::Server. @@ -322,12 +351,18 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { impl #client_ident { /// Returns a new client stub that sends requests over the given transport. - #vis async fn new(config: tarpc::client::Config, transport: T) - -> std::io::Result + #vis fn new(config: tarpc::client::Config, transport: T) + -> tarpc::client::NewClient< + Self, + tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T>> where - T: tarpc::Transport, tarpc::Response<#response_ident>> + Send + 'static + T: tarpc::Transport, tarpc::Response<#response_ident>> { - Ok(#client_ident(tarpc::client::new(config, transport).await?)) + let new_client = tarpc::client::new(config, transport); + tarpc::client::NewClient { + client: #client_ident(new_client.client), + dispatch: new_client.dispatch, + } } } diff --git a/rpc/src/client/channel.rs b/rpc/src/client/channel.rs index c6bb428..6c91fb5 100644 --- a/rpc/src/client/channel.rs +++ b/rpc/src/client/channel.rs @@ -19,11 +19,11 @@ use futures::{ Poll, }; use humantime::format_rfc3339; -use log::{debug, error, info, trace}; +use log::{debug, info, trace}; use pin_utils::{unsafe_pinned, unsafe_unpinned}; use std::{ io, - marker::{self, Unpin}, + marker::Unpin, pin::Pin, sync::{ atomic::{AtomicU64, Ordering}, @@ -33,7 +33,7 @@ use std::{ }; use trace::SpanId; -use super::Config; +use super::{Config, NewClient}; /// Handles communication from the client to request dispatch. #[derive(Debug)] @@ -246,48 +246,39 @@ impl Drop for DispatchResponse { } } -/// Spawns a dispatch task on the default executor that manages the lifecycle of requests initiated -/// by the returned [`Channel`]. -pub async fn spawn(config: Config, transport: C) -> io::Result> +/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the +/// channel. +pub fn new( + config: Config, + transport: C, +) -> NewClient, RequestDispatch> where - Req: marker::Send + 'static, - Resp: marker::Send + 'static, - C: Transport, Response> + marker::Send + 'static, + C: Transport, Response>, { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); let canceled_requests = canceled_requests.fuse(); - crate::spawn( - RequestDispatch { + NewClient { + client: Channel { + to_dispatch, + cancellation, + next_request_id: Arc::new(AtomicU64::new(0)), + }, + dispatch: RequestDispatch { config, canceled_requests, transport: transport.fuse(), in_flight_requests: FnvHashMap::default(), pending_requests: pending_requests.fuse(), - } - .unwrap_or_else(move |e| error!("Connection broken: {}", e)), - ) - .map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - format!( - "Could not spawn client dispatch task. Is shutdown: {}", - e.is_shutdown() - ), - ) - })?; - - Ok(Channel { - to_dispatch, - cancellation, - next_request_id: Arc::new(AtomicU64::new(0)), - }) + }, + } } /// Handles the lifecycle of requests, writing requests to the wire, managing cancellations, /// and dispatching responses to the appropriate channel. -struct RequestDispatch { +#[derive(Debug)] +pub struct RequestDispatch { /// Writes requests to the wire and reads responses off the wire. transport: Fuse, /// Requests waiting to be written to the wire. @@ -302,8 +293,6 @@ struct RequestDispatch { impl RequestDispatch where - Req: marker::Send, - Resp: marker::Send, C: Transport, Response>, { unsafe_pinned!(in_flight_requests: FnvHashMap>); @@ -492,8 +481,6 @@ where impl Future for RequestDispatch where - Req: marker::Send, - Resp: marker::Send, C: Transport, Response>, { type Output = io::Result<()>; @@ -532,6 +519,7 @@ struct DispatchRequest { response_completion: oneshot::Sender>, } +#[derive(Debug)] struct InFlightData { ctx: context::Context, response_completion: oneshot::Sender>, @@ -776,7 +764,7 @@ mod tests { }; use futures_test::task::noop_waker_ref; use std::time::Duration; - use std::{marker, pin::Pin, sync::atomic::AtomicU64, sync::Arc, time::Instant}; + use std::{pin::Pin, sync::atomic::AtomicU64, sync::Arc, time::Instant}; #[test] fn dispatch_response_cancels_on_timeout() { @@ -955,7 +943,7 @@ mod tests { impl PollTest for Poll>> where - E: ::std::fmt::Display + marker::Send + 'static, + E: ::std::fmt::Display, { type T = Option; diff --git a/rpc/src/client/mod.rs b/rpc/src/client/mod.rs index ccb0758..fdec802 100644 --- a/rpc/src/client/mod.rs +++ b/rpc/src/client/mod.rs @@ -6,13 +6,14 @@ //! Provides a client that connects to a server and sends multiplexed requests. -use crate::{context, ClientMessage, Response, Transport}; +use crate::context; use futures::prelude::*; +use log::error; use std::io; /// Provides a [`Client`] backed by a transport. pub mod channel; -pub use self::channel::Channel; +pub use channel::{new, Channel}; /// Sends multiplexed requests to, and receives responses from, a server. pub trait Client<'a, Req> { @@ -125,15 +126,34 @@ impl Default for Config { } } -/// Creates a new Client by wrapping a [`Transport`] and spawning a dispatch task -/// that manages the lifecycle of requests. -/// -/// Must only be called from on an executor. -pub async fn new(config: Config, transport: T) -> io::Result> -where - Req: Send + 'static, - Resp: Send + 'static, - T: Transport, Response> + Send + 'static, -{ - Ok(channel::spawn(config, transport).await?) +/// A channel and dispatch pair. The dispatch drives the sending and receiving of requests +/// and must be polled continuously or spawned. +#[derive(Debug)] +pub struct NewClient { + /// The new client. + pub client: C, + /// The client's dispatch. + pub dispatch: D, +} + +impl NewClient +where + D: Future> + Send + 'static, +{ + /// Helper method to spawn the dispatch on the default executor. + pub fn spawn(self) -> io::Result { + let dispatch = self + .dispatch + .unwrap_or_else(move |e| error!("Connection broken: {}", e)); + crate::spawn(dispatch).map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!( + "Could not spawn client dispatch task. Is shutdown: {}", + e.is_shutdown() + ), + ) + })?; + Ok(self.client) + } } diff --git a/rpc/src/server/mod.rs b/rpc/src/server/mod.rs index 60c6124..cec87b3 100644 --- a/rpc/src/server/mod.rs +++ b/rpc/src/server/mod.rs @@ -78,7 +78,7 @@ impl Config { /// Returns a channel backed by `transport` and configured with `self`. pub fn channel(self, transport: T) -> BaseChannel where - T: Transport, ClientMessage> + Send, + T: Transport, ClientMessage>, { BaseChannel::new(self, transport) } @@ -101,49 +101,13 @@ impl Server { /// Returns a stream of server channels. pub fn incoming(self, listener: S) -> impl Stream> where - Req: Send, - Resp: Send, S: Stream, - T: Transport, ClientMessage> + Send, + T: Transport, ClientMessage>, { listener.map(move |t| BaseChannel::new(self.config.clone(), t)) } } -/// The future driving the server. -#[derive(Debug)] -pub struct Running { - incoming: St, - server: Se, -} - -impl Running { - unsafe_pinned!(incoming: St); - unsafe_unpinned!(server: Se); -} - -impl Future for Running -where - St: Sized + Stream, - C: Channel + Send + 'static, - Se: Serve + Send + 'static, - Se::Fut: Send + 'static -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - while let Some(channel) = ready!(self.as_mut().incoming().poll_next(cx)) { - if let Err(e) = - crate::spawn(channel.respond_with(self.as_mut().server().clone())) - { - warn!("Failed to spawn channel handler: {:?}", e); - } - } - info!("Server shutting down."); - Poll::Ready(()) - } -} - /// Basically a Fn(Req) -> impl Future; pub trait Serve: Sized + Clone { /// Type of response. @@ -191,8 +155,7 @@ where /// Responds to all requests with `server`. fn respond_with(self, server: S) -> Running where - S: Serve + Send + 'static, - S::Fut: Send + 'static, + S: Serve, { Running { incoming: self, @@ -226,7 +189,7 @@ impl BaseChannel { impl BaseChannel where - T: Transport, ClientMessage> + Send, + T: Transport, ClientMessage>, { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -288,10 +251,10 @@ where Self: Transport::Resp>, Request<::Req>>, { /// Type of request item. - type Req: Send + 'static; + type Req; /// Type of response sink item. - type Resp: Send + 'static; + type Resp; /// Configuration of the channel. fn config(&self) -> &Config; @@ -314,16 +277,15 @@ where /// Respond to requests coming over the channel with `f`. Returns a future that drives the /// responses and resolves when the connection is closed. - fn respond_with(self, server: S) -> ResponseHandler + fn respond_with(self, server: S) -> ClientHandler where - S: Serve + Send + 'static, - S::Fut: Send + 'static, + S: Serve, Self: Sized, { let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer); let responses = responses.fuse(); - ResponseHandler { + ClientHandler { channel: self, server, pending_responses: responses, @@ -334,9 +296,7 @@ where impl Stream for BaseChannel where - T: Transport, ClientMessage> + Send + 'static, - Req: Send + 'static, - Resp: Send + 'static, + T: Transport, ClientMessage>, { type Item = io::Result>; @@ -362,9 +322,7 @@ where impl Sink> for BaseChannel where - T: Transport, ClientMessage> + Send + 'static, - Req: Send + 'static, - Resp: Send + 'static, + T: Transport, ClientMessage>, { type Error = io::Error; @@ -402,9 +360,7 @@ impl AsRef for BaseChannel { impl Channel for BaseChannel where - T: Transport, ClientMessage> + Send + 'static, - Req: Send + 'static, - Resp: Send + 'static, + T: Transport, ClientMessage>, { type Req = Req; type Resp = Resp; @@ -429,7 +385,7 @@ where /// A running handler serving all requests coming over a channel. #[derive(Debug)] -pub struct ResponseHandler +pub struct ClientHandler where C: Channel, { @@ -438,11 +394,11 @@ where pending_responses: Fuse)>>, /// Handed out to request handlers to fan in responses. responses_tx: mpsc::Sender<(context::Context, Response)>, - /// Request handler. + /// Server server: S, } -impl ResponseHandler +impl ClientHandler where C: Channel, { @@ -450,22 +406,21 @@ where unsafe_pinned!(pending_responses: Fuse)>>); unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response)>); // For this to be safe, field f must be private, and code in this module must never - // construct PinMut. + // construct PinMut. unsafe_unpinned!(server: S); } -impl ResponseHandler +impl ClientHandler where C: Channel, - S: Serve + Send + 'static, - S::Fut: Send + 'static, + S: Serve, { - fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> { + fn pump_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> PollIo> { match ready!(self.as_mut().channel().poll_next(cx)?) { - Some(request) => { - self.handle_request(request)?; - Poll::Ready(Some(Ok(()))) - } + Some(request) => Poll::Ready(Some(Ok(self.handle_request(request)))), None => Poll::Ready(None), } } @@ -518,13 +473,16 @@ where match ready!(self.as_mut().pending_responses().poll_next(cx)) { Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))), None => { - // This branch likely won't happen, since the ResponseHandler is holding a Sender. + // This branch likely won't happen, since the ClientHandler is holding a Sender. Poll::Ready(None) } } } - fn handle_request(mut self: Pin<&mut Self>, request: Request) -> io::Result<()> { + fn handle_request( + mut self: Pin<&mut Self>, + request: Request, + ) -> RequestHandler { let request_id = request.id; let deadline = request.context.deadline; let timeout = deadline.as_duration(); @@ -536,70 +494,144 @@ where ); let ctx = request.context; let request = request.message; - let mut response_tx = self.as_mut().responses_tx().clone(); - let trace_id = *ctx.trace_id(); let response = self.as_mut().server().clone().serve(ctx, request); - let response = deadline_compat::Deadline::new(response, Instant::now() + timeout).then( - move |result| { - async move { - let response = Response { - request_id, - message: match result { - Ok(message) => Ok(message), - Err(e) => Err(make_server_error(e, trace_id, deadline)), - }, - }; - trace!("[{}] Sending response.", trace_id); - response_tx - .send((ctx, response)) - .unwrap_or_else(|_| ()) - .await; - } - }, - ); + let response = Resp { + state: RespState::PollResp, + request_id, + ctx, + deadline, + f: deadline_compat::Deadline::new(response, Instant::now() + timeout), + response: None, + response_tx: self.as_mut().responses_tx().clone(), + }; let abort_registration = self.as_mut().channel().start_request(request_id); - let response = Abortable::new(response, abort_registration); - crate::spawn(response.map(|_| ())).map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - format!( - "Could not spawn response task. Is shutdown: {}", - e.is_shutdown() - ), - ) - })?; - Ok(()) + RequestHandler { + resp: Abortable::new(response, abort_registration), + } } } -impl Future for ResponseHandler +/// A future fulfilling a single client request. +#[derive(Debug)] +pub struct RequestHandler { + resp: Abortable>, +} + +impl RequestHandler { + unsafe_pinned!(resp: Abortable>); +} + +impl Future for RequestHandler where - C: Channel, - S: Serve + Send + 'static, - S::Fut: Send + 'static, + F: Future, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let _ = ready!(self.resp().poll(cx)); + Poll::Ready(()) + } +} + +#[derive(Debug)] +struct Resp { + state: RespState, + request_id: u64, + ctx: context::Context, + deadline: SystemTime, + f: deadline_compat::Deadline, + response: Option>, + response_tx: mpsc::Sender<(context::Context, Response)>, +} + +#[derive(Debug)] +enum RespState { + PollResp, + PollReady, + PollFlush, +} + +impl Resp { + unsafe_pinned!(f: deadline_compat::Deadline); + unsafe_pinned!(response_tx: mpsc::Sender<(context::Context, Response)>); + unsafe_unpinned!(response: Option>); + unsafe_unpinned!(state: RespState); +} + +impl Future for Resp +where + F: Future, { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - move || -> Poll> { - loop { - let read = self.as_mut().pump_read(cx)?; - match ( - read, - self.as_mut().pump_write(cx, read == Poll::Ready(None))?, - ) { - (Poll::Ready(None), Poll::Ready(None)) => { - return Poll::Ready(Ok(())); + loop { + match self.as_mut().state() { + RespState::PollResp => { + let result = ready!(self.as_mut().f().poll(cx)); + *self.as_mut().response() = Some(Response { + request_id: self.request_id, + message: match result { + Ok(message) => Ok(message), + Err(e) => { + Err(make_server_error(e, *self.ctx.trace_id(), self.deadline)) + } + }, + }); + *self.as_mut().state() = RespState::PollReady; + } + RespState::PollReady => { + let ready = ready!(self.as_mut().response_tx().poll_ready(cx)); + if ready.is_err() { + return Poll::Ready(()); } - (Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {} - _ => { - return Poll::Pending; + let resp = (self.ctx, self.as_mut().response().take().unwrap()); + if self.as_mut().response_tx().start_send(resp).is_err() { + return Poll::Ready(()); } + *self.as_mut().state() = RespState::PollFlush; + } + RespState::PollFlush => { + let ready = ready!(self.as_mut().response_tx().poll_flush(cx)); + if ready.is_err() { + return Poll::Ready(()); + } + return Poll::Ready(()); } } - }() - .map(|r| r.unwrap_or_else(|e| info!("ResponseHandler errored out: {}", e))) + } + } +} + +impl Stream for ClientHandler +where + C: Channel, + S: Serve, +{ + type Item = io::Result>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + let read = self.as_mut().pump_read(cx)?; + let read_closed = if let Poll::Ready(None) = read { + true + } else { + false + }; + match (read, self.as_mut().pump_write(cx, read_closed)?) { + (Poll::Ready(None), Poll::Ready(None)) => { + return Poll::Ready(None); + } + (Poll::Ready(Some(request_handler)), _) => { + return Poll::Ready(Some(Ok(request_handler))); + } + (_, Poll::Ready(Some(()))) => {} + _ => { + return Poll::Pending; + } + } + } } } @@ -641,3 +673,72 @@ fn make_server_error( } } } + +// Send + 'static execution helper methods. + +impl ClientHandler +where + C: Channel + 'static, + C::Req: Send + 'static, + C::Resp: Send + 'static, + S: Serve + Send + 'static, + S::Fut: Send + 'static, +{ + /// Runs the client handler until completion by spawning each + /// request handler onto the default executor. + pub fn execute(self) -> impl Future { + self.try_for_each(|request_handler| { + async { + crate::spawn(request_handler).map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!( + "Could not spawn response task. Is shutdown: {}", + e.is_shutdown() + ), + ) + }) + } + }) + .unwrap_or_else(|e| info!("ClientHandler errored out: {}", e)) + } +} + +/// A future that drives the server by spawning channels and request handlers on the default +/// executor. +#[derive(Debug)] +pub struct Running { + incoming: St, + server: Se, +} + +impl Running { + unsafe_pinned!(incoming: St); + unsafe_unpinned!(server: Se); +} + +impl Future for Running +where + St: Sized + Stream, + C: Channel + Send + 'static, + C::Req: Send + 'static, + C::Resp: Send + 'static, + Se: Serve + Send + 'static + Clone, + Se::Fut: Send + 'static, +{ + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + while let Some(channel) = ready!(self.as_mut().incoming().poll_next(cx)) { + if let Err(e) = crate::spawn( + channel + .respond_with(self.as_mut().server().clone()) + .execute(), + ) { + warn!("Failed to spawn channel handler: {:?}", e); + } + } + info!("Server shutting down."); + Poll::Ready(()) + } +} diff --git a/rpc/src/server/testing.rs b/rpc/src/server/testing.rs index 804e601..5ba0455 100644 --- a/rpc/src/server/testing.rs +++ b/rpc/src/server/testing.rs @@ -60,8 +60,7 @@ impl Sink> for FakeChannel> { impl Channel for FakeChannel>, Response> where - Req: Unpin + Send + 'static, - Resp: Send + 'static, + Req: Unpin, { type Req = Req; type Resp = Resp; diff --git a/rpc/src/server/throttle.rs b/rpc/src/server/throttle.rs index 4ba9f19..cedb219 100644 --- a/rpc/src/server/throttle.rs +++ b/rpc/src/server/throttle.rs @@ -286,11 +286,7 @@ fn throttler_poll_next_throttled_sink_not_ready() { Poll::Pending } } - impl Channel for PendingSink>, Response> - where - Req: Send + 'static, - Resp: Send + 'static, - { + impl Channel for PendingSink>, Response> { type Req = Req; type Resp = Resp; fn config(&self) -> &Config { diff --git a/rpc/src/transport/channel.rs b/rpc/src/transport/channel.rs index 309d12b..c38c28f 100644 --- a/rpc/src/transport/channel.rs +++ b/rpc/src/transport/channel.rs @@ -106,7 +106,7 @@ mod tests { ) .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - let mut client = client::new(client::Config::default(), client_channel).await?; + let mut client = client::new(client::Config::default(), client_channel).spawn()?; let response1 = client.call(context::current(), "123".into()).await?; let response2 = client.call(context::current(), "abc".into()).await?; diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 7e9efe3..47b9d36 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -31,10 +31,12 @@ bytes = { version = "0.4", features = ["serde"] } env_logger = "0.6" futures-preview = { version = "0.3.0-alpha.17", features = ["compat"] } humantime = "1.0" +log = "0.4" runtime = "0.3.0-alpha.6" runtime-tokio = "0.3.0-alpha.5" tokio-tcp = "0.1" pin-utils = "0.1.0-alpha.4" +tokio = "0.1" [[example]] name = "server_calling_server" diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 85dfb9d..9444f78 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -116,7 +116,7 @@ impl publisher::Publisher for Publisher { ) -> io::Result<()> { let conn = bincode_transport::connect(&addr).await?; let subscriber = - subscriber::SubscriberClient::new(client::Config::default(), conn).await?; + subscriber::SubscriberClient::new(client::Config::default(), conn).spawn()?; eprintln!("Subscribing {}.", id); clients.lock().unwrap().insert(id, subscriber); Ok(()) @@ -160,7 +160,7 @@ async fn main() -> io::Result<()> { let publisher_conn = bincode_transport::connect(&publisher_addr); let publisher_conn = publisher_conn.await?; let mut publisher = - publisher::PublisherClient::new(client::Config::default(), publisher_conn).await?; + publisher::PublisherClient::new(client::Config::default(), publisher_conn).spawn()?; if let Err(e) = publisher .subscribe(context::current(), 0, subscriber1) diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 5d44be5..c4c7669 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -58,6 +58,7 @@ async fn main() -> io::Result<()> { // serve_world is generated by the tarpc::service attribute. It takes as input any type // implementing the generated World trait. .respond_with(HelloServer.serve()) + .execute() .await; }; let _ = runtime::spawn(server); @@ -66,7 +67,7 @@ async fn main() -> io::Result<()> { // WorldClient is generated by the tarpc::service attribute. It has a constructor `new` that // takes a config and any Transport as input. - let mut client = WorldClient::new(client::Config::default(), transport).await?; + let mut client = WorldClient::new(client::Config::default(), transport).spawn()?; // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context diff --git a/tarpc/examples/server_calling_server.rs b/tarpc/examples/server_calling_server.rs index 025d36c..400c69a 100644 --- a/tarpc/examples/server_calling_server.rs +++ b/tarpc/examples/server_calling_server.rs @@ -78,7 +78,7 @@ async fn main() -> io::Result<()> { let _ = runtime::spawn(add_server); let to_add_server = bincode_transport::connect(&addr).await?; - let add_client = add::AddClient::new(client::Config::default(), to_add_server).await?; + let add_client = add::AddClient::new(client::Config::default(), to_add_server).spawn()?; let double_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())? .filter_map(|r| future::ready(r.ok())); @@ -91,7 +91,7 @@ async fn main() -> io::Result<()> { let to_double_server = bincode_transport::connect(&addr).await?; let mut double_client = - double::DoubleClient::new(client::Config::default(), to_double_server).await?; + double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?; for i in 1..=5 { eprintln!("{:?}", double_client.double(context::current(), i).await?); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index fa7f3b5..5d3db88 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -1,21 +1,38 @@ #![feature(async_await)] -#[cfg(not(feature = "serde1"))] -use std::rc::Rc; use assert_matches::assert_matches; use futures::{ future::{ready, Ready}, prelude::*, }; - -use std::io; +use std::{rc::Rc, io}; use tarpc::{ - client, context, + client::{self, NewClient}, context, server::{self, BaseChannel, Channel, Handler}, transport::channel, }; +trait RuntimeExt { + fn exec_bg(&mut self, future: impl Future + 'static); + fn exec(&mut self, future: F) -> Result + where + F: Future>; +} + +impl RuntimeExt for tokio::runtime::current_thread::Runtime { + fn exec_bg(&mut self, future: impl Future + 'static) { + self.spawn(Box::pin(future.unit_error()).compat()); + } + + fn exec(&mut self, future: F) -> Result + where + F: Future>, + { + self.block_on(futures::compat::Compat::new(Box::pin(future))) + } +} + #[tarpc_plugins::service] trait Service { async fn add(x: i32, y: i32) -> i32; @@ -48,9 +65,10 @@ async fn sequential() -> io::Result<()> { let _ = runtime::spawn( BaseChannel::new(server::Config::default(), rx) .respond_with(Server.serve()) + .execute() ); - let mut client = ServiceClient::new(client::Config::default(), tx).await?; + let mut client = ServiceClient::new(client::Config::default(), tx).spawn()?; assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); assert_matches!( @@ -74,7 +92,7 @@ async fn serde() -> io::Result<()> { ); let transport = bincode_transport::connect(&addr).await?; - let mut client = ServiceClient::new(client::Config::default(), transport).await?; + let mut client = ServiceClient::new(client::Config::default(), transport).spawn()?; assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); assert_matches!( @@ -96,7 +114,7 @@ async fn concurrent() -> io::Result<()> { .respond_with(Server.serve()), ); - let client = ServiceClient::new(client::Config::default(), tx).await?; + let client = ServiceClient::new(client::Config::default(), tx).spawn()?; let mut c = client.clone(); let req1 = c.add(context::current(), 1, 2); @@ -113,3 +131,61 @@ async fn concurrent() -> io::Result<()> { Ok(()) } + +#[tarpc::service(derive_serde = false)] +trait InMemory { + async fn strong_count(rc: Rc<()>) -> usize; + async fn weak_count(rc: Rc<()>) -> usize; +} + +impl InMemory for () { + type StrongCountFut = Ready; + fn strong_count(self, _: context::Context, rc: Rc<()>) -> Self::StrongCountFut { + ready(Rc::strong_count(&rc)) + } + + type WeakCountFut = Ready; + fn weak_count(self, _: context::Context, rc: Rc<()>) -> Self::WeakCountFut { + ready(Rc::weak_count(&rc)) + } +} + +#[test] +fn in_memory_single_threaded() -> io::Result<()> { + use log::warn; + + let _ = env_logger::try_init(); + let mut runtime = tokio::runtime::current_thread::Runtime::new()?; + + let (tx, rx) = channel::unbounded(); + + let server = BaseChannel::new(server::Config::default(), rx) + .respond_with(().serve()) + .try_for_each(|r| async move { Ok(r.await) }); + runtime.exec_bg(async { + if let Err(e) = server.await { + warn!("Error while running server: {}", e); + } + }); + + let NewClient{mut client, dispatch} = InMemoryClient::new(client::Config::default(), tx); + runtime.exec_bg(async move { + if let Err(e) = dispatch.await { + warn!("Error while running client dispatch: {}", e) + } + }); + + let rc = Rc::new(()); + assert_matches!( + runtime.exec(client.strong_count(context::current(), rc.clone())), + Ok(2) + ); + + let _weak = Rc::downgrade(&rc); + assert_matches!( + runtime.exec(client.weak_count(context::current(), rc)), + Ok(1) + ); + + Ok(()) +}