diff --git a/README.md b/README.md index b9d8591..cf44266 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,7 @@ This example uses [tokio](https://tokio.rs), so add the following dependencies t your `Cargo.toml`: ```toml +anyhow = "1.0" futures = "1.0" tarpc = { version = "0.26", features = ["tokio1"] } tokio = { version = "1.0", features = ["macros"] } @@ -99,9 +100,8 @@ use futures::{ }; use tarpc::{ client, context, - server::{self, Incoming}, + server::{self, incoming::Incoming}, }; -use std::io; // This is the service definition. It looks a lot like a trait definition. // It defines one RPC, hello, which takes one arg, name, and returns a String. @@ -140,7 +140,7 @@ available behind the `tcp` feature. ```rust #[tokio::main] -async fn main() -> io::Result<()> { +async fn main() -> anyhow::Result<()> { let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); let server = server::BaseChannel::with_defaults(server_transport); diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 13e10f2..6316047 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -17,7 +17,7 @@ use std::{ }; use tarpc::{ context, - server::{self, Channel, Incoming}, + server::{self, incoming::Incoming, Channel}, tokio_serde::formats::Json, }; use tokio::time; diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 84e2246..1cd939c 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -9,7 +9,7 @@ use futures::{future, prelude::*}; use std::env; use tarpc::{ client, context, - server::{BaseChannel, Incoming}, + server::{incoming::Incoming, BaseChannel}, }; use tokio_serde::formats::Json; use tracing_subscriber::prelude::*; diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index ed3f824..f71f351 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -317,8 +317,8 @@ where .map_err(ChannelError::Ready) } - fn start_send<'a>( - self: &'a mut Pin<&mut Self>, + fn start_send( + self: &mut Pin<&mut Self>, message: ClientMessage, ) -> Result<(), ChannelError> { self.transport_pin_mut() diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index ae562e2..e69bc3f 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -88,7 +88,7 @@ //! }; //! use tarpc::{ //! client, context, -//! server::{self, Incoming}, +//! server::{self, incoming::Incoming}, //! }; //! //! // This is the service definition. It looks a lot like a trait definition. @@ -111,7 +111,7 @@ //! # }; //! # use tarpc::{ //! # client, context, -//! # server::{self, Incoming}, +//! # server::{self, incoming::Incoming}, //! # }; //! # // This is the service definition. It looks a lot like a trait definition. //! # // It defines one RPC, hello, which takes one arg, name, and returns a String. diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 13fd429..acb0cf1 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -10,6 +10,7 @@ use crate::{ context::{self, SpanExt}, trace, ClientMessage, Request, Response, Transport, }; +use ::tokio::sync::mpsc; use futures::{ future::{AbortRegistration, Abortable}, prelude::*, @@ -19,20 +20,23 @@ use futures::{ }; use in_flight_requests::{AlreadyExistsError, InFlightRequests}; use pin_project::pin_project; -use std::{convert::TryFrom, error::Error, fmt, hash::Hash, marker::PhantomData, pin::Pin}; -use tokio::sync::mpsc; +use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin}; use tracing::{info_span, instrument::Instrument, Span}; -mod filter; mod in_flight_requests; #[cfg(test)] mod testing; -mod throttle; -pub use self::{ - filter::ChannelFilter, - throttle::{Throttler, ThrottlerStream}, -}; +/// Provides functionality to apply server limits. +pub mod limits; + +/// Provides helper methods for streams of Channels. +pub mod incoming; + +/// Provides convenience functionality for tokio-enabled applications. +#[cfg(feature = "tokio1")] +#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] +pub mod tokio; /// Settings that control the behavior of [channels](Channel). #[derive(Clone, Debug)] @@ -91,51 +95,13 @@ where } } -/// An extension trait for [streams](Stream) of [`Channels`](Channel). -pub trait Incoming -where - Self: Sized + Stream, - C: Channel, -{ - /// Enforces channel per-key limits. - fn max_channels_per_key(self, n: u32, keymaker: KF) -> filter::ChannelFilter - where - K: fmt::Display + Eq + Hash + Clone + Unpin, - KF: Fn(&C) -> K, - { - ChannelFilter::new(self, n, keymaker) - } - - /// Caps the number of concurrent requests per channel. - fn max_concurrent_requests_per_channel(self, n: usize) -> ThrottlerStream { - ThrottlerStream::new(self, n) - } - - /// [Executes](Channel::execute) each incoming channel. Each channel will be handled - /// concurrently by spawning on tokio's default executor, and each request will be also - /// be spawned on tokio's default executor. - #[cfg(feature = "tokio1")] - #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] - fn execute(self, serve: S) -> TokioServerExecutor - where - S: Serve, - { - TokioServerExecutor { inner: self, serve } - } -} - -impl Incoming for S -where - S: Sized + Stream, - C: Channel, -{ -} - -/// BaseChannel is a [Transport] that keeps track of in-flight requests. It converts a -/// [`Transport`](Transport) of [`ClientMessages`](ClientMessage) into a stream of -/// [requests](ClientMessage::Request). +/// BaseChannel is the standard implementation of a [`Channel`]. /// -/// Besides requests, the other type of client message is [cancellation +/// BaseChannel manages a [`Transport`](Transport) of client [`messages`](ClientMessage) and +/// implements a [`Stream`] of [requests](TrackedRequest). See the [`Channel`] documentation for +/// how to use channels. +/// +/// Besides requests, the other type of client message handled by `BaseChannel` is [cancellation /// messages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation /// messages. Instead, it internally handles them by cancelling corresponding requests (removing /// the corresponding in-flight requests and aborting their handlers). @@ -216,15 +182,15 @@ where match start { Ok(abort_registration) => { drop(entered); - return Ok(TrackedRequest { + Ok(TrackedRequest { request, abort_registration, span, - }); + }) } Err(AlreadyExistsError) => { tracing::trace!("DuplicateRequest"); - return Err(AlreadyExistsError); + Err(AlreadyExistsError) } } } @@ -248,8 +214,8 @@ pub struct TrackedRequest { pub span: Span, } -/// The server end of an open connection with a client, streaming in requests from, and sinking -/// responses to, the client. +/// The server end of an open connection with a client, receiving requests from, and sending +/// responses to, the client. `Channel` is a [`Transport`] with request lifecycle management. /// /// The ways to use a Channel, in order of simplest to most complex, is: /// 1. [`Channel::execute`] - Requires the `tokio1` feature. This method is best for those who @@ -293,12 +259,21 @@ where /// Returns the transport underlying the channel. fn transport(&self) -> &Self::Transport; - /// Caps the number of concurrent requests to `limit`. - fn max_concurrent_requests(self, limit: usize) -> Throttler + /// Caps the number of concurrent requests to `limit`. An error will be returned for requests + /// over the concurrency limit. + /// + /// Note that this is a very + /// simplistic throttling heuristic. It is easy to set a number that is too low for the + /// resources available to the server. For production use cases, a more advanced throttler is + /// likely needed. + fn max_concurrent_requests( + self, + limit: usize, + ) -> limits::requests_per_channel::MaxRequests where Self: Sized, { - Throttler::new(self, limit) + limits::requests_per_channel::MaxRequests::new(self, limit) } /// Returns a stream of requests that automatically handle request cancellation and response @@ -321,11 +296,11 @@ where } /// Runs the channel until completion by executing all requests using the given service - /// function. Request handlers are run concurrently by [spawning](tokio::spawn) on tokio's + /// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's /// default executor. #[cfg(feature = "tokio1")] #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] - fn execute(self, serve: S) -> TokioChannelExecutor, S> + fn execute(self, serve: S) -> self::tokio::TokioChannelExecutor, S> where Self: Sized, S: Serve + Send + 'static, @@ -348,7 +323,7 @@ where Transport(#[source] E), /// An error occurred while polling expired requests. #[error("an error occurred while polling expired requests: {0}")] - Timer(#[source] tokio::time::error::Error), + Timer(#[source] ::tokio::time::error::Error), } impl Stream for BaseChannel @@ -533,18 +508,12 @@ where mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, C::Error>>> { - loop { - match ready!(self.channel_pin_mut().poll_next(cx)?) { - Some(request) => { - let response_tx = self.responses_tx.clone(); - return Poll::Ready(Some(Ok(InFlightRequest { - request, - response_tx, - }))); - } - None => return Poll::Ready(None), - } - } + self.channel_pin_mut() + .poll_next(cx) + .map_ok(|request| InFlightRequest { + request, + response_tx: self.responses_tx.clone(), + }) } fn pump_write( @@ -710,128 +679,22 @@ where } } -// Send + 'static execution helper methods. - -#[cfg(feature = "tokio1")] -#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] -impl Requests -where - C: Channel, - C::Req: Send + 'static, - C::Resp: Send + 'static, -{ - /// Executes all requests using the given service function. Requests are handled concurrently - /// by [spawning](tokio::spawn) each handler on tokio's default executor. - pub fn execute(self, serve: S) -> TokioChannelExecutor - where - S: Serve + Send + 'static, - { - TokioChannelExecutor { inner: self, serve } - } -} - -/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor) -/// for each new channel. -#[pin_project] -#[derive(Debug)] -#[cfg(feature = "tokio1")] -#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] -pub struct TokioServerExecutor { - #[pin] - inner: T, - serve: S, -} - -/// A future that drives the server by [spawning](tokio::spawn) each [response -/// handler](InFlightRequest::execute) on tokio's default executor. -#[pin_project] -#[derive(Debug)] -#[cfg(feature = "tokio1")] -#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] -pub struct TokioChannelExecutor { - #[pin] - inner: T, - serve: S, -} - -#[cfg(feature = "tokio1")] -#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] -impl TokioServerExecutor { - fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { - self.as_mut().project().inner - } -} - -#[cfg(feature = "tokio1")] -#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] -impl TokioChannelExecutor { - fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { - self.as_mut().project().inner - } -} - -#[cfg(feature = "tokio1")] -impl Future for TokioServerExecutor -where - St: Sized + Stream, - C: Channel + Send + 'static, - C::Req: Send + 'static, - C::Resp: Send + 'static, - Se: Serve + Send + 'static + Clone, - Se::Fut: Send, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) { - tokio::spawn(channel.execute(self.serve.clone())); - } - tracing::info!("Server shutting down."); - Poll::Ready(()) - } -} - -#[cfg(feature = "tokio1")] -impl Future for TokioChannelExecutor, S> -where - C: Channel + 'static, - C::Req: Send + 'static, - C::Resp: Send + 'static, - S: Serve + Send + 'static + Clone, - S::Fut: Send, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) { - match response_handler { - Ok(resp) => { - let server = self.serve.clone(); - tokio::spawn(async move { - resp.execute(server).await; - }); - } - Err(e) => { - tracing::warn!("Requests stream errored out: {}", e); - break; - } - } - } - Poll::Ready(()) - } -} - #[cfg(test)] mod tests { - use super::*; - + use super::{in_flight_requests::AlreadyExistsError, BaseChannel, Channel, Config, Requests}; use crate::{ - trace, + context, trace, transport::channel::{self, UnboundedChannel}, + ClientMessage, Request, Response, }; use assert_matches::assert_matches; - use futures::future::{pending, Aborted}; + use futures::{ + future::{pending, AbortRegistration, Abortable, Aborted}, + prelude::*, + Future, + }; use futures_test::task::noop_context; + use std::{pin::Pin, task::Poll}; fn test_channel() -> ( Pin, Response>>>>, diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs new file mode 100644 index 0000000..5479b05 --- /dev/null +++ b/tarpc/src/server/incoming.rs @@ -0,0 +1,49 @@ +use super::{ + limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel}, + Channel, +}; +use futures::prelude::*; +use std::{fmt, hash::Hash}; + +#[cfg(feature = "tokio1")] +use super::{tokio::TokioServerExecutor, Serve}; + +/// An extension trait for [streams](Stream) of [`Channels`](Channel). +pub trait Incoming +where + Self: Sized + Stream, + C: Channel, +{ + /// Enforces channel per-key limits. + fn max_channels_per_key(self, n: u32, keymaker: KF) -> MaxChannelsPerKey + where + K: fmt::Display + Eq + Hash + Clone + Unpin, + KF: Fn(&C) -> K, + { + MaxChannelsPerKey::new(self, n, keymaker) + } + + /// Caps the number of concurrent requests per channel. + fn max_concurrent_requests_per_channel(self, n: usize) -> MaxRequestsPerChannel { + MaxRequestsPerChannel::new(self, n) + } + + /// [Executes](Channel::execute) each incoming channel. Each channel will be handled + /// concurrently by spawning on tokio's default executor, and each request will be also + /// be spawned on tokio's default executor. + #[cfg(feature = "tokio1")] + #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] + fn execute(self, serve: S) -> TokioServerExecutor + where + S: Serve, + { + TokioServerExecutor::new(self, serve) + } +} + +impl Incoming for S +where + S: Sized + Stream, + C: Channel, +{ +} diff --git a/tarpc/src/server/limits.rs b/tarpc/src/server/limits.rs new file mode 100644 index 0000000..c74dba9 --- /dev/null +++ b/tarpc/src/server/limits.rs @@ -0,0 +1,5 @@ +/// Provides functionality to limit the number of active channels. +pub mod channels_per_key; + +/// Provides a [channel](crate::server::Channel) that limits the number of in-flight requests. +pub mod requests_per_channel; diff --git a/tarpc/src/server/filter.rs b/tarpc/src/server/limits/channels_per_key.rs similarity index 93% rename from tarpc/src/server/filter.rs rename to tarpc/src/server/limits/channels_per_key.rs index 295fd95..272dd56 100644 --- a/tarpc/src/server/filter.rs +++ b/tarpc/src/server/limits/channels_per_key.rs @@ -18,10 +18,14 @@ use std::{ use tokio::sync::mpsc; use tracing::{debug, info, trace}; -/// A single-threaded filter that drops channels based on per-key limits. +/// An [`Incoming`](crate::server::incoming::Incoming) stream that drops new channels based on +/// per-key limits. +/// +/// The decision to drop a Channel is made once at the time the Channel materializes. Once a +/// Channel is yielded, it will not be prematurely dropped. #[pin_project] #[derive(Debug)] -pub struct ChannelFilter +pub struct MaxChannelsPerKey where K: Eq + Hash, { @@ -34,7 +38,7 @@ where keymaker: F, } -/// A channel that is tracked by a ChannelFilter. +/// A channel that is tracked by [`MaxChannelsPerKey`]. #[pin_project] #[derive(Debug)] pub struct TrackedChannel { @@ -129,7 +133,7 @@ impl TrackedChannel { } } -impl ChannelFilter +impl MaxChannelsPerKey where K: Eq + Hash, S: Stream, @@ -138,7 +142,7 @@ where /// Sheds new channels to stay under configured limits. pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self { let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel(); - ChannelFilter { + MaxChannelsPerKey { listener: listener.fuse(), channels_per_key, dropped_keys, @@ -149,7 +153,7 @@ where } } -impl ChannelFilter +impl MaxChannelsPerKey where S: Stream, K: fmt::Display + Eq + Hash + Clone + Unpin, @@ -241,7 +245,7 @@ where } } -impl Stream for ChannelFilter +impl Stream for MaxChannelsPerKey where S: Stream, K: fmt::Display + Eq + Hash + Clone + Unpin, @@ -344,7 +348,7 @@ fn channel_filter_increment_channels_for_key() { key: &'static str, } let (_, listener) = futures::channel::mpsc::unbounded(); - let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); + let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap(); assert_eq!(Arc::strong_count(&tracker1), 1); @@ -365,7 +369,7 @@ fn channel_filter_handle_new_channel() { key: &'static str, } let (_, listener) = futures::channel::mpsc::unbounded(); - let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); + let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); let channel1 = filter .as_mut() @@ -397,7 +401,7 @@ fn channel_filter_poll_listener() { key: &'static str, } let (new_channels, listener) = futures::channel::mpsc::unbounded(); - let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); + let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); new_channels @@ -433,7 +437,7 @@ fn channel_filter_poll_closed_channels() { key: &'static str, } let (new_channels, listener) = futures::channel::mpsc::unbounded(); - let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); + let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); new_channels @@ -461,7 +465,7 @@ fn channel_filter_stream() { key: &'static str, } let (new_channels, listener) = futures::channel::mpsc::unbounded(); - let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); + let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); new_channels diff --git a/tarpc/src/server/throttle.rs b/tarpc/src/server/limits/requests_per_channel.rs similarity index 86% rename from tarpc/src/server/throttle.rs rename to tarpc/src/server/limits/requests_per_channel.rs index a02e60f..3c29836 100644 --- a/tarpc/src/server/throttle.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -4,44 +4,49 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use super::{Channel, Config}; -use crate::{Response, ServerError}; +use crate::{ + server::{Channel, Config}, + Response, ServerError, +}; use futures::{prelude::*, ready, task::*}; use pin_project::pin_project; use std::{io, pin::Pin}; -/// A [`Channel`] that limits the number of concurrent -/// requests by throttling. +/// A [`Channel`] that limits the number of concurrent requests by throttling. +/// +/// Note that this is a very basic throttling heuristic. It is easy to set a number that is too low +/// for the resources available to the server. For production use cases, a more advanced throttler +/// is likely needed. #[pin_project] #[derive(Debug)] -pub struct Throttler { +pub struct MaxRequests { max_in_flight_requests: usize, #[pin] inner: C, } -impl Throttler { +impl MaxRequests { /// Returns the inner channel. pub fn get_ref(&self) -> &C { &self.inner } } -impl Throttler +impl MaxRequests where C: Channel, { - /// Returns a new `Throttler` that wraps the given channel and limits concurrent requests to + /// Returns a new `MaxRequests` that wraps the given channel and limits concurrent requests to /// `max_in_flight_requests`. pub fn new(inner: C, max_in_flight_requests: usize) -> Self { - Throttler { + MaxRequests { max_in_flight_requests, inner, } } } -impl Stream for Throttler +impl Stream for MaxRequests where C: Channel, { @@ -75,7 +80,7 @@ where } } -impl Sink::Resp>> for Throttler +impl Sink::Resp>> for MaxRequests where C: Channel, { @@ -101,13 +106,13 @@ where } } -impl AsRef for Throttler { +impl AsRef for MaxRequests { fn as_ref(&self) -> &C { &self.inner } } -impl Channel for Throttler +impl Channel for MaxRequests where C: Channel, { @@ -128,16 +133,17 @@ where } } -/// A stream of throttling channels. +/// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on +/// the number of in-flight requests. #[pin_project] #[derive(Debug)] -pub struct ThrottlerStream { +pub struct MaxRequestsPerChannel { #[pin] inner: S, max_in_flight_requests: usize, } -impl ThrottlerStream +impl MaxRequestsPerChannel where S: Stream, ::Item: Channel, @@ -150,16 +156,16 @@ where } } -impl Stream for ThrottlerStream +impl Stream for MaxRequestsPerChannel where S: Stream, ::Item: Channel, { - type Item = Throttler<::Item>; + type Item = MaxRequests<::Item>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { match ready!(self.as_mut().project().inner.poll_next(cx)) { - Some(channel) => Poll::Ready(Some(Throttler::new( + Some(channel) => Poll::Ready(Some(MaxRequests::new( channel, *self.project().max_in_flight_requests, ))), @@ -185,7 +191,7 @@ mod tests { #[tokio::test] async fn throttler_in_flight_requests() { - let throttler = Throttler { + let throttler = MaxRequests { max_in_flight_requests: 0, inner: FakeChannel::default::(), }; @@ -207,7 +213,7 @@ mod tests { #[test] fn throttler_poll_next_done() { - let throttler = Throttler { + let throttler = MaxRequests { max_in_flight_requests: 0, inner: FakeChannel::default::(), }; @@ -218,7 +224,7 @@ mod tests { #[test] fn throttler_poll_next_some() -> io::Result<()> { - let throttler = Throttler { + let throttler = MaxRequests { max_in_flight_requests: 1, inner: FakeChannel::default::(), }; @@ -238,7 +244,7 @@ mod tests { #[test] fn throttler_poll_next_throttled() { - let throttler = Throttler { + let throttler = MaxRequests { max_in_flight_requests: 0, inner: FakeChannel::default::(), }; @@ -254,7 +260,7 @@ mod tests { #[test] fn throttler_poll_next_throttled_sink_not_ready() { - let throttler = Throttler { + let throttler = MaxRequests { max_in_flight_requests: 0, inner: PendingSink::default::(), }; @@ -309,7 +315,7 @@ mod tests { #[tokio::test] async fn throttler_start_send() { - let throttler = Throttler { + let throttler = MaxRequests { max_in_flight_requests: 0, inner: FakeChannel::default::(), }; diff --git a/tarpc/src/server/tokio.rs b/tarpc/src/server/tokio.rs new file mode 100644 index 0000000..87db4c2 --- /dev/null +++ b/tarpc/src/server/tokio.rs @@ -0,0 +1,111 @@ +use super::{Channel, Requests, Serve}; +use futures::{prelude::*, ready, task::*}; +use pin_project::pin_project; +use std::pin::Pin; + +/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor) +/// for each new channel. Returned by +/// [`Incoming::execute`](crate::server::incoming::Incoming::execute). +#[pin_project] +#[derive(Debug)] +pub struct TokioServerExecutor { + #[pin] + inner: T, + serve: S, +} + +impl TokioServerExecutor { + pub(crate) fn new(inner: T, serve: S) -> Self { + Self { inner, serve } + } +} + +/// A future that drives the server by [spawning](tokio::spawn) each [response +/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by +/// [`Channel::execute`](crate::server::Channel::execute). +#[pin_project] +#[derive(Debug)] +pub struct TokioChannelExecutor { + #[pin] + inner: T, + serve: S, +} + +impl TokioServerExecutor { + fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { + self.as_mut().project().inner + } +} + +impl TokioChannelExecutor { + fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { + self.as_mut().project().inner + } +} + +// Send + 'static execution helper methods. + +impl Requests +where + C: Channel, + C::Req: Send + 'static, + C::Resp: Send + 'static, +{ + /// Executes all requests using the given service function. Requests are handled concurrently + /// by [spawning](::tokio::spawn) each handler on tokio's default executor. + pub fn execute(self, serve: S) -> TokioChannelExecutor + where + S: Serve + Send + 'static, + { + TokioChannelExecutor { inner: self, serve } + } +} + +impl Future for TokioServerExecutor +where + St: Sized + Stream, + C: Channel + Send + 'static, + C::Req: Send + 'static, + C::Resp: Send + 'static, + Se: Serve + Send + 'static + Clone, + Se::Fut: Send, +{ + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) { + tokio::spawn(channel.execute(self.serve.clone())); + } + tracing::info!("Server shutting down."); + Poll::Ready(()) + } +} + +impl Future for TokioChannelExecutor, S> +where + C: Channel + 'static, + C::Req: Send + 'static, + C::Resp: Send + 'static, + S: Serve + Send + 'static + Clone, + S::Fut: Send, +{ + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) { + match response_handler { + Ok(resp) => { + let server = self.serve.clone(); + tokio::spawn(async move { + resp.execute(server).await; + }); + } + Err(e) => { + tracing::warn!("Requests stream errored out: {}", e); + break; + } + } + } + Poll::Ready(()) + } +} diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index f6dee17..c39ed93 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -151,7 +151,7 @@ impl Sink for Channel { mod tests { use crate::{ client, context, - server::{BaseChannel, Incoming}, + server::{incoming::Incoming, BaseChannel}, transport::{ self, channel::{Channel, UnboundedChannel}, diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 4cc52d6..9a3c926 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -2,7 +2,7 @@ use futures::prelude::*; use tarpc::serde_transport; use tarpc::{ client, context, - server::{BaseChannel, Incoming}, + server::{incoming::Incoming, BaseChannel}, }; use tokio_serde::formats::Json; diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 9d97ab2..b1aa431 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -7,7 +7,7 @@ use std::time::{Duration, SystemTime}; use tarpc::{ client::{self}, context, - server::{self, BaseChannel, Channel, Incoming}, + server::{self, incoming::Incoming, BaseChannel, Channel}, transport::channel, }; use tokio::join;