From f694b7573a4cf3f45001837bcd0c69a5cb2c9878 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Fri, 8 Oct 2021 20:10:16 -0700 Subject: [PATCH] Close TcpStream when client disconnects. An attempt at a clean shutdown helps the server to drop its connections more quickly. Testing this uncovered a latent bug in DelayQueue wherein `poll_expired` yields `Pending` when empty. A workaround was added to `InFlightRequests::poll_expired`: check if there are actually any outstanding requests before calling `DelayQueue::poll_expired`. --- tarpc/src/client.rs | 14 +++++++++++++- tarpc/src/server.rs | 6 ++++++ tarpc/src/server/in_flight_requests.rs | 26 +++++++++++++++++++++++++- 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index f71f351..9f82cac 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -291,6 +291,9 @@ where /// Could not flush the transport. #[error("could not flush the transport")] Flush(#[source] E), + /// Could not close the write end of the transport. + #[error("could not close the write end of the transport")] + Close(#[source] E), /// Could not poll expired requests. #[error("could not poll expired requests")] Timer(#[source] tokio::time::error::Error), @@ -335,6 +338,15 @@ where .map_err(ChannelError::Flush) } + fn poll_close<'a>( + self: &'a mut Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + self.transport_pin_mut() + .poll_close(cx) + .map_err(ChannelError::Close) + } + fn canceled_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut CanceledRequests { self.as_mut().project().canceled_requests } @@ -394,7 +406,7 @@ where match (pending_requests_status, canceled_requests_status) { (ReceiverStatus::Closed, ReceiverStatus::Closed) => { - ready!(self.poll_flush(cx)?); + ready!(self.poll_close(cx)?); Poll::Ready(None) } (ReceiverStatus::Pending, _) | (_, ReceiverStatus::Pending) => { diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index acb0cf1..1181eed 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -333,6 +333,7 @@ where type Item = Result, ChannelError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + #[derive(Debug)] enum ReceiverStatus { Ready, Pending, @@ -388,6 +389,11 @@ where Poll::Pending => Pending, }; + tracing::trace!( + "Expired requests: {:?}, Inbound: {:?}", + expiration_status, + request_status + ); match (expiration_status, request_status) { (Ready, _) | (_, Ready) => continue, (Closed, Closed) => return Poll::Ready(None), diff --git a/tarpc/src/server/in_flight_requests.rs b/tarpc/src/server/in_flight_requests.rs index 67817e2..912243c 100644 --- a/tarpc/src/server/in_flight_requests.rs +++ b/tarpc/src/server/in_flight_requests.rs @@ -98,6 +98,11 @@ impl InFlightRequests { &mut self, cx: &mut Context, ) -> Poll>> { + if self.deadlines.is_empty() { + // TODO(https://github.com/tokio-rs/tokio/issues/4161) + // This is a workaround for DelayQueue not always treating this case correctly. + return Poll::Ready(None); + } self.deadlines.poll_expired(cx).map_ok(|expired| { if let Some(RequestData { abort_handle, span, .. @@ -184,12 +189,31 @@ mod tests { #[tokio::test] async fn remove_request_doesnt_abort() { let mut in_flight_requests = InFlightRequests::default(); + assert!(in_flight_requests.deadlines.is_empty()); + let abort_registration = in_flight_requests - .start_request(0, SystemTime::now(), Span::current()) + .start_request( + 0, + SystemTime::now() + std::time::Duration::from_secs(10), + Span::current(), + ) .unwrap(); let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); + // Precondition: Pending expiration + assert_matches!( + in_flight_requests.poll_expired(&mut noop_context()), + Poll::Pending + ); + assert!(!in_flight_requests.deadlines.is_empty()); + assert_matches!(in_flight_requests.remove_request(0), Some(_)); + // Postcondition: No pending expirations + assert!(in_flight_requests.deadlines.is_empty()); + assert_matches!( + in_flight_requests.poll_expired(&mut noop_context()), + Poll::Ready(None) + ); assert_matches!( abortable_future.poll_unpin(&mut noop_context()), Poll::Pending