Clean up Channel request data more reliably.

When an InFlightRequest is dropped before response completion, request
data in the Channel persists until either the request expires or the
client cancels the request. In rare cases, requests with very large
deadlines could clog up the Channel long after request processing
ceases.

This commit adds a drop hook to InFlightRequest so that if it is dropped
before execution completes, a cancellation message is sent to the
Channel so that it can clean up the associated request data.

This only works for when using `InFlightRequest::execute` or
`Channel::execute`. However, users of raw `Channel` have access
to the `RequestCancellation` handle via `Channel::request_cancellation`,
so they can implement a similar method if they wish to manually clean up
request data.

Note that once a Channel's request data is cleaned up, that request can
never be responded to, even if a response is produced afterward.

Fixes https://github.com/google/tarpc/issues/314
This commit is contained in:
Tim Kuehn
2022-06-07 00:36:12 -07:00
parent 012c481861
commit 104dd71bba
5 changed files with 184 additions and 31 deletions

View File

@@ -827,13 +827,13 @@ mod tests {
request: request.to_string(),
response_completion,
};
channel.to_dispatch.send(request).await.unwrap();
ResponseGuard {
let response_guard = ResponseGuard {
response,
cancellation: &channel.cancellation,
request_id,
}
};
channel.to_dispatch.send(request).await.unwrap();
response_guard
}
async fn send_response(

View File

@@ -7,6 +7,7 @@
//! Provides a server that concurrently handles many connections sending multiplexed requests.
use crate::{
cancellations::{cancellations, CanceledRequests, RequestCancellation},
context::{self, SpanExt},
trace, ClientMessage, Request, Response, Transport,
};
@@ -20,7 +21,7 @@ use futures::{
};
use in_flight_requests::{AlreadyExistsError, InFlightRequests};
use pin_project::pin_project;
use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin};
use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, mem, pin::Pin};
use tracing::{info_span, instrument::Instrument, Span};
mod in_flight_requests;
@@ -111,6 +112,11 @@ pub struct BaseChannel<Req, Resp, T> {
/// Writes responses to the wire and reads requests off the wire.
#[pin]
transport: Fuse<T>,
/// In-flight requests that were dropped by the server before completion.
#[pin]
canceled_requests: CanceledRequests,
/// Notifies `canceled_requests` when a request is canceled.
request_cancellation: RequestCancellation,
/// Holds data necessary to clean up in-flight requests.
in_flight_requests: InFlightRequests,
/// Types the request and response.
@@ -123,9 +129,12 @@ where
{
/// Creates a new channel backed by `transport` and configured with `config`.
pub fn new(config: Config, transport: T) -> Self {
let (request_cancellation, canceled_requests) = cancellations();
BaseChannel {
config,
transport: transport.fuse(),
canceled_requests,
request_cancellation,
in_flight_requests: InFlightRequests::default(),
ghost: PhantomData,
}
@@ -150,12 +159,18 @@ where
self.as_mut().project().in_flight_requests
}
fn canceled_requests_pin_mut<'a>(
self: &'a mut Pin<&mut Self>,
) -> Pin<&'a mut CanceledRequests> {
self.as_mut().project().canceled_requests
}
fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<T>> {
self.as_mut().project().transport
}
fn start_request(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
mut request: Request<Req>,
) -> Result<TrackedRequest<Req>, AlreadyExistsError> {
let span = info_span!(
@@ -175,7 +190,7 @@ where
});
let entered = span.enter();
tracing::info!("ReceiveRequest");
let start = self.project().in_flight_requests.start_request(
let start = self.in_flight_requests_mut().start_request(
request.id,
request.context.deadline,
span.clone(),
@@ -260,6 +275,12 @@ where
/// Returns the transport underlying the channel.
fn transport(&self) -> &Self::Transport;
/// Returns a reference to the channel's request cancellation channel, which can be used to
/// clean up request data when request processing ends prematurely.
///
/// Once request data is cleaned up, a response cannot be sent back to the client.
fn request_cancellation(&self) -> &RequestCancellation;
/// Caps the number of concurrent requests to `limit`. An error will be returned for requests
/// over the concurrency limit.
///
@@ -334,15 +355,44 @@ where
type Item = Result<TrackedRequest<Req>, ChannelError<T::Error>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
#[derive(Debug)]
#[derive(Clone, Copy, Debug)]
enum ReceiverStatus {
Ready,
Pending,
Closed,
}
impl ReceiverStatus {
fn combine(self, other: Self) -> Self {
use ReceiverStatus::*;
match (self, other) {
(Ready, _) | (_, Ready) => Ready,
(Closed, Closed) => Closed,
(Pending, Closed) | (Closed, Pending) | (Pending, Pending) => Pending,
}
}
}
use ReceiverStatus::*;
loop {
let cancellation_status = match self.canceled_requests_pin_mut().poll_recv(cx) {
Poll::Ready(Some(request_id)) => {
if let Some(span) = self.in_flight_requests_mut().remove_request(request_id) {
let _entered = span.enter();
tracing::info!("ResponseCancelled");
}
Ready
}
// Pending cancellations don't block Channel closure, because all they do is ensure
// the Channel's internal state is cleaned up. But Channel closure also cleans up
// the Channel state, so there's no reason to wait on a cancellation before
// closing.
//
// Ready(None) can't happen, since `self` holds a Cancellation.
Poll::Pending | Poll::Ready(None) => Closed,
};
let expiration_status = match self
.in_flight_requests_mut()
.poll_expired(cx)
@@ -395,10 +445,13 @@ where
expiration_status,
request_status
);
match (expiration_status, request_status) {
(Ready, _) | (_, Ready) => continue,
(Closed, Closed) => return Poll::Ready(None),
(Pending, Closed) | (Closed, Pending) | (Pending, Pending) => return Poll::Pending,
match cancellation_status
.combine(expiration_status)
.combine(request_status)
{
Ready => continue,
Closed => return Poll::Ready(None),
Pending => return Poll::Pending,
}
}
}
@@ -420,9 +473,7 @@ where
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
if let Some(span) = self
.as_mut()
.project()
.in_flight_requests
.in_flight_requests_mut()
.remove_request(response.request_id)
{
let _entered = span.enter();
@@ -478,6 +529,10 @@ where
fn transport(&self) -> &Self::Transport {
self.get_ref()
}
fn request_cancellation(&self) -> &RequestCancellation {
&self.request_cancellation
}
}
/// A stream of requests coming over a channel. `Requests` also drives the sending of responses, so
@@ -499,6 +554,11 @@ impl<C> Requests<C>
where
C: Channel,
{
/// Returns a reference to the inner channel over which messages are sent and received.
pub fn channel(&self) -> &C {
&self.channel
}
/// Returns the inner channel over which messages are sent and received.
pub fn channel_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> {
self.as_mut().project().channel
@@ -515,12 +575,19 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<InFlightRequest<C::Req, C::Resp>, C::Error>>> {
self.channel_pin_mut()
.poll_next(cx)
.map_ok(|request| InFlightRequest {
request,
self.channel_pin_mut().poll_next(cx).map_ok(|request| {
let request_id = request.request.id;
InFlightRequest {
request: request.request,
abort_registration: request.abort_registration,
span: request.span,
response_tx: self.responses_tx.clone(),
})
response_guard: ResponseGuard {
request_id,
request_cancellation: self.channel.request_cancellation().clone(),
},
}
})
}
fn pump_write(
@@ -597,17 +664,37 @@ where
}
}
/// A fail-safe to ensure requests are properly canceled if an InFlightRequest is dropped before
/// completing.
#[derive(Debug)]
struct ResponseGuard {
request_cancellation: RequestCancellation,
request_id: u64,
}
impl Drop for ResponseGuard {
fn drop(&mut self) {
self.request_cancellation.cancel(self.request_id);
}
}
/// A request produced by [Channel::requests].
///
/// If dropped without calling [`execute`](InFlightRequest::execute), a cancellation message will
/// be sent to the Channel to clean up associated request state.
#[derive(Debug)]
pub struct InFlightRequest<Req, Res> {
request: TrackedRequest<Req>,
request: Request<Req>,
abort_registration: AbortRegistration,
response_guard: ResponseGuard,
span: Span,
response_tx: mpsc::Sender<Response<Res>>,
}
impl<Req, Res> InFlightRequest<Req, Res> {
/// Returns a reference to the request.
pub fn get(&self) -> &Request<Req> {
&self.request.request
&self.request
}
/// Returns a [future](Future) that executes the request using the given [service
@@ -621,22 +708,23 @@ impl<Req, Res> InFlightRequest<Req, Res> {
/// message](ClientMessage::Cancel) for this request.
/// 2. The request [deadline](crate::context::Context::deadline) is reached.
/// 3. The service function completes.
///
/// If the returned Future is dropped before completion, a cancellation message will be sent to
/// the Channel to clean up associated request state.
pub async fn execute<S>(self, serve: S)
where
S: Serve<Req, Resp = Res>,
{
let Self {
response_tx,
response_guard,
abort_registration,
span,
request:
TrackedRequest {
abort_registration,
span,
request:
Request {
context,
message,
id: request_id,
},
Request {
context,
message,
id: request_id,
},
} = self;
let method = serve.method(&message);
@@ -657,6 +745,10 @@ impl<Req, Res> InFlightRequest<Req, Res> {
)
.instrument(span)
.await;
// Request processing has completed, meaning either the channel canceled the request or
// a request was sent back to the channel. Either way, the channel will clean up the
// request data, so the request does not need to be canceled.
mem::forget(response_guard);
}
}
@@ -932,6 +1024,44 @@ mod tests {
assert_eq!(channel.in_flight_requests(), 0);
}
#[tokio::test]
async fn in_flight_request_drop_cancels_request() {
let (mut requests, mut tx) = test_requests::<(), ()>();
tx.send(fake_request(())).await.unwrap();
let request = match requests.as_mut().poll_next(&mut noop_context()) {
Poll::Ready(Some(Ok(request))) => request,
result => panic!("Unexpected result: {:?}", result),
};
drop(request);
let poll = requests
.as_mut()
.channel_pin_mut()
.poll_next(&mut noop_context());
assert!(poll.is_pending());
let in_flight_requests = requests.channel().in_flight_requests();
assert_eq!(in_flight_requests, 0);
}
#[tokio::test]
async fn in_flight_requests_successful_execute_doesnt_cancel_request() {
let (mut requests, mut tx) = test_requests::<(), ()>();
tx.send(fake_request(())).await.unwrap();
let request = match requests.as_mut().poll_next(&mut noop_context()) {
Poll::Ready(Some(Ok(request))) => request,
result => panic!("Unexpected result: {:?}", result),
};
request.execute(|_, _| async {}).await;
assert!(requests
.as_mut()
.channel_pin_mut()
.canceled_requests
.poll_recv(&mut noop_context())
.is_pending());
}
#[tokio::test]
async fn requests_poll_next_response_returns_pending_when_buffer_full() {
let (mut requests, _tx) = test_bounded_requests::<(), ()>(0);

View File

@@ -5,6 +5,7 @@
// https://opensource.org/licenses/MIT.
use crate::{
cancellations::RequestCancellation,
server::{self, Channel},
util::Compact,
};
@@ -119,6 +120,10 @@ where
fn transport(&self) -> &Self::Transport {
self.inner.transport()
}
fn request_cancellation(&self) -> &RequestCancellation {
self.inner.request_cancellation()
}
}
impl<C, K> TrackedChannel<C, K> {

View File

@@ -5,6 +5,7 @@
// https://opensource.org/licenses/MIT.
use crate::{
cancellations::RequestCancellation,
server::{Channel, Config},
Response, ServerError,
};
@@ -131,6 +132,10 @@ where
fn transport(&self) -> &Self::Transport {
self.inner.transport()
}
fn request_cancellation(&self) -> &RequestCancellation {
self.inner.request_cancellation()
}
}
/// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on
@@ -310,6 +315,9 @@ mod tests {
fn transport(&self) -> &() {
&()
}
fn request_cancellation(&self) -> &RequestCancellation {
unreachable!()
}
}
}

View File

@@ -5,6 +5,7 @@
// https://opensource.org/licenses/MIT.
use crate::{
cancellations::{cancellations, CanceledRequests, RequestCancellation},
context,
server::{Channel, Config, TrackedRequest},
Request, Response,
@@ -22,6 +23,8 @@ pub(crate) struct FakeChannel<In, Out> {
pub sink: VecDeque<Out>,
pub config: Config,
pub in_flight_requests: super::in_flight_requests::InFlightRequests,
pub request_cancellation: RequestCancellation,
pub canceled_requests: CanceledRequests,
}
impl<In, Out> Stream for FakeChannel<In, Out>
@@ -81,6 +84,10 @@ where
fn transport(&self) -> &() {
&()
}
fn request_cancellation(&self) -> &RequestCancellation {
&self.request_cancellation
}
}
impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
@@ -103,11 +110,14 @@ impl<Req, Resp> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
impl FakeChannel<(), ()> {
pub fn default<Req, Resp>() -> FakeChannel<io::Result<TrackedRequest<Req>>, Response<Resp>> {
let (request_cancellation, canceled_requests) = cancellations();
FakeChannel {
stream: Default::default(),
sink: Default::default(),
config: Default::default(),
in_flight_requests: Default::default(),
request_cancellation,
canceled_requests,
}
}
}