diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 60b34c2..df2df47 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -827,13 +827,13 @@ mod tests { request: request.to_string(), response_completion, }; - channel.to_dispatch.send(request).await.unwrap(); - - ResponseGuard { + let response_guard = ResponseGuard { response, cancellation: &channel.cancellation, request_id, - } + }; + channel.to_dispatch.send(request).await.unwrap(); + response_guard } async fn send_response( diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index afb1861..52bcd4c 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -7,6 +7,7 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. use crate::{ + cancellations::{cancellations, CanceledRequests, RequestCancellation}, context::{self, SpanExt}, trace, ClientMessage, Request, Response, Transport, }; @@ -20,7 +21,7 @@ use futures::{ }; use in_flight_requests::{AlreadyExistsError, InFlightRequests}; use pin_project::pin_project; -use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin}; +use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, mem, pin::Pin}; use tracing::{info_span, instrument::Instrument, Span}; mod in_flight_requests; @@ -111,6 +112,11 @@ pub struct BaseChannel { /// Writes responses to the wire and reads requests off the wire. #[pin] transport: Fuse, + /// In-flight requests that were dropped by the server before completion. + #[pin] + canceled_requests: CanceledRequests, + /// Notifies `canceled_requests` when a request is canceled. + request_cancellation: RequestCancellation, /// Holds data necessary to clean up in-flight requests. in_flight_requests: InFlightRequests, /// Types the request and response. @@ -123,9 +129,12 @@ where { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { + let (request_cancellation, canceled_requests) = cancellations(); BaseChannel { config, transport: transport.fuse(), + canceled_requests, + request_cancellation, in_flight_requests: InFlightRequests::default(), ghost: PhantomData, } @@ -150,12 +159,18 @@ where self.as_mut().project().in_flight_requests } + fn canceled_requests_pin_mut<'a>( + self: &'a mut Pin<&mut Self>, + ) -> Pin<&'a mut CanceledRequests> { + self.as_mut().project().canceled_requests + } + fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse> { self.as_mut().project().transport } fn start_request( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, mut request: Request, ) -> Result, AlreadyExistsError> { let span = info_span!( @@ -175,7 +190,7 @@ where }); let entered = span.enter(); tracing::info!("ReceiveRequest"); - let start = self.project().in_flight_requests.start_request( + let start = self.in_flight_requests_mut().start_request( request.id, request.context.deadline, span.clone(), @@ -260,6 +275,12 @@ where /// Returns the transport underlying the channel. fn transport(&self) -> &Self::Transport; + /// Returns a reference to the channel's request cancellation channel, which can be used to + /// clean up request data when request processing ends prematurely. + /// + /// Once request data is cleaned up, a response cannot be sent back to the client. + fn request_cancellation(&self) -> &RequestCancellation; + /// Caps the number of concurrent requests to `limit`. An error will be returned for requests /// over the concurrency limit. /// @@ -334,15 +355,44 @@ where type Item = Result, ChannelError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - #[derive(Debug)] + #[derive(Clone, Copy, Debug)] enum ReceiverStatus { Ready, Pending, Closed, } + + impl ReceiverStatus { + fn combine(self, other: Self) -> Self { + use ReceiverStatus::*; + match (self, other) { + (Ready, _) | (_, Ready) => Ready, + (Closed, Closed) => Closed, + (Pending, Closed) | (Closed, Pending) | (Pending, Pending) => Pending, + } + } + } + use ReceiverStatus::*; loop { + let cancellation_status = match self.canceled_requests_pin_mut().poll_recv(cx) { + Poll::Ready(Some(request_id)) => { + if let Some(span) = self.in_flight_requests_mut().remove_request(request_id) { + let _entered = span.enter(); + tracing::info!("ResponseCancelled"); + } + Ready + } + // Pending cancellations don't block Channel closure, because all they do is ensure + // the Channel's internal state is cleaned up. But Channel closure also cleans up + // the Channel state, so there's no reason to wait on a cancellation before + // closing. + // + // Ready(None) can't happen, since `self` holds a Cancellation. + Poll::Pending | Poll::Ready(None) => Closed, + }; + let expiration_status = match self .in_flight_requests_mut() .poll_expired(cx) @@ -395,10 +445,13 @@ where expiration_status, request_status ); - match (expiration_status, request_status) { - (Ready, _) | (_, Ready) => continue, - (Closed, Closed) => return Poll::Ready(None), - (Pending, Closed) | (Closed, Pending) | (Pending, Pending) => return Poll::Pending, + match cancellation_status + .combine(expiration_status) + .combine(request_status) + { + Ready => continue, + Closed => return Poll::Ready(None), + Pending => return Poll::Pending, } } } @@ -420,9 +473,7 @@ where fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { if let Some(span) = self - .as_mut() - .project() - .in_flight_requests + .in_flight_requests_mut() .remove_request(response.request_id) { let _entered = span.enter(); @@ -478,6 +529,10 @@ where fn transport(&self) -> &Self::Transport { self.get_ref() } + + fn request_cancellation(&self) -> &RequestCancellation { + &self.request_cancellation + } } /// A stream of requests coming over a channel. `Requests` also drives the sending of responses, so @@ -499,6 +554,11 @@ impl Requests where C: Channel, { + /// Returns a reference to the inner channel over which messages are sent and received. + pub fn channel(&self) -> &C { + &self.channel + } + /// Returns the inner channel over which messages are sent and received. pub fn channel_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> { self.as_mut().project().channel @@ -515,12 +575,19 @@ where mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, C::Error>>> { - self.channel_pin_mut() - .poll_next(cx) - .map_ok(|request| InFlightRequest { - request, + self.channel_pin_mut().poll_next(cx).map_ok(|request| { + let request_id = request.request.id; + InFlightRequest { + request: request.request, + abort_registration: request.abort_registration, + span: request.span, response_tx: self.responses_tx.clone(), - }) + response_guard: ResponseGuard { + request_id, + request_cancellation: self.channel.request_cancellation().clone(), + }, + } + }) } fn pump_write( @@ -597,17 +664,37 @@ where } } +/// A fail-safe to ensure requests are properly canceled if an InFlightRequest is dropped before +/// completing. +#[derive(Debug)] +struct ResponseGuard { + request_cancellation: RequestCancellation, + request_id: u64, +} + +impl Drop for ResponseGuard { + fn drop(&mut self) { + self.request_cancellation.cancel(self.request_id); + } +} + /// A request produced by [Channel::requests]. +/// +/// If dropped without calling [`execute`](InFlightRequest::execute), a cancellation message will +/// be sent to the Channel to clean up associated request state. #[derive(Debug)] pub struct InFlightRequest { - request: TrackedRequest, + request: Request, + abort_registration: AbortRegistration, + response_guard: ResponseGuard, + span: Span, response_tx: mpsc::Sender>, } impl InFlightRequest { /// Returns a reference to the request. pub fn get(&self) -> &Request { - &self.request.request + &self.request } /// Returns a [future](Future) that executes the request using the given [service @@ -621,22 +708,23 @@ impl InFlightRequest { /// message](ClientMessage::Cancel) for this request. /// 2. The request [deadline](crate::context::Context::deadline) is reached. /// 3. The service function completes. + /// + /// If the returned Future is dropped before completion, a cancellation message will be sent to + /// the Channel to clean up associated request state. pub async fn execute(self, serve: S) where S: Serve, { let Self { response_tx, + response_guard, + abort_registration, + span, request: - TrackedRequest { - abort_registration, - span, - request: - Request { - context, - message, - id: request_id, - }, + Request { + context, + message, + id: request_id, }, } = self; let method = serve.method(&message); @@ -657,6 +745,10 @@ impl InFlightRequest { ) .instrument(span) .await; + // Request processing has completed, meaning either the channel canceled the request or + // a request was sent back to the channel. Either way, the channel will clean up the + // request data, so the request does not need to be canceled. + mem::forget(response_guard); } } @@ -932,6 +1024,44 @@ mod tests { assert_eq!(channel.in_flight_requests(), 0); } + #[tokio::test] + async fn in_flight_request_drop_cancels_request() { + let (mut requests, mut tx) = test_requests::<(), ()>(); + tx.send(fake_request(())).await.unwrap(); + + let request = match requests.as_mut().poll_next(&mut noop_context()) { + Poll::Ready(Some(Ok(request))) => request, + result => panic!("Unexpected result: {:?}", result), + }; + drop(request); + + let poll = requests + .as_mut() + .channel_pin_mut() + .poll_next(&mut noop_context()); + assert!(poll.is_pending()); + let in_flight_requests = requests.channel().in_flight_requests(); + assert_eq!(in_flight_requests, 0); + } + + #[tokio::test] + async fn in_flight_requests_successful_execute_doesnt_cancel_request() { + let (mut requests, mut tx) = test_requests::<(), ()>(); + tx.send(fake_request(())).await.unwrap(); + + let request = match requests.as_mut().poll_next(&mut noop_context()) { + Poll::Ready(Some(Ok(request))) => request, + result => panic!("Unexpected result: {:?}", result), + }; + request.execute(|_, _| async {}).await; + assert!(requests + .as_mut() + .channel_pin_mut() + .canceled_requests + .poll_recv(&mut noop_context()) + .is_pending()); + } + #[tokio::test] async fn requests_poll_next_response_returns_pending_when_buffer_full() { let (mut requests, _tx) = test_bounded_requests::<(), ()>(0); diff --git a/tarpc/src/server/limits/channels_per_key.rs b/tarpc/src/server/limits/channels_per_key.rs index 272dd56..c0ddebc 100644 --- a/tarpc/src/server/limits/channels_per_key.rs +++ b/tarpc/src/server/limits/channels_per_key.rs @@ -5,6 +5,7 @@ // https://opensource.org/licenses/MIT. use crate::{ + cancellations::RequestCancellation, server::{self, Channel}, util::Compact, }; @@ -119,6 +120,10 @@ where fn transport(&self) -> &Self::Transport { self.inner.transport() } + + fn request_cancellation(&self) -> &RequestCancellation { + self.inner.request_cancellation() + } } impl TrackedChannel { diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 3c29836..9929e9b 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -5,6 +5,7 @@ // https://opensource.org/licenses/MIT. use crate::{ + cancellations::RequestCancellation, server::{Channel, Config}, Response, ServerError, }; @@ -131,6 +132,10 @@ where fn transport(&self) -> &Self::Transport { self.inner.transport() } + + fn request_cancellation(&self) -> &RequestCancellation { + self.inner.request_cancellation() + } } /// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on @@ -310,6 +315,9 @@ mod tests { fn transport(&self) -> &() { &() } + fn request_cancellation(&self) -> &RequestCancellation { + unreachable!() + } } } diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 709c90c..ad64bc4 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -5,6 +5,7 @@ // https://opensource.org/licenses/MIT. use crate::{ + cancellations::{cancellations, CanceledRequests, RequestCancellation}, context, server::{Channel, Config, TrackedRequest}, Request, Response, @@ -22,6 +23,8 @@ pub(crate) struct FakeChannel { pub sink: VecDeque, pub config: Config, pub in_flight_requests: super::in_flight_requests::InFlightRequests, + pub request_cancellation: RequestCancellation, + pub canceled_requests: CanceledRequests, } impl Stream for FakeChannel @@ -81,6 +84,10 @@ where fn transport(&self) -> &() { &() } + + fn request_cancellation(&self) -> &RequestCancellation { + &self.request_cancellation + } } impl FakeChannel>, Response> { @@ -103,11 +110,14 @@ impl FakeChannel>, Response> { impl FakeChannel<(), ()> { pub fn default() -> FakeChannel>, Response> { + let (request_cancellation, canceled_requests) = cancellations(); FakeChannel { stream: Default::default(), sink: Default::default(), config: Default::default(), in_flight_requests: Default::default(), + request_cancellation, + canceled_requests, } } }