From 104dd71bba9f409daaae675deccd64cbb7d98964 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Tue, 7 Jun 2022 00:36:12 -0700 Subject: [PATCH] Clean up Channel request data more reliably. When an InFlightRequest is dropped before response completion, request data in the Channel persists until either the request expires or the client cancels the request. In rare cases, requests with very large deadlines could clog up the Channel long after request processing ceases. This commit adds a drop hook to InFlightRequest so that if it is dropped before execution completes, a cancellation message is sent to the Channel so that it can clean up the associated request data. This only works when using `InFlightRequest::execute` or `Channel::execute`. However, users of raw `Channel` have access to the `RequestCancellation` handle via `Channel::request_cancellation`, so they can implement a similar method if they wish to manually clean up request data. Note that once a Channel's request data is cleaned up, that request can never be responded to, even if a response is produced afterward. 
Fixes https://github.com/google/tarpc/issues/314 --- tarpc/src/client.rs | 8 +- tarpc/src/server.rs | 184 +++++++++++++++--- tarpc/src/server/limits/channels_per_key.rs | 5 + .../src/server/limits/requests_per_channel.rs | 8 + tarpc/src/server/testing.rs | 10 + 5 files changed, 184 insertions(+), 31 deletions(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 60b34c2..df2df47 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -827,13 +827,13 @@ mod tests { request: request.to_string(), response_completion, }; - channel.to_dispatch.send(request).await.unwrap(); - - ResponseGuard { + let response_guard = ResponseGuard { response, cancellation: &channel.cancellation, request_id, - } + }; + channel.to_dispatch.send(request).await.unwrap(); + response_guard } async fn send_response( diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index afb1861..52bcd4c 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -7,6 +7,7 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. use crate::{ + cancellations::{cancellations, CanceledRequests, RequestCancellation}, context::{self, SpanExt}, trace, ClientMessage, Request, Response, Transport, }; @@ -20,7 +21,7 @@ use futures::{ }; use in_flight_requests::{AlreadyExistsError, InFlightRequests}; use pin_project::pin_project; -use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin}; +use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, mem, pin::Pin}; use tracing::{info_span, instrument::Instrument, Span}; mod in_flight_requests; @@ -111,6 +112,11 @@ pub struct BaseChannel { /// Writes responses to the wire and reads requests off the wire. #[pin] transport: Fuse, + /// In-flight requests that were dropped by the server before completion. + #[pin] + canceled_requests: CanceledRequests, + /// Notifies `canceled_requests` when a request is canceled. 
+ request_cancellation: RequestCancellation, /// Holds data necessary to clean up in-flight requests. in_flight_requests: InFlightRequests, /// Types the request and response. @@ -123,9 +129,12 @@ where { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { + let (request_cancellation, canceled_requests) = cancellations(); BaseChannel { config, transport: transport.fuse(), + canceled_requests, + request_cancellation, in_flight_requests: InFlightRequests::default(), ghost: PhantomData, } @@ -150,12 +159,18 @@ where self.as_mut().project().in_flight_requests } + fn canceled_requests_pin_mut<'a>( + self: &'a mut Pin<&mut Self>, + ) -> Pin<&'a mut CanceledRequests> { + self.as_mut().project().canceled_requests + } + fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse> { self.as_mut().project().transport } fn start_request( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, mut request: Request, ) -> Result, AlreadyExistsError> { let span = info_span!( @@ -175,7 +190,7 @@ where }); let entered = span.enter(); tracing::info!("ReceiveRequest"); - let start = self.project().in_flight_requests.start_request( + let start = self.in_flight_requests_mut().start_request( request.id, request.context.deadline, span.clone(), @@ -260,6 +275,12 @@ where /// Returns the transport underlying the channel. fn transport(&self) -> &Self::Transport; + /// Returns a reference to the channel's request cancellation channel, which can be used to + /// clean up request data when request processing ends prematurely. + /// + /// Once request data is cleaned up, a response cannot be sent back to the client. + fn request_cancellation(&self) -> &RequestCancellation; + /// Caps the number of concurrent requests to `limit`. An error will be returned for requests /// over the concurrency limit. 
/// @@ -334,15 +355,44 @@ where type Item = Result, ChannelError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - #[derive(Debug)] + #[derive(Clone, Copy, Debug)] enum ReceiverStatus { Ready, Pending, Closed, } + + impl ReceiverStatus { + fn combine(self, other: Self) -> Self { + use ReceiverStatus::*; + match (self, other) { + (Ready, _) | (_, Ready) => Ready, + (Closed, Closed) => Closed, + (Pending, Closed) | (Closed, Pending) | (Pending, Pending) => Pending, + } + } + } + use ReceiverStatus::*; loop { + let cancellation_status = match self.canceled_requests_pin_mut().poll_recv(cx) { + Poll::Ready(Some(request_id)) => { + if let Some(span) = self.in_flight_requests_mut().remove_request(request_id) { + let _entered = span.enter(); + tracing::info!("ResponseCancelled"); + } + Ready + } + // Pending cancellations don't block Channel closure, because all they do is ensure + // the Channel's internal state is cleaned up. But Channel closure also cleans up + // the Channel state, so there's no reason to wait on a cancellation before + // closing. + // + // Ready(None) can't happen, since `self` holds a Cancellation. 
+ Poll::Pending | Poll::Ready(None) => Closed, + }; + let expiration_status = match self .in_flight_requests_mut() .poll_expired(cx) @@ -395,10 +445,13 @@ where expiration_status, request_status ); - match (expiration_status, request_status) { - (Ready, _) | (_, Ready) => continue, - (Closed, Closed) => return Poll::Ready(None), - (Pending, Closed) | (Closed, Pending) | (Pending, Pending) => return Poll::Pending, + match cancellation_status + .combine(expiration_status) + .combine(request_status) + { + Ready => continue, + Closed => return Poll::Ready(None), + Pending => return Poll::Pending, } } } @@ -420,9 +473,7 @@ where fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { if let Some(span) = self - .as_mut() - .project() - .in_flight_requests + .in_flight_requests_mut() .remove_request(response.request_id) { let _entered = span.enter(); @@ -478,6 +529,10 @@ where fn transport(&self) -> &Self::Transport { self.get_ref() } + + fn request_cancellation(&self) -> &RequestCancellation { + &self.request_cancellation + } } /// A stream of requests coming over a channel. `Requests` also drives the sending of responses, so @@ -499,6 +554,11 @@ impl Requests where C: Channel, { + /// Returns a reference to the inner channel over which messages are sent and received. + pub fn channel(&self) -> &C { + &self.channel + } + /// Returns the inner channel over which messages are sent and received. 
pub fn channel_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> { self.as_mut().project().channel @@ -515,12 +575,19 @@ where mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, C::Error>>> { - self.channel_pin_mut() - .poll_next(cx) - .map_ok(|request| InFlightRequest { - request, + self.channel_pin_mut().poll_next(cx).map_ok(|request| { + let request_id = request.request.id; + InFlightRequest { + request: request.request, + abort_registration: request.abort_registration, + span: request.span, response_tx: self.responses_tx.clone(), - }) + response_guard: ResponseGuard { + request_id, + request_cancellation: self.channel.request_cancellation().clone(), + }, + } + }) } fn pump_write( @@ -597,17 +664,37 @@ where } } +/// A fail-safe to ensure requests are properly canceled if an InFlightRequest is dropped before +/// completing. +#[derive(Debug)] +struct ResponseGuard { + request_cancellation: RequestCancellation, + request_id: u64, +} + +impl Drop for ResponseGuard { + fn drop(&mut self) { + self.request_cancellation.cancel(self.request_id); + } +} + /// A request produced by [Channel::requests]. +/// +/// If dropped without calling [`execute`](InFlightRequest::execute), a cancellation message will +/// be sent to the Channel to clean up associated request state. #[derive(Debug)] pub struct InFlightRequest { - request: TrackedRequest, + request: Request, + abort_registration: AbortRegistration, + response_guard: ResponseGuard, + span: Span, response_tx: mpsc::Sender>, } impl InFlightRequest { /// Returns a reference to the request. pub fn get(&self) -> &Request { - &self.request.request + &self.request } /// Returns a [future](Future) that executes the request using the given [service @@ -621,22 +708,23 @@ impl InFlightRequest { /// message](ClientMessage::Cancel) for this request. /// 2. The request [deadline](crate::context::Context::deadline) is reached. /// 3. The service function completes. 
+ /// + /// If the returned Future is dropped before completion, a cancellation message will be sent to + /// the Channel to clean up associated request state. pub async fn execute(self, serve: S) where S: Serve, { let Self { response_tx, + response_guard, + abort_registration, + span, request: - TrackedRequest { - abort_registration, - span, - request: - Request { - context, - message, - id: request_id, - }, + Request { + context, + message, + id: request_id, }, } = self; let method = serve.method(&message); @@ -657,6 +745,10 @@ impl InFlightRequest { ) .instrument(span) .await; + // Request processing has completed, meaning either the channel canceled the request or + // a request was sent back to the channel. Either way, the channel will clean up the + // request data, so the request does not need to be canceled. + mem::forget(response_guard); } } @@ -932,6 +1024,44 @@ mod tests { assert_eq!(channel.in_flight_requests(), 0); } + #[tokio::test] + async fn in_flight_request_drop_cancels_request() { + let (mut requests, mut tx) = test_requests::<(), ()>(); + tx.send(fake_request(())).await.unwrap(); + + let request = match requests.as_mut().poll_next(&mut noop_context()) { + Poll::Ready(Some(Ok(request))) => request, + result => panic!("Unexpected result: {:?}", result), + }; + drop(request); + + let poll = requests + .as_mut() + .channel_pin_mut() + .poll_next(&mut noop_context()); + assert!(poll.is_pending()); + let in_flight_requests = requests.channel().in_flight_requests(); + assert_eq!(in_flight_requests, 0); + } + + #[tokio::test] + async fn in_flight_requests_successful_execute_doesnt_cancel_request() { + let (mut requests, mut tx) = test_requests::<(), ()>(); + tx.send(fake_request(())).await.unwrap(); + + let request = match requests.as_mut().poll_next(&mut noop_context()) { + Poll::Ready(Some(Ok(request))) => request, + result => panic!("Unexpected result: {:?}", result), + }; + request.execute(|_, _| async {}).await; + assert!(requests + .as_mut() + 
.channel_pin_mut() + .canceled_requests + .poll_recv(&mut noop_context()) + .is_pending()); + } + #[tokio::test] async fn requests_poll_next_response_returns_pending_when_buffer_full() { let (mut requests, _tx) = test_bounded_requests::<(), ()>(0); diff --git a/tarpc/src/server/limits/channels_per_key.rs b/tarpc/src/server/limits/channels_per_key.rs index 272dd56..c0ddebc 100644 --- a/tarpc/src/server/limits/channels_per_key.rs +++ b/tarpc/src/server/limits/channels_per_key.rs @@ -5,6 +5,7 @@ // https://opensource.org/licenses/MIT. use crate::{ + cancellations::RequestCancellation, server::{self, Channel}, util::Compact, }; @@ -119,6 +120,10 @@ where fn transport(&self) -> &Self::Transport { self.inner.transport() } + + fn request_cancellation(&self) -> &RequestCancellation { + self.inner.request_cancellation() + } } impl TrackedChannel { diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 3c29836..9929e9b 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -5,6 +5,7 @@ // https://opensource.org/licenses/MIT. use crate::{ + cancellations::RequestCancellation, server::{Channel, Config}, Response, ServerError, }; @@ -131,6 +132,10 @@ where fn transport(&self) -> &Self::Transport { self.inner.transport() } + + fn request_cancellation(&self) -> &RequestCancellation { + self.inner.request_cancellation() + } } /// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on @@ -310,6 +315,9 @@ mod tests { fn transport(&self) -> &() { &() } + fn request_cancellation(&self) -> &RequestCancellation { + unreachable!() + } } } diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 709c90c..ad64bc4 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -5,6 +5,7 @@ // https://opensource.org/licenses/MIT. 
use crate::{ + cancellations::{cancellations, CanceledRequests, RequestCancellation}, context, server::{Channel, Config, TrackedRequest}, Request, Response, @@ -22,6 +23,8 @@ pub(crate) struct FakeChannel { pub sink: VecDeque, pub config: Config, pub in_flight_requests: super::in_flight_requests::InFlightRequests, + pub request_cancellation: RequestCancellation, + pub canceled_requests: CanceledRequests, } impl Stream for FakeChannel @@ -81,6 +84,10 @@ where fn transport(&self) -> &() { &() } + + fn request_cancellation(&self) -> &RequestCancellation { + &self.request_cancellation + } } impl FakeChannel>, Response> { @@ -103,11 +110,14 @@ impl FakeChannel>, Response> { impl FakeChannel<(), ()> { pub fn default() -> FakeChannel>, Response> { + let (request_cancellation, canceled_requests) = cancellations(); FakeChannel { stream: Default::default(), sink: Default::default(), config: Default::default(), in_flight_requests: Default::default(), + request_cancellation, + canceled_requests, } } }