Handle deadlines in BaseChannel.

Before this commit, deadlines were handled by a timeout future that
wrapped each request handler. However, request handlers can be dropped
before sending a response back to the channel, so they can't be relied
on for channel state cleanup. Additionally, clients can't be relied on
to send cancellation messages. It was therefore theoretically possible
for pathological behaviors to cause an unbounded growth in orphan
request data in the Channel.

With this change, as long as requests are sent with reasonable deadlines,
the channel will be able to clean itself up. It is still possible
for requests to be sent with very large deadlines, which would prevent
the channel from cleaning itself up.
This commit is contained in:
Tim Kuehn
2021-03-07 00:12:46 -08:00
parent 6f419e9a9a
commit 3c978c5bf6
7 changed files with 375 additions and 149 deletions

View File

@@ -37,7 +37,7 @@ serde = { optional = true, version = "1.0", features = ["derive"] }
static_assertions = "1.1.0"
tarpc-plugins = { path = "../plugins", version = "0.9" }
tokio = { version = "1", features = ["time"] }
tokio-util = { optional = true, version = "0.6" }
tokio-util = { version = "0.6.3", features = ["time"] }
tokio-serde = { optional = true, version = "0.8" }
[dev-dependencies]
@@ -46,10 +46,11 @@ bincode = "1.3"
bytes = { version = "1", features = ["serde"] }
env_logger = "0.8"
flate2 = "1.0"
futures-test = "0.3"
log = "0.4"
pin-utils = "0.1.0-alpha"
serde_bytes = "0.11"
tokio = { version = "1", features = ["full"] }
tokio = { version = "1", features = ["full", "test-util"] }
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
trybuild = "1.0"

View File

@@ -4,16 +4,12 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use futures::{
future::{self, Ready},
prelude::*,
};
use futures::future::{self, Ready};
use std::io;
use tarpc::{
client, context,
server::{self, Channel},
};
use tokio_serde::formats::Json;
/// This is the service definition. It looks a lot like a trait definition.
/// It defines one RPC, hello, which takes one arg, name, and returns a String.

View File

@@ -6,25 +6,22 @@
//! Provides a server that concurrently handles many connections sending multiplexed requests.
use crate::{
context, trace, util::Compact, util::TimeUntil, ClientMessage, PollIo, Request, Response,
ServerError, Transport,
};
use fnv::FnvHashMap;
use crate::{context, ClientMessage, PollIo, Request, Response, ServerError, Transport};
use futures::{
channel::mpsc,
future::{AbortHandle, AbortRegistration, Abortable},
future::{AbortRegistration, Abortable},
prelude::*,
ready,
stream::Fuse,
task::*,
};
use humantime::format_rfc3339;
use log::{debug, trace};
use pin_project::{pin_project, pinned_drop};
use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin};
use log::{debug, info, trace};
use pin_project::pin_project;
use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
mod filter;
mod in_flight_requests;
#[cfg(test)]
mod testing;
mod throttle;
@@ -134,14 +131,14 @@ where
/// messages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation
/// messages. Instead, it internally handles them by cancelling corresponding requests (removing
/// the corresponding in-flight requests and aborting their handlers).
#[pin_project(PinnedDrop)]
#[pin_project]
pub struct BaseChannel<Req, Resp, T> {
config: Config,
/// Writes responses to the wire and reads requests off the wire.
#[pin]
transport: Fuse<T>,
/// Number of requests currently being responded to.
in_flight_requests: FnvHashMap<u64, AbortHandle>,
/// Holds data necessary to clean up in-flight requests.
in_flight_requests: in_flight_requests::InFlightRequests,
/// Types the request and response.
ghost: PhantomData<(Req, Resp)>,
}
@@ -155,7 +152,7 @@ where
BaseChannel {
config,
transport: transport.fuse(),
in_flight_requests: FnvHashMap::default(),
in_flight_requests: in_flight_requests::InFlightRequests::default(),
ghost: PhantomData,
}
}
@@ -176,35 +173,6 @@ where
}
}
impl<Req, Resp, T> BaseChannel<Req, Resp, T> {
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
// It's possible the request was already completed, so it's fine
// if this is None.
if let Some(cancel_handle) = self
.as_mut()
.project()
.in_flight_requests
.remove(&request_id)
{
self.as_mut().project().in_flight_requests.compact(0.1);
cancel_handle.abort();
let remaining = self.as_mut().project().in_flight_requests.len();
trace!(
"[{}] Request canceled. In-flight requests = {}",
trace_context.trace_id,
remaining,
);
} else {
trace!(
"[{}] Received cancellation, but response handler \
is already complete.",
trace_context.trace_id,
);
}
}
}
impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "BaseChannel")
@@ -260,7 +228,14 @@ where
/// Tells the Channel that request with ID `request_id` is being handled.
/// The request will be tracked until a response with the same ID is sent
/// to the Channel.
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration;
fn start_request(
self: Pin<&mut Self>,
id: u64,
deadline: SystemTime,
) -> Result<AbortRegistration, in_flight_requests::AlreadyExistsError>;
/// Yields a request that has expired, aborting any ongoing processing of that request.
fn poll_expired(self: Pin<&mut Self>, cx: &mut Context) -> PollIo<u64>;
/// Returns a stream of requests that automatically handle request cancellation and response
/// routing.
@@ -312,7 +287,25 @@ where
trace_context,
request_id,
} => {
self.as_mut().cancel_request(&trace_context, request_id);
if self
.as_mut()
.project()
.in_flight_requests
.cancel_request(request_id)
{
let remaining = self.in_flight_requests.len();
trace!(
"[{}] Request canceled. In-flight requests = {}",
trace_context.trace_id,
remaining,
);
} else {
trace!(
"[{}] Received cancellation, but response handler \
is already complete.",
trace_context.trace_id,
);
}
}
},
None => return Poll::Ready(None),
@@ -332,16 +325,10 @@ where
}
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
if self
.as_mut()
self.as_mut()
.project()
.in_flight_requests
.remove(&response.request_id)
.is_some()
{
self.as_mut().project().in_flight_requests.compact(0.1);
}
.remove_request(response.request_id);
self.project().transport.start_send(response)
}
@@ -354,17 +341,6 @@ where
}
}
#[pinned_drop]
impl<Req, Resp, T> PinnedDrop for BaseChannel<Req, Resp, T> {
fn drop(mut self: Pin<&mut Self>) {
self.as_mut()
.project()
.in_flight_requests
.values()
.for_each(AbortHandle::abort);
}
}
impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
fn as_ref(&self) -> &T {
self.transport.get_ref()
@@ -386,14 +362,18 @@ where
self.in_flight_requests.len()
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
let (abort_handle, abort_registration) = AbortHandle::new_pair();
assert!(self
.project()
fn start_request(
self: Pin<&mut Self>,
id: u64,
deadline: SystemTime,
) -> Result<AbortRegistration, in_flight_requests::AlreadyExistsError> {
self.project()
.in_flight_requests
.insert(request_id, abort_handle)
.is_none());
abort_registration
.start_request(id, deadline)
}
fn poll_expired(self: Pin<&mut Self>, cx: &mut Context) -> PollIo<u64> {
self.project().in_flight_requests.poll_expired(cx)
}
}
@@ -426,16 +406,41 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<InFlightRequest<C::Req, C::Resp>> {
match ready!(self.as_mut().project().channel.poll_next(cx)?) {
Some(request) => {
let abort_registration = self.as_mut().project().channel.start_request(request.id);
Poll::Ready(Some(Ok(InFlightRequest {
request,
response_tx: self.responses_tx.clone(),
abort_registration,
})))
loop {
match ready!(self.as_mut().project().channel.poll_next(cx)?) {
Some(request) => {
trace!(
"[{}] Handling request with deadline {}.",
request.context.trace_id(),
format_rfc3339(request.context.deadline),
);
match self
.channel_pin_mut()
.start_request(request.id, request.context.deadline)
{
Ok(abort_registration) => {
return Poll::Ready(Some(Ok(InFlightRequest {
request,
response_tx: self.responses_tx.clone(),
abort_registration,
})))
}
// Instead of closing the channel if a duplicate request is sent, just
// ignore it, since it's already being processed. Note that we cannot
// return Poll::Pending here, since nothing has scheduled a wakeup yet.
Err(in_flight_requests::AlreadyExistsError) => {
info!(
"[{}] Request ID {} delivered more than once.",
request.context.trace_id(),
request.id
);
continue;
}
}
}
None => return Poll::Ready(None),
}
None => Poll::Ready(None),
}
}
@@ -444,6 +449,17 @@ where
cx: &mut Context<'_>,
read_half_closed: bool,
) -> PollIo<()> {
if let Poll::Ready(Some(request_id)) = self.channel_pin_mut().poll_expired(cx)? {
debug!("Request {} did not complete before deadline", request_id);
self.channel_pin_mut().start_send(Response {
request_id,
message: Err(ServerError {
kind: io::ErrorKind::TimedOut,
detail: Some(format!("Request did not complete before deadline.")),
}),
})?;
return Poll::Ready(Some(Ok(())));
}
match self.as_mut().poll_next_response(cx)? {
Poll::Ready(Some((context, response))) => {
trace!(
@@ -451,6 +467,12 @@ where
context.trace_id(),
self.channel.in_flight_requests(),
);
// TODO: it's possible for poll_flush to be starved and start_send to end up full.
// Currently that would cause the channel to shut down. serde_transport internally
// uses tokio-util Framed, which will allocate as much as needed. But other
// transports may work differently.
//
// There should be a way to know if a flush is needed soon.
self.channel_pin_mut().start_send(response)?;
Poll::Ready(Some(Ok(())))
}
@@ -543,39 +565,10 @@ impl<Req, Res> InFlightRequest<Req, Res> {
message,
id: request_id,
} = request;
let trace_id = *request.context.trace_id();
let deadline = request.context.deadline;
let timeout = deadline.time_until();
trace!(
"[{}] Handling request with deadline {} (timeout {:?}).",
trace_id,
format_rfc3339(deadline),
timeout,
);
let result =
tokio::time::timeout(timeout, async { serve.serve(context, message).await })
.await;
let response = serve.serve(context, message).await;
let response = Response {
request_id,
message: match result {
Ok(message) => Ok(message),
Err(tokio::time::error::Elapsed { .. }) => {
debug!(
"[{}] Response did not complete before deadline of {}s.",
trace_id,
format_rfc3339(deadline)
);
// No point in responding, since the client will have dropped the
// request.
Err(ServerError {
kind: io::ErrorKind::TimedOut,
detail: Some(format!(
"Response did not complete before deadline of {}s.",
format_rfc3339(deadline)
)),
})
}
},
message: Ok(response),
};
let _ = response_tx.send((context, response)).await;
},
@@ -687,7 +680,7 @@ where
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
tokio::spawn(channel.execute(self.serve.clone()));
}
log::info!("Server shutting down.");
info!("Server shutting down.");
Poll::Ready(())
}
}
@@ -713,7 +706,7 @@ where
});
}
Err(e) => {
log::info!("Requests stream errored out: {}", e);
info!("Requests stream errored out: {}", e);
break;
}
}

View File

@@ -7,6 +7,7 @@
use crate::{
server::{self, Channel},
util::Compact,
PollIo,
};
use fnv::FnvHashMap;
use futures::{channel::mpsc, future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
@@ -15,6 +16,7 @@ use pin_project::pin_project;
use std::sync::{Arc, Weak};
use std::{
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin,
time::SystemTime,
};
/// A single-threaded filter that drops channels based on per-key limits.
@@ -112,8 +114,16 @@ where
self.inner.in_flight_requests()
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
self.project().inner.start_request(request_id)
fn start_request(
self: Pin<&mut Self>,
id: u64,
deadline: SystemTime,
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
self.project().inner.start_request(id, deadline)
}
fn poll_expired(self: Pin<&mut Self>, cx: &mut Context) -> PollIo<u64> {
self.project().inner.poll_expired(cx)
}
}

View File

@@ -0,0 +1,193 @@
use crate::{
util::{Compact, TimeUntil},
PollIo,
};
use fnv::FnvHashMap;
use futures::{
future::{AbortHandle, AbortRegistration},
ready,
};
use std::{
collections::hash_map,
io,
task::{Context, Poll},
time::SystemTime,
};
use tokio_util::time::delay_queue::{self, DelayQueue};
/// A data structure that tracks in-flight requests. It aborts requests,
/// either on demand or when a request deadline expires.
#[derive(Debug, Default)]
pub struct InFlightRequests {
/// Cleanup data for each in-flight request, keyed by request ID.
request_data: FnvHashMap<u64, RequestData>,
/// Deadline timers for in-flight requests. Each entry holds the ID of the request
/// whose deadline it tracks.
deadlines: DelayQueue<u64>,
}
#[derive(Debug)]
/// Data needed to clean up a single in-flight request.
///
/// One value is stored per request ID in `InFlightRequests::request_data`.
struct RequestData {
/// Aborts the response handler for the associated request.
abort_handle: AbortHandle,
/// The key to remove the timer for the request's deadline.
/// It identifies this request's entry in `InFlightRequests::deadlines`.
deadline_key: delay_queue::Key,
}
/// An error returned when a request attempted to start with the same ID as a request already
/// in flight.
///
/// Returned by `InFlightRequests::start_request`. The error is a unit struct and carries
/// no further information about the conflicting request.
#[derive(Debug)]
pub struct AlreadyExistsError;
impl InFlightRequests {
    /// Returns the number of requests currently being tracked.
    pub fn len(&self) -> usize {
        self.request_data.len()
    }

    /// Starts a request, unless a request with the same ID is already in flight.
    ///
    /// On success, a timer for `deadline` is registered and the returned
    /// [`AbortRegistration`] is paired with the [`AbortHandle`] that is triggered when the
    /// request is cancelled via [`cancel_request`](Self::cancel_request) or found to be
    /// expired by [`poll_expired`](Self::poll_expired).
    ///
    /// # Errors
    ///
    /// Returns [`AlreadyExistsError`] if `request_id` is already being tracked. In that case
    /// no state is modified.
    pub fn start_request(
        &mut self,
        request_id: u64,
        deadline: SystemTime,
    ) -> Result<AbortRegistration, AlreadyExistsError> {
        match self.request_data.entry(request_id) {
            hash_map::Entry::Vacant(vacant) => {
                // Only touch the timer queue once the ID is known to be new. A duplicate
                // request must not register (and then immediately unregister) a timer.
                let timeout = deadline.time_until();
                let (abort_handle, abort_registration) = AbortHandle::new_pair();
                let deadline_key = self.deadlines.insert(request_id, timeout);
                vacant.insert(RequestData {
                    abort_handle,
                    deadline_key,
                });
                Ok(abort_registration)
            }
            hash_map::Entry::Occupied(_) => Err(AlreadyExistsError),
        }
    }

    /// Cancels an in-flight request. Returns true iff the request was found.
    ///
    /// A found request is removed from tracking, its handler is aborted and its deadline
    /// timer is removed.
    pub fn cancel_request(&mut self, request_id: u64) -> bool {
        if let Some(request_data) = self.request_data.remove(&request_id) {
            self.request_data.compact(0.1);
            request_data.abort_handle.abort();
            self.deadlines.remove(&request_data.deadline_key);
            true
        } else {
            false
        }
    }

    /// Removes a request without aborting. Returns true iff the request was found.
    ///
    /// A found request is removed from tracking and its deadline timer is removed; the
    /// abort handle is dropped without being triggered.
    pub fn remove_request(&mut self, request_id: u64) -> bool {
        if let Some(request_data) = self.request_data.remove(&request_id) {
            self.request_data.compact(0.1);
            self.deadlines.remove(&request_data.deadline_key);
            true
        } else {
            false
        }
    }

    /// Yields a request that has expired, aborting any ongoing processing of that request.
    ///
    /// Returns `Poll::Pending` while the underlying delay queue is pending. When the queue
    /// yields an expired entry, the matching request data (if still present) is removed and
    /// its handler aborted, and the request ID is returned. A timer error is surfaced as an
    /// [`io::Error`] of kind [`io::ErrorKind::Other`]; `None` from the queue is passed
    /// through unchanged.
    pub fn poll_expired(&mut self, cx: &mut Context) -> PollIo<u64> {
        Poll::Ready(match ready!(self.deadlines.poll_expired(cx)) {
            Some(Ok(expired)) => {
                // The timer entry has already left the queue by expiring, so only the
                // request data needs cleaning up here.
                if let Some(request_data) = self.request_data.remove(expired.get_ref()) {
                    self.request_data.compact(0.1);
                    request_data.abort_handle.abort();
                }
                Some(Ok(expired.into_inner()))
            }
            Some(Err(e)) => Some(Err(io::Error::new(io::ErrorKind::Other, e))),
            None => None,
        })
    }
}
/// Dropping the tracker aborts the handler of every request that is still in flight.
impl Drop for InFlightRequests {
    fn drop(&mut self) {
        for pending_request in self.request_data.values() {
            pending_request.abort_handle.abort();
        }
    }
}
#[cfg(test)]
use {
assert_matches::assert_matches,
futures::{
future::{pending, Abortable},
FutureExt,
},
futures_test::task::noop_context,
};
/// Starting a request is reflected in the number of tracked requests.
#[tokio::test]
async fn start_request_increases_len() {
    let mut tracker = InFlightRequests::default();
    assert_eq!(tracker.len(), 0);

    let deadline = SystemTime::now();
    tracker.start_request(0, deadline).unwrap();

    assert_eq!(tracker.len(), 1);
}
#[tokio::test]
async fn polling_expired_aborts() {
let mut in_flight_requests = InFlightRequests::default();
let abort_registration = in_flight_requests
.start_request(0, SystemTime::now())
.unwrap();
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
tokio::time::pause();
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
assert_matches!(
in_flight_requests.poll_expired(&mut noop_context()),
Poll::Ready(Some(Ok(_)))
);
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
Poll::Ready(Err(_))
);
assert_eq!(in_flight_requests.len(), 0);
}
#[tokio::test]
async fn cancel_request_aborts() {
let mut in_flight_requests = InFlightRequests::default();
let abort_registration = in_flight_requests
.start_request(0, SystemTime::now())
.unwrap();
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
assert_eq!(in_flight_requests.cancel_request(0), true);
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
Poll::Ready(Err(_))
);
assert_eq!(in_flight_requests.len(), 0);
}
#[tokio::test]
async fn remove_request_doesnt_abort() {
let mut in_flight_requests = InFlightRequests::default();
let abort_registration = in_flight_requests
.start_request(0, SystemTime::now())
.unwrap();
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
assert_eq!(in_flight_requests.remove_request(0), true);
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
Poll::Pending
);
assert_eq!(in_flight_requests.len(), 0);
}

View File

@@ -4,14 +4,12 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use crate::server::{Channel, Config};
use crate::{context, Request, Response};
use fnv::FnvHashSet;
use futures::{
future::{AbortHandle, AbortRegistration},
task::*,
Sink, Stream,
use crate::{
context,
server::{Channel, Config},
PollIo, Request, Response,
};
use futures::{future::AbortRegistration, task::*, Sink, Stream};
use pin_project::pin_project;
use std::collections::VecDeque;
use std::io;
@@ -25,7 +23,7 @@ pub(crate) struct FakeChannel<In, Out> {
#[pin]
pub sink: VecDeque<Out>,
pub config: Config,
pub in_flight_requests: FnvHashSet<u64>,
pub in_flight_requests: super::in_flight_requests::InFlightRequests,
}
impl<In, Out> Stream for FakeChannel<In, Out>
@@ -50,7 +48,7 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
self.as_mut()
.project()
.in_flight_requests
.remove(&response.request_id);
.remove_request(response.request_id);
self.project()
.sink
.start_send(response)
@@ -81,9 +79,18 @@ where
self.in_flight_requests.len()
}
fn start_request(self: Pin<&mut Self>, id: u64) -> AbortRegistration {
self.project().in_flight_requests.insert(id);
AbortHandle::new_pair().1
fn start_request(
self: Pin<&mut Self>,
id: u64,
deadline: SystemTime,
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
self.project()
.in_flight_requests
.start_request(id, deadline)
}
fn poll_expired(self: Pin<&mut Self>, cx: &mut Context) -> PollIo<u64> {
self.project().in_flight_requests.poll_expired(cx)
}
}

View File

@@ -5,11 +5,11 @@
// https://opensource.org/licenses/MIT.
use super::{Channel, Config};
use crate::{Response, ServerError};
use crate::{PollIo, Response, ServerError};
use futures::{future::AbortRegistration, prelude::*, ready, task::*};
use log::debug;
use pin_project::pin_project;
use std::{io, pin::Pin};
use std::{io, pin::Pin, time::SystemTime};
/// A [`Channel`] that limits the number of concurrent
/// requests by throttling.
@@ -121,8 +121,16 @@ where
self.inner.config()
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
self.project().inner.start_request(request_id)
fn start_request(
self: Pin<&mut Self>,
id: u64,
deadline: SystemTime,
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
self.project().inner.start_request(id, deadline)
}
fn poll_expired(self: Pin<&mut Self>, cx: &mut Context) -> PollIo<u64> {
self.project().inner.poll_expired(cx)
}
}
@@ -173,10 +181,10 @@ use crate::Request;
#[cfg(test)]
use pin_utils::pin_mut;
#[cfg(test)]
use std::marker::PhantomData;
use std::{marker::PhantomData, time::Duration};
#[test]
fn throttler_in_flight_requests() {
#[tokio::test]
async fn throttler_in_flight_requests() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
@@ -184,20 +192,27 @@ fn throttler_in_flight_requests() {
pin_mut!(throttler);
for i in 0..5 {
throttler.inner.in_flight_requests.insert(i);
throttler
.inner
.in_flight_requests
.start_request(i, SystemTime::now() + Duration::from_secs(1))
.unwrap();
}
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
}
#[test]
fn throttler_start_request() {
#[tokio::test]
async fn throttler_start_request() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.as_mut().start_request(1);
throttler
.as_mut()
.start_request(1, SystemTime::now() + Duration::from_secs(1))
.unwrap();
assert_eq!(throttler.inner.in_flight_requests.len(), 1);
}
@@ -295,21 +310,32 @@ fn throttler_poll_next_throttled_sink_not_ready() {
fn in_flight_requests(&self) -> usize {
0
}
fn start_request(self: Pin<&mut Self>, _: u64) -> AbortRegistration {
fn start_request(
self: Pin<&mut Self>,
_id: u64,
_deadline: SystemTime,
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
unimplemented!()
}
fn poll_expired(self: Pin<&mut Self>, _cx: &mut Context) -> PollIo<u64> {
unimplemented!()
}
}
}
#[test]
fn throttler_start_send() {
#[tokio::test]
async fn throttler_start_send() {
let throttler = Throttler {
max_in_flight_requests: 0,
inner: FakeChannel::default::<isize, isize>(),
};
pin_mut!(throttler);
throttler.inner.in_flight_requests.insert(0);
throttler
.inner
.in_flight_requests
.start_request(0, SystemTime::now() + Duration::from_secs(1))
.unwrap();
throttler
.as_mut()
.start_send(Response {
@@ -317,7 +343,7 @@ fn throttler_start_send() {
message: Ok(1),
})
.unwrap();
assert!(throttler.inner.in_flight_requests.is_empty());
assert_eq!(throttler.inner.in_flight_requests.len(), 0);
assert_eq!(
throttler.inner.sink.get(0),
Some(&Response {