From 3b422eb179dd319940b659b41914ab2b08a8d973 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Sun, 24 Jan 2021 17:57:44 -0800 Subject: [PATCH] Abort all in-flight requests when dropping BaseChannel. Fixes #341 --- tarpc/src/rpc/client.rs | 4 +- tarpc/src/rpc/server.rs | 33 +++++++++++++++- tarpc/tests/service_functional.rs | 65 ++++++++++++++++++++++++++++++- 3 files changed, 97 insertions(+), 5 deletions(-) diff --git a/tarpc/src/rpc/client.rs b/tarpc/src/rpc/client.rs index 24dc91a..1f03e31 100644 --- a/tarpc/src/rpc/client.rs +++ b/tarpc/src/rpc/client.rs @@ -144,11 +144,11 @@ where #[cfg(feature = "tokio1")] #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] pub fn spawn(self) -> io::Result { - use log::error; + use log::warn; let dispatch = self .dispatch - .unwrap_or_else(move |e| error!("Connection broken: {}", e)); + .unwrap_or_else(move |e| warn!("Connection broken: {}", e)); tokio::spawn(dispatch); Ok(self.client) } diff --git a/tarpc/src/rpc/server.rs b/tarpc/src/rpc/server.rs index 6da957e..72d4cf6 100644 --- a/tarpc/src/rpc/server.rs +++ b/tarpc/src/rpc/server.rs @@ -21,7 +21,7 @@ use futures::{ }; use humantime::format_rfc3339; use log::{debug, trace}; -use pin_project::pin_project; +use pin_project::{pin_project, pinned_drop}; use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime}; use tokio::time::Timeout; @@ -171,7 +171,7 @@ where } /// BaseChannel lifts a Transport to a Channel by tracking in-flight requests. -#[pin_project] +#[pin_project(PinnedDrop)] pub struct BaseChannel { config: Config, /// Writes responses to the wire and reads requests off the wire. @@ -211,7 +211,9 @@ where pub fn get_pin_ref(self: Pin<&mut Self>) -> Pin<&mut T> { self.project().transport.get_pin_mut() } +} +impl BaseChannel { fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) { // It's possible the request was already completed, so it's fine // if this is None. @@ -360,6 +362,17 @@ where } } +#[pinned_drop] +impl PinnedDrop for BaseChannel { + fn drop(mut self: Pin<&mut Self>) { + self.as_mut() + .project() + .in_flight_requests + .values() + .for_each(AbortHandle::abort); + } +} + impl AsRef for BaseChannel { fn as_ref(&self) -> &T { self.transport.get_ref() @@ -729,3 +742,19 @@ where Poll::Ready(()) } } + +#[tokio::test] +async fn abort_in_flight_requests_on_channel_drop() { + use assert_matches::assert_matches; + use futures::future::Aborted; + + let (_, server_transport) = + super::transport::channel::unbounded::, ClientMessage<()>>(); + let channel = BaseChannel::with_defaults(server_transport); + let mut channel = Box::pin(channel); + + let abort_registration = channel.as_mut().start_request(1); + let future = Abortable::new(async { () }, abort_registration); + drop(channel); + assert_matches!(future.await, Err(Aborted)); +} diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 52e88ba..18f315b 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -3,7 +3,11 @@ use futures::{ future::{join_all, ready, Ready}, prelude::*, }; -use std::io; +use std::{ + io, + sync::Arc, + time::{Duration, SystemTime}, +}; use tarpc::{ client::{self}, context, @@ -57,6 +61,65 @@ async fn sequential() -> io::Result<()> { Ok(()) } +#[tokio::test] +async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> { + #[tarpc_plugins::service] + trait Loop { + async fn r#loop(); + } + + struct LoopServer(tokio::sync::mpsc::UnboundedSender); + + #[derive(Debug)] + struct AllHandlersComplete; + + impl Drop for LoopServer { + fn drop(&mut self) { + let _ = self.0.send(AllHandlersComplete); + } + } + + #[tarpc::server] + impl Loop for Arc { + async fn r#loop(self, _: context::Context) { + loop { + futures::pending!(); + } + } + } + + let _ = env_logger::try_init(); + + let (tx, rx) = channel::unbounded(); + let (rpc_finished_tx, mut rpc_finished) = tokio::sync::mpsc::unbounded_channel(); + + // Set up a client that initiates a long-lived request. + // The request will complete in error when the server drops the connection. + tokio::spawn(async move { + let mut client = LoopClient::new(client::Config::default(), tx) + .spawn() + .unwrap(); + + let mut ctx = context::current(); + ctx.deadline = SystemTime::now() + Duration::from_secs(60 * 60); + let _ = client.r#loop(ctx).await; + }); + + let mut server = + BaseChannel::with_defaults(rx).respond_with(Arc::new(LoopServer(rpc_finished_tx)).serve()); + let first_handler = server.next().await.unwrap()?; + + drop(server); + first_handler.await; + + // At this point, a single RPC has been sent and a single response initiated. + // The request handler will loop for a long time unless aborted. + // Now, we assert that the act of disconnecting a client is sufficient to abort all + // handlers initiated by the connection's RPCs. + assert_matches!(rpc_finished.recv().await, Some(AllHandlersComplete)); + Ok(()) +} + #[cfg(all(feature = "serde-transport", feature = "tcp"))] #[tokio::test] async fn serde() -> io::Result<()> {