diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 7331c9f..f068bf2 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -292,20 +292,20 @@ where fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> { enum ReceiverStatus { - NotReady, + Pending, Closed, } let pending_requests_status = match self.as_mut().poll_write_request(cx)? { Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))), Poll::Ready(None) => ReceiverStatus::Closed, - Poll::Pending => ReceiverStatus::NotReady, + Poll::Pending => ReceiverStatus::Pending, }; let canceled_requests_status = match self.as_mut().poll_write_cancel(cx)? { Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))), Poll::Ready(None) => ReceiverStatus::Closed, - Poll::Pending => ReceiverStatus::NotReady, + Poll::Pending => ReceiverStatus::Pending, }; // Receiving Poll::Ready(None) when polling expired requests never indicates "Closed", @@ -323,7 +323,7 @@ where ready!(self.transport_pin_mut().poll_flush(cx)?); Poll::Ready(None) } - (ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => { + (ReceiverStatus::Pending, _) | (_, ReceiverStatus::Pending) => { // No more messages to process, so flush any messages buffered in the transport. ready!(self.transport_pin_mut().poll_flush(cx)?); diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index d497f9f..f395447 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -15,6 +15,7 @@ use futures::{ task::*, }; use humantime::format_rfc3339; +use in_flight_requests::{AlreadyExistsError, InFlightRequests}; use log::{debug, info, trace}; use pin_project::pin_project; use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime}; @@ -138,7 +139,7 @@ pub struct BaseChannel { #[pin] transport: Fuse, /// Holds data necessary to clean up in-flight requests. - in_flight_requests: in_flight_requests::InFlightRequests, + in_flight_requests: InFlightRequests, /// Types the request and response. ghost: PhantomData<(Req, Resp)>, } @@ -152,7 +153,7 @@ where BaseChannel { config, transport: transport.fuse(), - in_flight_requests: in_flight_requests::InFlightRequests::default(), + in_flight_requests: InFlightRequests::default(), ghost: PhantomData, } } @@ -171,6 +172,14 @@ where pub fn get_pin_ref(self: Pin<&mut Self>) -> Pin<&mut T> { self.project().transport.get_pin_mut() } + + fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { + self.as_mut().project().in_flight_requests + } + + fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse> { + self.as_mut().project().transport + } } impl fmt::Debug for BaseChannel { @@ -231,7 +240,7 @@ where self: Pin<&mut Self>, id: u64, deadline: SystemTime, - ) -> Result; + ) -> Result; /// Returns a stream of requests that automatically handle request cancellation and response /// routing. @@ -276,16 +285,27 @@ where type Item = io::Result>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let self_ = self.as_mut().project(); - while let Poll::Ready(Some(request_id)) = self_.in_flight_requests.poll_expired(cx)? { - // No need to send a response, since the client wouldn't be waiting for one anymore. - debug!("Request {} did not complete before deadline", request_id); + enum ReceiverStatus { + Ready, + Pending, + Closed, } + use ReceiverStatus::*; loop { - let self_ = self.as_mut().project(); - match ready!(self_.transport.poll_next(cx)?) { - Some(message) => match message { + let expiration_status = match self.in_flight_requests_mut().poll_expired(cx)? { + Poll::Ready(Some(request_id)) => { + // No need to send a response, since the client wouldn't be waiting for one + // anymore. + debug!("Request {} did not complete before deadline", request_id); + Ready + } + Poll::Ready(None) => Closed, + Poll::Pending => Pending, + }; + + let request_status = match self.transport_pin_mut().poll_next(cx)? { + Poll::Ready(Some(message)) => match message { ClientMessage::Request(request) => { return Poll::Ready(Some(Ok(request))); } @@ -293,8 +313,8 @@ where trace_context, request_id, } => { - if self_.in_flight_requests.cancel_request(request_id) { - let remaining = self_.in_flight_requests.len(); + if self.in_flight_requests_mut().cancel_request(request_id) { + let remaining = self.in_flight_requests.len(); trace!( "[{}] Request canceled. In-flight requests = {}", trace_context.trace_id, @@ -307,9 +327,17 @@ where trace_context.trace_id, ); } + Ready } }, - None => return Poll::Ready(None), + Poll::Ready(None) => Closed, + Poll::Pending => Pending, + }; + + match (expiration_status, request_status) { + (Ready, _) | (_, Ready) => continue, + (Closed, Closed) => return Poll::Ready(None), + (Pending, Closed) | (Closed, Pending) | (Pending, Pending) => return Poll::Pending, } } } @@ -367,7 +395,7 @@ where self: Pin<&mut Self>, id: u64, deadline: SystemTime, - ) -> Result { + ) -> Result { self.project() .in_flight_requests .start_request(id, deadline) @@ -432,7 +460,7 @@ where // Instead of closing the channel if a duplicate request is sent, just // ignore it, since it's already being processed. Note that we cannot // return Poll::Pending here, since nothing has scheduled a wakeup yet. - Err(in_flight_requests::AlreadyExistsError) => { + Err(AlreadyExistsError) => { info!( "[{}] Request ID {} delivered more than once.", request.context.trace_id(), @@ -712,3 +740,310 @@ where Poll::Ready(()) } } + +#[cfg(test)] +use { + crate::{ + trace, + transport::channel::{self, UnboundedChannel}, + }, + assert_matches::assert_matches, + futures::future::{pending, Aborted}, + futures_test::task::noop_context, + std::time::Duration, +}; + +#[cfg(test)] +fn test_channel() -> ( + Pin, Response>>>>, + UnboundedChannel, ClientMessage>, +) { + let (tx, rx) = crate::transport::channel::unbounded(); + (Box::pin(BaseChannel::new(Config::default(), rx)), tx) +} + +#[cfg(test)] +fn test_requests() -> ( + Pin< + Box, Response>>>>, + >, + UnboundedChannel, ClientMessage>, +) { + let (tx, rx) = crate::transport::channel::unbounded(); + ( + Box::pin(BaseChannel::new(Config::default(), rx).requests()), + tx, + ) +} + +#[cfg(test)] +fn test_bounded_requests( + capacity: usize, +) -> ( + Pin< + Box, Response>>>>, + >, + channel::Channel, ClientMessage>, +) { + let (tx, rx) = crate::transport::channel::bounded(capacity); + let mut config = Config::default(); + // Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded). + config.pending_response_buffer = capacity + 1; + (Box::pin(BaseChannel::new(config, rx).requests()), tx) +} + +#[cfg(test)] +fn fake_request(req: Req) -> ClientMessage { + ClientMessage::Request(Request { + context: context::current(), + id: 0, + message: req, + }) +} + +#[cfg(test)] +fn test_abortable( + abort_registration: AbortRegistration, +) -> impl Future> { + Abortable::new(pending(), abort_registration) +} + +#[tokio::test] +async fn base_channel_start_send_duplicate_request_returns_error() { + let (mut channel, _tx) = test_channel::<(), ()>(); + + channel + .as_mut() + .start_request(0, SystemTime::now()) + .unwrap(); + assert_matches!( + channel.as_mut().start_request(0, SystemTime::now()), + Err(AlreadyExistsError) + ); +} + +#[tokio::test] +async fn base_channel_poll_next_aborts_multiple_requests() { + let (mut channel, _tx) = test_channel::<(), ()>(); + + tokio::time::pause(); + let abort_registration0 = channel + .as_mut() + .start_request(0, SystemTime::now()) + .unwrap(); + let abort_registration1 = channel + .as_mut() + .start_request(1, SystemTime::now()) + .unwrap(); + tokio::time::advance(std::time::Duration::from_secs(1000)).await; + + assert_matches!( + channel.as_mut().poll_next(&mut noop_context()), + Poll::Pending + ); + assert_matches!(test_abortable(abort_registration0).await, Err(Aborted)); + assert_matches!(test_abortable(abort_registration1).await, Err(Aborted)); +} + +#[tokio::test] +async fn base_channel_poll_next_aborts_canceled_request() { + let (mut channel, mut tx) = test_channel::<(), ()>(); + + tokio::time::pause(); + let abort_registration = channel + .as_mut() + .start_request(0, SystemTime::now() + Duration::from_millis(100)) + .unwrap(); + + tx.send(ClientMessage::Cancel { + trace_context: trace::Context::default(), + request_id: 0, + }) + .await + .unwrap(); + + assert_matches!( + channel.as_mut().poll_next(&mut noop_context()), + Poll::Pending + ); + + assert_matches!(test_abortable(abort_registration).await, Err(Aborted)); +} + +#[tokio::test] +async fn base_channel_with_closed_transport_and_in_flight_request_returns_pending() { + let (mut channel, tx) = test_channel::<(), ()>(); + + tokio::time::pause(); + let _abort_registration = channel + .as_mut() + .start_request(0, SystemTime::now() + Duration::from_millis(100)) + .unwrap(); + + drop(tx); + assert_matches!( + channel.as_mut().poll_next(&mut noop_context()), + Poll::Pending + ); +} + +#[tokio::test] +async fn base_channel_with_closed_transport_and_no_in_flight_requests_returns_closed() { + let (mut channel, tx) = test_channel::<(), ()>(); + drop(tx); + assert_matches!( + channel.as_mut().poll_next(&mut noop_context()), + Poll::Ready(None) + ); +} + +#[tokio::test] +async fn base_channel_poll_next_yields_request() { + let (mut channel, mut tx) = test_channel::<(), ()>(); + tx.send(fake_request(())).await.unwrap(); + + assert_matches!( + channel.as_mut().poll_next(&mut noop_context()), + Poll::Ready(Some(Ok(_))) + ); +} + +#[tokio::test] +async fn base_channel_poll_next_aborts_request_and_yields_request() { + let (mut channel, mut tx) = test_channel::<(), ()>(); + + tokio::time::pause(); + let abort_registration = channel + .as_mut() + .start_request(0, SystemTime::now()) + .unwrap(); + tokio::time::advance(std::time::Duration::from_secs(1000)).await; + + tx.send(fake_request(())).await.unwrap(); + + assert_matches!( + channel.as_mut().poll_next(&mut noop_context()), + Poll::Ready(Some(Ok(_))) + ); + assert_matches!(test_abortable(abort_registration).await, Err(Aborted)); +} + +#[tokio::test] +async fn base_channel_start_send_removes_in_flight_request() { + let (mut channel, _tx) = test_channel::<(), ()>(); + + channel + .as_mut() + .start_request(0, SystemTime::now()) + .unwrap(); + assert_eq!(channel.in_flight_requests(), 1); + channel + .as_mut() + .start_send(Response { + request_id: 0, + message: Ok(()), + }) + .unwrap(); + assert_eq!(channel.in_flight_requests(), 0); +} + +#[tokio::test] +async fn requests_poll_next_response_returns_pending_when_buffer_full() { + let (mut requests, _tx) = test_bounded_requests::<(), ()>(0); + + // Response written to the transport. + requests + .as_mut() + .channel_pin_mut() + .start_send(Response { + request_id: 0, + message: Ok(()), + }) + .unwrap(); + + // Response waiting to be written. + requests + .as_mut() + .project() + .responses_tx + .send(( + context::current(), + Response { + request_id: 1, + message: Ok(()), + }, + )) + .await + .unwrap(); + + requests + .as_mut() + .channel_pin_mut() + .start_request(1, SystemTime::now()) + .unwrap(); + + assert_matches!( + requests.as_mut().poll_next_response(&mut noop_context()), + Poll::Pending + ); +} + +#[tokio::test] +async fn requests_pump_write_returns_pending_when_buffer_full() { + let (mut requests, _tx) = test_bounded_requests::<(), ()>(0); + + // Response written to the transport. + requests + .as_mut() + .channel_pin_mut() + .start_send(Response { + request_id: 0, + message: Ok(()), + }) + .unwrap(); + + // Response waiting to be written. + requests + .as_mut() + .project() + .responses_tx + .send(( + context::current(), + Response { + request_id: 1, + message: Ok(()), + }, + )) + .await + .unwrap(); + + requests + .as_mut() + .channel_pin_mut() + .start_request(1, SystemTime::now()) + .unwrap(); + + assert_matches!( + requests.as_mut().pump_write(&mut noop_context(), true), + Poll::Pending + ); + // Assert that the pending response was not polled while the channel was blocked. + assert_matches!( + requests.as_mut().pending_responses_mut().recv().await, + Some(_) + ); +} + +#[tokio::test] +async fn requests_pump_read() { + let (mut requests, mut tx) = test_requests::<(), ()>(); + + // Response written to the transport. + tx.send(fake_request(())).await.unwrap(); + + assert_matches!( + requests.as_mut().pump_read(&mut noop_context()), + Poll::Ready(Some(Ok(_))) + ); + assert_eq!(requests.channel.in_flight_requests(), 1); +}