Make server methods more composable.

-- Connection Limits

The problem with having ConnectionFilter enabled by default is elaborated on in https://github.com/google/tarpc/issues/217. The gist of it is that not all servers want a policy based on `SocketAddr`. This PR allows customizing the behavior of ConnectionFilter, at the cost of not having it enabled by default. However, enabling it is as simple as one line:

incoming.max_channels_per_key(10, ip_addr)

The second argument is a key function that takes the user-chosen transport and returns some hashable, equatable, cloneable key. In the above example, it returns an `IpAddr`.
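For illustration, a key function fitting the example above might look like the following sketch. The `AsRef<TcpStream>` bound and `peer_addr` call are assumptions about the user-chosen transport, not APIs introduced by this PR:

use std::net::{IpAddr, TcpStream};

// Hypothetical key function for a TCP-backed channel; assumes the channel
// exposes its underlying TcpStream via AsRef. Panics if the peer address
// cannot be determined.
fn ip_addr<C: AsRef<TcpStream>>(channel: &C) -> IpAddr {
    channel.as_ref().peer_addr().unwrap().ip()
}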

This also allows the address-related functions to be removed from the `Transport` trait, which means it has become simply an alias for `Stream + Sink`.
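Sketching what that alias could look like with the `trait_alias` feature this PR enables (the exact bounds and error types below are illustrative, not copied from the crate):

use futures::{Sink, Stream};
use std::io;

// Requires #![feature(trait_alias)]. The sink item comes first and the stream
// item second, matching bounds like `Transport<ClientMessage<Req>, Response<Resp>>`
// seen in the diff below.
pub trait Transport<SinkItem, Item> =
    Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error>;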

-- Per-Channel Request Throttling

With respect to Channel's throttling behavior, the same argument applies. There isn't a one-size-fits-all solution to throttling requests, and the policy applied by tarpc is just one of potentially many solutions. As such, `Channel` is now a trait that offers a few combinators, one of which is throttling:

channel.max_concurrent_requests(10).respond_with(serve(Server))

This functionality is also available on the existing `Handler` trait, which applies it to all incoming channels and can be used in tandem with connection limits:

incoming
    .max_channels_per_key(10, ip_addr)
    .max_concurrent_requests_per_channel(10).respond_with(serve(Server))

-- Global Request Throttling

I've entirely removed the overall request limit enforced across all channels. This functionality is easily recovered via [`StreamExt::buffer_unordered`](https://rust-lang-nursery.github.io/futures-api-docs/0.3.0-alpha.1/futures/stream/trait.StreamExt.html#method.buffer_unordered), with the difference being that the previous behavior allowed you to spawn channels onto different threads, whereas `buffer_unordered` means the `Channel`s are handled on a single thread (the per-request handlers are still spawned). Considering the existing options, I don't believe the benefit this functionality provided justified keeping it.
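For reference, this is roughly how `buffer_unordered` bounds concurrency; `requests` and `handle` are placeholders rather than tarpc APIs:

use futures::{stream, StreamExt};

// At most 10 handler futures are polled concurrently on this task; results
// are yielded in completion order, not submission order.
let responses: Vec<_> = stream::iter(requests)
    .map(|req| handle(req))
    .buffer_unordered(10)
    .collect()
    .await;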
Commit 1089415451 (parent 146496d08c) by Tim Kuehn, 2019-07-15 18:58:36 -07:00
36 changed files with 1303 additions and 989 deletions

@@ -7,7 +7,7 @@
use crate::{
context,
util::{deadline_compat, AsDuration, Compact},
ClientMessage, ClientMessageKind, PollIo, Request, Response, Transport,
ClientMessage, PollIo, Request, Response, Transport,
};
use fnv::FnvHashMap;
use futures::{
@@ -24,7 +24,6 @@ use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::{
io,
marker::{self, Unpin},
net::SocketAddr,
pin::Pin,
sync::{
atomic::{AtomicU64, Ordering},
@@ -44,7 +43,6 @@ pub struct Channel<Req, Resp> {
cancellation: RequestCancellation,
/// The ID to use for the next request to stage.
next_request_id: Arc<AtomicU64>,
server_addr: SocketAddr,
}
impl<Req, Resp> Clone for Channel<Req, Resp> {
@@ -53,7 +51,6 @@ impl<Req, Resp> Clone for Channel<Req, Resp> {
to_dispatch: self.to_dispatch.clone(),
cancellation: self.cancellation.clone(),
next_request_id: self.next_request_id.clone(),
server_addr: self.server_addr,
}
}
}
@@ -122,9 +119,8 @@ impl<Req, Resp> Channel<Req, Resp> {
let timeout = ctx.deadline.as_duration();
let deadline = Instant::now() + timeout;
trace!(
"[{}/{}] Queuing request with deadline {} (timeout {:?}).",
"[{}] Queuing request with deadline {} (timeout {:?}).",
ctx.trace_id(),
self.server_addr,
format_rfc3339(ctx.deadline),
timeout,
);
@@ -132,7 +128,6 @@ impl<Req, Resp> Channel<Req, Resp> {
let (response_completion, response) = oneshot::channel();
let cancellation = self.cancellation.clone();
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
let server_addr = self.server_addr;
Send {
fut: MapOkDispatchResponse::new(
MapErrConnectionReset::new(self.to_dispatch.send(DispatchRequest {
@@ -147,7 +142,6 @@ impl<Req, Resp> Channel<Req, Resp> {
request_id,
cancellation,
ctx,
server_addr,
},
),
}
@@ -171,11 +165,9 @@ struct DispatchResponse<Resp> {
complete: bool,
cancellation: RequestCancellation,
request_id: u64,
server_addr: SocketAddr,
}
impl<Resp> DispatchResponse<Resp> {
unsafe_pinned!(server_addr: SocketAddr);
unsafe_pinned!(ctx: context::Context);
}
@@ -192,7 +184,6 @@ impl<Resp> Future for DispatchResponse<Resp> {
}
Err(e) => Err({
let trace_id = *self.as_mut().ctx().trace_id();
let server_addr = *self.as_mut().server_addr();
if e.is_elapsed() {
io::Error::new(
@@ -209,12 +200,9 @@ impl<Resp> Future for DispatchResponse<Resp> {
.to_string(),
)
} else if e.is_shutdown() {
panic!("[{}/{}] Timer was shutdown", trace_id, server_addr)
panic!("[{}] Timer was shutdown", trace_id)
} else {
panic!(
"[{}/{}] Unrecognized timer error: {}",
trace_id, server_addr, e
)
panic!("[{}] Unrecognized timer error: {}", trace_id, e)
}
} else if e.is_inner() {
// The oneshot is Canceled when the dispatch task ends. In that case,
@@ -223,10 +211,7 @@ impl<Resp> Future for DispatchResponse<Resp> {
self.complete = true;
io::Error::from(io::ErrorKind::ConnectionReset)
} else {
panic!(
"[{}/{}] Unrecognized deadline error: {}",
trace_id, server_addr, e
)
panic!("[{}] Unrecognized deadline error: {}", trace_id, e)
}
}),
})
@@ -255,15 +240,11 @@ impl<Resp> Drop for DispatchResponse<Resp> {
/// Spawns a dispatch task on the default executor that manages the lifecycle of requests initiated
/// by the returned [`Channel`].
pub async fn spawn<Req, Resp, C>(
config: Config,
transport: C,
server_addr: SocketAddr,
) -> io::Result<Channel<Req, Resp>>
pub async fn spawn<Req, Resp, C>(config: Config, transport: C) -> io::Result<Channel<Req, Resp>>
where
Req: marker::Send + 'static,
Resp: marker::Send + 'static,
C: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>> + marker::Send + 'static,
C: Transport<ClientMessage<Req>, Response<Resp>> + marker::Send + 'static,
{
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
let (cancellation, canceled_requests) = cancellations();
@@ -272,13 +253,12 @@ where
crate::spawn(
RequestDispatch {
config,
server_addr,
canceled_requests,
transport: transport.fuse(),
in_flight_requests: FnvHashMap::default(),
pending_requests: pending_requests.fuse(),
}
.unwrap_or_else(move |e| error!("[{}] Connection broken: {}", server_addr, e)),
.unwrap_or_else(move |e| error!("Connection broken: {}", e)),
)
.map_err(|e| {
io::Error::new(
@@ -293,7 +273,6 @@ where
Ok(Channel {
to_dispatch,
cancellation,
server_addr,
next_request_id: Arc::new(AtomicU64::new(0)),
})
}
@@ -311,17 +290,14 @@ struct RequestDispatch<Req, Resp, C> {
in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>,
/// Configures limits to prevent unlimited resource usage.
config: Config,
/// The address of the server connected to.
server_addr: SocketAddr,
}
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
where
Req: marker::Send,
Resp: marker::Send,
C: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>>,
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
unsafe_pinned!(server_addr: SocketAddr);
unsafe_pinned!(in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>);
unsafe_pinned!(canceled_requests: Fuse<CanceledRequests>);
unsafe_pinned!(pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>);
@@ -333,10 +309,7 @@ where
self.complete(response);
Some(Ok(()))
}
None => {
trace!("[{}] read half closed", self.as_mut().server_addr());
None
}
None => None,
})
}
@@ -415,10 +388,7 @@ where
return Poll::Ready(Some(Ok(request)));
}
None => {
trace!("[{}] pending_requests closed", self.as_mut().server_addr());
return Poll::Ready(None);
}
None => return Poll::Ready(None),
}
}
}
@@ -440,23 +410,11 @@ where
self.as_mut().in_flight_requests().remove(&request_id)
{
self.as_mut().in_flight_requests().compact(0.1);
debug!(
"[{}/{}] Removed request.",
in_flight_data.ctx.trace_id(),
self.as_mut().server_addr()
);
debug!("[{}] Removed request.", in_flight_data.ctx.trace_id());
return Poll::Ready(Some(Ok((in_flight_data.ctx, request_id))));
}
}
None => {
trace!(
"[{}] canceled_requests closed.",
self.as_mut().server_addr()
);
return Poll::Ready(None);
}
None => return Poll::Ready(None),
}
}
}
@@ -466,14 +424,14 @@ where
dispatch_request: DispatchRequest<Req, Resp>,
) -> io::Result<()> {
let request_id = dispatch_request.request_id;
let request = ClientMessage {
trace_context: dispatch_request.ctx.trace_context,
message: ClientMessageKind::Request(Request {
id: request_id,
message: dispatch_request.request,
let request = ClientMessage::Request(Request {
id: request_id,
message: dispatch_request.request,
context: context::Context {
deadline: dispatch_request.ctx.deadline,
}),
};
trace_context: dispatch_request.ctx.trace_context,
},
});
self.as_mut().transport().start_send(request)?;
self.as_mut().in_flight_requests().insert(
request_id,
@@ -491,16 +449,12 @@ where
request_id: u64,
) -> io::Result<()> {
let trace_id = *context.trace_id();
let cancel = ClientMessage {
let cancel = ClientMessage::Cancel {
trace_context: context.trace_context,
message: ClientMessageKind::Cancel { request_id },
request_id,
};
self.as_mut().transport().start_send(cancel)?;
trace!(
"[{}/{}] Cancel message sent.",
trace_id,
self.as_mut().server_addr()
);
trace!("[{}] Cancel message sent.", trace_id);
Ok(())
}
@@ -513,18 +467,13 @@ where
{
self.as_mut().in_flight_requests().compact(0.1);
trace!(
"[{}/{}] Received response.",
in_flight_data.ctx.trace_id(),
self.as_mut().server_addr()
);
trace!("[{}] Received response.", in_flight_data.ctx.trace_id());
let _ = in_flight_data.response_completion.send(response);
return true;
}
debug!(
"[{}] No in-flight request found for request_id = {}.",
self.as_mut().server_addr(),
"No in-flight request found for request_id = {}.",
response.request_id
);
@@ -537,58 +486,29 @@ impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
where
Req: marker::Send,
Resp: marker::Send,
C: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>>,
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
type Output = io::Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
trace!("[{}] RequestDispatch::poll", self.as_mut().server_addr());
loop {
match (self.pump_read(cx)?, self.pump_write(cx)?) {
(read, write @ Poll::Ready(None)) => {
(read, Poll::Ready(None)) => {
if self.as_mut().in_flight_requests().is_empty() {
info!(
"[{}] Shutdown: write half closed, and no requests in flight.",
self.as_mut().server_addr()
);
info!("Shutdown: write half closed, and no requests in flight.");
return Poll::Ready(Ok(()));
}
let addr = *self.as_mut().server_addr();
info!(
"[{}] {} requests in flight.",
addr,
"Shutdown: write half closed, and {} requests in flight.",
self.as_mut().in_flight_requests().len()
);
match read {
Poll::Ready(Some(())) => continue,
_ => {
trace!(
"[{}] read: {:?}, write: {:?}, (not ready)",
self.as_mut().server_addr(),
read,
write,
);
return Poll::Pending;
}
_ => return Poll::Pending,
}
}
(read @ Poll::Ready(Some(())), write) | (read, write @ Poll::Ready(Some(()))) => {
trace!(
"[{}] read: {:?}, write: {:?}",
self.as_mut().server_addr(),
read,
write,
)
}
(read, write) => {
trace!(
"[{}] read: {:?}, write: {:?} (not ready)",
self.as_mut().server_addr(),
read,
write,
);
return Poll::Pending;
}
(Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
_ => return Poll::Pending,
}
}
}
@@ -848,14 +768,7 @@ mod tests {
};
use futures_test::task::noop_waker_ref;
use std::time::Duration;
use std::{
marker,
net::{IpAddr, Ipv4Addr, SocketAddr},
pin::Pin,
sync::atomic::AtomicU64,
sync::Arc,
time::Instant,
};
use std::{marker, pin::Pin, sync::atomic::AtomicU64, sync::Arc, time::Instant};
#[test]
fn dispatch_response_cancels_on_timeout() {
@@ -869,7 +782,6 @@ mod tests {
request_id: 3,
cancellation,
ctx: context::current(),
server_addr: SocketAddr::from(([0, 0, 0, 0], 9999)),
};
{
pin_utils::pin_mut!(resp);
@@ -994,7 +906,6 @@ mod tests {
canceled_requests: CanceledRequests(canceled_requests).fuse(),
in_flight_requests: FnvHashMap::default(),
config: Config::default(),
server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
};
let cancellation = RequestCancellation(cancel_tx);
@@ -1002,7 +913,6 @@ mod tests {
to_dispatch,
cancellation,
next_request_id: Arc::new(AtomicU64::new(0)),
server_addr: SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
};
(dispatch, channel, server_channel)

@@ -8,11 +8,7 @@
use crate::{context, ClientMessage, Response, Transport};
use futures::prelude::*;
use log::warn;
use std::{
io,
net::{Ipv4Addr, SocketAddr},
};
use std::io;
/// Provides a [`Client`] backed by a transport.
pub mod channel;
@@ -137,15 +133,7 @@ pub async fn new<Req, Resp, T>(config: Config, transport: T) -> io::Result<Chann
where
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Item = Response<Resp>, SinkItem = ClientMessage<Req>> + Send + 'static,
T: Transport<ClientMessage<Req>, Response<Resp>> + Send + 'static,
{
let server_addr = transport.peer_addr().unwrap_or_else(|e| {
warn!(
"Setting peer to unspecified because peer could not be determined: {}",
e
);
SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0)
});
Ok(channel::spawn(config, transport, server_addr).await?)
Ok(channel::spawn(config, transport).await?)
}

@@ -16,10 +16,20 @@ use trace::{self, TraceId};
/// The context should not be stored directly in a server implementation, because the context will
/// be different for each request in scope.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Context {
/// When the client expects the request to be complete by. The server should cancel the request
/// if it is not complete by this time.
#[cfg_attr(
feature = "serde1",
serde(serialize_with = "crate::util::serde::serialize_epoch_secs")
)]
#[cfg_attr(
feature = "serde1",
serde(deserialize_with = "crate::util::serde::deserialize_epoch_secs")
)]
#[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))]
pub deadline: SystemTime,
/// Uniquely identifies requests originating from the same source.
/// When a service handles a request by making requests itself, those requests should
@@ -28,6 +38,11 @@ pub struct Context {
pub trace_context: trace::Context,
}
#[cfg(feature = "serde1")]
fn ten_seconds_from_now() -> SystemTime {
return SystemTime::now() + Duration::from_secs(10);
}
/// Returns the context for the current request, or a default Context if no request is active.
// TODO: populate Context with request-scoped data, with default fallbacks.
pub fn current() -> Context {

@@ -5,11 +5,14 @@
// https://opensource.org/licenses/MIT.
#![feature(
weak_counts,
non_exhaustive,
integer_atomics,
try_trait,
arbitrary_self_types,
async_await
async_await,
trait_alias,
async_closure
)]
#![deny(missing_docs, missing_debug_implementations)]
@@ -49,19 +52,7 @@ use std::{cell::RefCell, io, sync::Once, time::SystemTime};
#[derive(Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct ClientMessage<T> {
/// The trace context associates the message with a specific chain of causally-related actions,
/// possibly orchestrated across many distributed systems.
pub trace_context: trace::Context,
/// The message payload.
pub message: ClientMessageKind<T>,
}
/// Different messages that can be sent from a client to a server.
#[derive(Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum ClientMessageKind<T> {
pub enum ClientMessage<T> {
/// A request initiated by a user. The server responds to a request by invoking a
/// service-provided request handler. The handler completes with a [`response`](Response), which
/// the server sends back to the client.
@@ -74,35 +65,30 @@ pub enum ClientMessageKind<T> {
/// not be canceled, because the framework layer does not
/// know about them.
Cancel {
/// The trace context associates the message with a specific chain of causally-related actions,
/// possibly orchestrated across many distributed systems.
#[cfg_attr(feature = "serde", serde(default))]
trace_context: trace::Context,
/// The ID of the request to cancel.
request_id: u64,
},
}
/// A request from a client to a server.
#[derive(Debug)]
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Request<T> {
/// Trace context, deadline, and other cross-cutting concerns.
pub context: context::Context,
/// Uniquely identifies the request across all requests sent over a single channel.
pub id: u64,
/// The request body.
pub message: T,
/// When the client expects the request to be complete by. The server will cancel the request
/// if it is not complete by this time.
#[cfg_attr(
feature = "serde1",
serde(serialize_with = "util::serde::serialize_epoch_secs")
)]
#[cfg_attr(
feature = "serde1",
serde(deserialize_with = "util::serde::deserialize_epoch_secs")
)]
pub deadline: SystemTime,
}
/// A response from a server to a client.
#[derive(Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Response<T> {
@@ -113,7 +99,7 @@ pub struct Response<T> {
}
/// An error response from a server to a client.
#[derive(Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct ServerError {
@@ -140,7 +126,7 @@ impl From<ServerError> for io::Error {
impl<T> Request<T> {
/// Returns the deadline for this request.
pub fn deadline(&self) -> &SystemTime {
&self.deadline
&self.context.deadline
}
}

@@ -5,259 +5,331 @@
// https://opensource.org/licenses/MIT.
use crate::{
server::{Channel, Config},
server::{self, Channel},
util::Compact,
ClientMessage, PollIo, Response, Transport,
Response,
};
use fnv::FnvHashMap;
use futures::{
channel::mpsc,
future::AbortRegistration,
prelude::*,
ready,
stream::Fuse,
task::{Context, Poll},
};
use log::{debug, error, info, trace, warn};
use pin_utils::unsafe_pinned;
use log::{debug, info, trace};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::sync::{Arc, Weak};
use std::{
collections::hash_map::Entry,
io,
marker::PhantomData,
net::{IpAddr, SocketAddr},
ops::Try,
option::NoneError,
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, io, marker::Unpin, ops::Try,
pin::Pin,
};
/// Drops connections under configurable conditions:
///
/// 1. If the max number of connections is reached.
/// 2. If the max number of connections for a single IP is reached.
/// A single-threaded filter that drops channels based on per-key limits.
#[derive(Debug)]
pub struct ConnectionFilter<S, Req, Resp> {
pub struct ChannelFilter<S, K, F>
where
K: Eq + Hash,
{
listener: Fuse<S>,
closed_connections: mpsc::UnboundedSender<SocketAddr>,
closed_connections_rx: mpsc::UnboundedReceiver<SocketAddr>,
config: Config,
connections_per_ip: FnvHashMap<IpAddr, usize>,
open_connections: usize,
ghost: PhantomData<(Req, Resp)>,
channels_per_key: u32,
dropped_keys: mpsc::UnboundedReceiver<K>,
dropped_keys_tx: mpsc::UnboundedSender<K>,
key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
keymaker: F,
}
enum NewConnection<Req, Resp, C> {
Filtered,
Accepted(Channel<Req, Resp, C>),
/// A channel that is tracked by a ChannelFilter.
#[derive(Debug)]
pub struct TrackedChannel<C, K> {
inner: C,
tracker: Arc<Tracker<K>>,
}
impl<Req, Resp, C> Try for NewConnection<Req, Resp, C> {
type Ok = Channel<Req, Resp, C>;
type Error = NoneError;
impl<C, K> TrackedChannel<C, K> {
unsafe_pinned!(inner: C);
}
fn into_result(self) -> Result<Channel<Req, Resp, C>, NoneError> {
#[derive(Debug)]
struct Tracker<K> {
key: Option<K>,
dropped_keys: mpsc::UnboundedSender<K>,
}
impl<K> Drop for Tracker<K> {
fn drop(&mut self) {
// Don't care if the listener is dropped.
let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap());
}
}
/// A running handler serving all requests for a single client.
#[derive(Debug)]
pub struct TrackedHandler<K, Fut> {
inner: Fut,
tracker: Tracker<K>,
}
impl<K, Fut> TrackedHandler<K, Fut>
where
Fut: Future,
{
unsafe_pinned!(inner: Fut);
}
impl<K, Fut> Future for TrackedHandler<K, Fut>
where
Fut: Future,
{
type Output = Fut::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.inner().poll(cx)
}
}
impl<C, K> Stream for TrackedChannel<C, K>
where
C: Channel,
{
type Item = <C as Stream>::Item;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
self.channel().poll_next(cx)
}
}
impl<C, K> Sink<Response<C::Resp>> for TrackedChannel<C, K>
where
C: Channel,
{
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.channel().poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Response<C::Resp>) -> Result<(), Self::Error> {
self.channel().start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.channel().poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.channel().poll_close(cx)
}
}
impl<C, K> AsRef<C> for TrackedChannel<C, K> {
fn as_ref(&self) -> &C {
&self.inner
}
}
impl<C, K> Channel for TrackedChannel<C, K>
where
C: Channel,
{
type Req = C::Req;
type Resp = C::Resp;
fn config(&self) -> &server::Config {
self.inner.config()
}
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
self.inner().in_flight_requests()
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
self.inner().start_request(request_id)
}
}
impl<C, K> TrackedChannel<C, K> {
/// Returns the inner channel.
pub fn get_ref(&self) -> &C {
&self.inner
}
/// Returns the pinned inner channel.
fn channel<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut C> {
self.inner()
}
}
enum NewChannel<C, K> {
Accepted(TrackedChannel<C, K>),
Filtered(K),
}
impl<C, K> Try for NewChannel<C, K> {
type Ok = TrackedChannel<C, K>;
type Error = K;
fn into_result(self) -> Result<TrackedChannel<C, K>, K> {
match self {
NewConnection::Filtered => Err(NoneError),
NewConnection::Accepted(channel) => Ok(channel),
NewChannel::Accepted(channel) => Ok(channel),
NewChannel::Filtered(k) => Err(k),
}
}
fn from_error(_: NoneError) -> Self {
NewConnection::Filtered
fn from_error(k: K) -> Self {
NewChannel::Filtered(k)
}
fn from_ok(channel: Channel<Req, Resp, C>) -> Self {
NewConnection::Accepted(channel)
fn from_ok(channel: TrackedChannel<C, K>) -> Self {
NewChannel::Accepted(channel)
}
}
impl<S, Req, Resp> ConnectionFilter<S, Req, Resp> {
unsafe_pinned!(open_connections: usize);
unsafe_pinned!(config: Config);
unsafe_pinned!(connections_per_ip: FnvHashMap<IpAddr, usize>);
unsafe_pinned!(closed_connections_rx: mpsc::UnboundedReceiver<SocketAddr>);
impl<S, K, F> ChannelFilter<S, K, F>
where
K: fmt::Display + Eq + Hash + Clone,
{
unsafe_pinned!(listener: Fuse<S>);
unsafe_pinned!(dropped_keys: mpsc::UnboundedReceiver<K>);
unsafe_pinned!(dropped_keys_tx: mpsc::UnboundedSender<K>);
unsafe_unpinned!(key_counts: FnvHashMap<K, Weak<Tracker<K>>>);
unsafe_unpinned!(channels_per_key: u32);
unsafe_unpinned!(keymaker: F);
}
/// Sheds new connections to stay under configured limits.
pub fn filter<C>(listener: S, config: Config) -> Self
where
S: Stream<Item = Result<C, io::Error>>,
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
let (closed_connections, closed_connections_rx) = mpsc::unbounded();
ConnectionFilter {
impl<S, K, F> ChannelFilter<S, K, F>
where
K: Eq + Hash,
S: Stream,
{
/// Sheds new channels to stay under configured limits.
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
let (dropped_keys_tx, dropped_keys) = mpsc::unbounded();
ChannelFilter {
listener: listener.fuse(),
closed_connections,
closed_connections_rx,
config,
connections_per_ip: FnvHashMap::default(),
open_connections: 0,
ghost: PhantomData,
channels_per_key,
dropped_keys,
dropped_keys_tx,
key_counts: FnvHashMap::default(),
keymaker,
}
}
}
fn handle_new_connection<C>(self: &mut Pin<&mut Self>, stream: C) -> NewConnection<Req, Resp, C>
where
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
let peer = match stream.peer_addr() {
Ok(peer) => peer,
Err(e) => {
warn!("Could not get peer_addr of new connection: {}", e);
return NewConnection::Filtered;
}
};
let open_connections = *self.as_mut().open_connections();
if open_connections >= self.as_mut().config().max_connections {
warn!(
"[{}] Shedding connection because the maximum open connections \
limit is reached ({}/{}).",
peer,
open_connections,
self.as_mut().config().max_connections
);
return NewConnection::Filtered;
}
let config = self.config.clone();
let open_connections_for_ip = self.increment_connections_for_ip(&peer)?;
*self.as_mut().open_connections() += 1;
impl<S, C, K, F> ChannelFilter<S, K, F>
where
S: Stream<Item = C>,
C: Channel,
K: fmt::Display + Eq + Hash + Clone + Unpin,
F: Fn(&C) -> K,
{
fn handle_new_channel(self: &mut Pin<&mut Self>, stream: C) -> NewChannel<C, K> {
let key = self.as_mut().keymaker()(&stream);
let tracker = self.increment_channels_for_key(key.clone())?;
let max = self.as_mut().channels_per_key();
debug!(
"[{}] Opening channel ({}/{} connections for IP, {} total).",
peer,
open_connections_for_ip,
config.max_connections_per_ip,
self.as_mut().open_connections(),
"[{}] Opening channel ({}/{}) channels for key.",
key,
Arc::strong_count(&tracker),
max
);
NewConnection::Accepted(Channel {
client_addr: peer,
closed_connections: self.closed_connections.clone(),
transport: stream.fuse(),
config,
ghost: PhantomData,
NewChannel::Accepted(TrackedChannel {
tracker,
inner: stream,
})
}
fn handle_closed_connection(self: &mut Pin<&mut Self>, addr: &SocketAddr) {
*self.as_mut().open_connections() -= 1;
debug!(
"[{}] Closing channel. {} open connections remaining.",
addr, self.open_connections
);
self.decrement_connections_for_ip(&addr);
self.as_mut().connections_per_ip().compact(0.1);
}
fn increment_channels_for_key(self: &mut Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
let channels_per_key = self.channels_per_key;
let dropped_keys = self.dropped_keys_tx.clone();
let key_counts = &mut self.as_mut().key_counts();
match key_counts.entry(key.clone()) {
Entry::Vacant(vacant) => {
let tracker = Arc::new(Tracker {
key: Some(key),
dropped_keys,
});
fn increment_connections_for_ip(self: &mut Pin<&mut Self>, peer: &SocketAddr) -> Option<usize> {
let max_connections_per_ip = self.as_mut().config().max_connections_per_ip;
let mut occupied;
let mut connections_per_ip = self.as_mut().connections_per_ip();
let occupied = match connections_per_ip.entry(peer.ip()) {
Entry::Vacant(vacant) => vacant.insert(0),
Entry::Occupied(o) => {
if *o.get() < max_connections_per_ip {
// Store the reference outside the block to extend the lifetime.
occupied = o;
occupied.get_mut()
} else {
vacant.insert(Arc::downgrade(&tracker));
Ok(tracker)
}
Entry::Occupied(mut o) => {
let count = o.get().strong_count();
if count >= channels_per_key.try_into().unwrap() {
info!(
"[{}] Opened max connections from IP ({}/{}).",
peer,
o.get(),
max_connections_per_ip
"[{}] Opened max channels from key ({}/{}).",
key, count, channels_per_key
);
return None;
}
}
};
*occupied += 1;
Some(*occupied)
}
fn decrement_connections_for_ip(self: &mut Pin<&mut Self>, addr: &SocketAddr) {
let should_compact = match self.as_mut().connections_per_ip().entry(addr.ip()) {
Entry::Vacant(_) => {
error!("[{}] Got vacant entry when closing connection.", addr);
return;
}
Entry::Occupied(mut occupied) => {
*occupied.get_mut() -= 1;
if *occupied.get() == 0 {
occupied.remove();
true
Err(key)
} else {
false
Ok(o.get().upgrade().unwrap_or_else(|| {
let tracker = Arc::new(Tracker {
key: Some(key),
dropped_keys,
});
*o.get_mut() = Arc::downgrade(&tracker);
tracker
}))
}
}
};
if should_compact {
self.as_mut().connections_per_ip().compact(0.1);
}
}
fn poll_listener<C>(
fn poll_listener(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<NewConnection<Req, Resp, C>>
where
S: Stream<Item = Result<C, io::Error>>,
C: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
{
match ready!(self.as_mut().listener().poll_next_unpin(cx)?) {
Some(codec) => Poll::Ready(Some(Ok(self.handle_new_connection(codec)))),
) -> Poll<Option<NewChannel<C, K>>> {
match ready!(self.as_mut().listener().poll_next_unpin(cx)) {
Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
None => Poll::Ready(None),
}
}
fn poll_closed_connections(
self: &mut Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
match ready!(self.as_mut().closed_connections_rx().poll_next_unpin(cx)) {
Some(addr) => {
self.handle_closed_connection(&addr);
Poll::Ready(Ok(()))
fn poll_closed_channels(self: &mut Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
match ready!(self.as_mut().dropped_keys().poll_next_unpin(cx)) {
Some(key) => {
debug!("All channels dropped for key [{}]", key);
self.as_mut().key_counts().remove(&key);
self.as_mut().key_counts().compact(0.1);
Poll::Ready(())
}
None => unreachable!("Holding a copy of closed_connections and didn't close it."),
None => unreachable!("Holding a copy of closed_channels and didn't close it."),
}
}
}
impl<S, Req, Resp, T> Stream for ConnectionFilter<S, Req, Resp>
impl<S, C, K, F> Stream for ChannelFilter<S, K, F>
where
S: Stream<Item = Result<T, io::Error>>,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
S: Stream<Item = C>,
C: Channel,
K: fmt::Display + Eq + Hash + Clone + Unpin,
F: Fn(&C) -> K,
{
type Item = io::Result<Channel<Req, Resp, T>>;
type Item = TrackedChannel<C, K>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Channel<Req, Resp, T>> {
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<TrackedChannel<C, K>>> {
loop {
match (
self.as_mut().poll_listener(cx)?,
self.poll_closed_connections(cx)?,
self.as_mut().poll_listener(cx),
self.poll_closed_channels(cx),
) {
(Poll::Ready(Some(NewConnection::Accepted(channel))), _) => {
return Poll::Ready(Some(Ok(channel)));
(Poll::Ready(Some(NewChannel::Accepted(channel))), _) => {
return Poll::Ready(Some(channel));
}
(Poll::Ready(Some(NewConnection::Filtered)), _) | (_, Poll::Ready(())) => {
trace!(
"Filtered a connection; {} open.",
self.as_mut().open_connections()
);
(Poll::Ready(Some(NewChannel::Filtered(_))), _) => {
continue;
}
(_, Poll::Ready(())) => continue,
(Poll::Pending, Poll::Pending) => return Poll::Pending,
(Poll::Ready(None), Poll::Pending) => {
if *self.as_mut().open_connections() > 0 {
trace!(
"Listener closed; {} open connections.",
self.as_mut().open_connections()
);
return Poll::Pending;
}
trace!("Shutting down listener: all connections closed, and no more coming.");
trace!("Shutting down listener.");
return Poll::Ready(None);
}
}

@@ -7,27 +7,27 @@
//! Provides a server that concurrently handles many connections sending multiplexed requests.
use crate::{
context, util::deadline_compat, util::AsDuration, util::Compact, ClientMessage,
ClientMessageKind, PollIo, Request, Response, ServerError, Transport,
context, util::deadline_compat, util::AsDuration, util::Compact, ClientMessage, PollIo,
Request, Response, ServerError, Transport,
};
use fnv::FnvHashMap;
use futures::{
channel::mpsc,
future::{abortable, AbortHandle},
future::{AbortHandle, AbortRegistration, Abortable},
prelude::*,
ready,
stream::Fuse,
task::{Context, Poll},
try_ready,
};
use humantime::format_rfc3339;
use log::{debug, error, info, trace, warn};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::{
error::Error as StdError,
fmt,
hash::Hash,
io,
marker::PhantomData,
net::SocketAddr,
pin::Pin,
time::{Instant, SystemTime},
};
@@ -35,6 +35,14 @@ use tokio_timer::timeout;
use trace::{self, TraceId};
mod filter;
#[cfg(test)]
mod testing;
mod throttle;
pub use self::{
filter::ChannelFilter,
throttle::{Throttler, ThrottlerStream},
};
/// Manages clients, serving multiplexed requests over each connection.
#[derive(Debug)]
@@ -53,17 +61,6 @@ impl<Req, Resp> Default for Server<Req, Resp> {
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct Config {
/// The maximum number of clients that can be connected to the server at once. When at the
/// limit, existing connections are honored and new connections are rejected.
pub max_connections: usize,
/// The maximum number of clients per IP address that can be connected to the server at once.
/// When an IP is at the limit, existing connections are honored and new connections on that IP
/// address are rejected.
pub max_connections_per_ip: usize,
/// The maximum number of requests that can be in flight for each client. When a client is at
/// the in-flight request limit, existing requests are fulfilled and new requests are rejected.
/// Rejected requests are sent a response error.
pub max_in_flight_requests_per_connection: usize,
/// The number of responses per client that can be buffered server-side before being sent.
/// `pending_response_buffer` controls the buffer size of the channel that a server's
/// response tasks use to send responses to the client handler task.
@@ -73,14 +70,21 @@ pub struct Config {
impl Default for Config {
fn default() -> Self {
Config {
max_connections: 1_000_000,
max_connections_per_ip: 1_000,
max_in_flight_requests_per_connection: 1_000,
pending_response_buffer: 100,
}
}
}
impl Config {
/// Returns a channel backed by `transport` and configured with `self`.
pub fn channel<Req, Resp, T>(self, transport: T) -> BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>> + Send,
{
BaseChannel::new(self, transport)
}
}
/// Returns a new server with configuration specified `config`.
pub fn new<Req, Resp>(config: Config) -> Server<Req, Resp> {
Server {
@@ -95,18 +99,15 @@ impl<Req, Resp> Server<Req, Resp> {
&self.config
}
/// Returns a stream of the incoming connections to the server.
pub fn incoming<S, T>(
self,
listener: S,
) -> impl Stream<Item = io::Result<Channel<Req, Resp, T>>>
/// Returns a stream of server channels.
pub fn incoming<S, T>(self, listener: S) -> impl Stream<Item = BaseChannel<Req, Resp, T>>
where
Req: Send,
Resp: Send,
S: Stream<Item = io::Result<T>>,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
S: Stream<Item = T>,
T: Transport<Response<Resp>, ClientMessage<Req>> + Send,
{
self::filter::ConnectionFilter::filter(listener, self.config.clone())
listener.map(move |t| BaseChannel::new(self.config.clone(), t))
}
}
@@ -122,31 +123,21 @@ impl<S, F> Running<S, F> {
unsafe_unpinned!(request_handler: F);
}
impl<S, T, Req, Resp, F, Fut> Future for Running<S, F>
impl<S, C, F, Fut> Future for Running<S, F>
where
S: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send + 'static,
F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
S: Sized + Stream<Item = C>,
C: Channel + Send + 'static,
F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<C::Resp>> + Send + 'static,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
while let Some(channel) = ready!(self.as_mut().incoming().poll_next(cx)) {
match channel {
Ok(channel) => {
let peer = channel.client_addr;
if let Err(e) =
crate::spawn(channel.respond_with(self.as_mut().request_handler().clone()))
{
warn!("[{}] Failed to spawn connection handler: {:?}", peer, e);
}
}
Err(e) => {
warn!("Incoming connection error: {}", e);
}
if let Err(e) =
crate::spawn(channel.respond_with(self.as_mut().request_handler().clone()))
{
warn!("Failed to spawn channel handler: {:?}", e);
}
}
info!("Server shutting down.");
@@ -155,18 +146,30 @@ where
}
/// A utility trait enabling a stream to fluently chain a request handler.
pub trait Handler<T, Req, Resp>
pub trait Handler<C>
where
Self: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
Req: Send,
Resp: Send,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
Self: Sized + Stream<Item = C>,
C: Channel,
{
/// Enforces channel per-key limits.
fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> filter::ChannelFilter<Self, K, KF>
where
K: fmt::Display + Eq + Hash + Clone + Unpin,
KF: Fn(&C) -> K,
{
ChannelFilter::new(self, n, keymaker)
}
/// Caps the number of concurrent requests per channel.
fn max_concurrent_requests_per_channel(self, n: usize) -> ThrottlerStream<Self> {
ThrottlerStream::new(self, n)
}
/// Responds to all requests with `request_handler`.
fn respond_with<F, Fut>(self, request_handler: F) -> Running<Self, F>
where
F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<C::Resp>> + Send + 'static,
{
Running {
incoming: self,
@@ -175,191 +178,276 @@ where
}
}
impl<T, Req, Resp, S> Handler<T, Req, Resp> for S
impl<S, C> Handler<C> for S
where
S: Sized + Stream<Item = io::Result<Channel<Req, Resp, T>>>,
Req: Send,
Resp: Send,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
S: Sized + Stream<Item = C>,
C: Channel,
{
}
/// Responds to all requests with `request_handler`.
/// The server end of an open connection with a client.
/// BaseChannel lifts a Transport to a Channel by tracking in-flight requests.
#[derive(Debug)]
pub struct Channel<Req, Resp, T> {
pub struct BaseChannel<Req, Resp, T> {
config: Config,
/// Writes responses to the wire and reads requests off the wire.
transport: Fuse<T>,
/// Signals the connection is closed when `Channel` is dropped.
closed_connections: mpsc::UnboundedSender<SocketAddr>,
/// Channel limits to prevent unlimited resource usage.
config: Config,
/// The address of the server connected to.
client_addr: SocketAddr,
/// Number of requests currently being responded to.
in_flight_requests: FnvHashMap<u64, AbortHandle>,
/// Types the request and response.
ghost: PhantomData<(Req, Resp)>,
}
impl<Req, Resp, T> Drop for Channel<Req, Resp, T> {
fn drop(&mut self) {
trace!("[{}] Closing channel.", self.client_addr);
impl<Req, Resp, T> BaseChannel<Req, Resp, T> {
unsafe_unpinned!(in_flight_requests: FnvHashMap<u64, AbortHandle>);
}
// Even in a bounded channel, each connection would have a guaranteed slot, so using
// an unbounded sender is actually no different. And, the bound is on the maximum number
// of open connections.
if self
.closed_connections
.unbounded_send(self.client_addr)
.is_err()
{
warn!(
"[{}] Failed to send closed connection message.",
self.client_addr
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>> + Send,
{
/// Creates a new channel backed by `transport` and configured with `config`.
pub fn new(config: Config, transport: T) -> Self {
BaseChannel {
config,
transport: transport.fuse(),
in_flight_requests: FnvHashMap::default(),
ghost: PhantomData,
}
}
/// Creates a new channel backed by `transport` and configured with the defaults.
pub fn with_defaults(transport: T) -> Self {
Self::new(Config::default(), transport)
}
/// Returns the inner transport.
pub fn get_ref(&self) -> &T {
self.transport.get_ref()
}
/// Returns the pinned inner transport.
pub fn transport<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut T> {
unsafe { self.map_unchecked_mut(|me| me.transport.get_mut()) }
}
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
// It's possible the request was already completed, so it's fine
// if this is None.
if let Some(cancel_handle) = self.as_mut().in_flight_requests().remove(&request_id) {
self.as_mut().in_flight_requests().compact(0.1);
cancel_handle.abort();
let remaining = self.as_mut().in_flight_requests().len();
trace!(
"[{}] Request canceled. In-flight requests = {}",
trace_context.trace_id,
remaining,
);
} else {
trace!(
"[{}] Received cancellation, but response handler \
is already complete.",
trace_context.trace_id,
);
}
}
}
impl<Req, Resp, T> Channel<Req, Resp, T> {
unsafe_pinned!(transport: Fuse<T>);
}
impl<Req, Resp, T> Channel<Req, Resp, T>
/// The server end of an open connection with a client, streaming in requests from, and sinking
/// responses to, the client.
///
/// Channels are free to somewhat rely on the assumption that all in-flight requests are eventually
/// either [cancelled](Channel::cancel_request) or [responded to](Sink::start_send). Safety cannot
/// rely on this assumption, but it is best for `Channel` users to always account for all outstanding
/// requests.
pub trait Channel
where
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
Req: Send,
Resp: Send,
Self: Transport<Response<<Self as Channel>::Resp>, Request<<Self as Channel>::Req>>,
{
pub(crate) fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> io::Result<()> {
self.as_mut().transport().start_send(response)
/// Type of request item.
type Req: Send + 'static;
/// Type of response sink item.
type Resp: Send + 'static;
/// Configuration of the channel.
fn config(&self) -> &Config;
/// Returns the number of in-flight requests over this channel.
fn in_flight_requests(self: Pin<&mut Self>) -> usize;
/// Caps the number of concurrent requests.
fn max_concurrent_requests(self, n: usize) -> Throttler<Self>
where
Self: Sized,
{
Throttler::new(self, n)
}
pub(crate) fn poll_ready(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
self.as_mut().transport().poll_ready(cx)
}
pub(crate) fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
self.as_mut().transport().poll_flush(cx)
}
pub(crate) fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<ClientMessage<Req>> {
self.as_mut().transport().poll_next(cx)
}
/// Returns the address of the client connected to the channel.
pub fn client_addr(&self) -> &SocketAddr {
&self.client_addr
}
/// Tells the Channel that request with ID `request_id` is being handled.
/// The request will be tracked until a response with the same ID is sent
/// to the Channel.
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration;
/// Respond to requests coming over the channel with `f`. Returns a future that drives the
/// responses and resolves when the connection is closed.
pub fn respond_with<F, Fut>(self, f: F) -> impl Future<Output = ()>
fn respond_with<F, Fut>(self, f: F) -> ResponseHandler<Self, F>
where
F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
Req: 'static,
Resp: 'static,
F: FnOnce(context::Context, Self::Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Self::Resp>> + Send + 'static,
Self: Sized,
{
let (responses_tx, responses) = mpsc::channel(self.config.pending_response_buffer);
let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer);
let responses = responses.fuse();
let peer = self.client_addr;
ClientHandler {
ResponseHandler {
channel: self,
f,
pending_responses: responses,
responses_tx,
in_flight_requests: FnvHashMap::default(),
}
.unwrap_or_else(move |e| {
info!("[{}] ClientHandler errored out: {}", peer, e);
})
}
}
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>> + Send + 'static,
Req: Send + 'static,
Resp: Send + 'static,
{
type Item = io::Result<Request<Req>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
loop {
match ready!(self.as_mut().transport().poll_next(cx)?) {
Some(message) => match message {
ClientMessage::Request(request) => {
return Poll::Ready(Some(Ok(request)));
}
ClientMessage::Cancel {
trace_context,
request_id,
} => {
self.as_mut().cancel_request(&trace_context, request_id);
}
},
None => return Poll::Ready(None),
}
}
}
}
impl<Req, Resp, T> Sink<Response<Resp>> for BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>> + Send + 'static,
Req: Send + 'static,
Resp: Send + 'static,
{
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.transport().poll_ready(cx)
}
fn start_send(
mut self: Pin<&mut Self>,
response: Response<Resp>,
) -> Result<(), Self::Error> {
if self
.as_mut()
.in_flight_requests()
.remove(&response.request_id)
.is_some()
{
self.as_mut().in_flight_requests().compact(0.1);
}
self.transport().start_send(response)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.transport().poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.transport().poll_close(cx)
}
}
impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
fn as_ref(&self) -> &T {
self.transport.get_ref()
}
}
impl<Req, Resp, T> Channel for BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>> + Send + 'static,
Req: Send + 'static,
Resp: Send + 'static,
{
type Req = Req;
type Resp = Resp;
fn config(&self) -> &Config {
&self.config
}
fn in_flight_requests(mut self: Pin<&mut Self>) -> usize {
self.as_mut().in_flight_requests().len()
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
let (abort_handle, abort_registration) = AbortHandle::new_pair();
assert!(self
.in_flight_requests()
.insert(request_id, abort_handle)
.is_none());
abort_registration
}
}
/// A running handler serving all requests coming over a channel.
#[derive(Debug)]
struct ClientHandler<Req, Resp, T, F> {
channel: Channel<Req, Resp, T>,
pub struct ResponseHandler<C, F>
where
C: Channel,
{
channel: C,
/// Responses waiting to be written to the wire.
pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<Resp>)>>,
pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>,
/// Handed out to request handlers to fan in responses.
responses_tx: mpsc::Sender<(context::Context, Response<Resp>)>,
/// Number of requests currently being responded to.
in_flight_requests: FnvHashMap<u64, AbortHandle>,
responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
/// Request handler.
f: F,
}
impl<Req, Resp, T, F> ClientHandler<Req, Resp, T, F> {
unsafe_pinned!(channel: Channel<Req, Resp, T>);
unsafe_pinned!(in_flight_requests: FnvHashMap<u64, AbortHandle>);
unsafe_pinned!(pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<Resp>)>>);
unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response<Resp>)>);
impl<C, F> ResponseHandler<C, F>
where
C: Channel,
{
unsafe_pinned!(channel: C);
unsafe_pinned!(pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>);
unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>);
// For this to be safe, field f must be private, and code in this module must never
// construct PinMut<F>.
unsafe_unpinned!(f: F);
}
impl<Req, Resp, T, F, Fut> ClientHandler<Req, Resp, T, F>
impl<C, F, Fut> ResponseHandler<C, F>
where
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
C: Channel,
F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<C::Resp>> + Send + 'static,
{
/// If at max in-flight requests, check that there's room to immediately write a throttled
/// response.
fn poll_ready_if_throttling(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
if self.in_flight_requests.len()
>= self.channel.config.max_in_flight_requests_per_connection
{
let peer = self.as_mut().channel().client_addr;
while let Poll::Pending = self.as_mut().channel().poll_ready(cx)? {
info!(
"[{}] In-flight requests at max ({}), and transport is not ready.",
peer,
self.as_mut().in_flight_requests().len(),
);
try_ready!(self.as_mut().channel().poll_flush(cx));
}
}
Poll::Ready(Ok(()))
}
fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
ready!(self.as_mut().poll_ready_if_throttling(cx)?);
Poll::Ready(match ready!(self.as_mut().channel().poll_next(cx)?) {
Some(message) => {
match message.message {
ClientMessageKind::Request(request) => {
self.handle_request(message.trace_context, request)?;
}
ClientMessageKind::Cancel { request_id } => {
self.cancel_request(&message.trace_context, request_id);
}
}
Some(Ok(()))
match ready!(self.as_mut().channel().poll_next(cx)?) {
Some(request) => {
self.handle_request(request)?;
Poll::Ready(Some(Ok(())))
}
None => {
trace!("[{}] Read half closed", self.channel.client_addr);
None
}
})
None => Poll::Ready(None),
}
}
fn pump_write(
@@ -368,7 +456,12 @@ where
read_half_closed: bool,
) -> PollIo<()> {
match self.as_mut().poll_next_response(cx)? {
Poll::Ready(Some((_, response))) => {
Poll::Ready(Some((ctx, response))) => {
trace!(
"[{}] Staging response. In-flight requests = {}.",
ctx.trace_id(),
self.as_mut().channel().in_flight_requests(),
);
self.as_mut().channel().start_send(response)?;
Poll::Ready(Some(Ok(())))
}
@@ -384,7 +477,7 @@ where
// Being here means there are no staged requests and all written responses are
// fully flushed. So, if the read half is closed and there are no in-flight
// requests, then we can close the write half.
if read_half_closed && self.as_mut().in_flight_requests().is_empty() {
if read_half_closed && self.as_mut().channel().in_flight_requests() == 0 {
Poll::Ready(None)
} else {
Poll::Pending
@@ -396,90 +489,33 @@ where
fn poll_next_response(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<(context::Context, Response<Resp>)> {
) -> PollIo<(context::Context, Response<C::Resp>)> {
// Ensure there's room to write a response.
while let Poll::Pending = self.as_mut().channel().poll_ready(cx)? {
ready!(self.as_mut().channel().poll_flush(cx)?);
}
let peer = self.as_mut().channel().client_addr;
match ready!(self.as_mut().pending_responses().poll_next(cx)) {
Some((ctx, response)) => {
if self
.as_mut()
.in_flight_requests()
.remove(&response.request_id)
.is_some()
{
self.as_mut().in_flight_requests().compact(0.1);
}
trace!(
"[{}/{}] Staging response. In-flight requests = {}.",
ctx.trace_id(),
peer,
self.as_mut().in_flight_requests().len(),
);
Poll::Ready(Some(Ok((ctx, response))))
}
Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))),
None => {
// This branch likely won't happen, since the ClientHandler is holding a Sender.
trace!("[{}] No new responses.", peer);
// This branch likely won't happen, since the ResponseHandler is holding a Sender.
Poll::Ready(None)
}
}
}
fn handle_request(
mut self: Pin<&mut Self>,
trace_context: trace::Context,
request: Request<Req>,
) -> io::Result<()> {
fn handle_request(mut self: Pin<&mut Self>, request: Request<C::Req>) -> io::Result<()> {
let request_id = request.id;
let peer = self.as_mut().channel().client_addr;
let ctx = context::Context {
deadline: request.deadline,
trace_context,
};
let request = request.message;
if self.as_mut().in_flight_requests().len()
>= self
.as_mut()
.channel()
.config
.max_in_flight_requests_per_connection
{
debug!(
"[{}/{}] Client has reached in-flight request limit ({}/{}).",
ctx.trace_id(),
peer,
self.as_mut().in_flight_requests().len(),
self.as_mut()
.channel()
.config
.max_in_flight_requests_per_connection
);
self.as_mut().channel().start_send(Response {
request_id,
message: Err(ServerError {
kind: io::ErrorKind::WouldBlock,
detail: Some("Server throttled the request.".into()),
}),
})?;
return Ok(());
}
let deadline = ctx.deadline;
let deadline = request.context.deadline;
let timeout = deadline.as_duration();
trace!(
"[{}/{}] Received request with deadline {} (timeout {:?}).",
ctx.trace_id(),
peer,
"[{}] Received request with deadline {} (timeout {:?}).",
request.context.trace_id(),
format_rfc3339(deadline),
timeout,
);
let ctx = request.context;
let request = request.message;
let mut response_tx = self.as_mut().responses_tx().clone();
let trace_id = *ctx.trace_id();
@@ -490,18 +526,19 @@ where
request_id,
message: match result {
Ok(message) => Ok(message),
Err(e) => Err(make_server_error(e, trace_id, peer, deadline)),
Err(e) => Err(make_server_error(e, trace_id, deadline)),
},
};
trace!("[{}/{}] Sending response.", trace_id, peer);
trace!("[{}] Sending response.", trace_id);
response_tx
.send((ctx, response))
.unwrap_or_else(|_| ())
.await;
},
);
let (abortable_response, abort_handle) = abortable(response);
crate::spawn(abortable_response.map(|_| ())).map_err(|e| {
let abort_registration = self.as_mut().channel().start_request(request_id);
let response = Abortable::new(response, abort_registration);
crate::spawn(response.map(|_| ())).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!(
@@ -510,92 +547,49 @@ where
),
)
})?;
self.as_mut()
.in_flight_requests()
.insert(request_id, abort_handle);
Ok(())
}
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
// It's possible the request was already completed, so it's fine
// if this is None.
if let Some(cancel_handle) = self.as_mut().in_flight_requests().remove(&request_id) {
self.as_mut().in_flight_requests().compact(0.1);
cancel_handle.abort();
let remaining = self.as_mut().in_flight_requests().len();
trace!(
"[{}/{}] Request canceled. In-flight requests = {}",
trace_context.trace_id,
self.channel.client_addr,
remaining,
);
} else {
trace!(
"[{}/{}] Received cancellation, but response handler \
is already complete.",
trace_context.trace_id,
self.channel.client_addr
);
}
}
}
impl<Req, Resp, T, F, Fut> Future for ClientHandler<Req, Resp, T, F>
impl<C, F, Fut> Future for ResponseHandler<C, F>
where
Req: Send + 'static,
Resp: Send + 'static,
T: Transport<Item = ClientMessage<Req>, SinkItem = Response<Resp>> + Send,
F: FnOnce(context::Context, Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<Resp>> + Send + 'static,
C: Channel,
F: FnOnce(context::Context, C::Req) -> Fut + Send + 'static + Clone,
Fut: Future<Output = io::Result<C::Resp>> + Send + 'static,
{
type Output = io::Result<()>;
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
trace!("[{}] ClientHandler::poll", self.channel.client_addr);
loop {
let read = self.as_mut().pump_read(cx)?;
match (
read,
self.as_mut().pump_write(cx, read == Poll::Ready(None))?,
) {
(Poll::Ready(None), Poll::Ready(None)) => {
info!("[{}] Client disconnected.", self.channel.client_addr);
return Poll::Ready(Ok(()));
}
(read @ Poll::Ready(Some(())), write) | (read, write @ Poll::Ready(Some(()))) => {
trace!(
"[{}] read: {:?}, write: {:?}.",
self.channel.client_addr,
read,
write
)
}
(read, write) => {
trace!(
"[{}] read: {:?}, write: {:?} (not ready).",
self.channel.client_addr,
read,
write,
);
return Poll::Pending;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
move || -> Poll<io::Result<()>> {
loop {
let read = self.as_mut().pump_read(cx)?;
match (
read,
self.as_mut().pump_write(cx, read == Poll::Ready(None))?,
) {
(Poll::Ready(None), Poll::Ready(None)) => {
return Poll::Ready(Ok(()));
}
(Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
_ => {
return Poll::Pending;
}
}
}
}
}()
.map(|r| r.unwrap_or_else(|e| info!("ResponseHandler errored out: {}", e)))
}
}
fn make_server_error(
e: timeout::Error<io::Error>,
trace_id: TraceId,
peer: SocketAddr,
deadline: SystemTime,
) -> ServerError {
if e.is_elapsed() {
debug!(
"[{}/{}] Response did not complete before deadline of {}s.",
"[{}] Response did not complete before deadline of {}s.",
trace_id,
peer,
format_rfc3339(deadline)
);
// No point in responding, since the client will have dropped the request.
@@ -608,8 +602,8 @@ fn make_server_error(
}
} else if e.is_timer() {
error!(
"[{}/{}] Response failed because of an issue with a timer: {}",
trace_id, peer, e
"[{}] Response failed because of an issue with a timer: {}",
trace_id, e
);
ServerError {
@@ -623,7 +617,7 @@ fn make_server_error(
detail: Some(e.description().into()),
}
} else {
error!("[{}/{}] Unexpected response failure: {}", trace_id, peer, e);
error!("[{}] Unexpected response failure: {}", trace_id, e);
ServerError {
kind: io::ErrorKind::Other,

rpc/src/server/testing.rs (new file, 125 lines)

@@ -0,0 +1,125 @@
use crate::server::{Channel, Config};
use crate::{context, Request, Response};
use fnv::FnvHashSet;
use futures::future::{AbortHandle, AbortRegistration};
use futures::{Sink, Stream};
use futures_test::task::noop_waker_ref;
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::collections::VecDeque;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::SystemTime;
pub(crate) struct FakeChannel<In, Out> {
pub stream: VecDeque<In>,
pub sink: VecDeque<Out>,
pub config: Config,
pub in_flight_requests: FnvHashSet<u64>,
}
impl<In, Out> FakeChannel<In, Out> {
unsafe_pinned!(stream: VecDeque<In>);
unsafe_pinned!(sink: VecDeque<Out>);
unsafe_unpinned!(in_flight_requests: FnvHashSet<u64>);
}
impl<In, Out> Stream for FakeChannel<In, Out>
where
In: Unpin,
{
type Item = In;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
self.stream().poll_next(cx)
}
}
impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.sink().poll_ready(cx).map_err(|e| match e {})
}
fn start_send(
mut self: Pin<&mut Self>,
response: Response<Resp>,
) -> Result<(), Self::Error> {
self.as_mut()
.in_flight_requests()
.remove(&response.request_id);
self.sink().start_send(response).map_err(|e| match e {})
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.sink().poll_flush(cx).map_err(|e| match e {})
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.sink().poll_close(cx).map_err(|e| match e {})
}
}
impl<Req, Resp> Channel for FakeChannel<io::Result<Request<Req>>, Response<Resp>>
where
Req: Unpin + Send + 'static,
Resp: Send + 'static,
{
type Req = Req;
type Resp = Resp;
fn config(&self) -> &Config {
&self.config
}
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
self.in_flight_requests.len()
}
fn start_request(self: Pin<&mut Self>, id: u64) -> AbortRegistration {
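// Only the ID is tracked, for test assertions; the paired AbortHandle is
// discarded below, so fake requests can never actually be aborted.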
self.in_flight_requests().insert(id);
AbortHandle::new_pair().1
}
}
impl<Req, Resp> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
pub fn push_req(&mut self, id: u64, message: Req) {
self.stream.push_back(Ok(Request {
context: context::Context {
deadline: SystemTime::UNIX_EPOCH,
trace_context: Default::default(),
},
id,
message,
}));
}
}
impl FakeChannel<(), ()> {
pub fn default<Req, Resp>() -> FakeChannel<io::Result<Request<Req>>, Response<Resp>> {
FakeChannel {
stream: VecDeque::default(),
sink: VecDeque::default(),
config: Config::default(),
in_flight_requests: FnvHashSet::default(),
}
}
}
pub trait PollExt {
fn is_done(&self) -> bool;
}
impl<T> PollExt for Poll<Option<T>> {
fn is_done(&self) -> bool {
match self {
Poll::Ready(None) => true,
_ => false,
}
}
}
pub fn cx() -> Context<'static> {
Context::from_waker(&noop_waker_ref())
}
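To ground the test double above, a hedged sketch of how a test might drive it; `push_req` and `cx` come from this file and `pin_mut` from pin_utils, while the assertion flow itself is illustrative:

    #[test]
    fn fake_channel_yields_pushed_request() {
        let mut channel = FakeChannel::default::<u64, u64>();
        channel.push_req(0, 42);
        pin_mut!(channel);
        // The pushed request comes back out of the Stream half.
        assert!(channel.as_mut().poll_next(&mut cx()).is_ready());
    }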

rpc/src/server/throttle.rs
View File

@@ -0,0 +1,332 @@
use super::{Channel, Config};
use crate::{Response, ServerError};
use futures::{
future::AbortRegistration,
prelude::*,
ready,
task::{Context, Poll},
};
use log::debug;
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::{io, pin::Pin};
/// A [`Channel`] that limits the number of concurrent
/// requests by throttling.
#[derive(Debug)]
pub struct Throttler<C> {
max_in_flight_requests: usize,
inner: C,
}
impl<C> Throttler<C> {
unsafe_unpinned!(max_in_flight_requests: usize);
unsafe_pinned!(inner: C);
/// Returns the inner channel.
pub fn get_ref(&self) -> &C {
&self.inner
}
}
impl<C> Throttler<C>
where
C: Channel,
{
/// Returns a new `Throttler` that wraps the given channel and limits concurrent requests to
/// `max_in_flight_requests`.
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
Throttler {
inner,
max_in_flight_requests,
}
}
}
impl<C> Stream for Throttler<C>
where
C: Channel,
{
type Item = <C as Stream>::Item;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
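// While at or over the limit, incoming requests are drained and rejected
// immediately (once the sink is ready), rather than being left unread to
// exert backpressure.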
while self.as_mut().in_flight_requests() >= *self.as_mut().max_in_flight_requests() {
ready!(self.as_mut().inner().poll_ready(cx)?);
match ready!(self.as_mut().inner().poll_next(cx)?) {
Some(request) => {
debug!(
"[{}] Client has reached in-flight request limit ({}/{}).",
request.context.trace_id(),
self.as_mut().in_flight_requests(),
self.as_mut().max_in_flight_requests(),
);
self.as_mut().start_send(Response {
request_id: request.id,
message: Err(ServerError {
kind: io::ErrorKind::WouldBlock,
detail: Some("Server throttled the request.".into()),
}),
})?;
}
None => return Poll::Ready(None),
}
}
self.inner().poll_next(cx)
}
}
impl<C> Sink<Response<<C as Channel>::Resp>> for Throttler<C>
where
C: Channel,
{
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.inner().poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Response<<C as Channel>::Resp>) -> io::Result<()> {
self.inner().start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.inner().poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.inner().poll_close(cx)
}
}
impl<C> AsRef<C> for Throttler<C> {
fn as_ref(&self) -> &C {
&self.inner
}
}
impl<C> Channel for Throttler<C>
where
C: Channel,
{
type Req = <C as Channel>::Req;
type Resp = <C as Channel>::Resp;
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
self.inner().in_flight_requests()
}
fn config(&self) -> &Config {
self.inner.config()
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
self.inner().start_request(request_id)
}
}
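Because `Throttler` implements `Channel` itself, it composes transparently with the rest of the server plumbing. A sketch of wrapping a single channel directly (the surrounding serve/spawn glue is assumed):

    // Requests beyond 10 in flight receive a WouldBlock ServerError
    // response instead of being processed.
    let channel = Throttler::new(channel, 10);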
/// A stream of throttling channels.
#[derive(Debug)]
pub struct ThrottlerStream<S> {
inner: S,
max_in_flight_requests: usize,
}
impl<S> ThrottlerStream<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
unsafe_pinned!(inner: S);
unsafe_unpinned!(max_in_flight_requests: usize);
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
Self {
inner,
max_in_flight_requests,
}
}
}
impl<S> Stream for ThrottlerStream<S>
where
S: Stream,
<S as Stream>::Item: Channel,
{
type Item = Throttler<<S as Stream>::Item>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match ready!(self.as_mut().inner().poll_next(cx)) {
Some(channel) => Poll::Ready(Some(Throttler::new(
channel,
*self.max_in_flight_requests(),
))),
None => Poll::Ready(None),
}
}
}
#[cfg(test)]
use super::testing::{self, FakeChannel, PollExt};
#[cfg(test)]
use crate::Request;
#[cfg(test)]
use pin_utils::pin_mut;
#[cfg(test)]
use std::marker::PhantomData;
#[test]
fn throttler_in_flight_requests() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
for i in 0..5 {
throttler.inner.in_flight_requests.insert(i);
}
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
}
#[test]
fn throttler_start_request() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.as_mut().start_request(1);
assert_eq!(throttler.inner.in_flight_requests.len(), 1);
}
#[test]
fn throttler_poll_next_done() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
}
#[test]
fn throttler_poll_next_some() -> io::Result<()> {
let throttler = Throttler {
max_in_flight_requests: 1,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.push_req(0, 1);
assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());
assert_eq!(
throttler
.as_mut()
.poll_next(&mut testing::cx())?
.map(|r| r.map(|r| (r.id, r.message))),
Poll::Ready(Some((0, 1)))
);
Ok(())
}
#[test]
fn throttler_poll_next_throttled() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.push_req(1, 1);
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
assert_eq!(throttler.inner.sink.len(), 1);
let resp = throttler.inner.sink.get(0).unwrap();
assert_eq!(resp.request_id, 1);
assert!(resp.message.is_err());
}
#[test]
fn throttler_poll_next_throttled_sink_not_ready() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: PendingSink::default::<isize, isize>(),
};
pin_mut!(throttler);
assert!(throttler.poll_next(&mut testing::cx()).is_pending());
struct PendingSink<In, Out> {
ghost: PhantomData<fn(Out) -> In>,
}
impl PendingSink<(), ()> {
pub fn default<Req, Resp>() -> PendingSink<io::Result<Request<Req>>, Response<Resp>> {
PendingSink { ghost: PhantomData }
}
}
impl<In, Out> Stream for PendingSink<In, Out> {
type Item = In;
fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
unimplemented!()
}
}
impl<In, Out> Sink<Out> for PendingSink<In, Out> {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
Err(io::Error::from(io::ErrorKind::WouldBlock))
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
}
impl<Req, Resp> Channel for PendingSink<io::Result<Request<Req>>, Response<Resp>>
where
Req: Send + 'static,
Resp: Send + 'static,
{
type Req = Req;
type Resp = Resp;
fn config(&self) -> &Config {
unimplemented!()
}
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
0
}
fn start_request(self: Pin<&mut Self>, _: u64) -> AbortRegistration {
unimplemented!()
}
}
}
#[test]
fn throttler_start_send() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.in_flight_requests.insert(0);
throttler
.as_mut()
.start_send(Response {
request_id: 0,
message: Ok(1),
})
.unwrap();
assert!(throttler.inner.in_flight_requests.is_empty());
assert_eq!(
throttler.inner.sink.get(0),
Some(&Response {
request_id: 0,
message: Ok(1),
})
);
}

rpc/src/transport/channel.rs
View File

@@ -6,14 +6,11 @@
//! Transports backed by in-memory channels.
use crate::{PollIo, Transport};
use crate::PollIo;
use futures::{channel::mpsc, task::Context, Poll, Sink, Stream};
use pin_utils::unsafe_pinned;
use std::io;
use std::pin::Pin;
use std::{
io,
net::{IpAddr, Ipv4Addr, SocketAddr},
};
/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's
/// [`Sink`].
@@ -51,7 +48,7 @@ impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
}
impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
type SinkError = io::Error;
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.tx()
@@ -65,7 +62,7 @@ impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::SinkError>> {
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx()
.poll_flush(cx)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
@@ -78,19 +75,6 @@ impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
}
}
impl<Item, SinkItem> Transport for UnboundedChannel<Item, SinkItem> {
type SinkItem = SinkItem;
type Item = Item;
fn peer_addr(&self) -> io::Result<SocketAddr> {
Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
}
fn local_addr(&self) -> io::Result<SocketAddr> {
Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
}
}
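With `Transport` reduced to a trait alias (see transport/mod.rs below), this impl is no longer needed: `UnboundedChannel` is already `Stream + Sink`, so it satisfies the new `Transport` automatically.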
#[cfg(test)]
mod tests {
use crate::{
@@ -110,7 +94,7 @@ mod tests {
let (client_channel, server_channel) = transport::channel::unbounded();
let server = Server::<String, u64>::default()
.incoming(stream::once(future::ready(Ok(server_channel))))
.incoming(stream::once(future::ready(server_channel)))
.respond_with(|_ctx, request| {
future::ready(request.parse::<u64>().map_err(|_| {
io::Error::new(

rpc/src/transport/mod.rs
View File

@@ -10,114 +10,10 @@
//! can be plugged in, using whatever protocol it wants.
use futures::prelude::*;
use std::{
io,
marker::PhantomData,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
};
use std::io;
pub mod channel;
/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
pub trait Transport
where
Self: Stream<Item = io::Result<<Self as Transport>::Item>>,
Self: Sink<<Self as Transport>::SinkItem, SinkError = io::Error>,
{
/// The type read off the transport.
type Item;
/// The type written to the transport.
type SinkItem;
/// The address of the remote peer this transport is in communication with.
fn peer_addr(&self) -> io::Result<SocketAddr>;
/// The address of the local half of this transport.
fn local_addr(&self) -> io::Result<SocketAddr>;
}
/// Returns a new Transport backed by the given Stream + Sink and connecting addresses.
pub fn new<S, SinkItem, Item>(
inner: S,
peer_addr: SocketAddr,
local_addr: SocketAddr,
) -> impl Transport<Item = Item, SinkItem = SinkItem>
where
S: Stream<Item = io::Result<Item>>,
S: Sink<SinkItem, SinkError = io::Error>,
{
TransportShim {
inner,
peer_addr,
local_addr,
_marker: PhantomData,
}
}
/// A transport created by adding peers to a Stream + Sink.
#[derive(Debug)]
struct TransportShim<S, SinkItem> {
peer_addr: SocketAddr,
local_addr: SocketAddr,
inner: S,
_marker: PhantomData<SinkItem>,
}
impl<S, SinkItem> TransportShim<S, SinkItem> {
pin_utils::unsafe_pinned!(inner: S);
}
impl<S, SinkItem> Stream for TransportShim<S, SinkItem>
where
S: Stream,
{
type Item = S::Item;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<S::Item>> {
self.inner().poll_next(cx)
}
}
impl<S, Item> Sink<Item> for TransportShim<S, Item>
where
S: Sink<Item>,
{
type SinkError = S::SinkError;
fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), S::SinkError> {
self.inner().start_send(item)
}
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), S::SinkError>> {
self.inner().poll_ready(cx)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), S::SinkError>> {
self.inner().poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), S::SinkError>> {
self.inner().poll_close(cx)
}
}
impl<S, SinkItem, Item> Transport for TransportShim<S, SinkItem>
where
S: Stream + Sink<SinkItem>,
Self: Stream<Item = io::Result<Item>>,
Self: Sink<SinkItem, SinkError = io::Error>,
{
type Item = Item;
type SinkItem = SinkItem;
/// The address of the remote peer this transport is in communication with.
fn peer_addr(&self) -> io::Result<SocketAddr> {
Ok(self.peer_addr)
}
/// The address of the local half of this transport.
fn local_addr(&self) -> io::Result<SocketAddr> {
Ok(self.local_addr)
}
}
pub trait Transport<SinkItem, Item> =
Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error>;
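Trait aliases require the unstable `trait_alias` feature. For comparison, a sketch of the closest stable-Rust equivalent, a marker trait with a blanket impl:

    pub trait Transport<SinkItem, Item>:
        Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error>
    {
    }

    impl<T, SinkItem, Item> Transport<SinkItem, Item> for T where
        T: Stream<Item = io::Result<Item>> + Sink<SinkItem, Error = io::Error>
    {
    }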

rpc/src/util.rs
View File

@@ -38,9 +38,11 @@ where
H: BuildHasher,
{
fn compact(&mut self, usage_ratio_threshold: f64) {
let usage_ratio = self.len() as f64 / self.capacity() as f64;
if usage_ratio < usage_ratio_threshold {
self.shrink_to_fit();
if self.capacity() > 1000 {
let usage_ratio = self.len() as f64 / self.capacity() as f64;
if usage_ratio < usage_ratio_threshold {
self.shrink_to_fit();
}
}
}
}
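The new guard means small maps are never shrunk: with the 0.1 threshold used at the call sites above, a map at capacity 2048 holding 100 entries (ratio ≈ 0.05) shrinks, while one at capacity 800 is left alone regardless of occupancy, avoiding shrink/regrow churn on hot, small maps. A hedged test sketch, assuming `Compact` and `fnv::FnvHashMap` are in scope:

    #[test]
    fn compact_ignores_small_maps() {
        let mut small: FnvHashMap<u64, ()> =
            FnvHashMap::with_capacity_and_hasher(100, Default::default());
        small.insert(1, ());
        small.compact(0.1);
        // Capacity is well under the 1000 cutoff, so shrink_to_fit never runs.
        assert!(small.capacity() >= 100);
    }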