Alternate polling expired and new requests.

Previously, there were two loops:

- Expired in-flight requests were polled until Pending.
- New requests were polled until Pending.

Now there is one loop that alternates between polling expired requests and
new requests, so that neither type of action can starve the other.
Author: Tim Kuehn
Date: 2021-03-08 20:19:57 -08:00
Parent: 3feb465ad3
Commit: 27aacab432
2 changed files with 354 additions and 19 deletions


@@ -292,20 +292,20 @@ where
fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
enum ReceiverStatus {
- NotReady,
+ Pending,
Closed,
}
let pending_requests_status = match self.as_mut().poll_write_request(cx)? {
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
Poll::Ready(None) => ReceiverStatus::Closed,
- Poll::Pending => ReceiverStatus::NotReady,
+ Poll::Pending => ReceiverStatus::Pending,
};
let canceled_requests_status = match self.as_mut().poll_write_cancel(cx)? {
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
Poll::Ready(None) => ReceiverStatus::Closed,
- Poll::Pending => ReceiverStatus::NotReady,
+ Poll::Pending => ReceiverStatus::Pending,
};
// Receiving Poll::Ready(None) when polling expired requests never indicates "Closed",
@@ -323,7 +323,7 @@ where
ready!(self.transport_pin_mut().poll_flush(cx)?);
Poll::Ready(None)
}
- (ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
+ (ReceiverStatus::Pending, _) | (_, ReceiverStatus::Pending) => {
// No more messages to process, so flush any messages buffered in the transport.
ready!(self.transport_pin_mut().poll_flush(cx)?);


@@ -15,6 +15,7 @@ use futures::{
task::*,
};
use humantime::format_rfc3339;
+ use in_flight_requests::{AlreadyExistsError, InFlightRequests};
use log::{debug, info, trace};
use pin_project::pin_project;
use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
@@ -138,7 +139,7 @@ pub struct BaseChannel<Req, Resp, T> {
#[pin]
transport: Fuse<T>,
/// Holds data necessary to clean up in-flight requests.
- in_flight_requests: in_flight_requests::InFlightRequests,
+ in_flight_requests: InFlightRequests,
/// Types the request and response.
ghost: PhantomData<(Req, Resp)>,
}
@@ -152,7 +153,7 @@ where
BaseChannel {
config,
transport: transport.fuse(),
- in_flight_requests: in_flight_requests::InFlightRequests::default(),
+ in_flight_requests: InFlightRequests::default(),
ghost: PhantomData,
}
}
@@ -171,6 +172,14 @@ where
pub fn get_pin_ref(self: Pin<&mut Self>) -> Pin<&mut T> {
self.project().transport.get_pin_mut()
}
+ fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests {
+ self.as_mut().project().in_flight_requests
+ }
+ fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<T>> {
+ self.as_mut().project().transport
+ }
}
impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
@@ -231,7 +240,7 @@ where
self: Pin<&mut Self>,
id: u64,
deadline: SystemTime,
- ) -> Result<AbortRegistration, in_flight_requests::AlreadyExistsError>;
+ ) -> Result<AbortRegistration, AlreadyExistsError>;
/// Returns a stream of requests that automatically handle request cancellation and response
/// routing.
@@ -276,16 +285,27 @@ where
type Item = io::Result<Request<Req>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
- let self_ = self.as_mut().project();
- while let Poll::Ready(Some(request_id)) = self_.in_flight_requests.poll_expired(cx)? {
- // No need to send a response, since the client wouldn't be waiting for one anymore.
- debug!("Request {} did not complete before deadline", request_id);
- }
+ enum ReceiverStatus {
+ Ready,
+ Pending,
+ Closed,
+ }
+ use ReceiverStatus::*;
loop {
- let self_ = self.as_mut().project();
- match ready!(self_.transport.poll_next(cx)?) {
- Some(message) => match message {
+ let expiration_status = match self.in_flight_requests_mut().poll_expired(cx)? {
+ Poll::Ready(Some(request_id)) => {
+ // No need to send a response, since the client wouldn't be waiting for one
+ // anymore.
+ debug!("Request {} did not complete before deadline", request_id);
+ Ready
+ }
+ Poll::Ready(None) => Closed,
+ Poll::Pending => Pending,
+ };
+ let request_status = match self.transport_pin_mut().poll_next(cx)? {
+ Poll::Ready(Some(message)) => match message {
ClientMessage::Request(request) => {
return Poll::Ready(Some(Ok(request)));
}
@@ -293,8 +313,8 @@ where
trace_context,
request_id,
} => {
- if self_.in_flight_requests.cancel_request(request_id) {
- let remaining = self_.in_flight_requests.len();
+ if self.in_flight_requests_mut().cancel_request(request_id) {
+ let remaining = self.in_flight_requests.len();
trace!(
"[{}] Request canceled. In-flight requests = {}",
trace_context.trace_id,
@@ -307,9 +327,17 @@ where
trace_context.trace_id,
);
}
+ Ready
}
},
- None => return Poll::Ready(None),
+ Poll::Ready(None) => Closed,
+ Poll::Pending => Pending,
+ };
+ match (expiration_status, request_status) {
+ (Ready, _) | (_, Ready) => continue,
+ (Closed, Closed) => return Poll::Ready(None),
+ (Pending, Closed) | (Closed, Pending) | (Pending, Pending) => return Poll::Pending,
+ }
}
}
@@ -367,7 +395,7 @@ where
self: Pin<&mut Self>,
id: u64,
deadline: SystemTime,
- ) -> Result<AbortRegistration, in_flight_requests::AlreadyExistsError> {
+ ) -> Result<AbortRegistration, AlreadyExistsError> {
self.project()
.in_flight_requests
.start_request(id, deadline)
@@ -432,7 +460,7 @@ where
// Instead of closing the channel if a duplicate request is sent, just
// ignore it, since it's already being processed. Note that we cannot
// return Poll::Pending here, since nothing has scheduled a wakeup yet.
- Err(in_flight_requests::AlreadyExistsError) => {
+ Err(AlreadyExistsError) => {
info!(
"[{}] Request ID {} delivered more than once.",
request.context.trace_id(),
@@ -712,3 +740,310 @@ where
Poll::Ready(())
}
}
#[cfg(test)]
use {
crate::{
trace,
transport::channel::{self, UnboundedChannel},
},
assert_matches::assert_matches,
futures::future::{pending, Aborted},
futures_test::task::noop_context,
std::time::Duration,
};
#[cfg(test)]
fn test_channel<Req, Resp>() -> (
Pin<Box<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>,
UnboundedChannel<Response<Resp>, ClientMessage<Req>>,
) {
let (tx, rx) = crate::transport::channel::unbounded();
(Box::pin(BaseChannel::new(Config::default(), rx)), tx)
}
#[cfg(test)]
fn test_requests<Req, Resp>() -> (
Pin<
Box<Requests<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>,
>,
UnboundedChannel<Response<Resp>, ClientMessage<Req>>,
) {
let (tx, rx) = crate::transport::channel::unbounded();
(
Box::pin(BaseChannel::new(Config::default(), rx).requests()),
tx,
)
}
#[cfg(test)]
fn test_bounded_requests<Req, Resp>(
capacity: usize,
) -> (
Pin<
Box<Requests<BaseChannel<Req, Resp, channel::Channel<ClientMessage<Req>, Response<Resp>>>>>,
>,
channel::Channel<Response<Resp>, ClientMessage<Req>>,
) {
let (tx, rx) = crate::transport::channel::bounded(capacity);
let mut config = Config::default();
// Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded).
config.pending_response_buffer = capacity + 1;
(Box::pin(BaseChannel::new(config, rx).requests()), tx)
}
#[cfg(test)]
fn fake_request<Req>(req: Req) -> ClientMessage<Req> {
ClientMessage::Request(Request {
context: context::current(),
id: 0,
message: req,
})
}
#[cfg(test)]
fn test_abortable(
abort_registration: AbortRegistration,
) -> impl Future<Output = Result<(), Aborted>> {
Abortable::new(pending(), abort_registration)
}
#[tokio::test]
async fn base_channel_start_send_duplicate_request_returns_error() {
let (mut channel, _tx) = test_channel::<(), ()>();
channel
.as_mut()
.start_request(0, SystemTime::now())
.unwrap();
assert_matches!(
channel.as_mut().start_request(0, SystemTime::now()),
Err(AlreadyExistsError)
);
}
#[tokio::test]
async fn base_channel_poll_next_aborts_multiple_requests() {
let (mut channel, _tx) = test_channel::<(), ()>();
tokio::time::pause();
let abort_registration0 = channel
.as_mut()
.start_request(0, SystemTime::now())
.unwrap();
let abort_registration1 = channel
.as_mut()
.start_request(1, SystemTime::now())
.unwrap();
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
assert_matches!(
channel.as_mut().poll_next(&mut noop_context()),
Poll::Pending
);
assert_matches!(test_abortable(abort_registration0).await, Err(Aborted));
assert_matches!(test_abortable(abort_registration1).await, Err(Aborted));
}
#[tokio::test]
async fn base_channel_poll_next_aborts_canceled_request() {
let (mut channel, mut tx) = test_channel::<(), ()>();
tokio::time::pause();
let abort_registration = channel
.as_mut()
.start_request(0, SystemTime::now() + Duration::from_millis(100))
.unwrap();
tx.send(ClientMessage::Cancel {
trace_context: trace::Context::default(),
request_id: 0,
})
.await
.unwrap();
assert_matches!(
channel.as_mut().poll_next(&mut noop_context()),
Poll::Pending
);
assert_matches!(test_abortable(abort_registration).await, Err(Aborted));
}
#[tokio::test]
async fn base_channel_with_closed_transport_and_in_flight_request_returns_pending() {
let (mut channel, tx) = test_channel::<(), ()>();
tokio::time::pause();
let _abort_registration = channel
.as_mut()
.start_request(0, SystemTime::now() + Duration::from_millis(100))
.unwrap();
drop(tx);
assert_matches!(
channel.as_mut().poll_next(&mut noop_context()),
Poll::Pending
);
}
#[tokio::test]
async fn base_channel_with_closed_transport_and_no_in_flight_requests_returns_closed() {
let (mut channel, tx) = test_channel::<(), ()>();
drop(tx);
assert_matches!(
channel.as_mut().poll_next(&mut noop_context()),
Poll::Ready(None)
);
}
#[tokio::test]
async fn base_channel_poll_next_yields_request() {
let (mut channel, mut tx) = test_channel::<(), ()>();
tx.send(fake_request(())).await.unwrap();
assert_matches!(
channel.as_mut().poll_next(&mut noop_context()),
Poll::Ready(Some(Ok(_)))
);
}
#[tokio::test]
async fn base_channel_poll_next_aborts_request_and_yields_request() {
let (mut channel, mut tx) = test_channel::<(), ()>();
tokio::time::pause();
let abort_registration = channel
.as_mut()
.start_request(0, SystemTime::now())
.unwrap();
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
tx.send(fake_request(())).await.unwrap();
assert_matches!(
channel.as_mut().poll_next(&mut noop_context()),
Poll::Ready(Some(Ok(_)))
);
assert_matches!(test_abortable(abort_registration).await, Err(Aborted));
}
#[tokio::test]
async fn base_channel_start_send_removes_in_flight_request() {
let (mut channel, _tx) = test_channel::<(), ()>();
channel
.as_mut()
.start_request(0, SystemTime::now())
.unwrap();
assert_eq!(channel.in_flight_requests(), 1);
channel
.as_mut()
.start_send(Response {
request_id: 0,
message: Ok(()),
})
.unwrap();
assert_eq!(channel.in_flight_requests(), 0);
}
#[tokio::test]
async fn requests_poll_next_response_returns_pending_when_buffer_full() {
let (mut requests, _tx) = test_bounded_requests::<(), ()>(0);
// Response written to the transport.
requests
.as_mut()
.channel_pin_mut()
.start_send(Response {
request_id: 0,
message: Ok(()),
})
.unwrap();
// Response waiting to be written.
requests
.as_mut()
.project()
.responses_tx
.send((
context::current(),
Response {
request_id: 1,
message: Ok(()),
},
))
.await
.unwrap();
requests
.as_mut()
.channel_pin_mut()
.start_request(1, SystemTime::now())
.unwrap();
assert_matches!(
requests.as_mut().poll_next_response(&mut noop_context()),
Poll::Pending
);
}
#[tokio::test]
async fn requests_pump_write_returns_pending_when_buffer_full() {
let (mut requests, _tx) = test_bounded_requests::<(), ()>(0);
// Response written to the transport.
requests
.as_mut()
.channel_pin_mut()
.start_send(Response {
request_id: 0,
message: Ok(()),
})
.unwrap();
// Response waiting to be written.
requests
.as_mut()
.project()
.responses_tx
.send((
context::current(),
Response {
request_id: 1,
message: Ok(()),
},
))
.await
.unwrap();
requests
.as_mut()
.channel_pin_mut()
.start_request(1, SystemTime::now())
.unwrap();
assert_matches!(
requests.as_mut().pump_write(&mut noop_context(), true),
Poll::Pending
);
// Assert that the pending response was not polled while the channel was blocked.
assert_matches!(
requests.as_mut().pending_responses_mut().recv().await,
Some(_)
);
}
#[tokio::test]
async fn requests_pump_read() {
let (mut requests, mut tx) = test_requests::<(), ()>();
// Request written to the transport.
tx.send(fake_request(())).await.unwrap();
assert_matches!(
requests.as_mut().pump_read(&mut noop_context()),
Poll::Ready(Some(Ok(_)))
);
assert_eq!(requests.channel.in_flight_requests(), 1);
}