Abort all in-flight requests when dropping BaseChannel.

Fixes #341
This commit is contained in:
Tim Kuehn
2021-01-24 17:57:44 -08:00
parent 4b513bad73
commit 3b422eb179
3 changed files with 97 additions and 5 deletions

View File

@@ -144,11 +144,11 @@ where
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
pub fn spawn(self) -> io::Result<C> {
use log::error;
use log::warn;
let dispatch = self
.dispatch
.unwrap_or_else(move |e| error!("Connection broken: {}", e));
.unwrap_or_else(move |e| warn!("Connection broken: {}", e));
tokio::spawn(dispatch);
Ok(self.client)
}

View File

@@ -21,7 +21,7 @@ use futures::{
};
use humantime::format_rfc3339;
use log::{debug, trace};
use pin_project::pin_project;
use pin_project::{pin_project, pinned_drop};
use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
use tokio::time::Timeout;
@@ -171,7 +171,7 @@ where
}
/// BaseChannel lifts a Transport to a Channel by tracking in-flight requests.
#[pin_project]
#[pin_project(PinnedDrop)]
pub struct BaseChannel<Req, Resp, T> {
config: Config,
/// Writes responses to the wire and reads requests off the wire.
@@ -211,7 +211,9 @@ where
pub fn get_pin_ref(self: Pin<&mut Self>) -> Pin<&mut T> {
self.project().transport.get_pin_mut()
}
}
impl<Req, Resp, T> BaseChannel<Req, Resp, T> {
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
// It's possible the request was already completed, so it's fine
// if this is None.
@@ -360,6 +362,17 @@ where
}
}
#[pinned_drop]
impl<Req, Resp, T> PinnedDrop for BaseChannel<Req, Resp, T> {
fn drop(mut self: Pin<&mut Self>) {
self.as_mut()
.project()
.in_flight_requests
.values()
.for_each(AbortHandle::abort);
}
}
impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
fn as_ref(&self) -> &T {
self.transport.get_ref()
@@ -729,3 +742,19 @@ where
Poll::Ready(())
}
}
#[tokio::test]
async fn abort_in_flight_requests_on_channel_drop() {
use assert_matches::assert_matches;
use futures::future::Aborted;
let (_, server_transport) =
super::transport::channel::unbounded::<Response<()>, ClientMessage<()>>();
let channel = BaseChannel::with_defaults(server_transport);
let mut channel = Box::pin(channel);
let abort_registration = channel.as_mut().start_request(1);
let future = Abortable::new(async { () }, abort_registration);
drop(channel);
assert_matches!(future.await, Err(Aborted));
}

View File

@@ -3,7 +3,11 @@ use futures::{
future::{join_all, ready, Ready},
prelude::*,
};
use std::io;
use std::{
io,
sync::Arc,
time::{Duration, SystemTime},
};
use tarpc::{
client::{self},
context,
@@ -57,6 +61,65 @@ async fn sequential() -> io::Result<()> {
Ok(())
}
#[tokio::test]
async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
#[tarpc_plugins::service]
trait Loop {
async fn r#loop();
}
struct LoopServer(tokio::sync::mpsc::UnboundedSender<AllHandlersComplete>);
#[derive(Debug)]
struct AllHandlersComplete;
impl Drop for LoopServer {
fn drop(&mut self) {
let _ = self.0.send(AllHandlersComplete);
}
}
#[tarpc::server]
impl Loop for Arc<LoopServer> {
async fn r#loop(self, _: context::Context) {
loop {
futures::pending!();
}
}
}
let _ = env_logger::try_init();
let (tx, rx) = channel::unbounded();
let (rpc_finished_tx, mut rpc_finished) = tokio::sync::mpsc::unbounded_channel();
// Set up a client that initiates a long-lived request.
// The request will complete in error when the server drops the connection.
tokio::spawn(async move {
let mut client = LoopClient::new(client::Config::default(), tx)
.spawn()
.unwrap();
let mut ctx = context::current();
ctx.deadline = SystemTime::now() + Duration::from_secs(60 * 60);
let _ = client.r#loop(ctx).await;
});
let mut server =
BaseChannel::with_defaults(rx).respond_with(Arc::new(LoopServer(rpc_finished_tx)).serve());
let first_handler = server.next().await.unwrap()?;
drop(server);
first_handler.await;
// At this point, a single RPC has been sent and a single response initiated.
// The request handler will loop for a long time unless aborted.
// Now, we assert that the act of disconnecting a client is sufficient to abort all
// handlers initiated by the connection's RPCs.
assert_matches!(rpc_finished.recv().await, Some(AllHandlersComplete));
Ok(())
}
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
#[tokio::test]
async fn serde() -> io::Result<()> {