From 68863e3db0ca75f2860af1a1ec26d773f88974c3 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Fri, 12 Aug 2022 13:09:50 -0700 Subject: [PATCH] Remove Channel::request_cancellation. This trait fn returns a private type, which means it's useless for anyone using the Channel. Instead, add an inert (now-public) ResponseGuard to TrackedRequest that, when taken out of the ManuallyDrop, ensures a Channel's request state is cleaned up. It's preferable to make ResponseGuard public instead of RequestCancellations because it's a smaller API surface (no public methods, just a Drop fn) and harder to misuse, because it is already associated with the correct request ID to cancel. --- tarpc/src/cancellations.rs | 8 ++ tarpc/src/server.rs | 76 ++++++++++--------- tarpc/src/server/limits/channels_per_key.rs | 5 -- .../src/server/limits/requests_per_channel.rs | 8 -- tarpc/src/server/testing.rs | 13 ++-- 5 files changed, 57 insertions(+), 53 deletions(-) diff --git a/tarpc/src/cancellations.rs b/tarpc/src/cancellations.rs index 6d6d684..631c7b1 100644 --- a/tarpc/src/cancellations.rs +++ b/tarpc/src/cancellations.rs @@ -20,6 +20,14 @@ pub fn cancellations() -> (RequestCancellation, CanceledRequests) { impl RequestCancellation { /// Cancels the request with ID `request_id`. + /// + /// No validation is done of `request_id`. There is no way to know if the request id provided + /// corresponds to a request actually tracked by the backing channel. `RequestCancellation` is + /// a one-way communication channel. + /// + /// Once request data is cleaned up, a response will never be received by the client. This is + /// useful primarily when request processing ends prematurely for requests with long deadlines + /// which would otherwise continue to be tracked by the backing channel—a kind of leak. pub fn cancel(&self, request_id: u64) { let _ = self.0.send(request_id); } diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 3a0aae5..7cf6a95 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -21,7 +21,14 @@ use futures::{ }; use in_flight_requests::{AlreadyExistsError, InFlightRequests}; use pin_project::pin_project; -use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, mem, pin::Pin}; +use std::{ + convert::TryFrom, + error::Error, + fmt, + marker::PhantomData, + mem::{self, ManuallyDrop}, + pin::Pin, +}; use tracing::{info_span, instrument::Instrument, Span}; mod in_flight_requests; @@ -199,9 +206,13 @@ where Ok(abort_registration) => { drop(entered); Ok(TrackedRequest { - request, abort_registration, span, + response_guard: ManuallyDrop::new(ResponseGuard { + request_id: request.id, + request_cancellation: self.request_cancellation.clone(), + }), + request, }) } Err(AlreadyExistsError) => { @@ -228,6 +239,8 @@ pub struct TrackedRequest { pub abort_registration: AbortRegistration, /// A span representing the server processing of this request. pub span: Span, + /// An inert response guard. Becomes active in an InFlightRequest. + pub response_guard: ManuallyDrop, } /// The server end of an open connection with a client, receiving requests from, and sending @@ -246,13 +259,15 @@ pub struct TrackedRequest { /// [`Sink::send`](futures::sink::SinkExt::send) - A user is free to manually read requests /// from, and send responses into, a Channel in lieu of the previous methods. Channels stream /// [`TrackedRequests`](TrackedRequest), which, in addition to the request itself, contains the -/// server [`Span`] and request lifetime [`AbortRegistration`]. Wrapping response -/// logic in an [`Abortable`] future using the abort registration will ensure that the response -/// does not execute longer than the request deadline. The `Channel` itself will clean up -/// request state once either the deadline expires, or a cancellation message is received, or a -/// response is sent. Because there is no guarantee that a cancellation message will ever be -/// received for a request, or that requests come with reasonably short deadlines, services -/// should strive to clean up Channel resources by sending a response for every request. +/// server [`Span`], request lifetime [`AbortRegistration`], and an inert [`ResponseGuard`]. +/// Wrapping response logic in an [`Abortable`] future using the abort registration will ensure +/// that the response does not execute longer than the request deadline. The `Channel` itself +/// will clean up request state once either the deadline expires, or the response guard is +/// dropped, or a response is sent. +/// +/// Channels must be implemented using the decorator pattern: the only way to create a +/// `TrackedRequest` is to get one from another `Channel`. Ultimately, all `TrackedRequests` are +/// created by [`BaseChannel`]. pub trait Channel where Self: Transport::Resp>, TrackedRequest<::Req>>, @@ -275,12 +290,6 @@ where /// Returns the transport underlying the channel. fn transport(&self) -> &Self::Transport; - /// Returns a reference to the channel's request cancellation channel, which can be used to - /// clean up request data when request processing ends prematurely. - /// - /// Once request data is cleaned up, a response cannot be sent back to the client. - fn request_cancellation(&self) -> &RequestCancellation; - /// Caps the number of concurrent requests to `limit`. An error will be returned for requests /// over the concurrency limit. /// @@ -525,10 +534,6 @@ where fn transport(&self) -> &Self::Transport { self.get_ref() } - - fn request_cancellation(&self) -> &RequestCancellation { - &self.request_cancellation - } } /// A stream of requests coming over a channel. `Requests` also drives the sending of responses, so @@ -571,19 +576,22 @@ where mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, C::Error>>> { - self.channel_pin_mut().poll_next(cx).map_ok(|request| { - let request_id = request.request.id; - InFlightRequest { - request: request.request, - abort_registration: request.abort_registration, - span: request.span, - response_tx: self.responses_tx.clone(), - response_guard: ResponseGuard { - request_id, - request_cancellation: self.channel.request_cancellation().clone(), - }, - } - }) + self.channel_pin_mut().poll_next(cx).map_ok( + |TrackedRequest { + request, + abort_registration, + span, + response_guard, + }| { + InFlightRequest { + request, + abort_registration, + span, + response_guard: ManuallyDrop::into_inner(response_guard), + response_tx: self.responses_tx.clone(), + } + }, + ) } fn pump_write( @@ -660,10 +668,10 @@ where } } -/// A fail-safe to ensure requests are properly canceled if an InFlightRequest is dropped before +/// A fail-safe to ensure requests are properly canceled if request processing is aborted before /// completing. #[derive(Debug)] -struct ResponseGuard { +pub struct ResponseGuard { request_cancellation: RequestCancellation, request_id: u64, } diff --git a/tarpc/src/server/limits/channels_per_key.rs b/tarpc/src/server/limits/channels_per_key.rs index c0ddebc..272dd56 100644 --- a/tarpc/src/server/limits/channels_per_key.rs +++ b/tarpc/src/server/limits/channels_per_key.rs @@ -5,7 +5,6 @@ // https://opensource.org/licenses/MIT. use crate::{ - cancellations::RequestCancellation, server::{self, Channel}, util::Compact, }; @@ -120,10 +119,6 @@ where fn transport(&self) -> &Self::Transport { self.inner.transport() } - - fn request_cancellation(&self) -> &RequestCancellation { - self.inner.request_cancellation() - } } impl TrackedChannel { diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 9929e9b..3c29836 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -5,7 +5,6 @@ // https://opensource.org/licenses/MIT. use crate::{ - cancellations::RequestCancellation, server::{Channel, Config}, Response, ServerError, }; @@ -132,10 +131,6 @@ where fn transport(&self) -> &Self::Transport { self.inner.transport() } - - fn request_cancellation(&self) -> &RequestCancellation { - self.inner.request_cancellation() - } } /// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on @@ -315,9 +310,6 @@ mod tests { fn transport(&self) -> &() { &() } - fn request_cancellation(&self) -> &RequestCancellation { - unreachable!() - } } } diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index ad64bc4..1c683da 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -7,12 +7,12 @@ use crate::{ cancellations::{cancellations, CanceledRequests, RequestCancellation}, context, - server::{Channel, Config, TrackedRequest}, + server::{Channel, Config, ResponseGuard, TrackedRequest}, Request, Response, }; use futures::{task::*, Sink, Stream}; use pin_project::pin_project; -use std::{collections::VecDeque, io, pin::Pin, time::SystemTime}; +use std::{collections::VecDeque, io, mem::ManuallyDrop, pin::Pin, time::SystemTime}; use tracing::Span; #[pin_project] @@ -84,15 +84,12 @@ where fn transport(&self) -> &() { &() } - - fn request_cancellation(&self) -> &RequestCancellation { - &self.request_cancellation - } } impl FakeChannel>, Response> { pub fn push_req(&mut self, id: u64, message: Req) { let (_, abort_registration) = futures::future::AbortHandle::new_pair(); + let (request_cancellation, _) = cancellations(); self.stream.push_back(Ok(TrackedRequest { request: Request { context: context::Context { @@ -104,6 +101,10 @@ impl FakeChannel>, Response> { }, abort_registration, span: Span::none(), + response_guard: ManuallyDrop::new(ResponseGuard { + request_cancellation, + request_id: id, + }), })); } }