diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs
index 95fc91b..7331c9f 100644
--- a/tarpc/src/client.rs
+++ b/tarpc/src/client.rs
@@ -14,7 +14,7 @@ use crate::{
 use futures::{prelude::*, ready, stream::Fuse, task::*};
 use in_flight_requests::InFlightRequests;
 use log::{info, trace};
-use pin_project::{pin_project, pinned_drop};
+use pin_project::pin_project;
 use std::{
     convert::TryFrom,
     fmt, io,
@@ -165,7 +165,6 @@ impl<Req, Resp> Channel<Req, Resp> {
 
 /// A server response that is completed by request dispatch when the corresponding response
 /// arrives off the wire.
-#[pin_project(PinnedDrop)]
 #[derive(Debug)]
 struct DispatchResponse<Resp> {
     response: oneshot::Receiver<Response<Resp>>,
@@ -193,11 +192,9 @@ impl<Resp> Future for DispatchResponse<Resp> {
     }
 }
 // Cancels the request when dropped, if not already complete.
-#[pinned_drop]
-impl<Resp> PinnedDrop for DispatchResponse<Resp> {
-    fn drop(mut self: Pin<&mut Self>) {
-        let self_ = self.project();
-        if let Some(cancellation) = self_.cancellation {
+impl<Resp> Drop for DispatchResponse<Resp> {
+    fn drop(&mut self) {
+        if let Some(cancellation) = &mut self.cancellation {
             // The receiver needs to be closed to handle the edge case that the request has not
             // yet been received by the dispatch task. It is possible for the cancel message to
             // arrive before the request itself, in which case the request could get stuck in the
@@ -208,8 +205,8 @@ impl<Resp> PinnedDrop for DispatchResponse<Resp> {
             // closing the receiver before sending the cancel message, it is guaranteed that if the
             // dispatch task misses an early-arriving cancellation message, then it will see the
             // receiver as closed.
-            self_.response.close();
-            cancellation.cancel(*self_.request_id);
+            self.response.close();
+            cancellation.cancel(self.request_id);
         }
     }
 }
@@ -252,10 +249,8 @@ pub struct RequestDispatch<Req, Resp, C> {
     #[pin]
     transport: Fuse<C>,
     /// Requests waiting to be written to the wire.
-    #[pin]
     pending_requests: mpsc::Receiver<DispatchRequest<Req, Resp>>,
     /// Requests that were dropped.
-    #[pin]
     canceled_requests: CanceledRequests,
     /// Requests already written to the wire that haven't yet received responses.
     in_flight_requests: InFlightRequests<Resp>,
@@ -271,16 +266,28 @@ where
         self.as_mut().project().in_flight_requests
     }
 
+    fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<C>> {
+        self.as_mut().project().transport
+    }
+
+    fn canceled_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut CanceledRequests {
+        self.as_mut().project().canceled_requests
+    }
+
+    fn pending_requests_mut<'a>(
+        self: &'a mut Pin<&mut Self>,
+    ) -> &'a mut mpsc::Receiver<DispatchRequest<Req, Resp>> {
+        self.as_mut().project().pending_requests
+    }
+
     fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
-        Poll::Ready(
-            match ready!(self.as_mut().project().transport.poll_next(cx)?) {
-                Some(response) => {
-                    self.complete(response);
-                    Some(Ok(()))
-                }
-                None => None,
-            },
-        )
+        Poll::Ready(match ready!(self.transport_pin_mut().poll_next(cx)?) {
+            Some(response) => {
+                self.complete(response);
+                Some(Ok(()))
+            }
+            None => None,
+        })
     }
 
     fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
@@ -289,20 +296,14 @@ where
             Closed,
         }
 
-        let pending_requests_status = match self.as_mut().poll_next_request(cx)? {
-            Poll::Ready(Some(dispatch_request)) => {
-                self.as_mut().write_request(dispatch_request)?;
-                return Poll::Ready(Some(Ok(())));
-            }
+        let pending_requests_status = match self.as_mut().poll_write_request(cx)? {
+            Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
             Poll::Ready(None) => ReceiverStatus::Closed,
             Poll::Pending => ReceiverStatus::NotReady,
         };
 
-        let canceled_requests_status = match self.as_mut().poll_next_cancellation(cx)? {
-            Poll::Ready(Some((context, request_id))) => {
-                self.as_mut().write_cancel(context, request_id)?;
-                return Poll::Ready(Some(Ok(())));
-            }
+        let canceled_requests_status = match self.as_mut().poll_write_cancel(cx)? {
+            Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
             Poll::Ready(None) => ReceiverStatus::Closed,
             Poll::Pending => ReceiverStatus::NotReady,
         };
@@ -319,12 +320,12 @@ where
 
         match (pending_requests_status, canceled_requests_status) {
             (ReceiverStatus::Closed, ReceiverStatus::Closed) => {
-                ready!(self.as_mut().project().transport.poll_flush(cx)?);
+                ready!(self.transport_pin_mut().poll_flush(cx)?);
                 Poll::Ready(None)
             }
             (ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
                 // No more messages to process, so flush any messages buffered in the transport.
-                ready!(self.as_mut().project().transport.poll_flush(cx)?);
+                ready!(self.transport_pin_mut().poll_flush(cx)?);
 
                 // Even if we fully-flush, we return Pending, because we have no more requests
                 // or cancellations right now.
@@ -334,6 +335,9 @@ where
     }
 
     /// Yields the next pending request, if one is ready to be sent.
+    ///
+    /// Note that a request will only be yielded if the transport is *ready* to be written to (i.e.
+    /// start_send would succeed).
     fn poll_next_request(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
@@ -350,19 +354,10 @@ where
             return Poll::Pending;
         }
 
-        while self
-            .as_mut()
-            .project()
-            .transport
-            .poll_ready(cx)?
-            .is_pending()
-        {
-            // We can't yield a request-to-be-sent before the transport is capable of buffering it.
-            ready!(self.as_mut().project().transport.poll_flush(cx)?);
-        }
+        ready!(self.ensure_writeable(cx)?);
 
         loop {
-            match ready!(self.as_mut().project().pending_requests.poll_recv(cx)) {
+            match ready!(self.pending_requests_mut().poll_recv(cx)) {
                 Some(request) => {
                     if request.response_completion.is_closed() {
                         trace!(
@@ -380,27 +375,17 @@ where
     }
 
     /// Yields the next pending cancellation, and, if one is ready, cancels the associated request.
+    ///
+    /// Note that a request to cancel will only be yielded if the transport is *ready* to be
+    /// written to (i.e. start_send would succeed).
     fn poll_next_cancellation(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> PollIo<(context::Context, u64)> {
-        while self
-            .as_mut()
-            .project()
-            .transport
-            .poll_ready(cx)?
-            .is_pending()
-        {
-            ready!(self.as_mut().project().transport.poll_flush(cx)?);
-        }
+        ready!(self.ensure_writeable(cx)?);
 
         loop {
-            let cancellation = self
-                .as_mut()
-                .project()
-                .canceled_requests
-                .poll_next_unpin(cx);
-            match ready!(cancellation) {
+            match ready!(self.canceled_requests_mut().poll_next_unpin(cx)) {
                 Some(request_id) => {
                     if let Some(ctx) = self.in_flight_requests().cancel_request(request_id) {
                         return Poll::Ready(Some(Ok((ctx, request_id))));
@@ -411,10 +396,24 @@ where
         }
     }
 
-    fn write_request(
-        mut self: Pin<&mut Self>,
-        dispatch_request: DispatchRequest<Req, Resp>,
-    ) -> io::Result<()> {
+    /// Returns Ready if writing a message to the transport (i.e. via write_request or
+    /// write_cancel) would not fail due to a full buffer. If the transport is not ready to be
+    /// written to, flushes it until it is ready.
+    fn ensure_writeable<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
+        while self.transport_pin_mut().poll_ready(cx)?.is_pending() {
+            ready!(self.transport_pin_mut().poll_flush(cx)?);
+        }
+        Poll::Ready(Some(Ok(())))
+    }
+
+    fn poll_write_request<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
+        let dispatch_request = match ready!(self.as_mut().poll_next_request(cx)?) {
+            Some(dispatch_request) => dispatch_request,
+            None => return Poll::Ready(None),
+        };
+        // poll_next_request only returns Ready if there is room to buffer another request.
+        // Therefore, we can call write_request without fear of erroring due to a full
+        // buffer.
         let request_id = dispatch_request.request_id;
         let request = ClientMessage::Request(Request {
             id: request_id,
@@ -424,7 +423,7 @@ where
                 trace_context: dispatch_request.ctx.trace_context,
             },
         });
-        self.as_mut().project().transport.start_send(request)?;
+        self.transport_pin_mut().start_send(request)?;
         self.in_flight_requests()
             .insert_request(
                 request_id,
@@ -432,22 +431,23 @@ where
                 dispatch_request.response_completion,
             )
             .expect("Request IDs should be unique");
-        Ok(())
+        Poll::Ready(Some(Ok(())))
     }
 
-    fn write_cancel(
-        mut self: Pin<&mut Self>,
-        context: context::Context,
-        request_id: u64,
-    ) -> io::Result<()> {
+    fn poll_write_cancel<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
+        let (context, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
+            Some((context, request_id)) => (context, request_id),
+            None => return Poll::Ready(None),
+        };
+
         let trace_id = *context.trace_id();
         let cancel = ClientMessage::Cancel {
             trace_context: context.trace_context,
             request_id,
         };
-        self.as_mut().project().transport.start_send(cancel)?;
+        self.transport_pin_mut().start_send(cancel)?;
         trace!("[{}] Cancel message sent.", trace_id);
-        Ok(())
+        Poll::Ready(Some(Ok(())))
     }
 
     /// Sends a server response to the client task that initiated the associated request.
@@ -532,11 +532,17 @@ impl RequestCancellation {
     }
 }
 
+impl CanceledRequests {
+    fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<u64>> {
+        self.0.poll_recv(cx)
+    }
+}
+
 impl Stream for CanceledRequests {
     type Item = u64;
 
     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
-        self.0.poll_recv(cx)
+        self.poll_recv(cx)
    }
 }
 
diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs
index 9c5a75f..366eb3b 100644
--- a/tarpc/src/server.rs
+++ b/tarpc/src/server.rs
@@ -381,10 +381,8 @@ where
     #[pin]
     channel: C,
     /// Responses waiting to be written to the wire.
-    #[pin]
     pending_responses: mpsc::Receiver<(context::Context, Response<C::Resp>)>,
     /// Handed out to request handlers to fan in responses.
-    #[pin]
     responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
 }
 
@@ -397,6 +395,13 @@ where
         self.as_mut().project().channel
     }
 
+    /// Returns the inner channel over which messages are sent and received.
+    pub fn pending_responses_mut<'a>(
+        self: &'a mut Pin<&mut Self>,
+    ) -> &'a mut mpsc::Receiver<(context::Context, Response<C::Resp>)> {
+        self.as_mut().project().pending_responses
+    }
+
     fn pump_read(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
@@ -451,12 +456,8 @@ where
                     context.trace_id(),
                     self.channel.in_flight_requests(),
                 );
-                // TODO: it's possible for poll_flush to be starved and start_send to end up full.
-                // Currently that would cause the channel to shut down. serde_transport internally
-                // uses tokio-util Framed, which will allocate as much as needed. But other
-                // transports may work differently.
-                //
-                // There should be a way to know if a flush is needed soon.
+                // A Ready result from poll_next_response means the Channel is ready to be written
+                // to. Therefore, we can call start_send without worry of a full buffer.
                 self.channel_pin_mut().start_send(response)?;
                 Poll::Ready(Some(Ok(())))
             }
@@ -481,16 +482,17 @@ where
         }
     }
 
+    /// Yields a response ready to be written to the Channel sink.
+    ///
+    /// Note that a response will only be yielded if the Channel is *ready* to be written to (i.e.
+    /// start_send would succeed).
     fn poll_next_response(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> PollIo<(context::Context, Response<C::Resp>)> {
-        // Ensure there's room to write a response.
-        while self.channel_pin_mut().poll_ready(cx)?.is_pending() {
-            ready!(self.as_mut().project().channel.poll_flush(cx)?);
-        }
+        ready!(self.ensure_writeable(cx)?);
 
-        match ready!(self.as_mut().project().pending_responses.poll_recv(cx)) {
+        match ready!(self.pending_responses_mut().poll_recv(cx)) {
             Some(response) => Poll::Ready(Some(Ok(response))),
             None => {
                 // This branch likely won't happen, since the Requests stream is holding a Sender.
@@ -498,6 +500,15 @@ where
             }
         }
     }
+
+    /// Returns Ready if writing a message to the Channel would not fail due to a full buffer. If
+    /// the Channel is not ready to be written to, flushes it until it is ready.
+    fn ensure_writeable<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
+        while self.channel_pin_mut().poll_ready(cx)?.is_pending() {
+            ready!(self.channel_pin_mut().poll_flush(cx)?);
+        }
+        Poll::Ready(Some(Ok(())))
+    }
 }
 
 impl<C> fmt::Debug for Requests<C>
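The repeated poll_ready/poll_flush loops that this change factors out into ensure_writeable, on both the client dispatch task and the server Requests stream, follow one sink-readiness idiom: before an item may be buffered with start_send, the sink must report readiness, and while it does not, it is flushed so that already-buffered messages drain. Below is a minimal standalone sketch of that idiom against any futures::Sink; the poll_send_item helper name is illustrative and not part of tarpc.

use std::pin::Pin;
use std::task::{Context, Poll};

use futures::{ready, Sink};

/// Flushes `sink` until it can accept one more item, then buffers the item.
///
/// This mirrors the shape of the `ensure_writeable` helpers in the diff: readiness is
/// established first, so the subsequent `start_send` cannot fail due to a full buffer.
fn poll_send_item<S, T>(
    mut sink: Pin<&mut S>,
    cx: &mut Context<'_>,
    item: T,
) -> Poll<Result<(), S::Error>>
where
    S: Sink<T>,
{
    // Equivalent to ensure_writeable: keep flushing while the sink reports it is not ready.
    while sink.as_mut().poll_ready(cx)?.is_pending() {
        ready!(sink.as_mut().poll_flush(cx)?);
    }
    // The readiness check above guarantees capacity, so buffering cannot fail for lack of room.
    sink.as_mut().start_send(item)?;
    Poll::Ready(Ok(()))
}

In the diff itself the two halves are split: poll_next_request and poll_next_response only yield an item once the transport is ready, and poll_write_request and poll_write_cancel then call start_send immediately, which is why the old TODO about a starved poll_flush could be removed.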
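The rewritten Drop impl for DispatchResponse preserves the ordering the comment argues for: close the response receiver first, then send the cancellation message, so a dispatch task that sees the cancel message before the request itself can still detect abandonment through the closed completion channel. A small self-contained sketch of that ordering with Tokio channels follows; the AbandonedRequest and CancellationSender types are illustrative stand-ins, not tarpc's actual types.

use tokio::sync::{mpsc, oneshot};

/// Illustrative stand-in for the dispatch-side cancellation handle.
struct CancellationSender(mpsc::UnboundedSender<u64>);

impl CancellationSender {
    fn cancel(&self, request_id: u64) {
        // Ignore errors: if the dispatch task is gone, there is nothing left to cancel.
        let _ = self.0.send(request_id);
    }
}

/// A pending response that cancels its request if dropped before completion.
struct AbandonedRequest {
    request_id: u64,
    response: oneshot::Receiver<String>,
    cancellation: Option<CancellationSender>,
}

impl Drop for AbandonedRequest {
    fn drop(&mut self) {
        if let Some(cancellation) = &mut self.cancellation {
            // Close the receiver *before* sending the cancel message. If the cancel message
            // races ahead of the request itself, the dispatch side can still observe that
            // the completion channel is closed and drop the request instead of leaking it.
            self.response.close();
            cancellation.cancel(self.request_id);
        }
    }
}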