mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-01-08 12:32:11 +01:00
Port to pin-project
This commit is contained in:
@@ -18,7 +18,7 @@ use futures::{
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use log::{debug, info, trace};
|
||||
use pin_utils::{unsafe_pinned, unsafe_unpinned};
|
||||
use pin_project::pin_project;
|
||||
use raii_counter::{Counter, WeakCounter};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::{
|
||||
@@ -26,30 +26,32 @@ use std::{
|
||||
};
|
||||
|
||||
/// A single-threaded filter that drops channels based on per-key limits.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct ChannelFilter<S, K, F>
|
||||
where
|
||||
K: Eq + Hash,
|
||||
{
|
||||
#[pin]
|
||||
listener: Fuse<S>,
|
||||
channels_per_key: u32,
|
||||
#[pin]
|
||||
dropped_keys: mpsc::UnboundedReceiver<K>,
|
||||
#[pin]
|
||||
dropped_keys_tx: mpsc::UnboundedSender<K>,
|
||||
key_counts: FnvHashMap<K, TrackerPrototype<K>>,
|
||||
keymaker: F,
|
||||
}
|
||||
|
||||
/// A channel that is tracked by a ChannelFilter.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct TrackedChannel<C, K> {
|
||||
#[pin]
|
||||
inner: C,
|
||||
tracker: Tracker<K>,
|
||||
}
|
||||
|
||||
impl<C, K> TrackedChannel<C, K> {
|
||||
unsafe_pinned!(inner: C);
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Tracker<K> {
|
||||
key: Option<Arc<K>>,
|
||||
@@ -130,11 +132,11 @@ where
|
||||
}
|
||||
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
self.inner().in_flight_requests()
|
||||
self.project().inner.in_flight_requests()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
self.inner().start_request(request_id)
|
||||
self.project().inner.start_request(request_id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,22 +148,10 @@ impl<C, K> TrackedChannel<C, K> {
|
||||
|
||||
/// Returns the pinned inner channel.
|
||||
fn channel<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut C> {
|
||||
self.inner()
|
||||
self.project().inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, K, F> ChannelFilter<S, K, F>
|
||||
where
|
||||
K: fmt::Display + Eq + Hash + Clone,
|
||||
{
|
||||
unsafe_pinned!(listener: Fuse<S>);
|
||||
unsafe_pinned!(dropped_keys: mpsc::UnboundedReceiver<K>);
|
||||
unsafe_pinned!(dropped_keys_tx: mpsc::UnboundedSender<K>);
|
||||
unsafe_unpinned!(key_counts: FnvHashMap<K, TrackerPrototype<K>>);
|
||||
unsafe_unpinned!(channels_per_key: u32);
|
||||
unsafe_unpinned!(keymaker: F);
|
||||
}
|
||||
|
||||
impl<S, K, F> ChannelFilter<S, K, F>
|
||||
where
|
||||
K: Eq + Hash,
|
||||
@@ -192,14 +182,14 @@ where
|
||||
mut self: Pin<&mut Self>,
|
||||
stream: S::Item,
|
||||
) -> Result<TrackedChannel<S::Item, K>, K> {
|
||||
let key = self.as_mut().keymaker()(&stream);
|
||||
let key = (self.as_mut().keymaker)(&stream);
|
||||
let tracker = self.as_mut().increment_channels_for_key(key.clone())?;
|
||||
|
||||
trace!(
|
||||
"[{}] Opening channel ({}/{}) channels for key.",
|
||||
key,
|
||||
tracker.counter.count(),
|
||||
self.as_mut().channels_per_key()
|
||||
self.as_mut().project().channels_per_key
|
||||
);
|
||||
|
||||
Ok(TrackedChannel {
|
||||
@@ -211,7 +201,7 @@ where
|
||||
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Tracker<K>, K> {
|
||||
let channels_per_key = self.channels_per_key;
|
||||
let dropped_keys = self.dropped_keys_tx.clone();
|
||||
let key_counts = &mut self.as_mut().key_counts();
|
||||
let key_counts = &mut self.as_mut().project().key_counts;
|
||||
match key_counts.entry(key.clone()) {
|
||||
Entry::Vacant(vacant) => {
|
||||
let key = Arc::new(key);
|
||||
@@ -256,18 +246,18 @@ where
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
|
||||
match ready!(self.as_mut().listener().poll_next_unpin(cx)) {
|
||||
match ready!(self.as_mut().project().listener.poll_next_unpin(cx)) {
|
||||
Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_closed_channels(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
match ready!(self.as_mut().dropped_keys().poll_next_unpin(cx)) {
|
||||
match ready!(self.as_mut().project().dropped_keys.poll_next_unpin(cx)) {
|
||||
Some(key) => {
|
||||
debug!("All channels dropped for key [{}]", key);
|
||||
self.as_mut().key_counts().remove(&key);
|
||||
self.as_mut().key_counts().compact(0.1);
|
||||
self.as_mut().project().key_counts.remove(&key);
|
||||
self.as_mut().project().key_counts.compact(0.1);
|
||||
Poll::Ready(())
|
||||
}
|
||||
None => unreachable!("Holding a copy of closed_channels and didn't close it."),
|
||||
|
||||
@@ -21,7 +21,7 @@ use futures::{
|
||||
};
|
||||
use humantime::format_rfc3339;
|
||||
use log::{debug, trace};
|
||||
use pin_utils::{unsafe_pinned, unsafe_unpinned};
|
||||
use pin_project::pin_project;
|
||||
use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
|
||||
use tokio_timer::{timeout, Timeout};
|
||||
|
||||
@@ -165,10 +165,12 @@ where
|
||||
}
|
||||
|
||||
/// BaseChannel lifts a Transport to a Channel by tracking in-flight requests.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct BaseChannel<Req, Resp, T> {
|
||||
config: Config,
|
||||
/// Writes responses to the wire and reads requests off the wire.
|
||||
#[pin]
|
||||
transport: Fuse<T>,
|
||||
/// Number of requests currently being responded to.
|
||||
in_flight_requests: FnvHashMap<u64, AbortHandle>,
|
||||
@@ -176,10 +178,6 @@ pub struct BaseChannel<Req, Resp, T> {
|
||||
ghost: PhantomData<(Req, Resp)>,
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> BaseChannel<Req, Resp, T> {
|
||||
unsafe_unpinned!(in_flight_requests: FnvHashMap<u64, AbortHandle>);
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
@@ -204,19 +202,19 @@ where
|
||||
self.transport.get_ref()
|
||||
}
|
||||
|
||||
/// Returns the pinned inner transport.
|
||||
pub fn transport<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut T> {
|
||||
unsafe { self.map_unchecked_mut(|me| me.transport.get_mut()) }
|
||||
}
|
||||
|
||||
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
|
||||
// It's possible the request was already completed, so it's fine
|
||||
// if this is None.
|
||||
if let Some(cancel_handle) = self.as_mut().in_flight_requests().remove(&request_id) {
|
||||
self.as_mut().in_flight_requests().compact(0.1);
|
||||
if let Some(cancel_handle) = self
|
||||
.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&request_id)
|
||||
{
|
||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
||||
|
||||
cancel_handle.abort();
|
||||
let remaining = self.as_mut().in_flight_requests().len();
|
||||
let remaining = self.as_mut().project().in_flight_requests.len();
|
||||
trace!(
|
||||
"[{}] Request canceled. In-flight requests = {}",
|
||||
trace_context.trace_id,
|
||||
@@ -295,7 +293,7 @@ where
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
match ready!(self.as_mut().transport().poll_next(cx)?) {
|
||||
match ready!(self.as_mut().project().transport.poll_next(cx)?) {
|
||||
Some(message) => match message {
|
||||
ClientMessage::Request(request) => {
|
||||
return Poll::Ready(Some(Ok(request)));
|
||||
@@ -321,28 +319,29 @@ where
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.transport().poll_ready(cx)
|
||||
self.project().transport.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
|
||||
if self
|
||||
.as_mut()
|
||||
.in_flight_requests()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&response.request_id)
|
||||
.is_some()
|
||||
{
|
||||
self.as_mut().in_flight_requests().compact(0.1);
|
||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
||||
}
|
||||
|
||||
self.transport().start_send(response)
|
||||
self.project().transport.start_send(response)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.transport().poll_flush(cx)
|
||||
self.project().transport.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.transport().poll_close(cx)
|
||||
self.project().transport.poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -364,13 +363,14 @@ where
|
||||
}
|
||||
|
||||
fn in_flight_requests(mut self: Pin<&mut Self>) -> usize {
|
||||
self.as_mut().in_flight_requests().len()
|
||||
self.as_mut().project().in_flight_requests.len()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
let (abort_handle, abort_registration) = AbortHandle::new_pair();
|
||||
assert!(self
|
||||
.in_flight_requests()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.insert(request_id, abort_handle)
|
||||
.is_none());
|
||||
abort_registration
|
||||
@@ -378,32 +378,24 @@ where
|
||||
}
|
||||
|
||||
/// A running handler serving all requests coming over a channel.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct ClientHandler<C, S>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
#[pin]
|
||||
channel: C,
|
||||
/// Responses waiting to be written to the wire.
|
||||
#[pin]
|
||||
pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>,
|
||||
/// Handed out to request handlers to fan in responses.
|
||||
#[pin]
|
||||
responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
|
||||
/// Server
|
||||
server: S,
|
||||
}
|
||||
|
||||
impl<C, S> ClientHandler<C, S>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
unsafe_pinned!(channel: C);
|
||||
unsafe_pinned!(pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>);
|
||||
unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>);
|
||||
// For this to be safe, field f must be private, and code in this module must never
|
||||
// construct PinMut<S>.
|
||||
unsafe_unpinned!(server: S);
|
||||
}
|
||||
|
||||
impl<C, S> ClientHandler<C, S>
|
||||
where
|
||||
C: Channel,
|
||||
@@ -413,7 +405,7 @@ where
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<RequestHandler<S::Fut, C::Resp>> {
|
||||
match ready!(self.as_mut().channel().poll_next(cx)?) {
|
||||
match ready!(self.as_mut().project().channel.poll_next(cx)?) {
|
||||
Some(request) => Poll::Ready(Some(Ok(self.handle_request(request)))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
@@ -429,24 +421,24 @@ where
|
||||
trace!(
|
||||
"[{}] Staging response. In-flight requests = {}.",
|
||||
ctx.trace_id(),
|
||||
self.as_mut().channel().in_flight_requests(),
|
||||
self.as_mut().project().channel.in_flight_requests(),
|
||||
);
|
||||
self.as_mut().channel().start_send(response)?;
|
||||
self.as_mut().project().channel.start_send(response)?;
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
// Shutdown can't be done before we finish pumping out remaining responses.
|
||||
ready!(self.as_mut().channel().poll_flush(cx)?);
|
||||
ready!(self.as_mut().project().channel.poll_flush(cx)?);
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Pending => {
|
||||
// No more requests to process, so flush any requests buffered in the transport.
|
||||
ready!(self.as_mut().channel().poll_flush(cx)?);
|
||||
ready!(self.as_mut().project().channel.poll_flush(cx)?);
|
||||
|
||||
// Being here means there are no staged requests and all written responses are
|
||||
// fully flushed. So, if the read half is closed and there are no in-flight
|
||||
// requests, then we can close the write half.
|
||||
if read_half_closed && self.as_mut().channel().in_flight_requests() == 0 {
|
||||
if read_half_closed && self.as_mut().project().channel.in_flight_requests() == 0 {
|
||||
Poll::Ready(None)
|
||||
} else {
|
||||
Poll::Pending
|
||||
@@ -460,11 +452,11 @@ where
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<(context::Context, Response<C::Resp>)> {
|
||||
// Ensure there's room to write a response.
|
||||
while let Poll::Pending = self.as_mut().channel().poll_ready(cx)? {
|
||||
ready!(self.as_mut().channel().poll_flush(cx)?);
|
||||
while let Poll::Pending = self.as_mut().project().channel.poll_ready(cx)? {
|
||||
ready!(self.as_mut().project().channel.poll_flush(cx)?);
|
||||
}
|
||||
|
||||
match ready!(self.as_mut().pending_responses().poll_next(cx)) {
|
||||
match ready!(self.as_mut().project().pending_responses.poll_next(cx)) {
|
||||
Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))),
|
||||
None => {
|
||||
// This branch likely won't happen, since the ClientHandler is holding a Sender.
|
||||
@@ -489,7 +481,7 @@ where
|
||||
let ctx = request.context;
|
||||
let request = request.message;
|
||||
|
||||
let response = self.as_mut().server().clone().serve(ctx, request);
|
||||
let response = self.as_mut().project().server.clone().serve(ctx, request);
|
||||
let response = Resp {
|
||||
state: RespState::PollResp,
|
||||
request_id,
|
||||
@@ -497,9 +489,9 @@ where
|
||||
deadline,
|
||||
f: Timeout::new(response, timeout),
|
||||
response: None,
|
||||
response_tx: self.as_mut().responses_tx().clone(),
|
||||
response_tx: self.as_mut().project().responses_tx.clone(),
|
||||
};
|
||||
let abort_registration = self.as_mut().channel().start_request(request_id);
|
||||
let abort_registration = self.as_mut().project().channel.start_request(request_id);
|
||||
RequestHandler {
|
||||
resp: Abortable::new(response, abort_registration),
|
||||
}
|
||||
@@ -507,15 +499,13 @@ where
|
||||
}
|
||||
|
||||
/// A future fulfilling a single client request.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct RequestHandler<F, R> {
|
||||
#[pin]
|
||||
resp: Abortable<Resp<F, R>>,
|
||||
}
|
||||
|
||||
impl<F, R> RequestHandler<F, R> {
|
||||
unsafe_pinned!(resp: Abortable<Resp<F, R>>);
|
||||
}
|
||||
|
||||
impl<F, R> Future for RequestHandler<F, R>
|
||||
where
|
||||
F: Future<Output = R>,
|
||||
@@ -523,19 +513,22 @@ where
|
||||
type Output = ();
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
let _ = ready!(self.resp().poll(cx));
|
||||
let _ = ready!(self.project().resp.poll(cx));
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
struct Resp<F, R> {
|
||||
state: RespState,
|
||||
request_id: u64,
|
||||
ctx: context::Context,
|
||||
deadline: SystemTime,
|
||||
#[pin]
|
||||
f: Timeout<F>,
|
||||
response: Option<Response<R>>,
|
||||
#[pin]
|
||||
response_tx: mpsc::Sender<(context::Context, Response<R>)>,
|
||||
}
|
||||
|
||||
@@ -546,13 +539,6 @@ enum RespState {
|
||||
PollFlush,
|
||||
}
|
||||
|
||||
impl<F, R> Resp<F, R> {
|
||||
unsafe_pinned!(f: Timeout<F>);
|
||||
unsafe_pinned!(response_tx: mpsc::Sender<(context::Context, Response<R>)>);
|
||||
unsafe_unpinned!(response: Option<Response<R>>);
|
||||
unsafe_unpinned!(state: RespState);
|
||||
}
|
||||
|
||||
impl<F, R> Future for Resp<F, R>
|
||||
where
|
||||
F: Future<Output = R>,
|
||||
@@ -561,10 +547,10 @@ where
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
loop {
|
||||
match self.as_mut().state() {
|
||||
match self.as_mut().project().state {
|
||||
RespState::PollResp => {
|
||||
let result = ready!(self.as_mut().f().poll(cx));
|
||||
*self.as_mut().response() = Some(Response {
|
||||
let result = ready!(self.as_mut().project().f.poll(cx));
|
||||
*self.as_mut().project().response = Some(Response {
|
||||
request_id: self.request_id,
|
||||
message: match result {
|
||||
Ok(message) => Ok(message),
|
||||
@@ -588,21 +574,27 @@ where
|
||||
},
|
||||
_non_exhaustive: (),
|
||||
});
|
||||
*self.as_mut().state() = RespState::PollReady;
|
||||
*self.as_mut().project().state = RespState::PollReady;
|
||||
}
|
||||
RespState::PollReady => {
|
||||
let ready = ready!(self.as_mut().response_tx().poll_ready(cx));
|
||||
let ready = ready!(self.as_mut().project().response_tx.poll_ready(cx));
|
||||
if ready.is_err() {
|
||||
return Poll::Ready(());
|
||||
}
|
||||
let resp = (self.ctx, self.as_mut().response().take().unwrap());
|
||||
if self.as_mut().response_tx().start_send(resp).is_err() {
|
||||
let resp = (self.ctx, self.as_mut().project().response.take().unwrap());
|
||||
if self
|
||||
.as_mut()
|
||||
.project()
|
||||
.response_tx
|
||||
.start_send(resp)
|
||||
.is_err()
|
||||
{
|
||||
return Poll::Ready(());
|
||||
}
|
||||
*self.as_mut().state() = RespState::PollFlush;
|
||||
*self.as_mut().project().state = RespState::PollFlush;
|
||||
}
|
||||
RespState::PollFlush => {
|
||||
let ready = ready!(self.as_mut().response_tx().poll_flush(cx));
|
||||
let ready = ready!(self.as_mut().project().response_tx.poll_flush(cx));
|
||||
if ready.is_err() {
|
||||
return Poll::Ready(());
|
||||
}
|
||||
@@ -672,19 +664,15 @@ where
|
||||
|
||||
/// A future that drives the server by spawning channels and request handlers on the default
|
||||
/// executor.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[cfg(feature = "tokio1")]
|
||||
pub struct Running<St, Se> {
|
||||
#[pin]
|
||||
incoming: St,
|
||||
server: Se,
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
impl<St, Se> Running<St, Se> {
|
||||
unsafe_pinned!(incoming: St);
|
||||
unsafe_unpinned!(server: Se);
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
impl<St, C, Se> Future for Running<St, Se>
|
||||
where
|
||||
@@ -700,10 +688,10 @@ where
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
use log::info;
|
||||
|
||||
while let Some(channel) = ready!(self.as_mut().incoming().poll_next(cx)) {
|
||||
while let Some(channel) = ready!(self.as_mut().project().incoming.poll_next(cx)) {
|
||||
tokio::spawn(
|
||||
channel
|
||||
.respond_with(self.as_mut().server().clone())
|
||||
.respond_with(self.as_mut().project().server.clone())
|
||||
.execute(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -4,26 +4,23 @@ use fnv::FnvHashSet;
|
||||
use futures::future::{AbortHandle, AbortRegistration};
|
||||
use futures::{Sink, Stream};
|
||||
use futures_test::task::noop_waker_ref;
|
||||
use pin_utils::{unsafe_pinned, unsafe_unpinned};
|
||||
use pin_project::pin_project;
|
||||
use std::collections::VecDeque;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::SystemTime;
|
||||
|
||||
#[pin_project]
|
||||
pub(crate) struct FakeChannel<In, Out> {
|
||||
#[pin]
|
||||
pub stream: VecDeque<In>,
|
||||
#[pin]
|
||||
pub sink: VecDeque<Out>,
|
||||
pub config: Config,
|
||||
pub in_flight_requests: FnvHashSet<u64>,
|
||||
}
|
||||
|
||||
impl<In, Out> FakeChannel<In, Out> {
|
||||
unsafe_pinned!(stream: VecDeque<In>);
|
||||
unsafe_pinned!(sink: VecDeque<Out>);
|
||||
unsafe_unpinned!(in_flight_requests: FnvHashSet<u64>);
|
||||
}
|
||||
|
||||
impl<In, Out> Stream for FakeChannel<In, Out>
|
||||
where
|
||||
In: Unpin,
|
||||
@@ -31,7 +28,7 @@ where
|
||||
type Item = In;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
self.stream().poll_next(cx)
|
||||
self.project().stream.poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,22 +36,26 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.sink().poll_ready(cx).map_err(|e| match e {})
|
||||
self.project().sink.poll_ready(cx).map_err(|e| match e {})
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
|
||||
self.as_mut()
|
||||
.in_flight_requests()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&response.request_id);
|
||||
self.sink().start_send(response).map_err(|e| match e {})
|
||||
self.project()
|
||||
.sink
|
||||
.start_send(response)
|
||||
.map_err(|e| match e {})
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.sink().poll_flush(cx).map_err(|e| match e {})
|
||||
self.project().sink.poll_flush(cx).map_err(|e| match e {})
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.sink().poll_close(cx).map_err(|e| match e {})
|
||||
self.project().sink.poll_close(cx).map_err(|e| match e {})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +75,7 @@ where
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, id: u64) -> AbortRegistration {
|
||||
self.in_flight_requests().insert(id);
|
||||
self.project().in_flight_requests.insert(id);
|
||||
AbortHandle::new_pair().1
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,21 +7,20 @@ use futures::{
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use log::debug;
|
||||
use pin_utils::{unsafe_pinned, unsafe_unpinned};
|
||||
use pin_project::pin_project;
|
||||
use std::{io, pin::Pin};
|
||||
|
||||
/// A [`Channel`] that limits the number of concurrent
|
||||
/// requests by throttling.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct Throttler<C> {
|
||||
max_in_flight_requests: usize,
|
||||
#[pin]
|
||||
inner: C,
|
||||
}
|
||||
|
||||
impl<C> Throttler<C> {
|
||||
unsafe_unpinned!(max_in_flight_requests: usize);
|
||||
unsafe_pinned!(inner: C);
|
||||
|
||||
/// Returns the inner channel.
|
||||
pub fn get_ref(&self) -> &C {
|
||||
&self.inner
|
||||
@@ -49,16 +48,17 @@ where
|
||||
type Item = <C as Stream>::Item;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
while self.as_mut().in_flight_requests() >= *self.as_mut().max_in_flight_requests() {
|
||||
ready!(self.as_mut().inner().poll_ready(cx)?);
|
||||
while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
|
||||
{
|
||||
ready!(self.as_mut().project().inner.poll_ready(cx)?);
|
||||
|
||||
match ready!(self.as_mut().inner().poll_next(cx)?) {
|
||||
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
|
||||
Some(request) => {
|
||||
debug!(
|
||||
"[{}] Client has reached in-flight request limit ({}/{}).",
|
||||
request.context.trace_id(),
|
||||
self.as_mut().in_flight_requests(),
|
||||
self.as_mut().max_in_flight_requests(),
|
||||
self.as_mut().project().max_in_flight_requests,
|
||||
);
|
||||
|
||||
self.as_mut().start_send(Response {
|
||||
@@ -74,7 +74,7 @@ where
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
self.inner().poll_next(cx)
|
||||
self.project().inner.poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,19 +85,19 @@ where
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner().poll_ready(cx)
|
||||
self.project().inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: Response<<C as Channel>::Resp>) -> io::Result<()> {
|
||||
self.inner().start_send(item)
|
||||
self.project().inner.start_send(item)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
self.inner().poll_flush(cx)
|
||||
self.project().inner.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
self.inner().poll_close(cx)
|
||||
self.project().inner.poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,7 +115,7 @@ where
|
||||
type Resp = <C as Channel>::Resp;
|
||||
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
self.inner().in_flight_requests()
|
||||
self.project().inner.in_flight_requests()
|
||||
}
|
||||
|
||||
fn config(&self) -> &Config {
|
||||
@@ -123,13 +123,15 @@ where
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
self.inner().start_request(request_id)
|
||||
self.project().inner.start_request(request_id)
|
||||
}
|
||||
}
|
||||
|
||||
/// A stream of throttling channels.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct ThrottlerStream<S> {
|
||||
#[pin]
|
||||
inner: S,
|
||||
max_in_flight_requests: usize,
|
||||
}
|
||||
@@ -139,9 +141,6 @@ where
|
||||
S: Stream,
|
||||
<S as Stream>::Item: Channel,
|
||||
{
|
||||
unsafe_pinned!(inner: S);
|
||||
unsafe_unpinned!(max_in_flight_requests: usize);
|
||||
|
||||
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
@@ -158,10 +157,10 @@ where
|
||||
type Item = Throttler<<S as Stream>::Item>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
match ready!(self.as_mut().inner().poll_next(cx)) {
|
||||
match ready!(self.as_mut().project().inner.poll_next(cx)) {
|
||||
Some(channel) => Poll::Ready(Some(Throttler::new(
|
||||
channel,
|
||||
*self.max_in_flight_requests(),
|
||||
*self.project().max_in_flight_requests,
|
||||
))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user