diff --git a/README.md b/README.md index d5aab1e..c997816 100644 --- a/README.md +++ b/README.md @@ -91,7 +91,7 @@ use futures::{ }; use tarpc::{ client, context, - server::{self, Handler}, + server::{self, Incoming}, }; use std::io; @@ -135,16 +135,11 @@ available behind the `tcp` feature. async fn main() -> io::Result<()> { let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); - let server = server::new(server::Config::default()) - // incoming() takes a stream of transports such as would be returned by - // TcpListener::incoming (but a stream instead of an iterator). - .incoming(stream::once(future::ready(server_transport))) - .respond_with(HelloServer.serve()); + let server = server::BaseChannel::with_defaults(server_transport); + tokio::spawn(server.execute(HelloServer.serve())); - tokio::spawn(server); - - // WorldClient is generated by the macro. It has a constructor `new` that takes a config and - // any Transport as input + // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` + // that takes a config and any Transport as input. let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?; // The client has an RPC method for each RPC defined in the annotated trait. It takes the same diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 59f9ed0..0faadc8 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -13,7 +13,7 @@ use std::{ }; use tarpc::{ context, - server::{self, Channel, Handler}, + server::{self, Channel, Incoming}, tokio_serde::formats::Json, }; @@ -69,7 +69,7 @@ async fn main() -> io::Result<()> { // the generated World trait. .map(|channel| { let server = HelloServer(channel.as_ref().as_ref().peer_addr().unwrap()); - channel.respond_with(server.serve()).execute() + channel.requests().execute(server.serve()) }) // Max 10 channels. .buffer_unordered(10) diff --git a/hooks/pre-commit b/hooks/pre-commit index ac1bea9..4052a2b 100755 --- a/hooks/pre-commit +++ b/hooks/pre-commit @@ -93,7 +93,7 @@ diff="" for file in $(git diff --name-only --cached); do if [ ${file: -3} == ".rs" ]; then - diff="$diff$(cargo fmt -- --unstable-features --skip-children --check $file)" + diff="$diff$(cargo fmt -- --check $file)" fi done if grep --quiet "^[-+]" <<< "$diff"; then diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 923b497..c52711e 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -485,7 +485,7 @@ impl<'a> ServiceGenerator<'a> { #vis trait #service_ident: Clone { #( #types_and_fns )* - /// Returns a serving function to use with [tarpc::server::Channel::respond_with]. + /// Returns a serving function to use with [tarpc::server::InFlightRequest::execute]. fn serve(self) -> #server_ident { #server_ident { service: self } } @@ -499,7 +499,7 @@ impl<'a> ServiceGenerator<'a> { } = self; quote! { - /// A serving function to use with [tarpc::server::Channel::respond_with]. + /// A serving function to use with [tarpc::server::InFlightRequest::execute]. #[derive(Clone)] #vis struct #server_ident { service: S, @@ -662,7 +662,8 @@ impl<'a> ServiceGenerator<'a> { quote! { #[allow(unused)] #[derive(Clone, Debug)] - /// The client stub that makes RPC calls to the server. Exposes a Future interface. + /// The client stub that makes RPC calls to the server. ALl request methods return + /// [Futures](std::future::Future). #vis struct #client_ident>(C); } } diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 70cd97a..f0faee7 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -113,8 +113,7 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(async move { let transport = incoming.next().await.unwrap().unwrap(); BaseChannel::with_defaults(add_compression(transport)) - .respond_with(HelloServer.serve()) - .execute() + .execute(HelloServer.serve()) .await; }); diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 2fe01c5..d0df3bd 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -36,9 +36,7 @@ async fn main() -> std::io::Result<()> { let framed = codec_builder.new_framed(conn); let transport = transport::new(framed, Bincode::default()); - let fut = BaseChannel::with_defaults(transport) - .respond_with(Service.serve()) - .execute(); + let fut = BaseChannel::with_defaults(transport).execute(Service.serve()); tokio::spawn(fut); } }); diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index f98a237..6b43bc6 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -105,11 +105,11 @@ impl Subscriber { ) -> anyhow::Result { let publisher = tcp::connect(publisher_addr, Json::default).await?; let local_addr = publisher.local_addr()?; - let mut handler = server::BaseChannel::with_defaults(publisher) - .respond_with(Subscriber { local_addr, topics }.serve()); - // The first request is for the topics being subscriibed to. + let mut handler = server::BaseChannel::with_defaults(publisher).requests(); + let subscriber = Subscriber { local_addr, topics }; + // The first request is for the topics being subscribed to. match handler.next().await { - Some(init_topics) => init_topics?.await, + Some(init_topics) => init_topics?.execute(subscriber.clone().serve()).await, None => { return Err(anyhow!( "[{}] Server never initialized the subscriber.", @@ -117,7 +117,7 @@ impl Subscriber { )) } }; - let (handler, abort_handle) = future::abortable(handler.execute()); + let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve())); tokio::spawn(async move { match handler.await { Ok(()) | Err(future::Aborted) => info!("[{}] subscriber shutdown.", local_addr), @@ -162,8 +162,7 @@ impl Publisher { info!("[{}] publisher connected.", publisher.peer_addr().unwrap()); server::BaseChannel::with_defaults(publisher) - .respond_with(self.serve()) - .execute() + .execute(self.serve()) .await }); diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index fe40886..26a63ba 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -11,7 +11,7 @@ use futures::{ use std::io; use tarpc::{ client, context, - server::{BaseChannel, Channel}, + server::{self, Channel}, }; use tokio_serde::formats::Json; @@ -40,40 +40,21 @@ impl World for HelloServer { #[tokio::main] async fn main() -> io::Result<()> { - // tarpc_json_transport is provided by the associated crate json_transport. It makes it - // easy to start up a serde-powered JSON serialization strategy over TCP. - let mut transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default).await?; - let addr = transport.local_addr(); + let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); - let server = async move { - // For this example, we're just going to wait for one connection. - let client = transport.next().await.unwrap().unwrap(); + let server = server::BaseChannel::with_defaults(server_transport); + tokio::spawn(server.execute(HelloServer.serve())); - // `Channel` is a trait representing a server-side connection. It is a trait to allow - // for some channels to be instrumented: for example, to track the number of open connections. - // BaseChannel is the most basic channel, simply wrapping a transport with no added - // functionality. - BaseChannel::with_defaults(client) - // serve_world is generated by the tarpc::service attribute. It takes as input any type - // implementing the generated World trait. - .respond_with(HelloServer.serve()) - .execute() - .await; - }; - tokio::spawn(server); - - let transport = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; - - // WorldClient is generated by the tarpc::service attribute. It has a constructor `new` that - // takes a config and any Transport as input. - let mut client = WorldClient::new(client::Config::default(), transport).spawn()?; + // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` + // that takes a config and any Transport as input. + let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?; // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. let hello = client.hello(context::current(), "Stim".to_string()).await?; - eprintln!("{}", hello); + println!("{}", hello); Ok(()) } diff --git a/tarpc/examples/server_calling_server.rs b/tarpc/examples/server_calling_server.rs index 3592c23..eb066f0 100644 --- a/tarpc/examples/server_calling_server.rs +++ b/tarpc/examples/server_calling_server.rs @@ -9,7 +9,7 @@ use futures::{future, prelude::*}; use std::io; use tarpc::{ client, context, - server::{Handler, Server}, + server::{BaseChannel, Incoming}, }; use tokio_serde::formats::Json; @@ -62,10 +62,10 @@ async fn main() -> io::Result<()> { .await? .filter_map(|r| future::ready(r.ok())); let addr = add_listener.get_ref().local_addr(); - let add_server = Server::default() - .incoming(add_listener) + let add_server = add_listener + .map(BaseChannel::with_defaults) .take(1) - .respond_with(AddServer.serve()); + .execute(AddServer.serve()); tokio::spawn(add_server); let to_add_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; @@ -75,10 +75,10 @@ async fn main() -> io::Result<()> { .await? .filter_map(|r| future::ready(r.ok())); let addr = double_listener.get_ref().local_addr(); - let double_server = tarpc::Server::default() - .incoming(double_listener) + let double_server = double_listener + .map(BaseChannel::with_defaults) .take(1) - .respond_with(DoubleServer { add_client }.serve()); + .execute(DoubleServer { add_client }.serve()); tokio::spawn(double_server); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index e844a9c..a3b161d 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -79,7 +79,7 @@ //! }; //! use tarpc::{ //! client, context, -//! server::{self, Handler}, +//! server::{self, Incoming}, //! }; //! use std::io; //! @@ -103,7 +103,7 @@ //! # }; //! # use tarpc::{ //! # client, context, -//! # server::{self, Handler}, +//! # server::{self, Incoming}, //! # }; //! # use std::io; //! # // This is the service definition. It looks a lot like a trait definition. @@ -143,7 +143,7 @@ //! # }; //! # use tarpc::{ //! # client, context, -//! # server::{self, Handler}, +//! # server::{self, Channel}, //! # }; //! # use std::io; //! # // This is the service definition. It looks a lot like a trait definition. @@ -172,16 +172,11 @@ //! async fn main() -> io::Result<()> { //! let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); //! -//! let server = server::new(server::Config::default()) -//! // incoming() takes a stream of transports such as would be returned by -//! // TcpListener::incoming (but a stream instead of an iterator). -//! .incoming(stream::once(future::ready(server_transport))) -//! .respond_with(HelloServer.serve()); +//! let server = server::BaseChannel::with_defaults(server_transport); +//! tokio::spawn(server.execute(HelloServer.serve())); //! -//! tokio::spawn(server); -//! -//! // WorldClient is generated by the macro. It has a constructor `new` that takes a config and -//! // any Transport as input +//! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` +//! // that takes a config and any Transport as input. //! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?; //! //! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same @@ -304,7 +299,7 @@ pub mod server; pub mod transport; pub(crate) mod util; -pub use crate::{client::Client, server::Server, transport::sealed::Transport}; +pub use crate::{client::Client, transport::sealed::Transport}; use anyhow::Context as _; use futures::task::*; diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index f875220..295d0c4 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -22,8 +22,7 @@ use futures::{ use humantime::format_rfc3339; use log::{debug, trace}; use pin_project::{pin_project, pinned_drop}; -use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime}; -use tokio::time::Timeout; +use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin}; mod filter; #[cfg(test)] @@ -35,30 +34,12 @@ pub use self::{ throttle::{Throttler, ThrottlerStream}, }; -/// Manages clients, serving multiplexed requests over each connection. -pub struct Server { - config: Config, - ghost: PhantomData<(Req, Resp)>, -} - -impl Default for Server { - fn default() -> Self { - new(Config::default()) - } -} - -impl fmt::Debug for Server { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "Server") - } -} - -/// Settings that control the behavior of the server. +/// Settings that control the behavior of [channels](Channel). #[derive(Clone, Debug)] pub struct Config { - /// The number of responses per client that can be buffered server-side before being sent. - /// `pending_response_buffer` controls the buffer size of the channel that a server's - /// response tasks use to send responses to the client handler task. + /// Controls the buffer size of the in-process channel over which a server's handlers send + /// responses to the [`Channel`]. In other words, this is the number of responses that can sit + /// in the outbound queue before request handlers begin blocking. pub pending_response_buffer: usize, } @@ -80,32 +61,8 @@ impl Config { } } -/// Returns a new server with configuration specified `config`. -pub fn new(config: Config) -> Server { - Server { - config, - ghost: PhantomData, - } -} - -impl Server { - /// Returns the config for this server. - pub fn config(&self) -> &Config { - &self.config - } - - /// Returns a stream of server channels. - pub fn incoming(self, listener: S) -> impl Stream> - where - S: Stream, - T: Transport, ClientMessage>, - { - listener.map(move |t| BaseChannel::new(self.config.clone(), t)) - } -} - -/// Basically a Fn(Req) -> impl Future; -pub trait Serve: Sized + Clone { +/// Equivalent to a `FnOnce(Req) -> impl Future`. +pub trait Serve { /// Type of response. type Resp; @@ -129,8 +86,8 @@ where } } -/// A utility trait enabling a stream to fluently chain a request handler. -pub trait Handler +/// An extension trait for [streams](Stream) of [`Channels`](Channel). +pub trait Incoming where Self: Sized + Stream, C: Channel, @@ -149,28 +106,34 @@ where ThrottlerStream::new(self, n) } - /// Responds to all requests with [`server::serve`](Serve). + /// [Executes](Channel::execute) each incoming channel. Each channel will be handled + /// concurrently by spawning on tokio's default executor, and each request will be also + /// be spawned on tokio's default executor. #[cfg(feature = "tokio1")] #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] - fn respond_with(self, server: S) -> Running + fn execute(self, serve: S) -> TokioServerExecutor where S: Serve, { - Running { - incoming: self, - server, - } + TokioServerExecutor { inner: self, serve } } } -impl Handler for S +impl Incoming for S where S: Sized + Stream, C: Channel, { } -/// BaseChannel lifts a Transport to a Channel by tracking in-flight requests. +/// BaseChannel is a [Transport] that keeps track of in-flight requests. It converts a +/// [`Transport`](Transport) of [`ClientMessages`](ClientMessage) into a stream of +/// [requests](ClientMessage::Request). +/// +/// Besides requests, the other type of client message is [cancellation +/// messages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation +/// messages. Instead, it internally handles them by cancelling corresponding requests (removing +/// the corresponding in-flight requests and aborting their handlers). #[pin_project(PinnedDrop)] pub struct BaseChannel { config: Config, @@ -251,10 +214,25 @@ impl fmt::Debug for BaseChannel { /// The server end of an open connection with a client, streaming in requests from, and sinking /// responses to, the client. /// -/// Channels are free to somewhat rely on the assumption that all in-flight requests are eventually -/// either [cancelled](BaseChannel::cancel_request) or [responded to](Sink::start_send). Safety cannot -/// rely on this assumption, but it is best for `Channel` users to always account for all outstanding -/// requests. +/// +/// The ways to use a Channel, in order of simplest to most complex, is: +/// 1. [Channel::execute] - Requires the `tokio1` feature. This method is best for those who +/// do not have specific scheduling needs and whose services are `Send + 'static`. +/// 2. [Channel::requests] - This method is best for those who need direct access to individual +/// requests, or are not using `tokio`, or want control over [futures](Future) scheduling. +/// 3. [Raw stream]() - A user is free to manually handle requests produced by +/// Channel. If they do so, they should uphold the service contract: +/// 1. All work being done as part of processing request `request_id` is aborted when +/// either of the following occurs: +/// - The channel receives a [cancellation message](ClientMessage::Cancel) for request +/// `request_id`. +/// - The [deadline](crate::context::Context::deadline) of request `request_id` is reached. +/// 2. When a server completes a response for request `request_id`, it is +/// [sent](Sink::start_send) into the Channel. Because there is no guarantee that a +/// cancellation message will ever be received for a request, services should strive to clean +/// up Channel resources by sending a response for every request. For example, [`BaseChannel`] +/// has a map of requests to [abort handles][AbortHandle] whose entries are only removed +/// upon either request cancellation or response completion. pub trait Channel where Self: Transport::Resp>, Request<::Req>>, @@ -269,14 +247,14 @@ where fn config(&self) -> &Config; /// Returns the number of in-flight requests over this channel. - fn in_flight_requests(self: Pin<&mut Self>) -> usize; + fn in_flight_requests(&self) -> usize; - /// Caps the number of concurrent requests. - fn max_concurrent_requests(self, n: usize) -> Throttler + /// Caps the number of concurrent requests to `limit`. + fn max_concurrent_requests(self, limit: usize) -> Throttler where Self: Sized, { - Throttler::new(self, n) + Throttler::new(self, limit) } /// Tells the Channel that request with ID `request_id` is being handled. @@ -284,23 +262,37 @@ where /// to the Channel. fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration; - /// Respond to requests coming over the channel with `f`. Returns a future that drives the - /// responses and resolves when the connection is closed. - fn respond_with(self, server: S) -> ClientHandler + /// Returns a stream of requests that automatically handle request cancellation and response + /// routing. + fn requests(self) -> Requests where - S: Serve, Self: Sized, { let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer); let responses = responses.fuse(); - ClientHandler { + Requests { channel: self, - server, pending_responses: responses, responses_tx, } } + + /// Runs the channel until completion by executing all requests using the given service + /// function. Request handlers are run concurrently by [spawning](tokio::spawn) on tokio's + /// default executor. + #[cfg(feature = "tokio1")] + #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] + fn execute(self, serve: S) -> TokioChannelExecutor, S> + where + Self: Sized, + S: Serve + Send + Sync + 'static, + S::Fut: Send, + Self::Req: Send + 'static, + Self::Resp: Send + 'static, + { + self.requests().execute(serve) + } } impl Stream for BaseChannel @@ -390,8 +382,8 @@ where &self.config } - fn in_flight_requests(mut self: Pin<&mut Self>) -> usize { - self.as_mut().project().in_flight_requests.len() + fn in_flight_requests(&self) -> usize { + self.in_flight_requests.len() } fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration { @@ -405,9 +397,9 @@ where } } -/// A running handler serving all requests coming over a channel. +/// A stream of requests coming over a channel. #[pin_project] -pub struct ClientHandler +pub struct Requests where C: Channel, { @@ -419,26 +411,30 @@ where /// Handed out to request handlers to fan in responses. #[pin] responses_tx: mpsc::Sender<(context::Context, Response)>, - /// Server - server: S, } -impl ClientHandler +impl Requests where C: Channel, - S: Serve, { /// Returns the inner channel over which messages are sent and received. - pub fn get_pin_channel(self: Pin<&mut Self>) -> Pin<&mut C> { - self.project().channel + pub fn channel_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> { + self.as_mut().project().channel } fn pump_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> PollIo> { + ) -> PollIo> { match ready!(self.as_mut().project().channel.poll_next(cx)?) { - Some(request) => Poll::Ready(Some(Ok(self.handle_request(request)))), + Some(request) => { + let abort_registration = self.as_mut().project().channel.start_request(request.id); + Poll::Ready(Some(Ok(InFlightRequest { + request, + response_tx: self.responses_tx.clone(), + abort_registration, + }))) + } None => Poll::Ready(None), } } @@ -449,28 +445,28 @@ where read_half_closed: bool, ) -> PollIo<()> { match self.as_mut().poll_next_response(cx)? { - Poll::Ready(Some((ctx, response))) => { + Poll::Ready(Some((context, response))) => { trace!( "[{}] Staging response. In-flight requests = {}.", - ctx.trace_id(), - self.as_mut().project().channel.in_flight_requests(), + context.trace_id(), + self.channel.in_flight_requests(), ); - self.as_mut().project().channel.start_send(response)?; + self.channel_pin_mut().start_send(response)?; Poll::Ready(Some(Ok(()))) } Poll::Ready(None) => { // Shutdown can't be done before we finish pumping out remaining responses. - ready!(self.as_mut().project().channel.poll_flush(cx)?); + ready!(self.channel_pin_mut().poll_flush(cx)?); Poll::Ready(None) } Poll::Pending => { // No more requests to process, so flush any requests buffered in the transport. - ready!(self.as_mut().project().channel.poll_flush(cx)?); + ready!(self.channel_pin_mut().poll_flush(cx)?); // Being here means there are no staged requests and all written responses are // fully flushed. So, if the read half is closed and there are no in-flight // requests, then we can close the write half. - if read_half_closed && self.as_mut().project().channel.in_flight_requests() == 0 { + if read_half_closed && self.channel.in_flight_requests() == 0 { Poll::Ready(None) } else { Poll::Pending @@ -484,183 +480,116 @@ where cx: &mut Context<'_>, ) -> PollIo<(context::Context, Response)> { // Ensure there's room to write a response. - while self.as_mut().project().channel.poll_ready(cx)?.is_pending() { + while self.channel_pin_mut().poll_ready(cx)?.is_pending() { ready!(self.as_mut().project().channel.poll_flush(cx)?); } match ready!(self.as_mut().project().pending_responses.poll_next(cx)) { - Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))), + Some(response) => Poll::Ready(Some(Ok(response))), None => { - // This branch likely won't happen, since the ClientHandler is holding a Sender. + // This branch likely won't happen, since the Requests stream is holding a Sender. Poll::Ready(None) } } } - - fn handle_request( - mut self: Pin<&mut Self>, - request: Request, - ) -> RequestHandler { - let request_id = request.id; - let deadline = request.context.deadline; - let timeout = deadline.time_until(); - trace!( - "[{}] Received request with deadline {} (timeout {:?}).", - request.context.trace_id(), - format_rfc3339(deadline), - timeout, - ); - let ctx = request.context; - let request = request.message; - - let response = self.as_mut().project().server.clone().serve(ctx, request); - let response = Resp { - state: RespState::PollResp, - request_id, - ctx, - deadline, - f: tokio::time::timeout(timeout, response), - response: None, - response_tx: self.as_mut().project().responses_tx.clone(), - }; - let abort_registration = self.as_mut().project().channel.start_request(request_id); - RequestHandler { - resp: Abortable::new(response, abort_registration), - } - } } -impl fmt::Debug for ClientHandler +impl fmt::Debug for Requests where C: Channel, { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "ClientHandler") + write!(fmt, "Requests") } } -/// A future fulfilling a single client request. -#[pin_project] -pub struct RequestHandler { - #[pin] - resp: Abortable>, -} - -impl Future for RequestHandler -where - F: Future, -{ - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - let _ = ready!(self.project().resp.poll(cx)); - Poll::Ready(()) - } -} - -impl fmt::Debug for RequestHandler { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "RequestHandler") - } -} - -#[pin_project] -struct Resp { - state: RespState, - request_id: u64, - ctx: context::Context, - deadline: SystemTime, - #[pin] - f: Timeout, - response: Option>, - #[pin] - response_tx: mpsc::Sender<(context::Context, Response)>, -} - +/// A request produced by [Channel::requests]. #[derive(Debug)] -#[allow(clippy::enum_variant_names)] -enum RespState { - PollResp, - PollReady, - PollFlush, +pub struct InFlightRequest { + request: Request, + response_tx: mpsc::Sender<(context::Context, Response)>, + abort_registration: AbortRegistration, } -impl Future for Resp -where - F: Future, -{ - type Output = (); +impl InFlightRequest { + /// Returns a reference to the request. + pub fn get(&self) -> &Request { + &self.request + } - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - loop { - match self.as_mut().project().state { - RespState::PollResp => { - let result = ready!(self.as_mut().project().f.poll(cx)); - *self.as_mut().project().response = Some(Response { - request_id: self.request_id, - message: match result { - Ok(message) => Ok(message), - Err(tokio::time::error::Elapsed { .. }) => { - debug!( - "[{}] Response did not complete before deadline of {}s.", - self.ctx.trace_id(), - format_rfc3339(self.deadline) - ); - // No point in responding, since the client will have dropped the - // request. - Err(ServerError { - kind: io::ErrorKind::TimedOut, - detail: Some(format!( - "Response did not complete before deadline of {}s.", - format_rfc3339(self.deadline) - )), - }) - } - }, - }); - *self.as_mut().project().state = RespState::PollReady; - } - RespState::PollReady => { - let ready = ready!(self.as_mut().project().response_tx.poll_ready(cx)); - if ready.is_err() { - return Poll::Ready(()); - } - let resp = (self.ctx, self.as_mut().project().response.take().unwrap()); - if self - .as_mut() - .project() - .response_tx - .start_send(resp) - .is_err() - { - return Poll::Ready(()); - } - *self.as_mut().project().state = RespState::PollFlush; - } - RespState::PollFlush => { - let ready = ready!(self.as_mut().project().response_tx.poll_flush(cx)); - if ready.is_err() { - return Poll::Ready(()); - } - return Poll::Ready(()); - } - } - } + /// Returns a [future](Future) that executes the request using the given [service + /// function](Serve). The service function's output is automatically sent back to the + /// [Channel] that yielded this request. + /// + /// The returned future will stop executing when the first of the following conditions is met: + /// + /// 1. The channel that yielded this request receives a [cancellation + /// message](ClientMessage::Cancel) for this request. + /// 2. The request [deadline](crate::context::Context::deadline) is reached. + /// 3. The service function completes. + pub fn execute(self, serve: S) -> impl Future + where + S: Serve, + { + let Self { + abort_registration, + request, + mut response_tx, + } = self; + Abortable::new( + async move { + let Request { + context, + message, + id: request_id, + } = request; + let trace_id = *request.context.trace_id(); + let deadline = request.context.deadline; + let timeout = deadline.time_until(); + trace!( + "[{}] Handling request with deadline {} (timeout {:?}).", + trace_id, + format_rfc3339(deadline), + timeout, + ); + let result = + tokio::time::timeout(timeout, async { serve.serve(context, message).await }) + .await; + let response = Response { + request_id, + message: match result { + Ok(message) => Ok(message), + Err(tokio::time::error::Elapsed { .. }) => { + debug!( + "[{}] Response did not complete before deadline of {}s.", + trace_id, + format_rfc3339(deadline) + ); + // No point in responding, since the client will have dropped the + // request. + Err(ServerError { + kind: io::ErrorKind::TimedOut, + detail: Some(format!( + "Response did not complete before deadline of {}s.", + format_rfc3339(deadline) + )), + }) + } + }, + }; + let _ = response_tx.send((context, response)).await; + }, + abort_registration, + ) + .unwrap_or_else(|_| {}) } } -impl fmt::Debug for Resp { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "Resp") - } -} - -impl Stream for ClientHandler +impl Stream for Requests where C: Channel, - S: Serve, { - type Item = io::Result>; + type Item = io::Result>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { @@ -684,77 +613,111 @@ where // Send + 'static execution helper methods. -impl ClientHandler +#[cfg(feature = "tokio1")] +#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] +impl Requests where - C: Channel + 'static, + C: Channel, C::Req: Send + 'static, C::Resp: Send + 'static, - S: Serve + Send + 'static, - S::Fut: Send + 'static, { - /// Runs the client handler until completion by [spawning](tokio::spawn) each - /// request handler onto the default executor. - #[cfg(feature = "tokio1")] - #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] - pub fn execute(self) -> impl Future { - self.try_for_each(|request_handler| async { - tokio::spawn(request_handler); - Ok(()) - }) - .map_ok(|()| log::info!("ClientHandler finished.")) - .unwrap_or_else(|e| log::info!("ClientHandler errored out: {}", e)) + /// Executes all requests using the given service function. Requests are handled concurrently + /// by [spawning](tokio::spawn) each handler on tokio's default executor. + pub fn execute(self, serve: S) -> TokioChannelExecutor + where + S: Serve + Send + Sync + 'static, + { + TokioChannelExecutor { inner: self, serve } } } -/// A future that drives the server by [spawning](tokio::spawn) channels and request handlers on the default -/// executor. +/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor) +/// for each new channel. #[pin_project] #[derive(Debug)] #[cfg(feature = "tokio1")] #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] -pub struct Running { +pub struct TokioServerExecutor { #[pin] - incoming: St, - server: Se, + inner: T, + serve: S, +} + +/// A future that drives the server by [spawning](tokio::spawn) each [response handler](ResponseHandler) +/// on tokio's default executor. +#[pin_project] +#[derive(Debug)] +#[cfg(feature = "tokio1")] +#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] +pub struct TokioChannelExecutor { + #[pin] + inner: T, + serve: S, } #[cfg(feature = "tokio1")] -impl Future for Running +#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] +impl TokioServerExecutor { + fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { + self.as_mut().project().inner + } +} + +#[cfg(feature = "tokio1")] +#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] +impl TokioChannelExecutor { + fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { + self.as_mut().project().inner + } +} + +#[cfg(feature = "tokio1")] +impl Future for TokioServerExecutor where St: Sized + Stream, C: Channel + Send + 'static, C::Req: Send + 'static, C::Resp: Send + 'static, - Se: Serve + Send + 'static + Clone, - Se::Fut: Send + 'static, + Se: Serve + Send + Sync + 'static + Clone, + Se::Fut: Send, { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - while let Some(channel) = ready!(self.as_mut().project().incoming.poll_next(cx)) { - tokio::spawn( - channel - .respond_with(self.as_mut().project().server.clone()) - .execute(), - ); + while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) { + tokio::spawn(channel.execute(self.serve.clone())); } log::info!("Server shutting down."); Poll::Ready(()) } } -#[tokio::test] -async fn abort_in_flight_requests_on_channel_drop() { - use assert_matches::assert_matches; - use futures::future::Aborted; +#[cfg(feature = "tokio1")] +impl Future for TokioChannelExecutor, S> +where + C: Channel + 'static, + C::Req: Send + 'static, + C::Resp: Send + 'static, + S: Serve + Send + Sync + 'static + Clone, + S::Fut: Send, +{ + type Output = (); - let (_, server_transport) = - super::transport::channel::unbounded::, ClientMessage<()>>(); - let channel = BaseChannel::with_defaults(server_transport); - let mut channel = Box::pin(channel); - - let abort_registration = channel.as_mut().start_request(1); - let future = Abortable::new(async { () }, abort_registration); - drop(channel); - assert_matches!(future.await, Err(Aborted)); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) { + match response_handler { + Ok(resp) => { + let server = self.serve.clone(); + tokio::spawn(async move { + resp.execute(server).await; + }); + } + Err(e) => { + log::info!("Requests stream errored out: {}", e); + break; + } + } + } + Poll::Ready(()) + } } diff --git a/tarpc/src/server/filter.rs b/tarpc/src/server/filter.rs index 8e5d257..b457304 100644 --- a/tarpc/src/server/filter.rs +++ b/tarpc/src/server/filter.rs @@ -108,8 +108,8 @@ where self.inner.config() } - fn in_flight_requests(self: Pin<&mut Self>) -> usize { - self.project().inner.in_flight_requests() + fn in_flight_requests(&self) -> usize { + self.inner.in_flight_requests() } fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration { diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 8de8ef0..cc274d6 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -77,7 +77,7 @@ where &self.config } - fn in_flight_requests(self: Pin<&mut Self>) -> usize { + fn in_flight_requests(&self) -> usize { self.in_flight_requests.len() } diff --git a/tarpc/src/server/throttle.rs b/tarpc/src/server/throttle.rs index d1f7913..807ae43 100644 --- a/tarpc/src/server/throttle.rs +++ b/tarpc/src/server/throttle.rs @@ -113,8 +113,8 @@ where type Req = ::Req; type Resp = ::Resp; - fn in_flight_requests(self: Pin<&mut Self>) -> usize { - self.project().inner.in_flight_requests() + fn in_flight_requests(&self) -> usize { + self.inner.in_flight_requests() } fn config(&self) -> &Config { @@ -292,7 +292,7 @@ fn throttler_poll_next_throttled_sink_not_ready() { fn config(&self) -> &Config { unimplemented!() } - fn in_flight_requests(self: Pin<&mut Self>) -> usize { + fn in_flight_requests(&self) -> usize { 0 } fn start_request(self: Pin<&mut Self>, _: u64) -> AbortRegistration { diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 4b3fcb0..3223732 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -82,7 +82,7 @@ impl Sink for UnboundedChannel { mod tests { use crate::{ client, context, - server::{Handler, Server}, + server::{BaseChannel, Incoming}, transport, }; use assert_matches::assert_matches; @@ -96,9 +96,9 @@ mod tests { let (client_channel, server_channel) = transport::channel::unbounded(); tokio::spawn( - Server::default() - .incoming(stream::once(future::ready(server_channel))) - .respond_with(|_ctx, request: String| { + stream::once(future::ready(server_channel)) + .map(BaseChannel::with_defaults) + .execute(|_ctx, request: String| { future::ready(request.parse::().map_err(|_| { io::Error::new( io::ErrorKind::InvalidInput, diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index b961dfe..96e06a0 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,7 +1,10 @@ use futures::prelude::*; use std::io; use tarpc::serde_transport; -use tarpc::{client, context, server::Handler}; +use tarpc::{ + client, context, + server::{BaseChannel, Incoming}, +}; use tokio_serde::formats::Json; #[tarpc::derive_serde] @@ -34,9 +37,11 @@ async fn test_call() -> io::Result<()> { let transport = tarpc::serde_transport::tcp::listen("localhost:56797", Json::default).await?; let addr = transport.local_addr(); tokio::spawn( - tarpc::Server::default() - .incoming(transport.take(1).filter_map(|r| async { r.ok() })) - .respond_with(ColorServer.serve()), + transport + .take(1) + .filter_map(|r| async { r.ok() }) + .map(BaseChannel::with_defaults) + .execute(ColorServer.serve()), ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 18f315b..8a00274 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -5,13 +5,12 @@ use futures::{ }; use std::{ io, - sync::Arc, time::{Duration, SystemTime}, }; use tarpc::{ client::{self}, context, - server::{self, BaseChannel, Channel, Handler}, + server::{self, BaseChannel, Channel, Incoming}, transport::channel, }; use tokio::join; @@ -47,8 +46,8 @@ async fn sequential() -> io::Result<()> { tokio::spawn( BaseChannel::new(server::Config::default(), rx) - .respond_with(Server.serve()) - .execute(), + .requests() + .execute(Server.serve()), ); let mut client = ServiceClient::new(client::Config::default(), tx).spawn()?; @@ -68,19 +67,14 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> { async fn r#loop(); } - struct LoopServer(tokio::sync::mpsc::UnboundedSender); + #[derive(Clone)] + struct LoopServer; #[derive(Debug)] struct AllHandlersComplete; - impl Drop for LoopServer { - fn drop(&mut self) { - let _ = self.0.send(AllHandlersComplete); - } - } - #[tarpc::server] - impl Loop for Arc { + impl Loop for LoopServer { async fn r#loop(self, _: context::Context) { loop { futures::pending!(); @@ -91,7 +85,6 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> { let _ = env_logger::try_init(); let (tx, rx) = channel::unbounded(); - let (rpc_finished_tx, mut rpc_finished) = tokio::sync::mpsc::unbounded_channel(); // Set up a client that initiates a long-lived request. // The request will complete in error when the server drops the connection. @@ -105,18 +98,16 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> { let _ = client.r#loop(ctx).await; }); - let mut server = - BaseChannel::with_defaults(rx).respond_with(Arc::new(LoopServer(rpc_finished_tx)).serve()); - let first_handler = server.next().await.unwrap()?; + let mut requests = BaseChannel::with_defaults(rx).requests(); + // Reading a request should trigger the request being registered with BaseChannel. + let first_request = requests.next().await.unwrap()?; + // Dropping the channel should trigger cleanup of outstanding requests. + drop(requests); + // In-flight requests should be aborted by channel cleanup. + // The first and only request sent by the client is `loop`, which is an infinite loop + // on the server side, so if cleanup was not triggered, this line should hang indefinitely. + first_request.execute(LoopServer.serve()).await; - drop(server); - first_handler.await; - - // At this point, a single RPC has been sent and a single response initiated. - // The request handler will loop for a long time unless aborted. - // Now, we assert that the act of disconnecting a client is sufficient to abort all - // handlers initiated by the connection's RPCs. - assert_matches!(rpc_finished.recv().await, Some(AllHandlersComplete)); Ok(()) } @@ -131,9 +122,11 @@ async fn serde() -> io::Result<()> { let transport = tarpc::serde_transport::tcp::listen("localhost:56789", Json::default).await?; let addr = transport.local_addr(); tokio::spawn( - tarpc::Server::default() - .incoming(transport.take(1).filter_map(|r| async { r.ok() })) - .respond_with(Server.serve()), + transport + .take(1) + .filter_map(|r| async { r.ok() }) + .map(BaseChannel::with_defaults) + .execute(Server.serve()), ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; @@ -154,9 +147,9 @@ async fn concurrent() -> io::Result<()> { let (tx, rx) = channel::unbounded(); tokio::spawn( - tarpc::Server::default() - .incoming(stream::once(ready(rx))) - .respond_with(Server.serve()), + stream::once(ready(rx)) + .map(BaseChannel::with_defaults) + .execute(Server.serve()), ); let client = ServiceClient::new(client::Config::default(), tx).spawn()?; @@ -183,9 +176,9 @@ async fn concurrent_join() -> io::Result<()> { let (tx, rx) = channel::unbounded(); tokio::spawn( - tarpc::Server::default() - .incoming(stream::once(ready(rx))) - .respond_with(Server.serve()), + stream::once(ready(rx)) + .map(BaseChannel::with_defaults) + .execute(Server.serve()), ); let client = ServiceClient::new(client::Config::default(), tx).spawn()?; @@ -213,9 +206,9 @@ async fn concurrent_join_all() -> io::Result<()> { let (tx, rx) = channel::unbounded(); tokio::spawn( - tarpc::Server::default() - .incoming(stream::once(ready(rx))) - .respond_with(Server.serve()), + stream::once(ready(rx)) + .map(BaseChannel::with_defaults) + .execute(Server.serve()), ); let client = ServiceClient::new(client::Config::default(), tx).spawn()?;