From ea7b6763c4181e07479db3b13103d47b2fcd8aae Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Wed, 21 Apr 2021 15:57:08 -0700 Subject: [PATCH] Refactor server module. In the interest of the user's attention, some ancillary APIs have been moved to new submodules: - server::limits contains what was previously called Throttler and ChannelFilter. Both of those names were very generic, when the methods applied by these types were very specific (and also simplistic). Renames have occurred: - ThrottlerStream => MaxRequestsPerChannel - Throttler => MaxRequests - ChannelFilter => MaxChannelsPerKey - server::incoming contains the Incoming trait. - server::tokio contains the tokio-specific helper types. The 5 structs and 1 enum remaining in the base server module are all core to the functioning of the server. --- README.md | 6 +- example-service/src/server.rs | 2 +- tarpc/examples/tracing.rs | 2 +- tarpc/src/client.rs | 4 +- tarpc/src/lib.rs | 4 +- tarpc/src/server.rs | 243 ++++-------------- tarpc/src/server/incoming.rs | 49 ++++ tarpc/src/server/limits.rs | 5 + .../{filter.rs => limits/channels_per_key.rs} | 28 +- .../requests_per_channel.rs} | 56 ++-- tarpc/src/server/tokio.rs | 111 ++++++++ tarpc/src/transport/channel.rs | 2 +- tarpc/tests/dataservice.rs | 2 +- tarpc/tests/service_functional.rs | 2 +- 14 files changed, 277 insertions(+), 239 deletions(-) create mode 100644 tarpc/src/server/incoming.rs create mode 100644 tarpc/src/server/limits.rs rename tarpc/src/server/{filter.rs => limits/channels_per_key.rs} (93%) rename tarpc/src/server/{throttle.rs => limits/requests_per_channel.rs} (86%) create mode 100644 tarpc/src/server/tokio.rs diff --git a/README.md b/README.md index b9d8591..cf44266 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,7 @@ This example uses [tokio](https://tokio.rs), so add the following dependencies t your `Cargo.toml`: ```toml +anyhow = "1.0" futures = "1.0" tarpc = { version = "0.26", features = ["tokio1"] } tokio = { version = "1.0", features = ["macros"] } @@ -99,9 +100,8 @@ use futures::{ }; use tarpc::{ client, context, - server::{self, Incoming}, + server::{self, incoming::Incoming}, }; -use std::io; // This is the service definition. It looks a lot like a trait definition. // It defines one RPC, hello, which takes one arg, name, and returns a String. @@ -140,7 +140,7 @@ available behind the `tcp` feature. ```rust #[tokio::main] -async fn main() -> io::Result<()> { +async fn main() -> anyhow::Result<()> { let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); let server = server::BaseChannel::with_defaults(server_transport); diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 13e10f2..6316047 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -17,7 +17,7 @@ use std::{ }; use tarpc::{ context, - server::{self, Channel, Incoming}, + server::{self, incoming::Incoming, Channel}, tokio_serde::formats::Json, }; use tokio::time; diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 84e2246..1cd939c 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -9,7 +9,7 @@ use futures::{future, prelude::*}; use std::env; use tarpc::{ client, context, - server::{BaseChannel, Incoming}, + server::{incoming::Incoming, BaseChannel}, }; use tokio_serde::formats::Json; use tracing_subscriber::prelude::*; diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index ed3f824..f71f351 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -317,8 +317,8 @@ where .map_err(ChannelError::Ready) } - fn start_send<'a>( - self: &'a mut Pin<&mut Self>, + fn start_send( + self: &mut Pin<&mut Self>, message: ClientMessage, ) -> Result<(), ChannelError> { self.transport_pin_mut() diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index ae562e2..e69bc3f 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -88,7 +88,7 @@ //! }; //! use tarpc::{ //! client, context, -//! server::{self, Incoming}, +//! server::{self, incoming::Incoming}, //! }; //! //! // This is the service definition. It looks a lot like a trait definition. @@ -111,7 +111,7 @@ //! # }; //! # use tarpc::{ //! # client, context, -//! # server::{self, Incoming}, +//! # server::{self, incoming::Incoming}, //! # }; //! # // This is the service definition. It looks a lot like a trait definition. //! # // It defines one RPC, hello, which takes one arg, name, and returns a String. diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 13fd429..acb0cf1 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -10,6 +10,7 @@ use crate::{ context::{self, SpanExt}, trace, ClientMessage, Request, Response, Transport, }; +use ::tokio::sync::mpsc; use futures::{ future::{AbortRegistration, Abortable}, prelude::*, @@ -19,20 +20,23 @@ use futures::{ }; use in_flight_requests::{AlreadyExistsError, InFlightRequests}; use pin_project::pin_project; -use std::{convert::TryFrom, error::Error, fmt, hash::Hash, marker::PhantomData, pin::Pin}; -use tokio::sync::mpsc; +use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin}; use tracing::{info_span, instrument::Instrument, Span}; -mod filter; mod in_flight_requests; #[cfg(test)] mod testing; -mod throttle; -pub use self::{ - filter::ChannelFilter, - throttle::{Throttler, ThrottlerStream}, -}; +/// Provides functionality to apply server limits. +pub mod limits; + +/// Provides helper methods for streams of Channels. +pub mod incoming; + +/// Provides convenience functionality for tokio-enabled applications. +#[cfg(feature = "tokio1")] +#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] +pub mod tokio; /// Settings that control the behavior of [channels](Channel). #[derive(Clone, Debug)] @@ -91,51 +95,13 @@ where } } -/// An extension trait for [streams](Stream) of [`Channels`](Channel). -pub trait Incoming -where - Self: Sized + Stream, - C: Channel, -{ - /// Enforces channel per-key limits. - fn max_channels_per_key(self, n: u32, keymaker: KF) -> filter::ChannelFilter - where - K: fmt::Display + Eq + Hash + Clone + Unpin, - KF: Fn(&C) -> K, - { - ChannelFilter::new(self, n, keymaker) - } - - /// Caps the number of concurrent requests per channel. - fn max_concurrent_requests_per_channel(self, n: usize) -> ThrottlerStream { - ThrottlerStream::new(self, n) - } - - /// [Executes](Channel::execute) each incoming channel. Each channel will be handled - /// concurrently by spawning on tokio's default executor, and each request will be also - /// be spawned on tokio's default executor. - #[cfg(feature = "tokio1")] - #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] - fn execute(self, serve: S) -> TokioServerExecutor - where - S: Serve, - { - TokioServerExecutor { inner: self, serve } - } -} - -impl Incoming for S -where - S: Sized + Stream, - C: Channel, -{ -} - -/// BaseChannel is a [Transport] that keeps track of in-flight requests. It converts a -/// [`Transport`](Transport) of [`ClientMessages`](ClientMessage) into a stream of -/// [requests](ClientMessage::Request). +/// BaseChannel is the standard implementation of a [`Channel`]. /// -/// Besides requests, the other type of client message is [cancellation +/// BaseChannel manages a [`Transport`](Transport) of client [`messages`](ClientMessage) and +/// implements a [`Stream`] of [requests](TrackedRequest). See the [`Channel`] documentation for +/// how to use channels. +/// +/// Besides requests, the other type of client message handled by `BaseChannel` is [cancellation /// messages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation /// messages. Instead, it internally handles them by cancelling corresponding requests (removing /// the corresponding in-flight requests and aborting their handlers). @@ -216,15 +182,15 @@ where match start { Ok(abort_registration) => { drop(entered); - return Ok(TrackedRequest { + Ok(TrackedRequest { request, abort_registration, span, - }); + }) } Err(AlreadyExistsError) => { tracing::trace!("DuplicateRequest"); - return Err(AlreadyExistsError); + Err(AlreadyExistsError) } } } @@ -248,8 +214,8 @@ pub struct TrackedRequest { pub span: Span, } -/// The server end of an open connection with a client, streaming in requests from, and sinking -/// responses to, the client. +/// The server end of an open connection with a client, receiving requests from, and sending +/// responses to, the client. `Channel` is a [`Transport`] with request lifecycle management. /// /// The ways to use a Channel, in order of simplest to most complex, is: /// 1. [`Channel::execute`] - Requires the `tokio1` feature. This method is best for those who @@ -293,12 +259,21 @@ where /// Returns the transport underlying the channel. fn transport(&self) -> &Self::Transport; - /// Caps the number of concurrent requests to `limit`. - fn max_concurrent_requests(self, limit: usize) -> Throttler + /// Caps the number of concurrent requests to `limit`. An error will be returned for requests + /// over the concurrency limit. + /// + /// Note that this is a very + /// simplistic throttling heuristic. It is easy to set a number that is too low for the + /// resources available to the server. For production use cases, a more advanced throttler is + /// likely needed. + fn max_concurrent_requests( + self, + limit: usize, + ) -> limits::requests_per_channel::MaxRequests where Self: Sized, { - Throttler::new(self, limit) + limits::requests_per_channel::MaxRequests::new(self, limit) } /// Returns a stream of requests that automatically handle request cancellation and response @@ -321,11 +296,11 @@ where } /// Runs the channel until completion by executing all requests using the given service - /// function. Request handlers are run concurrently by [spawning](tokio::spawn) on tokio's + /// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's /// default executor. #[cfg(feature = "tokio1")] #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] - fn execute(self, serve: S) -> TokioChannelExecutor, S> + fn execute(self, serve: S) -> self::tokio::TokioChannelExecutor, S> where Self: Sized, S: Serve + Send + 'static, @@ -348,7 +323,7 @@ where Transport(#[source] E), /// An error occurred while polling expired requests. #[error("an error occurred while polling expired requests: {0}")] - Timer(#[source] tokio::time::error::Error), + Timer(#[source] ::tokio::time::error::Error), } impl Stream for BaseChannel @@ -533,18 +508,12 @@ where mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, C::Error>>> { - loop { - match ready!(self.channel_pin_mut().poll_next(cx)?) { - Some(request) => { - let response_tx = self.responses_tx.clone(); - return Poll::Ready(Some(Ok(InFlightRequest { - request, - response_tx, - }))); - } - None => return Poll::Ready(None), - } - } + self.channel_pin_mut() + .poll_next(cx) + .map_ok(|request| InFlightRequest { + request, + response_tx: self.responses_tx.clone(), + }) } fn pump_write( @@ -710,128 +679,22 @@ where } } -// Send + 'static execution helper methods. - -#[cfg(feature = "tokio1")] -#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] -impl Requests -where - C: Channel, - C::Req: Send + 'static, - C::Resp: Send + 'static, -{ - /// Executes all requests using the given service function. Requests are handled concurrently - /// by [spawning](tokio::spawn) each handler on tokio's default executor. - pub fn execute(self, serve: S) -> TokioChannelExecutor - where - S: Serve + Send + 'static, - { - TokioChannelExecutor { inner: self, serve } - } -} - -/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor) -/// for each new channel. -#[pin_project] -#[derive(Debug)] -#[cfg(feature = "tokio1")] -#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] -pub struct TokioServerExecutor { - #[pin] - inner: T, - serve: S, -} - -/// A future that drives the server by [spawning](tokio::spawn) each [response -/// handler](InFlightRequest::execute) on tokio's default executor. -#[pin_project] -#[derive(Debug)] -#[cfg(feature = "tokio1")] -#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] -pub struct TokioChannelExecutor { - #[pin] - inner: T, - serve: S, -} - -#[cfg(feature = "tokio1")] -#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] -impl TokioServerExecutor { - fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { - self.as_mut().project().inner - } -} - -#[cfg(feature = "tokio1")] -#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] -impl TokioChannelExecutor { - fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { - self.as_mut().project().inner - } -} - -#[cfg(feature = "tokio1")] -impl Future for TokioServerExecutor -where - St: Sized + Stream, - C: Channel + Send + 'static, - C::Req: Send + 'static, - C::Resp: Send + 'static, - Se: Serve + Send + 'static + Clone, - Se::Fut: Send, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) { - tokio::spawn(channel.execute(self.serve.clone())); - } - tracing::info!("Server shutting down."); - Poll::Ready(()) - } -} - -#[cfg(feature = "tokio1")] -impl Future for TokioChannelExecutor, S> -where - C: Channel + 'static, - C::Req: Send + 'static, - C::Resp: Send + 'static, - S: Serve + Send + 'static + Clone, - S::Fut: Send, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) { - match response_handler { - Ok(resp) => { - let server = self.serve.clone(); - tokio::spawn(async move { - resp.execute(server).await; - }); - } - Err(e) => { - tracing::warn!("Requests stream errored out: {}", e); - break; - } - } - } - Poll::Ready(()) - } -} - #[cfg(test)] mod tests { - use super::*; - + use super::{in_flight_requests::AlreadyExistsError, BaseChannel, Channel, Config, Requests}; use crate::{ - trace, + context, trace, transport::channel::{self, UnboundedChannel}, + ClientMessage, Request, Response, }; use assert_matches::assert_matches; - use futures::future::{pending, Aborted}; + use futures::{ + future::{pending, AbortRegistration, Abortable, Aborted}, + prelude::*, + Future, + }; use futures_test::task::noop_context; + use std::{pin::Pin, task::Poll}; fn test_channel() -> ( Pin, Response>>>>, diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs new file mode 100644 index 0000000..5479b05 --- /dev/null +++ b/tarpc/src/server/incoming.rs @@ -0,0 +1,49 @@ +use super::{ + limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel}, + Channel, +}; +use futures::prelude::*; +use std::{fmt, hash::Hash}; + +#[cfg(feature = "tokio1")] +use super::{tokio::TokioServerExecutor, Serve}; + +/// An extension trait for [streams](Stream) of [`Channels`](Channel). +pub trait Incoming +where + Self: Sized + Stream, + C: Channel, +{ + /// Enforces channel per-key limits. + fn max_channels_per_key(self, n: u32, keymaker: KF) -> MaxChannelsPerKey + where + K: fmt::Display + Eq + Hash + Clone + Unpin, + KF: Fn(&C) -> K, + { + MaxChannelsPerKey::new(self, n, keymaker) + } + + /// Caps the number of concurrent requests per channel. + fn max_concurrent_requests_per_channel(self, n: usize) -> MaxRequestsPerChannel { + MaxRequestsPerChannel::new(self, n) + } + + /// [Executes](Channel::execute) each incoming channel. Each channel will be handled + /// concurrently by spawning on tokio's default executor, and each request will be also + /// be spawned on tokio's default executor. + #[cfg(feature = "tokio1")] + #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] + fn execute(self, serve: S) -> TokioServerExecutor + where + S: Serve, + { + TokioServerExecutor::new(self, serve) + } +} + +impl Incoming for S +where + S: Sized + Stream, + C: Channel, +{ +} diff --git a/tarpc/src/server/limits.rs b/tarpc/src/server/limits.rs new file mode 100644 index 0000000..c74dba9 --- /dev/null +++ b/tarpc/src/server/limits.rs @@ -0,0 +1,5 @@ +/// Provides functionality to limit the number of active channels. +pub mod channels_per_key; + +/// Provides a [channel](crate::server::Channel) that limits the number of in-flight requests. +pub mod requests_per_channel; diff --git a/tarpc/src/server/filter.rs b/tarpc/src/server/limits/channels_per_key.rs similarity index 93% rename from tarpc/src/server/filter.rs rename to tarpc/src/server/limits/channels_per_key.rs index 295fd95..272dd56 100644 --- a/tarpc/src/server/filter.rs +++ b/tarpc/src/server/limits/channels_per_key.rs @@ -18,10 +18,14 @@ use std::{ use tokio::sync::mpsc; use tracing::{debug, info, trace}; -/// A single-threaded filter that drops channels based on per-key limits. +/// An [`Incoming`](crate::server::incoming::Incoming) stream that drops new channels based on +/// per-key limits. +/// +/// The decision to drop a Channel is made once at the time the Channel materializes. Once a +/// Channel is yielded, it will not be prematurely dropped. #[pin_project] #[derive(Debug)] -pub struct ChannelFilter +pub struct MaxChannelsPerKey where K: Eq + Hash, { @@ -34,7 +38,7 @@ where keymaker: F, } -/// A channel that is tracked by a ChannelFilter. +/// A channel that is tracked by [`MaxChannelsPerKey`]. #[pin_project] #[derive(Debug)] pub struct TrackedChannel { @@ -129,7 +133,7 @@ impl TrackedChannel { } } -impl ChannelFilter +impl MaxChannelsPerKey where K: Eq + Hash, S: Stream, @@ -138,7 +142,7 @@ where /// Sheds new channels to stay under configured limits. pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self { let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel(); - ChannelFilter { + MaxChannelsPerKey { listener: listener.fuse(), channels_per_key, dropped_keys, @@ -149,7 +153,7 @@ where } } -impl ChannelFilter +impl MaxChannelsPerKey where S: Stream, K: fmt::Display + Eq + Hash + Clone + Unpin, @@ -241,7 +245,7 @@ where } } -impl Stream for ChannelFilter +impl Stream for MaxChannelsPerKey where S: Stream, K: fmt::Display + Eq + Hash + Clone + Unpin, @@ -344,7 +348,7 @@ fn channel_filter_increment_channels_for_key() { key: &'static str, } let (_, listener) = futures::channel::mpsc::unbounded(); - let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); + let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap(); assert_eq!(Arc::strong_count(&tracker1), 1); @@ -365,7 +369,7 @@ fn channel_filter_handle_new_channel() { key: &'static str, } let (_, listener) = futures::channel::mpsc::unbounded(); - let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); + let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); let channel1 = filter .as_mut() @@ -397,7 +401,7 @@ fn channel_filter_poll_listener() { key: &'static str, } let (new_channels, listener) = futures::channel::mpsc::unbounded(); - let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); + let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); new_channels @@ -433,7 +437,7 @@ fn channel_filter_poll_closed_channels() { key: &'static str, } let (new_channels, listener) = futures::channel::mpsc::unbounded(); - let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); + let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); new_channels @@ -461,7 +465,7 @@ fn channel_filter_stream() { key: &'static str, } let (new_channels, listener) = futures::channel::mpsc::unbounded(); - let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); + let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); new_channels diff --git a/tarpc/src/server/throttle.rs b/tarpc/src/server/limits/requests_per_channel.rs similarity index 86% rename from tarpc/src/server/throttle.rs rename to tarpc/src/server/limits/requests_per_channel.rs index a02e60f..3c29836 100644 --- a/tarpc/src/server/throttle.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -4,44 +4,49 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use super::{Channel, Config}; -use crate::{Response, ServerError}; +use crate::{ + server::{Channel, Config}, + Response, ServerError, +}; use futures::{prelude::*, ready, task::*}; use pin_project::pin_project; use std::{io, pin::Pin}; -/// A [`Channel`] that limits the number of concurrent -/// requests by throttling. +/// A [`Channel`] that limits the number of concurrent requests by throttling. +/// +/// Note that this is a very basic throttling heuristic. It is easy to set a number that is too low +/// for the resources available to the server. For production use cases, a more advanced throttler +/// is likely needed. #[pin_project] #[derive(Debug)] -pub struct Throttler { +pub struct MaxRequests { max_in_flight_requests: usize, #[pin] inner: C, } -impl Throttler { +impl MaxRequests { /// Returns the inner channel. pub fn get_ref(&self) -> &C { &self.inner } } -impl Throttler +impl MaxRequests where C: Channel, { - /// Returns a new `Throttler` that wraps the given channel and limits concurrent requests to + /// Returns a new `MaxRequests` that wraps the given channel and limits concurrent requests to /// `max_in_flight_requests`. pub fn new(inner: C, max_in_flight_requests: usize) -> Self { - Throttler { + MaxRequests { max_in_flight_requests, inner, } } } -impl Stream for Throttler +impl Stream for MaxRequests where C: Channel, { @@ -75,7 +80,7 @@ where } } -impl Sink::Resp>> for Throttler +impl Sink::Resp>> for MaxRequests where C: Channel, { @@ -101,13 +106,13 @@ where } } -impl AsRef for Throttler { +impl AsRef for MaxRequests { fn as_ref(&self) -> &C { &self.inner } } -impl Channel for Throttler +impl Channel for MaxRequests where C: Channel, { @@ -128,16 +133,17 @@ where } } -/// A stream of throttling channels. +/// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on +/// the number of in-flight requests. #[pin_project] #[derive(Debug)] -pub struct ThrottlerStream { +pub struct MaxRequestsPerChannel { #[pin] inner: S, max_in_flight_requests: usize, } -impl ThrottlerStream +impl MaxRequestsPerChannel where S: Stream, ::Item: Channel, @@ -150,16 +156,16 @@ where } } -impl Stream for ThrottlerStream +impl Stream for MaxRequestsPerChannel where S: Stream, ::Item: Channel, { - type Item = Throttler<::Item>; + type Item = MaxRequests<::Item>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { match ready!(self.as_mut().project().inner.poll_next(cx)) { - Some(channel) => Poll::Ready(Some(Throttler::new( + Some(channel) => Poll::Ready(Some(MaxRequests::new( channel, *self.project().max_in_flight_requests, ))), @@ -185,7 +191,7 @@ mod tests { #[tokio::test] async fn throttler_in_flight_requests() { - let throttler = Throttler { + let throttler = MaxRequests { max_in_flight_requests: 0, inner: FakeChannel::default::(), }; @@ -207,7 +213,7 @@ mod tests { #[test] fn throttler_poll_next_done() { - let throttler = Throttler { + let throttler = MaxRequests { max_in_flight_requests: 0, inner: FakeChannel::default::(), }; @@ -218,7 +224,7 @@ mod tests { #[test] fn throttler_poll_next_some() -> io::Result<()> { - let throttler = Throttler { + let throttler = MaxRequests { max_in_flight_requests: 1, inner: FakeChannel::default::(), }; @@ -238,7 +244,7 @@ mod tests { #[test] fn throttler_poll_next_throttled() { - let throttler = Throttler { + let throttler = MaxRequests { max_in_flight_requests: 0, inner: FakeChannel::default::(), }; @@ -254,7 +260,7 @@ mod tests { #[test] fn throttler_poll_next_throttled_sink_not_ready() { - let throttler = Throttler { + let throttler = MaxRequests { max_in_flight_requests: 0, inner: PendingSink::default::(), }; @@ -309,7 +315,7 @@ mod tests { #[tokio::test] async fn throttler_start_send() { - let throttler = Throttler { + let throttler = MaxRequests { max_in_flight_requests: 0, inner: FakeChannel::default::(), }; diff --git a/tarpc/src/server/tokio.rs b/tarpc/src/server/tokio.rs new file mode 100644 index 0000000..87db4c2 --- /dev/null +++ b/tarpc/src/server/tokio.rs @@ -0,0 +1,111 @@ +use super::{Channel, Requests, Serve}; +use futures::{prelude::*, ready, task::*}; +use pin_project::pin_project; +use std::pin::Pin; + +/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor) +/// for each new channel. Returned by +/// [`Incoming::execute`](crate::server::incoming::Incoming::execute). +#[pin_project] +#[derive(Debug)] +pub struct TokioServerExecutor { + #[pin] + inner: T, + serve: S, +} + +impl TokioServerExecutor { + pub(crate) fn new(inner: T, serve: S) -> Self { + Self { inner, serve } + } +} + +/// A future that drives the server by [spawning](tokio::spawn) each [response +/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by +/// [`Channel::execute`](crate::server::Channel::execute). +#[pin_project] +#[derive(Debug)] +pub struct TokioChannelExecutor { + #[pin] + inner: T, + serve: S, +} + +impl TokioServerExecutor { + fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { + self.as_mut().project().inner + } +} + +impl TokioChannelExecutor { + fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> { + self.as_mut().project().inner + } +} + +// Send + 'static execution helper methods. + +impl Requests +where + C: Channel, + C::Req: Send + 'static, + C::Resp: Send + 'static, +{ + /// Executes all requests using the given service function. Requests are handled concurrently + /// by [spawning](::tokio::spawn) each handler on tokio's default executor. + pub fn execute(self, serve: S) -> TokioChannelExecutor + where + S: Serve + Send + 'static, + { + TokioChannelExecutor { inner: self, serve } + } +} + +impl Future for TokioServerExecutor +where + St: Sized + Stream, + C: Channel + Send + 'static, + C::Req: Send + 'static, + C::Resp: Send + 'static, + Se: Serve + Send + 'static + Clone, + Se::Fut: Send, +{ + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) { + tokio::spawn(channel.execute(self.serve.clone())); + } + tracing::info!("Server shutting down."); + Poll::Ready(()) + } +} + +impl Future for TokioChannelExecutor, S> +where + C: Channel + 'static, + C::Req: Send + 'static, + C::Resp: Send + 'static, + S: Serve + Send + 'static + Clone, + S::Fut: Send, +{ + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) { + match response_handler { + Ok(resp) => { + let server = self.serve.clone(); + tokio::spawn(async move { + resp.execute(server).await; + }); + } + Err(e) => { + tracing::warn!("Requests stream errored out: {}", e); + break; + } + } + } + Poll::Ready(()) + } +} diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index f6dee17..c39ed93 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -151,7 +151,7 @@ impl Sink for Channel { mod tests { use crate::{ client, context, - server::{BaseChannel, Incoming}, + server::{incoming::Incoming, BaseChannel}, transport::{ self, channel::{Channel, UnboundedChannel}, diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 4cc52d6..9a3c926 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -2,7 +2,7 @@ use futures::prelude::*; use tarpc::serde_transport; use tarpc::{ client, context, - server::{BaseChannel, Incoming}, + server::{incoming::Incoming, BaseChannel}, }; use tokio_serde::formats::Json; diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 9d97ab2..b1aa431 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -7,7 +7,7 @@ use std::time::{Duration, SystemTime}; use tarpc::{ client::{self}, context, - server::{self, BaseChannel, Channel, Incoming}, + server::{self, incoming::Incoming, BaseChannel, Channel}, transport::channel, }; use tokio::join;