Mirror of https://github.com/OMGeeky/tarpc.git (synced 2026-01-16 16:29:32 +01:00)
Make server methods more composable.
-- Connection Limits

The problem with having ConnectionFilter enabled by default is elaborated on in https://github.com/google/tarpc/issues/217. The gist of it is that not all servers want a policy based on `SocketAddr`. This PR allows customizing the behavior of ConnectionFilter, at the cost of not having it enabled by default. However, enabling it is as simple as one line:

    incoming.max_channels_per_key(10, ip_addr)

The second argument is a key function that takes the user-chosen transport and returns some hashable, equatable, cloneable key. In the above example, it returns an `IpAddr`. This also allows the addr fns to be removed from the `Transport` trait, which makes it simply an alias for `Stream + Sink`. (A sketch of the key-function contract follows this message.)

-- Per-Channel Request Throttling

The same argument applies to Channel's throttling behavior. There isn't a one-size-fits-all solution to throttling requests, and the policy applied by tarpc is just one of potentially many solutions. As such, `Channel` is now a trait that offers a few combinators, one of which is throttling:

    channel.max_concurrent_requests(10).respond_with(serve(Server))

This functionality is also available on the existing `Handler` trait, which applies it to all incoming channels and can be used in tandem with connection limits:

    incoming
        .max_channels_per_key(10, ip_addr)
        .max_concurrent_requests_per_channel(10)
        .respond_with(serve(Server))

-- Global Request Throttling

I've entirely removed the overall request limit enforced across all channels. This functionality is easily gotten back via [`StreamExt::buffer_unordered`](https://rust-lang-nursery.github.io/futures-api-docs/0.3.0-alpha.1/futures/stream/trait.StreamExt.html#method.buffer_unordered), with the difference being that the previous behavior allowed you to spawn channels onto different threads, whereas `buffer_unordered` means the `Channel`s are handled on a single thread (the per-request handlers are still spawned). Considering the existing options, I don't believe the benefit provided by this functionality held its own. (A sketch of the `buffer_unordered` pattern also follows this message.)
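A minimal sketch of the key-function contract described in the connection-limits section above. `MyChannel`, its `peer_ip` field, and the `main` harness are hypothetical stand-ins, not tarpc types; the only requirement stated by the commit is that the function maps a borrowed channel/transport to a hashable, equatable, cloneable key such as `IpAddr`.

```rust
use std::net::{IpAddr, Ipv4Addr};

// Hypothetical channel type standing in for whatever the incoming stream yields.
struct MyChannel {
    peer_ip: IpAddr,
}

// The key function: borrow a channel, return a key that is Hash + Eq + Clone.
// This is the shape of the `ip_addr` argument used in the commit message.
fn ip_addr(channel: &MyChannel) -> IpAddr {
    channel.peer_ip
}

fn main() {
    let channel = MyChannel {
        peer_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
    };
    assert_eq!(ip_addr(&channel), IpAddr::V4(Ipv4Addr::LOCALHOST));
}
```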
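And a hedged sketch of the `buffer_unordered` pattern the global-throttling section refers to: given a stream of response futures (one per in-flight request, produced however the caller sees fit), at most `n` of them are polled concurrently, recovering a global cap on in-flight work. The `drive_with_cap` name and the generic bounds are illustrative, not part of tarpc's API.

```rust
use futures::{Future, Stream, StreamExt};

// Cap global concurrency: poll at most `n` response futures at a time,
// driving them to completion on the current task.
async fn drive_with_cap<S, F>(response_futures: S, n: usize)
where
    S: Stream<Item = F>,
    F: Future<Output = ()>,
{
    response_futures
        .buffer_unordered(n)
        .for_each(|()| async {})
        .await;
}
```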
@@ -5,259 +5,331 @@
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use crate::{
|
||||
server::{Channel, Config},
|
||||
server::{self, Channel},
|
||||
util::Compact,
|
||||
ClientMessage, PollIo, Response, Transport,
|
||||
Response,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{
|
||||
channel::mpsc,
|
||||
future::AbortRegistration,
|
||||
prelude::*,
|
||||
ready,
|
||||
stream::Fuse,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use pin_utils::unsafe_pinned;
|
||||
use log::{debug, info, trace};
|
||||
use pin_utils::{unsafe_pinned, unsafe_unpinned};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::{
|
||||
collections::hash_map::Entry,
|
||||
io,
|
||||
marker::PhantomData,
|
||||
net::{IpAddr, SocketAddr},
|
||||
ops::Try,
|
||||
option::NoneError,
|
||||
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, io, marker::Unpin, ops::Try,
|
||||
pin::Pin,
|
||||
};
|
||||
|
||||
/// Drops connections under configurable conditions:
|
||||
///
|
||||
/// 1. If the max number of connections is reached.
|
||||
/// 2. If the max number of connections for a single IP is reached.
|
||||
/// A single-threaded filter that drops channels based on per-key limits.
|
||||
#[derive(Debug)]
|
||||
pub struct ConnectionFilter<S, Req, Resp> {
|
||||
pub struct ChannelFilter<S, K, F>
|
||||
where
|
||||
K: Eq + Hash,
|
||||
{
|
||||
listener: Fuse<S>,
|
||||
closed_connections: mpsc::UnboundedSender<SocketAddr>,
|
||||
closed_connections_rx: mpsc::UnboundedReceiver<SocketAddr>,
|
||||
config: Config,
|
||||
connections_per_ip: FnvHashMap<IpAddr, usize>,
|
||||
open_connections: usize,
|
||||
ghost: PhantomData<(Req, Resp)>,
|
||||
channels_per_key: u32,
|
||||
dropped_keys: mpsc::UnboundedReceiver<K>,
|
||||
dropped_keys_tx: mpsc::UnboundedSender<K>,
|
||||
key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
|
||||
keymaker: F,
|
||||
}
|
||||
|
||||
enum NewConnection<Req, Resp, C> {
|
||||
Filtered,
|
||||
Accepted(Channel<Req, Resp, C>),
|
||||
/// A channel that is tracked by a ChannelFilter.
|
||||
#[derive(Debug)]
|
||||
pub struct TrackedChannel<C, K> {
|
||||
inner: C,
|
||||
tracker: Arc<Tracker<K>>,
|
||||
}
|
||||
|
||||
impl<Req, Resp, C> Try for NewConnection<Req, Resp, C> {
|
||||
type Ok = Channel<Req, Resp, C>;
|
||||
type Error = NoneError;
|
||||
impl<C, K> TrackedChannel<C, K> {
|
||||
unsafe_pinned!(inner: C);
|
||||
}
|
||||
|
||||
fn into_result(self) -> Result<Channel<Req, Resp, C>, NoneError> {
|
||||
#[derive(Debug)]
|
||||
struct Tracker<K> {
|
||||
key: Option<K>,
|
||||
dropped_keys: mpsc::UnboundedSender<K>,
|
||||
}
|
||||
|
||||
impl<K> Drop for Tracker<K> {
|
||||
fn drop(&mut self) {
|
||||
// Don't care if the listener is dropped.
|
||||
let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
/// A running handler serving all requests for a single client.
|
||||
#[derive(Debug)]
|
||||
pub struct TrackedHandler<K, Fut> {
|
||||
inner: Fut,
|
||||
tracker: Tracker<K>,
|
||||
}
|
||||
|
||||
impl<K, Fut> TrackedHandler<K, Fut>
|
||||
where
|
||||
Fut: Future,
|
||||
{
|
||||
unsafe_pinned!(inner: Fut);
|
||||
}
|
||||
|
||||
impl<K, Fut> Future for TrackedHandler<K, Fut>
|
||||
where
|
||||
Fut: Future,
|
||||
{
|
||||
type Output = Fut::Output;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.inner().poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, K> Stream for TrackedChannel<C, K>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
type Item = <C as Stream>::Item;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
self.channel().poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, K> Sink<Response<C::Resp>> for TrackedChannel<C, K>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.channel().poll_ready(cx)
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: Response<C::Resp>) -> Result<(), Self::Error> {
|
||||
self.channel().start_send(item)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.channel().poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.channel().poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, K> AsRef<C> for TrackedChannel<C, K> {
|
||||
fn as_ref(&self) -> &C {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, K> Channel for TrackedChannel<C, K>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
type Req = C::Req;
|
||||
type Resp = C::Resp;
|
||||
|
||||
fn config(&self) -> &server::Config {
|
||||
self.inner.config()
|
||||
}
|
||||
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
self.inner().in_flight_requests()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
self.inner().start_request(request_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, K> TrackedChannel<C, K> {
|
||||
/// Returns the inner channel.
|
||||
pub fn get_ref(&self) -> &C {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
/// Returns the pinned inner channel.
|
||||
fn channel<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut C> {
|
||||
self.inner()
|
||||
}
|
||||
}
|
||||
|
||||
enum NewChannel<C, K> {
|
||||
Accepted(TrackedChannel<C, K>),
|
||||
Filtered(K),
|
||||
}
|
||||
|
||||
impl<C, K> Try for NewChannel<C, K> {
|
||||
type Ok = TrackedChannel<C, K>;
|
||||
type Error = K;
|
||||
|
||||
fn into_result(self) -> Result<TrackedChannel<C, K>, K> {
|
||||
match self {
|
||||
NewConnection::Filtered => Err(NoneError),
|
||||
NewConnection::Accepted(channel) => Ok(channel),
|
||||
NewChannel::Accepted(channel) => Ok(channel),
|
||||
NewChannel::Filtered(k) => Err(k),
|
||||
}
|
||||
}
|
||||
|
||||
fn from_error(_: NoneError) -> Self {
|
||||
NewConnection::Filtered
|
||||
fn from_error(k: K) -> Self {
|
||||
NewChannel::Filtered(k)
|
||||
}
|
||||
|
||||
fn from_ok(channel: Channel<Req, Resp, C>) -> Self {
|
||||
NewConnection::Accepted(channel)
|
||||
fn from_ok(channel: TrackedChannel<C, K>) -> Self {
|
||||
NewChannel::Accepted(channel)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, Req, Resp> ConnectionFilter<S, Req, Resp> {
|
||||
unsafe_pinned!(open_connections: usize);
|
||||
unsafe_pinned!(config: Config);
|
||||
unsafe_pinned!(connections_per_ip: FnvHashMap<IpAddr, usize>);
|
||||
unsafe_pinned!(closed_connections_rx: mpsc::UnboundedReceiver<SocketAddr>);
|
||||
impl<S, K, F> ChannelFilter<S, K, F>
|
||||
where
|
||||
K: fmt::Display + Eq + Hash + Clone,
|
||||
{
|
||||
unsafe_pinned!(listener: Fuse<S>);
|
||||
unsafe_pinned!(dropped_keys: mpsc::UnboundedReceiver<K>);
|
||||
unsafe_pinned!(dropped_keys_tx: mpsc::UnboundedSender<K>);
|
||||
unsafe_unpinned!(key_counts: FnvHashMap<K, Weak<Tracker<K>>>);
|
||||
unsafe_unpinned!(channels_per_key: u32);
|
||||
unsafe_unpinned!(keymaker: F);
|
||||
}
|
||||
|
||||
/// Sheds new connections to stay under configured limits.
|
||||
pub fn filter<C>(listener: S, config: Config) -> Self
|
||||
where
|
||||
S: Stream<Item = Result<C, io::Error>>,
|
||||
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
|
||||
{
|
||||
let (closed_connections, closed_connections_rx) = mpsc::unbounded();
|
||||
|
||||
ConnectionFilter {
|
||||
impl<S, K, F> ChannelFilter<S, K, F>
|
||||
where
|
||||
K: Eq + Hash,
|
||||
S: Stream,
|
||||
{
|
||||
/// Sheds new channels to stay under configured limits.
|
||||
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
|
||||
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded();
|
||||
ChannelFilter {
|
||||
listener: listener.fuse(),
|
||||
closed_connections,
|
||||
closed_connections_rx,
|
||||
config,
|
||||
connections_per_ip: FnvHashMap::default(),
|
||||
open_connections: 0,
|
||||
ghost: PhantomData,
|
||||
channels_per_key,
|
||||
dropped_keys,
|
||||
dropped_keys_tx,
|
||||
key_counts: FnvHashMap::default(),
|
||||
keymaker,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_new_connection<C>(self: &mut Pin<&mut Self>, stream: C) -> NewConnection<Req, Resp, C>
|
||||
where
|
||||
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
|
||||
{
|
||||
let peer = match stream.peer_addr() {
|
||||
Ok(peer) => peer,
|
||||
Err(e) => {
|
||||
warn!("Could not get peer_addr of new connection: {}", e);
|
||||
return NewConnection::Filtered;
|
||||
}
|
||||
};
|
||||
|
||||
let open_connections = *self.as_mut().open_connections();
|
||||
if open_connections >= self.as_mut().config().max_connections {
|
||||
warn!(
|
||||
"[{}] Shedding connection because the maximum open connections \
|
||||
limit is reached ({}/{}).",
|
||||
peer,
|
||||
open_connections,
|
||||
self.as_mut().config().max_connections
|
||||
);
|
||||
return NewConnection::Filtered;
|
||||
}
|
||||
|
||||
let config = self.config.clone();
|
||||
let open_connections_for_ip = self.increment_connections_for_ip(&peer)?;
|
||||
*self.as_mut().open_connections() += 1;
|
||||
impl<S, C, K, F> ChannelFilter<S, K, F>
|
||||
where
|
||||
S: Stream<Item = C>,
|
||||
C: Channel,
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
F: Fn(&C) -> K,
|
||||
{
|
||||
fn handle_new_channel(self: &mut Pin<&mut Self>, stream: C) -> NewChannel<C, K> {
|
||||
let key = self.as_mut().keymaker()(&stream);
|
||||
let tracker = self.increment_channels_for_key(key.clone())?;
|
||||
let max = self.as_mut().channels_per_key();
|
||||
|
||||
debug!(
|
||||
"[{}] Opening channel ({}/{} connections for IP, {} total).",
|
||||
peer,
|
||||
open_connections_for_ip,
|
||||
config.max_connections_per_ip,
|
||||
self.as_mut().open_connections(),
|
||||
"[{}] Opening channel ({}/{}) channels for key.",
|
||||
key,
|
||||
Arc::strong_count(&tracker),
|
||||
max
|
||||
);
|
||||
|
||||
NewConnection::Accepted(Channel {
|
||||
client_addr: peer,
|
||||
closed_connections: self.closed_connections.clone(),
|
||||
transport: stream.fuse(),
|
||||
config,
|
||||
ghost: PhantomData,
|
||||
NewChannel::Accepted(TrackedChannel {
|
||||
tracker,
|
||||
inner: stream,
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_closed_connection(self: &mut Pin<&mut Self>, addr: &SocketAddr) {
|
||||
*self.as_mut().open_connections() -= 1;
|
||||
debug!(
|
||||
"[{}] Closing channel. {} open connections remaining.",
|
||||
addr, self.open_connections
|
||||
);
|
||||
self.decrement_connections_for_ip(&addr);
|
||||
self.as_mut().connections_per_ip().compact(0.1);
|
||||
}
|
||||
fn increment_channels_for_key(self: &mut Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
|
||||
let channels_per_key = self.channels_per_key;
|
||||
let dropped_keys = self.dropped_keys_tx.clone();
|
||||
let key_counts = &mut self.as_mut().key_counts();
|
||||
match key_counts.entry(key.clone()) {
|
||||
Entry::Vacant(vacant) => {
|
||||
let tracker = Arc::new(Tracker {
|
||||
key: Some(key),
|
||||
dropped_keys,
|
||||
});
|
||||
|
||||
fn increment_connections_for_ip(self: &mut Pin<&mut Self>, peer: &SocketAddr) -> Option<usize> {
|
||||
let max_connections_per_ip = self.as_mut().config().max_connections_per_ip;
|
||||
let mut occupied;
|
||||
let mut connections_per_ip = self.as_mut().connections_per_ip();
|
||||
let occupied = match connections_per_ip.entry(peer.ip()) {
|
||||
Entry::Vacant(vacant) => vacant.insert(0),
|
||||
Entry::Occupied(o) => {
|
||||
if *o.get() < max_connections_per_ip {
|
||||
// Store the reference outside the block to extend the lifetime.
|
||||
occupied = o;
|
||||
occupied.get_mut()
|
||||
} else {
|
||||
vacant.insert(Arc::downgrade(&tracker));
|
||||
Ok(tracker)
|
||||
}
|
||||
Entry::Occupied(mut o) => {
|
||||
let count = o.get().strong_count();
|
||||
if count >= channels_per_key.try_into().unwrap() {
|
||||
info!(
|
||||
"[{}] Opened max connections from IP ({}/{}).",
|
||||
peer,
|
||||
o.get(),
|
||||
max_connections_per_ip
|
||||
"[{}] Opened max channels from key ({}/{}).",
|
||||
key, count, channels_per_key
|
||||
);
|
||||
return None;
|
||||
}
|
||||
}
|
||||
};
|
||||
*occupied += 1;
|
||||
Some(*occupied)
|
||||
}
|
||||
|
||||
fn decrement_connections_for_ip(self: &mut Pin<&mut Self>, addr: &SocketAddr) {
|
||||
let should_compact = match self.as_mut().connections_per_ip().entry(addr.ip()) {
|
||||
Entry::Vacant(_) => {
|
||||
error!("[{}] Got vacant entry when closing connection.", addr);
|
||||
return;
|
||||
}
|
||||
Entry::Occupied(mut occupied) => {
|
||||
*occupied.get_mut() -= 1;
|
||||
if *occupied.get() == 0 {
|
||||
occupied.remove();
|
||||
true
|
||||
Err(key)
|
||||
} else {
|
||||
false
|
||||
Ok(o.get().upgrade().unwrap_or_else(|| {
|
||||
let tracker = Arc::new(Tracker {
|
||||
key: Some(key),
|
||||
dropped_keys,
|
||||
});
|
||||
|
||||
*o.get_mut() = Arc::downgrade(&tracker);
|
||||
tracker
|
||||
}))
|
||||
}
|
||||
}
|
||||
};
|
||||
if should_compact {
|
||||
self.as_mut().connections_per_ip().compact(0.1);
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_listener<C>(
|
||||
fn poll_listener(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<NewConnection<Req, Resp, C>>
|
||||
where
|
||||
S: Stream<Item = Result<C, io::Error>>,
|
||||
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
|
||||
{
|
||||
match ready!(self.as_mut().listener().poll_next_unpin(cx)?) {
|
||||
Some(codec) => Poll::Ready(Some(Ok(self.handle_new_connection(codec)))),
|
||||
) -> Poll<Option<NewChannel<C, K>>> {
|
||||
match ready!(self.as_mut().listener().poll_next_unpin(cx)) {
|
||||
Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_closed_connections(
|
||||
self: &mut Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
match ready!(self.as_mut().closed_connections_rx().poll_next_unpin(cx)) {
|
||||
Some(addr) => {
|
||||
self.handle_closed_connection(&addr);
|
||||
Poll::Ready(Ok(()))
|
||||
fn poll_closed_channels(self: &mut Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
match ready!(self.as_mut().dropped_keys().poll_next_unpin(cx)) {
|
||||
Some(key) => {
|
||||
debug!("All channels dropped for key [{}]", key);
|
||||
self.as_mut().key_counts().remove(&key);
|
||||
self.as_mut().key_counts().compact(0.1);
|
||||
Poll::Ready(())
|
||||
}
|
||||
None => unreachable!("Holding a copy of closed_connections and didn't close it."),
|
||||
None => unreachable!("Holding a copy of closed_channels and didn't close it."),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, Req, Resp, T> Stream for ConnectionFilter<S, Req, Resp>
|
||||
impl<S, C, K, F> Stream for ChannelFilter<S, K, F>
|
||||
where
|
||||
S: Stream<Item = Result<T, io::Error>>,
|
||||
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
|
||||
S: Stream<Item = C>,
|
||||
C: Channel,
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
F: Fn(&C) -> K,
|
||||
{
|
||||
type Item = io::Result<Channel<Req, Resp, T>>;
|
||||
type Item = TrackedChannel<C, K>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Channel<Req, Resp, T>> {
|
||||
fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<TrackedChannel<C, K>>> {
|
||||
loop {
|
||||
match (
|
||||
self.as_mut().poll_listener(cx)?,
|
||||
self.poll_closed_connections(cx)?,
|
||||
self.as_mut().poll_listener(cx),
|
||||
self.poll_closed_channels(cx),
|
||||
) {
|
||||
(Poll::Ready(Some(NewConnection::Accepted(channel))), _) => {
|
||||
return Poll::Ready(Some(Ok(channel)));
|
||||
(Poll::Ready(Some(NewChannel::Accepted(channel))), _) => {
|
||||
return Poll::Ready(Some(channel));
|
||||
}
|
||||
(Poll::Ready(Some(NewConnection::Filtered)), _) | (_, Poll::Ready(())) => {
|
||||
trace!(
|
||||
"Filtered a connection; {} open.",
|
||||
self.as_mut().open_connections()
|
||||
);
|
||||
(Poll::Ready(Some(NewChannel::Filtered(_))), _) => {
|
||||
continue;
|
||||
}
|
||||
(_, Poll::Ready(())) => continue,
|
||||
(Poll::Pending, Poll::Pending) => return Poll::Pending,
|
||||
(Poll::Ready(None), Poll::Pending) => {
|
||||
if *self.as_mut().open_connections() > 0 {
|
||||
trace!(
|
||||
"Listener closed; {} open connections.",
|
||||
self.as_mut().open_connections()
|
||||
);
|
||||
return Poll::Pending;
|
||||
}
|
||||
trace!("Shutting down listener: all connections closed, and no more coming.");
|
||||
trace!("Shutting down listener.");
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,27 +7,27 @@
|
||||
//! Provides a server that concurrently handles many connections sending multiplexed requests.
|
||||
|
||||
use crate::{
|
||||
context, util::deadline_compat, util::AsDuration, util::Compact, ClientMessage,
|
||||
ClientMessageKind, PollIo, Request, Response, ServerError, Transport,
|
||||
context, util::deadline_compat, util::AsDuration, util::Compact, ClientMessage, PollIo,
|
||||
Request, Response, ServerError, Transport,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{
|
||||
channel::mpsc,
|
||||
future::{abortable, AbortHandle},
|
||||
future::{AbortHandle, AbortRegistration, Abortable},
|
||||
prelude::*,
|
||||
ready,
|
||||
stream::Fuse,
|
||||
task::{Context, Poll},
|
||||
try_ready,
|
||||
};
|
||||
use humantime::format_rfc3339;
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use pin_utils::{unsafe_pinned, unsafe_unpinned};
|
||||
use std::{
|
||||
error::Error as StdError,
|
||||
fmt,
|
||||
hash::Hash,
|
||||
io,
|
||||
marker::PhantomData,
|
||||
net::SocketAddr,
|
||||
pin::Pin,
|
||||
time::{Instant, SystemTime},
|
||||
};
|
||||
@@ -35,6 +35,14 @@ use tokio_timer::timeout;
|
||||
use trace::{self, TraceId};
|
||||
|
||||
mod filter;
|
||||
#[cfg(test)]
|
||||
mod testing;
|
||||
mod throttle;
|
||||
|
||||
pub use self::{
|
||||
filter::ChannelFilter,
|
||||
throttle::{Throttler, ThrottlerStream},
|
||||
};
|
||||
|
||||
/// Manages clients, serving multiplexed requests over each connection.
|
||||
#[derive(Debug)]
|
||||
@@ -53,17 +61,6 @@ impl<Req, Resp> Default for Server<Req, Resp> {
|
||||
#[non_exhaustive]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Config {
|
||||
/// The maximum number of clients that can be connected to the server at once. When at the
|
||||
/// limit, existing connections are honored and new connections are rejected.
|
||||
pub max_connections: usize,
|
||||
/// The maximum number of clients per IP address that can be connected to the server at once.
|
||||
/// When an IP is at the limit, existing connections are honored and new connections on that IP
|
||||
/// address are rejected.
|
||||
pub max_connections_per_ip: usize,
|
||||
/// The maximum number of requests that can be in flight for each client. When a client is at
|
||||
/// the in-flight request limit, existing requests are fulfilled and new requests are rejected.
|
||||
/// Rejected requests are sent a response error.
|
||||
pub max_in_flight_requests_per_connection: usize,
|
||||
/// The number of responses per client that can be buffered server-side before being sent.
|
||||
/// `pending_response_buffer` controls the buffer size of the channel that a server's
|
||||
/// response tasks use to send responses to the client handler task.
|
||||
@@ -73,14 +70,21 @@ pub struct Config {
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Config {
|
||||
max_connections: 1_000_000,
|
||||
max_connections_per_ip: 1_000,
|
||||
max_in_flight_requests_per_connection: 1_000,
|
||||
pending_response_buffer: 100,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Returns a channel backed by `transport` and configured with `self`.
|
||||
pub fn channel<Req, Resp, T>(self, transport: T) -> BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>> + Send,
|
||||
{
|
||||
BaseChannel::new(self, transport)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new server with configuration specified `config`.
|
||||
pub fn new<Req, Resp>(config: Config) -> Server<Req, Resp> {
|
||||
Server {
|
||||
@@ -95,18 +99,15 @@ impl<Req, Resp> Server<Req, Resp> {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Returns a stream of the incoming connections to the server.
|
||||
pub fn incoming<S, T>(
|
||||
self,
|
||||
listener: S,
|
||||
) -> impl Stream<Item = io::Result<Channel<Req, Resp, T>>>
|
||||
/// Returns a stream of server channels.
|
||||
pub fn incoming<S, T>(self, listener: S) -> impl Stream<Item = BaseChannel<Req, Resp, T>>
|
||||
where
|
||||
Req: Send,
|
||||
Resp: Send,
|
||||
S: Stream<Item = io::Result<T>>,
|
||||
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
|
||||
S: Stream<Item = T>,
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>> + Send,
|
||||
{
|
||||
self::filter::ConnectionFilter::filter(listener, self.config.clone())
|
||||
listener.map(move |t| BaseChannel::new(self.config.clone(), t))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,31 +123,21 @@ impl<S, F> Running<S, F> {
|
||||
unsafe_unpinned!(request_handler: F);
|
||||
}
|
||||
|
||||
impl<S, T, Req, Resp, F, Fut> Future for Running<S, F>
|
||||
impl<S, C, F, Fut> Future for Running<S, F>
|
||||
where
|
||||
S: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
|
||||
Req: Send + 'static,
|
||||
Resp: Send + 'static,
|
||||
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send + 'static,
|
||||
F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone,
|
||||
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
|
||||
S: Sized + Stream<Item = C>,
|
||||
C: Channel + Send + 'static,
|
||||
F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone,
|
||||
Fut: Future<Output = io::Result<C::Resp>> + Send + 'static,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
while let Some(channel) = ready!(self.as_mut().incoming().poll_next(cx)) {
|
||||
match channel {
|
||||
Ok(channel) => {
|
||||
let peer = channel.client_addr;
|
||||
if let Err(e) =
|
||||
crate::spawn(channel.respond_with(self.as_mut().request_handler().clone()))
|
||||
{
|
||||
warn!("[{}] Failed to spawn connection handler: {:?}", peer, e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Incoming connection error: {}", e);
|
||||
}
|
||||
if let Err(e) =
|
||||
crate::spawn(channel.respond_with(self.as_mut().request_handler().clone()))
|
||||
{
|
||||
warn!("Failed to spawn channel handler: {:?}", e);
|
||||
}
|
||||
}
|
||||
info!("Server shutting down.");
|
||||
@@ -155,18 +146,30 @@ where
|
||||
}
|
||||
|
||||
/// A utility trait enabling a stream to fluently chain a request handler.
|
||||
pub trait Handler<T, Req, Resp>
|
||||
pub trait Handler<C>
|
||||
where
|
||||
Self: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
|
||||
Req: Send,
|
||||
Resp: Send,
|
||||
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
|
||||
Self: Sized + Stream<Item = C>,
|
||||
C: Channel,
|
||||
{
|
||||
/// Enforces channel per-key limits.
|
||||
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> filter::ChannelFilter<Self, K, KF>
|
||||
where
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
KF: Fn(&C) -> K,
|
||||
{
|
||||
ChannelFilter::new(self, n, keymaker)
|
||||
}
|
||||
|
||||
/// Caps the number of concurrent requests per channel.
|
||||
fn max_concurrent_requests_per_channel(self, n: usize) -> ThrottlerStream<Self> {
|
||||
ThrottlerStream::new(self, n)
|
||||
}
|
||||
|
||||
/// Responds to all requests with `request_handler`.
|
||||
fn respond_with<F, Fut>(self, request_handler: F) -> Running<Self, F>
|
||||
where
|
||||
F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone,
|
||||
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
|
||||
F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone,
|
||||
Fut: Future<Output = io::Result<C::Resp>> + Send + 'static,
|
||||
{
|
||||
Running {
|
||||
incoming: self,
|
||||
@@ -175,191 +178,276 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, Req, Resp, S> Handler<T, Req, Resp> for S
|
||||
impl<S, C> Handler<C> for S
|
||||
where
|
||||
S: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
|
||||
Req: Send,
|
||||
Resp: Send,
|
||||
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
|
||||
S: Sized + Stream<Item = C>,
|
||||
C: Channel,
|
||||
{
|
||||
}
|
||||
|
||||
/// Responds to all requests with `request_handler`.
|
||||
/// The server end of an open connection with a client.
|
||||
/// BaseChannel lifts a Transport to a Channel by tracking in-flight requests.
|
||||
#[derive(Debug)]
|
||||
pub struct Channel<Req, Resp, T> {
|
||||
pub struct BaseChannel<Req, Resp, T> {
|
||||
config: Config,
|
||||
/// Writes responses to the wire and reads requests off the wire.
|
||||
transport: Fuse<T>,
|
||||
/// Signals the connection is closed when `Channel` is dropped.
|
||||
closed_connections: mpsc::UnboundedSender<SocketAddr>,
|
||||
/// Channel limits to prevent unlimited resource usage.
|
||||
config: Config,
|
||||
/// The address of the server connected to.
|
||||
client_addr: SocketAddr,
|
||||
/// Number of requests currently being responded to.
|
||||
in_flight_requests: FnvHashMap<u64, AbortHandle>,
|
||||
/// Types the request and response.
|
||||
ghost: PhantomData<(Req, Resp)>,
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> Drop for Channel<Req, Resp, T> {
|
||||
fn drop(&mut self) {
|
||||
trace!("[{}] Closing channel.", self.client_addr);
|
||||
impl<Req, Resp, T> BaseChannel<Req, Resp, T> {
|
||||
unsafe_unpinned!(in_flight_requests: FnvHashMap<u64, AbortHandle>);
|
||||
}
|
||||
|
||||
// Even in a bounded channel, each connection would have a guaranteed slot, so using
|
||||
// an unbounded sender is actually no different. And, the bound is on the maximum number
|
||||
// of open connections.
|
||||
if self
|
||||
.closed_connections
|
||||
.unbounded_send(self.client_addr)
|
||||
.is_err()
|
||||
{
|
||||
warn!(
|
||||
"[{}] Failed to send closed connection message.",
|
||||
self.client_addr
|
||||
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>> + Send,
|
||||
{
|
||||
/// Creates a new channel backed by `transport` and configured with `config`.
|
||||
pub fn new(config: Config, transport: T) -> Self {
|
||||
BaseChannel {
|
||||
config,
|
||||
transport: transport.fuse(),
|
||||
in_flight_requests: FnvHashMap::default(),
|
||||
ghost: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new channel backed by `transport` and configured with the defaults.
|
||||
pub fn with_defaults(transport: T) -> Self {
|
||||
Self::new(Config::default(), transport)
|
||||
}
|
||||
|
||||
/// Returns the inner transport.
|
||||
pub fn get_ref(&self) -> &T {
|
||||
self.transport.get_ref()
|
||||
}
|
||||
|
||||
/// Returns the pinned inner transport.
|
||||
pub fn transport<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut T> {
|
||||
unsafe { self.map_unchecked_mut(|me| me.transport.get_mut()) }
|
||||
}
|
||||
|
||||
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
|
||||
// It's possible the request was already completed, so it's fine
|
||||
// if this is None.
|
||||
if let Some(cancel_handle) = self.as_mut().in_flight_requests().remove(&request_id) {
|
||||
self.as_mut().in_flight_requests().compact(0.1);
|
||||
|
||||
cancel_handle.abort();
|
||||
let remaining = self.as_mut().in_flight_requests().len();
|
||||
trace!(
|
||||
"[{}] Request canceled. In-flight requests = {}",
|
||||
trace_context.trace_id,
|
||||
remaining,
|
||||
);
|
||||
} else {
|
||||
trace!(
|
||||
"[{}] Received cancellation, but response handler \
|
||||
is already complete.",
|
||||
trace_context.trace_id,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> Channel<Req, Resp, T> {
|
||||
unsafe_pinned!(transport: Fuse<T>);
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> Channel<Req, Resp, T>
|
||||
/// The server end of an open connection with a client, streaming in requests from, and sinking
|
||||
/// responses to, the client.
|
||||
///
|
||||
/// Channels are free to somewhat rely on the assumption that all in-flight requests are eventually
|
||||
/// either [cancelled](Channel::cancel_request) or [responded to](Sink::start_send). Safety cannot
|
||||
/// rely on this assumption, but it is best for `Channel` users to always account for all outstanding
|
||||
/// requests.
|
||||
pub trait Channel
|
||||
where
|
||||
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
|
||||
Req: Send,
|
||||
Resp: Send,
|
||||
Self: Transport<Response<<Self as Channel>::Resp>, Request<<Self as Channel>::Req>>,
|
||||
{
|
||||
pub(crate) fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> io::Result<()> {
|
||||
self.as_mut().transport().start_send(response)
|
||||
/// Type of request item.
|
||||
type Req: Send + 'static;
|
||||
|
||||
/// Type of response sink item.
|
||||
type Resp: Send + 'static;
|
||||
|
||||
/// Configuration of the channel.
|
||||
fn config(&self) -> &Config;
|
||||
|
||||
/// Returns the number of in-flight requests over this channel.
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize;
|
||||
|
||||
/// Caps the number of concurrent requests.
|
||||
fn max_concurrent_requests(self, n: usize) -> Throttler<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Throttler::new(self, n)
|
||||
}
|
||||
|
||||
pub(crate) fn poll_ready(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.as_mut().transport().poll_ready(cx)
|
||||
}
|
||||
|
||||
pub(crate) fn poll_flush(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.as_mut().transport().poll_flush(cx)
|
||||
}
|
||||
|
||||
pub(crate) fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<ClientMessage<Req>> {
|
||||
self.as_mut().transport().poll_next(cx)
|
||||
}
|
||||
|
||||
/// Returns the address of the client connected to the channel.
|
||||
pub fn client_addr(&self) -> &SocketAddr {
|
||||
&self.client_addr
|
||||
}
|
||||
/// Tells the Channel that request with ID `request_id` is being handled.
|
||||
/// The request will be tracked until a response with the same ID is sent
|
||||
/// to the Channel.
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration;
|
||||
|
||||
/// Respond to requests coming over the channel with `f`. Returns a future that drives the
|
||||
/// responses and resolves when the connection is closed.
|
||||
pub fn respond_with<F, Fut>(self, f: F) -> impl Future<Output = ()>
|
||||
fn respond_with<F, Fut>(self, f: F) -> ResponseHandler<Self, F>
|
||||
where
|
||||
F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone,
|
||||
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
|
||||
Req: 'static,
|
||||
Resp: 'static,
|
||||
F: FnOnce(context::Context, Self::Req) -> Fut + Send + 'static + Clone,
|
||||
Fut: Future<Output = io::Result<Self::Resp>> + Send + 'static,
|
||||
Self: Sized,
|
||||
{
|
||||
let (responses_tx, responses) = mpsc::channel(self.config.pending_response_buffer);
|
||||
let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer);
|
||||
let responses = responses.fuse();
|
||||
let peer = self.client_addr;
|
||||
|
||||
ClientHandler {
|
||||
ResponseHandler {
|
||||
channel: self,
|
||||
f,
|
||||
pending_responses: responses,
|
||||
responses_tx,
|
||||
in_flight_requests: FnvHashMap::default(),
|
||||
}
|
||||
.unwrap_or_else(move |e| {
|
||||
info!("[{}] ClientHandler errored out: {}", peer, e);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>> + Send + 'static,
|
||||
Req: Send + 'static,
|
||||
Resp: Send + 'static,
|
||||
{
|
||||
type Item = io::Result<Request<Req>>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
match ready!(self.as_mut().transport().poll_next(cx)?) {
|
||||
Some(message) => match message {
|
||||
ClientMessage::Request(request) => {
|
||||
return Poll::Ready(Some(Ok(request)));
|
||||
}
|
||||
ClientMessage::Cancel {
|
||||
trace_context,
|
||||
request_id,
|
||||
} => {
|
||||
self.as_mut().cancel_request(&trace_context, request_id);
|
||||
}
|
||||
},
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> Sink<Response<Resp>> for BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>> + Send + 'static,
|
||||
Req: Send + 'static,
|
||||
Resp: Send + 'static,
|
||||
{
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.transport().poll_ready(cx)
|
||||
}
|
||||
|
||||
fn start_send(
|
||||
mut self: Pin<&mut Self>,
|
||||
response: Response<Resp>,
|
||||
) -> Result<(), Self::Error> {
|
||||
if self
|
||||
.as_mut()
|
||||
.in_flight_requests()
|
||||
.remove(&response.request_id)
|
||||
.is_some()
|
||||
{
|
||||
self.as_mut().in_flight_requests().compact(0.1);
|
||||
}
|
||||
|
||||
self.transport().start_send(response)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.transport().poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.transport().poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
|
||||
fn as_ref(&self) -> &T {
|
||||
self.transport.get_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> Channel for BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>> + Send + 'static,
|
||||
Req: Send + 'static,
|
||||
Resp: Send + 'static,
|
||||
{
|
||||
type Req = Req;
|
||||
type Resp = Resp;
|
||||
|
||||
fn config(&self) -> &Config {
|
||||
&self.config
|
||||
}
|
||||
|
||||
fn in_flight_requests(mut self: Pin<&mut Self>) -> usize {
|
||||
self.as_mut().in_flight_requests().len()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
let (abort_handle, abort_registration) = AbortHandle::new_pair();
|
||||
assert!(self
|
||||
.in_flight_requests()
|
||||
.insert(request_id, abort_handle)
|
||||
.is_none());
|
||||
abort_registration
|
||||
}
|
||||
}
|
||||
|
||||
/// A running handler serving all requests coming over a channel.
|
||||
#[derive(Debug)]
|
||||
struct ClientHandler<Req, Resp, T, F> {
|
||||
channel: Channel<Req, Resp, T>,
|
||||
pub struct ResponseHandler<C, F>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
channel: C,
|
||||
/// Responses waiting to be written to the wire.
|
||||
pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<Resp>)>>,
|
||||
pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>,
|
||||
/// Handed out to request handlers to fan in responses.
|
||||
responses_tx: mpsc::Sender<(context::Context, Response<Resp>)>,
|
||||
/// Number of requests currently being responded to.
|
||||
in_flight_requests: FnvHashMap<u64, AbortHandle>,
|
||||
responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
|
||||
/// Request handler.
|
||||
f: F,
|
||||
}
|
||||
|
||||
impl<Req, Resp, T, F> ClientHandler<Req, Resp, T, F> {
|
||||
unsafe_pinned!(channel: Channel<Req, Resp, T>);
|
||||
unsafe_pinned!(in_flight_requests: FnvHashMap<u64, AbortHandle>);
|
||||
unsafe_pinned!(pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<Resp>)>>);
|
||||
unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response<Resp>)>);
|
||||
impl<C, F> ResponseHandler<C, F>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
unsafe_pinned!(channel: C);
|
||||
unsafe_pinned!(pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>);
|
||||
unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>);
|
||||
// For this to be safe, field f must be private, and code in this module must never
|
||||
// construct PinMut<F>.
|
||||
unsafe_unpinned!(f: F);
|
||||
}
|
||||
|
||||
impl<Req, Resp, T, F, Fut> ClientHandler<Req, Resp, T, F>
|
||||
impl<C, F, Fut> ResponseHandler<C, F>
|
||||
where
|
||||
Req: Send + 'static,
|
||||
Resp: Send + 'static,
|
||||
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
|
||||
F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone,
|
||||
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
|
||||
C: Channel,
|
||||
F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone,
|
||||
Fut: Future<Output = io::Result<C::Resp>> + Send + 'static,
|
||||
{
|
||||
/// If at max in-flight requests, check that there's room to immediately write a throttled
|
||||
/// response.
|
||||
fn poll_ready_if_throttling(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
if self.in_flight_requests.len()
|
||||
>= self.channel.config.max_in_flight_requests_per_connection
|
||||
{
|
||||
let peer = self.as_mut().channel().client_addr;
|
||||
|
||||
while let Poll::Pending = self.as_mut().channel().poll_ready(cx)? {
|
||||
info!(
|
||||
"[{}] In-flight requests at max ({}), and transport is not ready.",
|
||||
peer,
|
||||
self.as_mut().in_flight_requests().len(),
|
||||
);
|
||||
try_ready!(self.as_mut().channel().poll_flush(cx));
|
||||
}
|
||||
}
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
||||
ready!(self.as_mut().poll_ready_if_throttling(cx)?);
|
||||
|
||||
Poll::Ready(match ready!(self.as_mut().channel().poll_next(cx)?) {
|
||||
Some(message) => {
|
||||
match message.message {
|
||||
ClientMessageKind::Request(request) => {
|
||||
self.handle_request(message.trace_context, request)?;
|
||||
}
|
||||
ClientMessageKind::Cancel { request_id } => {
|
||||
self.cancel_request(&message.trace_context, request_id);
|
||||
}
|
||||
}
|
||||
Some(Ok(()))
|
||||
match ready!(self.as_mut().channel().poll_next(cx)?) {
|
||||
Some(request) => {
|
||||
self.handle_request(request)?;
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
None => {
|
||||
trace!("[{}] Read half closed", self.channel.client_addr);
|
||||
None
|
||||
}
|
||||
})
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn pump_write(
|
||||
@@ -368,7 +456,12 @@ where
|
||||
read_half_closed: bool,
|
||||
) -> PollIo<()> {
|
||||
match self.as_mut().poll_next_response(cx)? {
|
||||
Poll::Ready(Some((_, response))) => {
|
||||
Poll::Ready(Some((ctx, response))) => {
|
||||
trace!(
|
||||
"[{}] Staging response. In-flight requests = {}.",
|
||||
ctx.trace_id(),
|
||||
self.as_mut().channel().in_flight_requests(),
|
||||
);
|
||||
self.as_mut().channel().start_send(response)?;
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
@@ -384,7 +477,7 @@ where
|
||||
// Being here means there are no staged requests and all written responses are
|
||||
// fully flushed. So, if the read half is closed and there are no in-flight
|
||||
// requests, then we can close the write half.
|
||||
if read_half_closed && self.as_mut().in_flight_requests().is_empty() {
|
||||
if read_half_closed && self.as_mut().channel().in_flight_requests() == 0 {
|
||||
Poll::Ready(None)
|
||||
} else {
|
||||
Poll::Pending
|
||||
@@ -396,90 +489,33 @@ where
|
||||
fn poll_next_response(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<(context::Context, Response<Resp>)> {
|
||||
) -> PollIo<(context::Context, Response<C::Resp>)> {
|
||||
// Ensure there's room to write a response.
|
||||
while let Poll::Pending = self.as_mut().channel().poll_ready(cx)? {
|
||||
ready!(self.as_mut().channel().poll_flush(cx)?);
|
||||
}
|
||||
|
||||
let peer = self.as_mut().channel().client_addr;
|
||||
|
||||
match ready!(self.as_mut().pending_responses().poll_next(cx)) {
|
||||
Some((ctx, response)) => {
|
||||
if self
|
||||
.as_mut()
|
||||
.in_flight_requests()
|
||||
.remove(&response.request_id)
|
||||
.is_some()
|
||||
{
|
||||
self.as_mut().in_flight_requests().compact(0.1);
|
||||
}
|
||||
trace!(
|
||||
"[{}/{}] Staging response. In-flight requests = {}.",
|
||||
ctx.trace_id(),
|
||||
peer,
|
||||
self.as_mut().in_flight_requests().len(),
|
||||
);
|
||||
Poll::Ready(Some(Ok((ctx, response))))
|
||||
}
|
||||
Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))),
|
||||
None => {
|
||||
// This branch likely won't happen, since the ClientHandler is holding a Sender.
|
||||
trace!("[{}] No new responses.", peer);
|
||||
// This branch likely won't happen, since the ResponseHandler is holding a Sender.
|
||||
Poll::Ready(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_request(
|
||||
mut self: Pin<&mut Self>,
|
||||
trace_context: trace::Context,
|
||||
request: Request<Req>,
|
||||
) -> io::Result<()> {
|
||||
fn handle_request(mut self: Pin<&mut Self>, request: Request<C::Req>) -> io::Result<()> {
|
||||
let request_id = request.id;
|
||||
let peer = self.as_mut().channel().client_addr;
|
||||
let ctx = context::Context {
|
||||
deadline: request.deadline,
|
||||
trace_context,
|
||||
};
|
||||
let request = request.message;
|
||||
|
||||
if self.as_mut().in_flight_requests().len()
|
||||
>= self
|
||||
.as_mut()
|
||||
.channel()
|
||||
.config
|
||||
.max_in_flight_requests_per_connection
|
||||
{
|
||||
debug!(
|
||||
"[{}/{}] Client has reached in-flight request limit ({}/{}).",
|
||||
ctx.trace_id(),
|
||||
peer,
|
||||
self.as_mut().in_flight_requests().len(),
|
||||
self.as_mut()
|
||||
.channel()
|
||||
.config
|
||||
.max_in_flight_requests_per_connection
|
||||
);
|
||||
|
||||
self.as_mut().channel().start_send(Response {
|
||||
request_id,
|
||||
message: Err(ServerError {
|
||||
kind: io::ErrorKind::WouldBlock,
|
||||
detail: Some("Server throttled the request.".into()),
|
||||
}),
|
||||
})?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let deadline = ctx.deadline;
|
||||
let deadline = request.context.deadline;
|
||||
let timeout = deadline.as_duration();
|
||||
trace!(
|
||||
"[{}/{}] Received request with deadline {} (timeout {:?}).",
|
||||
ctx.trace_id(),
|
||||
peer,
|
||||
"[{}] Received request with deadline {} (timeout {:?}).",
|
||||
request.context.trace_id(),
|
||||
format_rfc3339(deadline),
|
||||
timeout,
|
||||
);
|
||||
let ctx = request.context;
|
||||
let request = request.message;
|
||||
let mut response_tx = self.as_mut().responses_tx().clone();
|
||||
|
||||
let trace_id = *ctx.trace_id();
|
||||
@@ -490,18 +526,19 @@ where
|
||||
request_id,
|
||||
message: match result {
|
||||
Ok(message) => Ok(message),
|
||||
Err(e) => Err(make_server_error(e, trace_id, peer, deadline)),
|
||||
Err(e) => Err(make_server_error(e, trace_id, deadline)),
|
||||
},
|
||||
};
|
||||
trace!("[{}/{}] Sending response.", trace_id, peer);
|
||||
trace!("[{}] Sending response.", trace_id);
|
||||
response_tx
|
||||
.send((ctx, response))
|
||||
.unwrap_or_else(|_| ())
|
||||
.await;
|
||||
},
|
||||
);
|
||||
let (abortable_response, abort_handle) = abortable(response);
|
||||
crate::spawn(abortable_response.map(|_| ())).map_err(|e| {
|
||||
let abort_registration = self.as_mut().channel().start_request(request_id);
|
||||
let response = Abortable::new(response, abort_registration);
|
||||
crate::spawn(response.map(|_| ())).map_err(|e| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!(
|
||||
@@ -510,92 +547,49 @@ where
|
||||
),
|
||||
)
|
||||
})?;
|
||||
self.as_mut()
|
||||
.in_flight_requests()
|
||||
.insert(request_id, abort_handle);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
|
||||
// It's possible the request was already completed, so it's fine
|
||||
// if this is None.
|
||||
if let Some(cancel_handle) = self.as_mut().in_flight_requests().remove(&request_id) {
|
||||
self.as_mut().in_flight_requests().compact(0.1);
|
||||
|
||||
cancel_handle.abort();
|
||||
let remaining = self.as_mut().in_flight_requests().len();
|
||||
trace!(
|
||||
"[{}/{}] Request canceled. In-flight requests = {}",
|
||||
trace_context.trace_id,
|
||||
self.channel.client_addr,
|
||||
remaining,
|
||||
);
|
||||
} else {
|
||||
trace!(
|
||||
"[{}/{}] Received cancellation, but response handler \
|
||||
is already complete.",
|
||||
trace_context.trace_id,
|
||||
self.channel.client_addr
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T, F, Fut> Future for ClientHandler<Req, Resp, T, F>
|
||||
impl<C, F, Fut> Future for ResponseHandler<C, F>
|
||||
where
|
||||
Req: Send + 'static,
|
||||
Resp: Send + 'static,
|
||||
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
|
||||
F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone,
|
||||
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
|
||||
C: Channel,
|
||||
F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone,
|
||||
Fut: Future<Output = io::Result<C::Resp>> + Send + 'static,
|
||||
{
|
||||
type Output = io::Result<()>;
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
trace!("[{}] ClientHandler::poll", self.channel.client_addr);
|
||||
loop {
|
||||
let read = self.as_mut().pump_read(cx)?;
|
||||
match (
|
||||
read,
|
||||
self.as_mut().pump_write(cx, read == Poll::Ready(None))?,
|
||||
) {
|
||||
(Poll::Ready(None), Poll::Ready(None)) => {
|
||||
info!("[{}] Client disconnected.", self.channel.client_addr);
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
(read @ Poll::Ready(Some(())), write) | (read, write @ Poll::Ready(Some(()))) => {
|
||||
trace!(
|
||||
"[{}] read: {:?}, write: {:?}.",
|
||||
self.channel.client_addr,
|
||||
read,
|
||||
write
|
||||
)
|
||||
}
|
||||
(read, write) => {
|
||||
trace!(
|
||||
"[{}] read: {:?}, write: {:?} (not ready).",
|
||||
self.channel.client_addr,
|
||||
read,
|
||||
write,
|
||||
);
|
||||
return Poll::Pending;
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
move || -> Poll<io::Result<()>> {
|
||||
loop {
|
||||
let read = self.as_mut().pump_read(cx)?;
|
||||
match (
|
||||
read,
|
||||
self.as_mut().pump_write(cx, read == Poll::Ready(None))?,
|
||||
) {
|
||||
(Poll::Ready(None), Poll::Ready(None)) => {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
(Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
|
||||
_ => {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
.map(|r| r.unwrap_or_else(|e| info!("ResponseHandler errored out: {}", e)))
|
||||
}
|
||||
}
|
||||
|
||||
fn make_server_error(
|
||||
e: timeout::Error<io::Error>,
|
||||
trace_id: TraceId,
|
||||
peer: SocketAddr,
|
||||
deadline: SystemTime,
|
||||
) -> ServerError {
|
||||
if e.is_elapsed() {
|
||||
debug!(
|
||||
"[{}/{}] Response did not complete before deadline of {}s.",
|
||||
"[{}] Response did not complete before deadline of {}s.",
|
||||
trace_id,
|
||||
peer,
|
||||
format_rfc3339(deadline)
|
||||
);
|
||||
// No point in responding, since the client will have dropped the request.
|
||||
@@ -608,8 +602,8 @@ fn make_server_error(
|
||||
}
|
||||
} else if e.is_timer() {
|
||||
error!(
|
||||
"[{}/{}] Response failed because of an issue with a timer: {}",
|
||||
trace_id, peer, e
|
||||
"[{}] Response failed because of an issue with a timer: {}",
|
||||
trace_id, e
|
||||
);
|
||||
|
||||
ServerError {
|
||||
@@ -623,7 +617,7 @@ fn make_server_error(
|
||||
detail: Some(e.description().into()),
|
||||
}
|
||||
} else {
|
||||
error!("[{}/{}] Unexpected response failure: {}", trace_id, peer, e);
|
||||
error!("[{}] Unexpected response failure: {}", trace_id, e);
|
||||
|
||||
ServerError {
|
||||
kind: io::ErrorKind::Other,
|
||||
|
||||
rpc/src/server/testing.rs (new file, 125 lines)
@@ -0,0 +1,125 @@
|
||||
use crate::server::{Channel, Config};
|
||||
use crate::{context, Request, Response};
|
||||
use fnv::FnvHashSet;
|
||||
use futures::future::{AbortHandle, AbortRegistration};
|
||||
use futures::{Sink, Stream};
|
||||
use futures_test::task::noop_waker_ref;
|
||||
use pin_utils::{unsafe_pinned, unsafe_unpinned};
|
||||
use std::collections::VecDeque;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::SystemTime;
|
||||
|
||||
pub(crate) struct FakeChannel<In, Out> {
|
||||
pub stream: VecDeque<In>,
|
||||
pub sink: VecDeque<Out>,
|
||||
pub config: Config,
|
||||
pub in_flight_requests: FnvHashSet<u64>,
|
||||
}
|
||||
|
||||
impl<In, Out> FakeChannel<In, Out> {
|
||||
unsafe_pinned!(stream: VecDeque<In>);
|
||||
unsafe_pinned!(sink: VecDeque<Out>);
|
||||
unsafe_unpinned!(in_flight_requests: FnvHashSet<u64>);
|
||||
}
|
||||
|
||||
impl<In, Out> Stream for FakeChannel<In, Out>
|
||||
where
|
||||
In: Unpin,
|
||||
{
|
||||
type Item = In;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
self.stream().poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.sink().poll_ready(cx).map_err(|e| match e {})
|
||||
}
|
||||
|
||||
fn start_send(
|
||||
mut self: Pin<&mut Self>,
|
||||
response: Response<Resp>,
|
||||
) -> Result<(), Self::Error> {
|
||||
self.as_mut()
|
||||
.in_flight_requests()
|
||||
.remove(&response.request_id);
|
||||
self.sink().start_send(response).map_err(|e| match e {})
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.sink().poll_flush(cx).map_err(|e| match e {})
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.sink().poll_close(cx).map_err(|e| match e {})
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> Channel for FakeChannel<io::Result<Request<Req>>, Response<Resp>>
|
||||
where
|
||||
Req: Unpin + Send + 'static,
|
||||
Resp: Send + 'static,
|
||||
{
|
||||
type Req = Req;
|
||||
type Resp = Resp;
|
||||
|
||||
fn config(&self) -> &Config {
|
||||
&self.config
|
||||
}
|
||||
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
self.in_flight_requests.len()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, id: u64) -> AbortRegistration {
|
||||
self.in_flight_requests().insert(id);
|
||||
AbortHandle::new_pair().1
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
|
||||
pub fn push_req(&mut self, id: u64, message: Req) {
|
||||
self.stream.push_back(Ok(Request {
|
||||
context: context::Context {
|
||||
deadline: SystemTime::UNIX_EPOCH,
|
||||
trace_context: Default::default(),
|
||||
},
|
||||
id,
|
||||
message,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeChannel<(), ()> {
|
||||
pub fn default<Req, Resp>() -> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
|
||||
FakeChannel {
|
||||
stream: VecDeque::default(),
|
||||
sink: VecDeque::default(),
|
||||
config: Config::default(),
|
||||
in_flight_requests: FnvHashSet::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait PollExt {
|
||||
fn is_done(&self) -> bool;
|
||||
}
|
||||
|
||||
impl<T> PollExt for Poll<Option<T>> {
|
||||
fn is_done(&self) -> bool {
|
||||
match self {
|
||||
Poll::Ready(None) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cx() -> Context<'static> {
|
||||
Context::from_waker(&noop_waker_ref())
|
||||
}
|
||||
rpc/src/server/throttle.rs (new file, 332 lines)
@@ -0,0 +1,332 @@
|
||||
use super::{Channel, Config};
|
||||
use crate::{Response, ServerError};
|
||||
use futures::{
|
||||
future::AbortRegistration,
|
||||
prelude::*,
|
||||
ready,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use log::debug;
|
||||
use pin_utils::{unsafe_pinned, unsafe_unpinned};
|
||||
use std::{io, pin::Pin};
|
||||
|
||||
/// A [`Channel`] that limits the number of concurrent
|
||||
/// requests by throttling.
|
||||
#[derive(Debug)]
|
||||
pub struct Throttler<C> {
|
||||
max_in_flight_requests: usize,
|
||||
inner: C,
|
||||
}
|
||||
|
||||
impl<C> Throttler<C> {
|
||||
unsafe_unpinned!(max_in_flight_requests: usize);
|
||||
unsafe_pinned!(inner: C);
|
||||
|
||||
/// Returns the inner channel.
|
||||
pub fn get_ref(&self) -> &C {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Throttler<C>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
/// Returns a new `Throttler` that wraps the given channel and limits concurrent requests to
|
||||
/// `max_in_flight_requests`.
|
||||
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
|
||||
Throttler {
|
||||
inner,
|
||||
max_in_flight_requests,
|
||||
}
|
||||
}
|
||||
}

impl<C> Stream for Throttler<C>
where
    C: Channel,
{
    type Item = <C as Stream>::Item;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        // While the channel is at or over its limit, no new requests are yielded to
        // the server. Each incoming request is instead rejected with a WouldBlock
        // error until the number of in-flight requests drops below the limit.
        while self.as_mut().in_flight_requests() >= *self.as_mut().max_in_flight_requests() {
            // The rejection is written to the sink, so the sink must be ready first.
            ready!(self.as_mut().inner().poll_ready(cx)?);

            match ready!(self.as_mut().inner().poll_next(cx)?) {
                Some(request) => {
                    debug!(
                        "[{}] Client has reached in-flight request limit ({}/{}).",
                        request.context.trace_id(),
                        self.as_mut().in_flight_requests(),
                        self.as_mut().max_in_flight_requests(),
                    );

                    self.as_mut().start_send(Response {
                        request_id: request.id,
                        message: Err(ServerError {
                            kind: io::ErrorKind::WouldBlock,
                            detail: Some("Server throttled the request.".into()),
                        }),
                    })?;
                }
                None => return Poll::Ready(None),
            }
        }
        self.inner().poll_next(cx)
    }
}

impl<C> Sink<Response<<C as Channel>::Resp>> for Throttler<C>
where
    C: Channel,
{
    type Error = io::Error;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        self.inner().poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Response<<C as Channel>::Resp>) -> io::Result<()> {
        self.inner().start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        self.inner().poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        self.inner().poll_close(cx)
    }
}

impl<C> AsRef<C> for Throttler<C> {
    fn as_ref(&self) -> &C {
        &self.inner
    }
}

impl<C> Channel for Throttler<C>
where
    C: Channel,
{
    type Req = <C as Channel>::Req;
    type Resp = <C as Channel>::Resp;

    fn in_flight_requests(self: Pin<&mut Self>) -> usize {
        self.inner().in_flight_requests()
    }

    fn config(&self) -> &Config {
        self.inner.config()
    }

    fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
        self.inner().start_request(request_id)
    }
}

/// A stream of throttling channels.
#[derive(Debug)]
pub struct ThrottlerStream<S> {
    inner: S,
    max_in_flight_requests: usize,
}

impl<S> ThrottlerStream<S>
where
    S: Stream,
    <S as Stream>::Item: Channel,
{
    unsafe_pinned!(inner: S);
    unsafe_unpinned!(max_in_flight_requests: usize);

    pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
        Self {
            inner,
            max_in_flight_requests,
        }
    }
}

impl<S> Stream for ThrottlerStream<S>
where
    S: Stream,
    <S as Stream>::Item: Channel,
{
    type Item = Throttler<<S as Stream>::Item>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        match ready!(self.as_mut().inner().poll_next(cx)) {
            Some(channel) => Poll::Ready(Some(Throttler::new(
                channel,
                *self.max_in_flight_requests(),
            ))),
            None => Poll::Ready(None),
        }
    }
}
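
// A minimal sketch (the test name is illustrative, and `futures::stream::iter`
// is assumed to be available in the `futures` version this crate builds with):
// each channel yielded by the wrapped stream comes back wrapped in a `Throttler`
// carrying the configured limit.
#[test]
fn throttler_stream_wraps_each_channel() {
    let stream = ThrottlerStream::new(
        futures::stream::iter(vec![FakeChannel::default::<isize, isize>()]),
        2,
    );
    pin_mut!(stream);
    match stream.poll_next(&mut testing::cx()) {
        Poll::Ready(Some(throttler)) => assert_eq!(throttler.max_in_flight_requests, 2),
        _ => panic!("expected the inner channel to be yielded immediately"),
    }
}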

#[cfg(test)]
use super::testing::{self, FakeChannel, PollExt};
#[cfg(test)]
use crate::Request;
#[cfg(test)]
use pin_utils::pin_mut;
#[cfg(test)]
use std::marker::PhantomData;

#[test]
fn throttler_in_flight_requests() {
    let throttler = Throttler {
        max_in_flight_requests: 0,
        inner: FakeChannel::default::<isize, isize>(),
    };

    pin_mut!(throttler);
    for i in 0..5 {
        throttler.inner.in_flight_requests.insert(i);
    }
    assert_eq!(throttler.as_mut().in_flight_requests(), 5);
}

#[test]
fn throttler_start_request() {
    let throttler = Throttler {
        max_in_flight_requests: 0,
        inner: FakeChannel::default::<isize, isize>(),
    };

    pin_mut!(throttler);
    throttler.as_mut().start_request(1);
    assert_eq!(throttler.inner.in_flight_requests.len(), 1);
}

#[test]
fn throttler_poll_next_done() {
    let throttler = Throttler {
        max_in_flight_requests: 0,
        inner: FakeChannel::default::<isize, isize>(),
    };

    pin_mut!(throttler);
    assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
}

#[test]
fn throttler_poll_next_some() -> io::Result<()> {
    let throttler = Throttler {
        max_in_flight_requests: 1,
        inner: FakeChannel::default::<isize, isize>(),
    };

    pin_mut!(throttler);
    throttler.inner.push_req(0, 1);
    assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());
    assert_eq!(
        throttler
            .as_mut()
            .poll_next(&mut testing::cx())?
            .map(|r| r.map(|r| (r.id, r.message))),
        Poll::Ready(Some((0, 1)))
    );
    Ok(())
}

#[test]
fn throttler_poll_next_throttled() {
    let throttler = Throttler {
        max_in_flight_requests: 0,
        inner: FakeChannel::default::<isize, isize>(),
    };

    pin_mut!(throttler);
    throttler.inner.push_req(1, 1);
    assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
    assert_eq!(throttler.inner.sink.len(), 1);
    let resp = throttler.inner.sink.get(0).unwrap();
    assert_eq!(resp.request_id, 1);
    assert!(resp.message.is_err());
}

#[test]
fn throttler_poll_next_throttled_sink_not_ready() {
    let throttler = Throttler {
        max_in_flight_requests: 0,
        inner: PendingSink::default::<isize, isize>(),
    };
    pin_mut!(throttler);
    assert!(throttler.poll_next(&mut testing::cx()).is_pending());

    struct PendingSink<In, Out> {
        ghost: PhantomData<fn(Out) -> In>,
    }
    impl PendingSink<(), ()> {
        pub fn default<Req, Resp>() -> PendingSink<io::Result<Request<Req>>, Response<Resp>> {
            PendingSink { ghost: PhantomData }
        }
    }
    impl<In, Out> Stream for PendingSink<In, Out> {
        type Item = In;
        fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
            unimplemented!()
        }
    }
    impl<In, Out> Sink<Out> for PendingSink<In, Out> {
        type Error = io::Error;
        fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }
        fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
            Err(io::Error::from(io::ErrorKind::WouldBlock))
        }
        fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }
        fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }
    }
    impl<Req, Resp> Channel for PendingSink<io::Result<Request<Req>>, Response<Resp>>
    where
        Req: Send + 'static,
        Resp: Send + 'static,
    {
        type Req = Req;
        type Resp = Resp;
        fn config(&self) -> &Config {
            unimplemented!()
        }
        fn in_flight_requests(self: Pin<&mut Self>) -> usize {
            0
        }
        fn start_request(self: Pin<&mut Self>, _: u64) -> AbortRegistration {
            unimplemented!()
        }
    }
}

#[test]
fn throttler_start_send() {
    let throttler = Throttler {
        max_in_flight_requests: 0,
        inner: FakeChannel::default::<isize, isize>(),
    };

    pin_mut!(throttler);
    throttler.inner.in_flight_requests.insert(0);
    throttler
        .as_mut()
        .start_send(Response {
            request_id: 0,
            message: Ok(1),
        })
        .unwrap();
    assert!(throttler.inner.in_flight_requests.is_empty());
    assert_eq!(
        throttler.inner.sink.get(0),
        Some(&Response {
            request_id: 0,
            message: Ok(1),
        })
    );
}