diff --git a/tarpc/src/rpc/client/channel.rs b/tarpc/src/rpc/client/channel.rs index 0432456..290602e 100644 --- a/tarpc/src/rpc/client/channel.rs +++ b/tarpc/src/rpc/client/channel.rs @@ -78,14 +78,21 @@ impl<'a, Req, Resp> Future for Send<'a, Req, Resp> { #[must_use = "futures do nothing unless polled"] pub struct Call<'a, Req, Resp> { #[pin] - fut: AndThenIdent, DispatchResponse>, + fut: tokio::time::Timeout, DispatchResponse>>, } impl<'a, Req, Resp> Future for Call<'a, Req, Resp> { type Output = io::Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.as_mut().project().fut.poll(cx) + let resp = ready!(self.as_mut().project().fut.poll(cx)); + Poll::Ready(match resp { + Ok(resp) => resp, + Err(tokio::time::Elapsed { .. }) => Err(io::Error::new( + io::ErrorKind::TimedOut, + "Client dropped expired request.".to_string(), + )), + }) } } @@ -97,13 +104,6 @@ impl Channel { ctx.trace_context.parent_id = Some(ctx.trace_context.span_id); ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng()); - let timeout = ctx.deadline.time_until(); - trace!( - "[{}] Queuing request with timeout {:?}.", - ctx.trace_id(), - timeout, - ); - let (response_completion, response) = oneshot::channel(); let cancellation = self.cancellation.clone(); let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed); @@ -116,7 +116,7 @@ impl Channel { response_completion, })), DispatchResponse { - response: tokio::time::timeout(timeout, response), + response, complete: false, request_id, cancellation, @@ -128,9 +128,16 @@ impl Channel { /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves to the response. - pub fn call(&mut self, context: context::Context, request: Req) -> Call { + pub fn call(&mut self, ctx: context::Context, request: Req) -> Call { + let timeout = ctx.deadline.time_until(); + trace!( + "[{}] Queuing request with timeout {:?}.", + ctx.trace_id(), + timeout, + ); + Call { - fut: AndThenIdent::new(self.send(context, request)), + fut: tokio::time::timeout(timeout, AndThenIdent::new(self.send(ctx, request))), } } } @@ -140,7 +147,7 @@ impl Channel { #[pin_project(PinnedDrop)] #[derive(Debug)] struct DispatchResponse { - response: tokio::time::Timeout>>, + response: oneshot::Receiver>, ctx: context::Context, complete: bool, cancellation: RequestCancellation, @@ -152,24 +159,15 @@ impl Future for DispatchResponse { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let resp = ready!(self.response.poll_unpin(cx)); - + self.complete = true; Poll::Ready(match resp { - Ok(resp) => { - self.complete = true; - match resp { - Ok(resp) => Ok(resp.message?), - Err(oneshot::Canceled) => { - // The oneshot is Canceled when the dispatch task ends. In that case, - // there's nothing listening on the other side, so there's no point in - // propagating cancellation. - Err(io::Error::from(io::ErrorKind::ConnectionReset)) - } - } + Ok(resp) => Ok(resp.message?), + Err(oneshot::Canceled) => { + // The oneshot is Canceled when the dispatch task ends. In that case, + // there's nothing listening on the other side, so there's no point in + // propagating cancellation. + Err(io::Error::from(io::ErrorKind::ConnectionReset)) } - Err(tokio::time::Elapsed { .. }) => Err(io::Error::new( - io::ErrorKind::TimedOut, - "Client dropped expired request.".to_string(), - )), }) } } @@ -189,7 +187,7 @@ impl PinnedDrop for DispatchResponse { // closing the receiver before sending the cancel message, it is guaranteed that if the // dispatch task misses an early-arriving cancellation message, then it will see the // receiver as closed. - self.response.get_mut().close(); + self.response.close(); let request_id = self.request_id; self.cancellation.cancel(request_id); } @@ -714,24 +712,21 @@ mod tests { prelude::*, task::*, }; - use std::time::Duration; use std::{pin::Pin, sync::atomic::AtomicU64, sync::Arc}; #[tokio::test(threaded_scheduler)] - async fn dispatch_response_cancels_on_timeout() { - let (_response_completion, response) = oneshot::channel(); + async fn dispatch_response_cancels_on_drop() { let (cancellation, mut canceled_requests) = cancellations(); - let resp = DispatchResponse:: { - // Timeout in the past should cause resp to error out when polled. - response: tokio::time::timeout(Duration::from_secs(0), response), + let (_, response) = oneshot::channel(); + drop(DispatchResponse:: { + response, + cancellation, complete: false, request_id: 3, - cancellation, ctx: context::current(), - }; - let _ = futures::poll!(resp); + }); // resp's drop() is run, which should send a cancel message. - assert!(canceled_requests.0.try_next().unwrap() == Some(3)); + assert_eq!(canceled_requests.0.try_next().unwrap(), Some(3)); } #[tokio::test(threaded_scheduler)] @@ -819,7 +814,7 @@ mod tests { // i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request // map. let mut resp = send_request(&mut channel, "hi").await; - resp.response.get_mut().close(); + resp.response.close(); assert!(dispatch.poll_next_request(cx).is_pending()); } diff --git a/tarpc/src/rpc/server/mod.rs b/tarpc/src/rpc/server/mod.rs index 6b350a6..d2aa9eb 100644 --- a/tarpc/src/rpc/server/mod.rs +++ b/tarpc/src/rpc/server/mod.rs @@ -651,11 +651,9 @@ where pub fn execute(self) -> impl Future { use log::info; - self.try_for_each(|request_handler| { - async { - tokio::spawn(request_handler); - Ok(()) - } + self.try_for_each(|request_handler| async { + tokio::spawn(request_handler); + Ok(()) }) .unwrap_or_else(|e| info!("ClientHandler errored out: {}", e)) }