Refactor server module.

In the interest of the user's attention, some ancillary APIs have been
moved to new submodules:

- server::limits contains what was previously called Throttler and
  ChannelFilter. Both of those names were very generic, when the methods
  applied by these types were very specific (and also simplistic). Renames
  have occurred:
  - ThrottlerStream => MaxRequestsPerChannel
  - Throttler => MaxRequests
  - ChannelFilter => MaxChannelsPerKey
- server::incoming contains the Incoming trait.
- server::tokio contains the tokio-specific helper types.

The 5 structs and 1 enum remaining in the base server module are all
core to the functioning of the server.
This commit is contained in:
Tim Kuehn
2021-04-21 15:57:08 -07:00
parent eb67c540b9
commit ea7b6763c4
14 changed files with 277 additions and 239 deletions

View File

@@ -80,6 +80,7 @@ This example uses [tokio](https://tokio.rs), so add the following dependencies t
your `Cargo.toml`:
```toml
anyhow = "1.0"
futures = "1.0"
tarpc = { version = "0.26", features = ["tokio1"] }
tokio = { version = "1.0", features = ["macros"] }
@@ -99,9 +100,8 @@ use futures::{
};
use tarpc::{
client, context,
server::{self, Incoming},
server::{self, incoming::Incoming},
};
use std::io;
// This is the service definition. It looks a lot like a trait definition.
// It defines one RPC, hello, which takes one arg, name, and returns a String.
@@ -140,7 +140,7 @@ available behind the `tcp` feature.
```rust
#[tokio::main]
async fn main() -> io::Result<()> {
async fn main() -> anyhow::Result<()> {
let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
let server = server::BaseChannel::with_defaults(server_transport);

View File

@@ -17,7 +17,7 @@ use std::{
};
use tarpc::{
context,
server::{self, Channel, Incoming},
server::{self, incoming::Incoming, Channel},
tokio_serde::formats::Json,
};
use tokio::time;

View File

@@ -9,7 +9,7 @@ use futures::{future, prelude::*};
use std::env;
use tarpc::{
client, context,
server::{BaseChannel, Incoming},
server::{incoming::Incoming, BaseChannel},
};
use tokio_serde::formats::Json;
use tracing_subscriber::prelude::*;

View File

@@ -317,8 +317,8 @@ where
.map_err(ChannelError::Ready)
}
fn start_send<'a>(
self: &'a mut Pin<&mut Self>,
fn start_send(
self: &mut Pin<&mut Self>,
message: ClientMessage<Req>,
) -> Result<(), ChannelError<C::Error>> {
self.transport_pin_mut()

View File

@@ -88,7 +88,7 @@
//! };
//! use tarpc::{
//! client, context,
//! server::{self, Incoming},
//! server::{self, incoming::Incoming},
//! };
//!
//! // This is the service definition. It looks a lot like a trait definition.
@@ -111,7 +111,7 @@
//! # };
//! # use tarpc::{
//! # client, context,
//! # server::{self, Incoming},
//! # server::{self, incoming::Incoming},
//! # };
//! # // This is the service definition. It looks a lot like a trait definition.
//! # // It defines one RPC, hello, which takes one arg, name, and returns a String.

View File

@@ -10,6 +10,7 @@ use crate::{
context::{self, SpanExt},
trace, ClientMessage, Request, Response, Transport,
};
use ::tokio::sync::mpsc;
use futures::{
future::{AbortRegistration, Abortable},
prelude::*,
@@ -19,20 +20,23 @@ use futures::{
};
use in_flight_requests::{AlreadyExistsError, InFlightRequests};
use pin_project::pin_project;
use std::{convert::TryFrom, error::Error, fmt, hash::Hash, marker::PhantomData, pin::Pin};
use tokio::sync::mpsc;
use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin};
use tracing::{info_span, instrument::Instrument, Span};
mod filter;
mod in_flight_requests;
#[cfg(test)]
mod testing;
mod throttle;
pub use self::{
filter::ChannelFilter,
throttle::{Throttler, ThrottlerStream},
};
/// Provides functionality to apply server limits.
pub mod limits;
/// Provides helper methods for streams of Channels.
pub mod incoming;
/// Provides convenience functionality for tokio-enabled applications.
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
pub mod tokio;
/// Settings that control the behavior of [channels](Channel).
#[derive(Clone, Debug)]
@@ -91,51 +95,13 @@ where
}
}
/// An extension trait for [streams](Stream) of [`Channels`](Channel).
pub trait Incoming<C>
where
Self: Sized + Stream<Item = C>,
C: Channel,
{
/// Enforces channel per-key limits.
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> filter::ChannelFilter<Self, K, KF>
where
K: fmt::Display + Eq + Hash + Clone + Unpin,
KF: Fn(&C) -> K,
{
ChannelFilter::new(self, n, keymaker)
}
/// Caps the number of concurrent requests per channel.
fn max_concurrent_requests_per_channel(self, n: usize) -> ThrottlerStream<Self> {
ThrottlerStream::new(self, n)
}
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled
/// concurrently by spawning on tokio's default executor, and each request will be also
/// be spawned on tokio's default executor.
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
where
S: Serve<C::Req, Resp = C::Resp>,
{
TokioServerExecutor { inner: self, serve }
}
}
impl<S, C> Incoming<C> for S
where
S: Sized + Stream<Item = C>,
C: Channel,
{
}
/// BaseChannel is a [Transport] that keeps track of in-flight requests. It converts a
/// [`Transport`](Transport) of [`ClientMessages`](ClientMessage) into a stream of
/// [requests](ClientMessage::Request).
/// BaseChannel is the standard implementation of a [`Channel`].
///
/// Besides requests, the other type of client message is [cancellation
/// BaseChannel manages a [`Transport`](Transport) of client [`messages`](ClientMessage) and
/// implements a [`Stream`] of [requests](TrackedRequest). See the [`Channel`] documentation for
/// how to use channels.
///
/// Besides requests, the other type of client message handled by `BaseChannel` is [cancellation
/// messages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation
/// messages. Instead, it internally handles them by cancelling corresponding requests (removing
/// the corresponding in-flight requests and aborting their handlers).
@@ -216,15 +182,15 @@ where
match start {
Ok(abort_registration) => {
drop(entered);
return Ok(TrackedRequest {
Ok(TrackedRequest {
request,
abort_registration,
span,
});
})
}
Err(AlreadyExistsError) => {
tracing::trace!("DuplicateRequest");
return Err(AlreadyExistsError);
Err(AlreadyExistsError)
}
}
}
@@ -248,8 +214,8 @@ pub struct TrackedRequest<Req> {
pub span: Span,
}
/// The server end of an open connection with a client, streaming in requests from, and sinking
/// responses to, the client.
/// The server end of an open connection with a client, receiving requests from, and sending
/// responses to, the client. `Channel` is a [`Transport`] with request lifecycle management.
///
/// The ways to use a Channel, in order of simplest to most complex, is:
/// 1. [`Channel::execute`] - Requires the `tokio1` feature. This method is best for those who
@@ -293,12 +259,21 @@ where
/// Returns the transport underlying the channel.
fn transport(&self) -> &Self::Transport;
/// Caps the number of concurrent requests to `limit`.
fn max_concurrent_requests(self, limit: usize) -> Throttler<Self>
/// Caps the number of concurrent requests to `limit`. An error will be returned for requests
/// over the concurrency limit.
///
/// Note that this is a very
/// simplistic throttling heuristic. It is easy to set a number that is too low for the
/// resources available to the server. For production use cases, a more advanced throttler is
/// likely needed.
fn max_concurrent_requests(
self,
limit: usize,
) -> limits::requests_per_channel::MaxRequests<Self>
where
Self: Sized,
{
Throttler::new(self, limit)
limits::requests_per_channel::MaxRequests::new(self, limit)
}
/// Returns a stream of requests that automatically handle request cancellation and response
@@ -321,11 +296,11 @@ where
}
/// Runs the channel until completion by executing all requests using the given service
/// function. Request handlers are run concurrently by [spawning](tokio::spawn) on tokio's
/// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's
/// default executor.
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
fn execute<S>(self, serve: S) -> TokioChannelExecutor<Requests<Self>, S>
fn execute<S>(self, serve: S) -> self::tokio::TokioChannelExecutor<Requests<Self>, S>
where
Self: Sized,
S: Serve<Self::Req, Resp = Self::Resp> + Send + 'static,
@@ -348,7 +323,7 @@ where
Transport(#[source] E),
/// An error occurred while polling expired requests.
#[error("an error occurred while polling expired requests: {0}")]
Timer(#[source] tokio::time::error::Error),
Timer(#[source] ::tokio::time::error::Error),
}
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
@@ -533,18 +508,12 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<InFlightRequest<C::Req, C::Resp>, C::Error>>> {
loop {
match ready!(self.channel_pin_mut().poll_next(cx)?) {
Some(request) => {
let response_tx = self.responses_tx.clone();
return Poll::Ready(Some(Ok(InFlightRequest {
request,
response_tx,
})));
}
None => return Poll::Ready(None),
}
}
self.channel_pin_mut()
.poll_next(cx)
.map_ok(|request| InFlightRequest {
request,
response_tx: self.responses_tx.clone(),
})
}
fn pump_write(
@@ -710,128 +679,22 @@ where
}
}
// Send + 'static execution helper methods.
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
impl<C> Requests<C>
where
C: Channel,
C::Req: Send + 'static,
C::Resp: Send + 'static,
{
/// Executes all requests using the given service function. Requests are handled concurrently
/// by [spawning](tokio::spawn) each handler on tokio's default executor.
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
where
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
{
TokioChannelExecutor { inner: self, serve }
}
}
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
/// for each new channel.
#[pin_project]
#[derive(Debug)]
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
pub struct TokioServerExecutor<T, S> {
#[pin]
inner: T,
serve: S,
}
/// A future that drives the server by [spawning](tokio::spawn) each [response
/// handler](InFlightRequest::execute) on tokio's default executor.
#[pin_project]
#[derive(Debug)]
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
pub struct TokioChannelExecutor<T, S> {
#[pin]
inner: T,
serve: S,
}
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
impl<T, S> TokioServerExecutor<T, S> {
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
self.as_mut().project().inner
}
}
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
impl<T, S> TokioChannelExecutor<T, S> {
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
self.as_mut().project().inner
}
}
#[cfg(feature = "tokio1")]
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
where
St: Sized + Stream<Item = C>,
C: Channel + Send + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
Se::Fut: Send,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
tokio::spawn(channel.execute(self.serve.clone()));
}
tracing::info!("Server shutting down.");
Poll::Ready(())
}
}
#[cfg(feature = "tokio1")]
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
where
C: Channel + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
S::Fut: Send,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
match response_handler {
Ok(resp) => {
let server = self.serve.clone();
tokio::spawn(async move {
resp.execute(server).await;
});
}
Err(e) => {
tracing::warn!("Requests stream errored out: {}", e);
break;
}
}
}
Poll::Ready(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use super::{in_flight_requests::AlreadyExistsError, BaseChannel, Channel, Config, Requests};
use crate::{
trace,
context, trace,
transport::channel::{self, UnboundedChannel},
ClientMessage, Request, Response,
};
use assert_matches::assert_matches;
use futures::future::{pending, Aborted};
use futures::{
future::{pending, AbortRegistration, Abortable, Aborted},
prelude::*,
Future,
};
use futures_test::task::noop_context;
use std::{pin::Pin, task::Poll};
fn test_channel<Req, Resp>() -> (
Pin<Box<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>,

View File

@@ -0,0 +1,49 @@
use super::{
limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel},
Channel,
};
use futures::prelude::*;
use std::{fmt, hash::Hash};
#[cfg(feature = "tokio1")]
use super::{tokio::TokioServerExecutor, Serve};
/// An extension trait for [streams](Stream) of [`Channels`](Channel).
pub trait Incoming<C>
where
Self: Sized + Stream<Item = C>,
C: Channel,
{
/// Enforces channel per-key limits.
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> MaxChannelsPerKey<Self, K, KF>
where
K: fmt::Display + Eq + Hash + Clone + Unpin,
KF: Fn(&C) -> K,
{
MaxChannelsPerKey::new(self, n, keymaker)
}
/// Caps the number of concurrent requests per channel.
fn max_concurrent_requests_per_channel(self, n: usize) -> MaxRequestsPerChannel<Self> {
MaxRequestsPerChannel::new(self, n)
}
/// [Executes](Channel::execute) each incoming channel. Each channel will be handled
/// concurrently by spawning on tokio's default executor, and each request will be also
/// be spawned on tokio's default executor.
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
fn execute<S>(self, serve: S) -> TokioServerExecutor<Self, S>
where
S: Serve<C::Req, Resp = C::Resp>,
{
TokioServerExecutor::new(self, serve)
}
}
impl<S, C> Incoming<C> for S
where
S: Sized + Stream<Item = C>,
C: Channel,
{
}

View File

@@ -0,0 +1,5 @@
/// Provides functionality to limit the number of active channels.
pub mod channels_per_key;
/// Provides a [channel](crate::server::Channel) that limits the number of in-flight requests.
pub mod requests_per_channel;

View File

@@ -18,10 +18,14 @@ use std::{
use tokio::sync::mpsc;
use tracing::{debug, info, trace};
/// A single-threaded filter that drops channels based on per-key limits.
/// An [`Incoming`](crate::server::incoming::Incoming) stream that drops new channels based on
/// per-key limits.
///
/// The decision to drop a Channel is made once at the time the Channel materializes. Once a
/// Channel is yielded, it will not be prematurely dropped.
#[pin_project]
#[derive(Debug)]
pub struct ChannelFilter<S, K, F>
pub struct MaxChannelsPerKey<S, K, F>
where
K: Eq + Hash,
{
@@ -34,7 +38,7 @@ where
keymaker: F,
}
/// A channel that is tracked by a ChannelFilter.
/// A channel that is tracked by [`MaxChannelsPerKey`].
#[pin_project]
#[derive(Debug)]
pub struct TrackedChannel<C, K> {
@@ -129,7 +133,7 @@ impl<C, K> TrackedChannel<C, K> {
}
}
impl<S, K, F> ChannelFilter<S, K, F>
impl<S, K, F> MaxChannelsPerKey<S, K, F>
where
K: Eq + Hash,
S: Stream,
@@ -138,7 +142,7 @@ where
/// Sheds new channels to stay under configured limits.
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel();
ChannelFilter {
MaxChannelsPerKey {
listener: listener.fuse(),
channels_per_key,
dropped_keys,
@@ -149,7 +153,7 @@ where
}
}
impl<S, K, F> ChannelFilter<S, K, F>
impl<S, K, F> MaxChannelsPerKey<S, K, F>
where
S: Stream,
K: fmt::Display + Eq + Hash + Clone + Unpin,
@@ -241,7 +245,7 @@ where
}
}
impl<S, K, F> Stream for ChannelFilter<S, K, F>
impl<S, K, F> Stream for MaxChannelsPerKey<S, K, F>
where
S: Stream,
K: fmt::Display + Eq + Hash + Clone + Unpin,
@@ -344,7 +348,7 @@ fn channel_filter_increment_channels_for_key() {
key: &'static str,
}
let (_, listener) = futures::channel::mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap();
assert_eq!(Arc::strong_count(&tracker1), 1);
@@ -365,7 +369,7 @@ fn channel_filter_handle_new_channel() {
key: &'static str,
}
let (_, listener) = futures::channel::mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
let channel1 = filter
.as_mut()
@@ -397,7 +401,7 @@ fn channel_filter_poll_listener() {
key: &'static str,
}
let (new_channels, listener) = futures::channel::mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels
@@ -433,7 +437,7 @@ fn channel_filter_poll_closed_channels() {
key: &'static str,
}
let (new_channels, listener) = futures::channel::mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels
@@ -461,7 +465,7 @@ fn channel_filter_stream() {
key: &'static str,
}
let (new_channels, listener) = futures::channel::mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels

View File

@@ -4,44 +4,49 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use super::{Channel, Config};
use crate::{Response, ServerError};
use crate::{
server::{Channel, Config},
Response, ServerError,
};
use futures::{prelude::*, ready, task::*};
use pin_project::pin_project;
use std::{io, pin::Pin};
/// A [`Channel`] that limits the number of concurrent
/// requests by throttling.
/// A [`Channel`] that limits the number of concurrent requests by throttling.
///
/// Note that this is a very basic throttling heuristic. It is easy to set a number that is too low
/// for the resources available to the server. For production use cases, a more advanced throttler
/// is likely needed.
#[pin_project]
#[derive(Debug)]
pub struct Throttler<C> {
pub struct MaxRequests<C> {
max_in_flight_requests: usize,
#[pin]
inner: C,
}
impl<C> Throttler<C> {
impl<C> MaxRequests<C> {
/// Returns the inner channel.
pub fn get_ref(&self) -> &C {
&self.inner
}
}
impl<C> Throttler<C>
impl<C> MaxRequests<C>
where
C: Channel,
{
/// Returns a new `Throttler` that wraps the given channel and limits concurrent requests to
/// Returns a new `MaxRequests` that wraps the given channel and limits concurrent requests to
/// `max_in_flight_requests`.
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
Throttler {
MaxRequests {
max_in_flight_requests,
inner,
}
}
}
impl<C> Stream for Throttler<C>
impl<C> Stream for MaxRequests<C>
where
C: Channel,
{
@@ -75,7 +80,7 @@ where
}
}
impl<C> Sink<Response<<C as Channel>::Resp>> for Throttler<C>
impl<C> Sink<Response<<C as Channel>::Resp>> for MaxRequests<C>
where
C: Channel,
{
@@ -101,13 +106,13 @@ where
}
}
impl<C> AsRef<C> for Throttler<C> {
impl<C> AsRef<C> for MaxRequests<C> {
fn as_ref(&self) -> &C {
&self.inner
}
}
impl<C> Channel for Throttler<C>
impl<C> Channel for MaxRequests<C>
where
C: Channel,
{
@@ -128,16 +133,17 @@ where
}
}
/// A stream of throttling channels.
/// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on
/// the number of in-flight requests.
#[pin_project]
#[derive(Debug)]
pub struct ThrottlerStream<S> {
pub struct MaxRequestsPerChannel<S> {
#[pin]
inner: S,
max_in_flight_requests: usize,
}
impl<S> ThrottlerStream<S>
impl<S> MaxRequestsPerChannel<S>
where
S: Stream,
<S as Stream>::Item: Channel,
@@ -150,16 +156,16 @@ where
}
}
impl<S> Stream for ThrottlerStream<S>
impl<S> Stream for MaxRequestsPerChannel<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
type Item = Throttler<<S as Stream>::Item>;
type Item = MaxRequests<<S as Stream>::Item>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match ready!(self.as_mut().project().inner.poll_next(cx)) {
Some(channel) => Poll::Ready(Some(Throttler::new(
Some(channel) => Poll::Ready(Some(MaxRequests::new(
channel,
*self.project().max_in_flight_requests,
))),
@@ -185,7 +191,7 @@ mod tests {
#[tokio::test]
async fn throttler_in_flight_requests() {
let throttler = Throttler {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
@@ -207,7 +213,7 @@ mod tests {
#[test]
fn throttler_poll_next_done() {
let throttler = Throttler {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
@@ -218,7 +224,7 @@ mod tests {
#[test]
fn throttler_poll_next_some() -> io::Result<()> {
let throttler = Throttler {
let throttler = MaxRequests {
max_in_flight_requests: 1,
inner: FakeChannel::default::<isize, isize>(),
};
@@ -238,7 +244,7 @@ mod tests {
#[test]
fn throttler_poll_next_throttled() {
let throttler = Throttler {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
@@ -254,7 +260,7 @@ mod tests {
#[test]
fn throttler_poll_next_throttled_sink_not_ready() {
let throttler = Throttler {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: PendingSink::default::<isize, isize>(),
};
@@ -309,7 +315,7 @@ mod tests {
#[tokio::test]
async fn throttler_start_send() {
let throttler = Throttler {
let throttler = MaxRequests {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};

111
tarpc/src/server/tokio.rs Normal file
View File

@@ -0,0 +1,111 @@
use super::{Channel, Requests, Serve};
use futures::{prelude::*, ready, task::*};
use pin_project::pin_project;
use std::pin::Pin;
/// A future that drives the server by [spawning](tokio::spawn) a [`TokioChannelExecutor`](TokioChannelExecutor)
/// for each new channel. Returned by
/// [`Incoming::execute`](crate::server::incoming::Incoming::execute).
#[pin_project]
#[derive(Debug)]
pub struct TokioServerExecutor<T, S> {
#[pin]
inner: T,
serve: S,
}
impl<T, S> TokioServerExecutor<T, S> {
pub(crate) fn new(inner: T, serve: S) -> Self {
Self { inner, serve }
}
}
/// A future that drives the server by [spawning](tokio::spawn) each [response
/// handler](super::InFlightRequest::execute) on tokio's default executor. Returned by
/// [`Channel::execute`](crate::server::Channel::execute).
#[pin_project]
#[derive(Debug)]
pub struct TokioChannelExecutor<T, S> {
#[pin]
inner: T,
serve: S,
}
impl<T, S> TokioServerExecutor<T, S> {
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
self.as_mut().project().inner
}
}
impl<T, S> TokioChannelExecutor<T, S> {
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut T> {
self.as_mut().project().inner
}
}
// Send + 'static execution helper methods.
impl<C> Requests<C>
where
C: Channel,
C::Req: Send + 'static,
C::Resp: Send + 'static,
{
/// Executes all requests using the given service function. Requests are handled concurrently
/// by [spawning](::tokio::spawn) each handler on tokio's default executor.
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
where
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
{
TokioChannelExecutor { inner: self, serve }
}
}
impl<St, C, Se> Future for TokioServerExecutor<St, Se>
where
St: Sized + Stream<Item = C>,
C: Channel + Send + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
Se::Fut: Send,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
tokio::spawn(channel.execute(self.serve.clone()));
}
tracing::info!("Server shutting down.");
Poll::Ready(())
}
}
impl<C, S> Future for TokioChannelExecutor<Requests<C>, S>
where
C: Channel + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
S::Fut: Send,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
while let Some(response_handler) = ready!(self.inner_pin_mut().poll_next(cx)) {
match response_handler {
Ok(resp) => {
let server = self.serve.clone();
tokio::spawn(async move {
resp.execute(server).await;
});
}
Err(e) => {
tracing::warn!("Requests stream errored out: {}", e);
break;
}
}
}
Poll::Ready(())
}
}

View File

@@ -151,7 +151,7 @@ impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
mod tests {
use crate::{
client, context,
server::{BaseChannel, Incoming},
server::{incoming::Incoming, BaseChannel},
transport::{
self,
channel::{Channel, UnboundedChannel},

View File

@@ -2,7 +2,7 @@ use futures::prelude::*;
use tarpc::serde_transport;
use tarpc::{
client, context,
server::{BaseChannel, Incoming},
server::{incoming::Incoming, BaseChannel},
};
use tokio_serde::formats::Json;

View File

@@ -7,7 +7,7 @@ use std::time::{Duration, SystemTime};
use tarpc::{
client::{self},
context,
server::{self, BaseChannel, Channel, Incoming},
server::{self, incoming::Incoming, BaseChannel, Channel},
transport::channel,
};
use tokio::join;