diff --git a/tarpc/src/cancellations.rs b/tarpc/src/cancellations.rs new file mode 100644 index 0000000..6d6d684 --- /dev/null +++ b/tarpc/src/cancellations.rs @@ -0,0 +1,41 @@ +use futures::{prelude::*, task::*}; +use std::pin::Pin; +use tokio::sync::mpsc; + +/// Sends request cancellation signals. +#[derive(Debug, Clone)] +pub struct RequestCancellation(mpsc::UnboundedSender); + +/// A stream of IDs of requests that have been canceled. +#[derive(Debug)] +pub struct CanceledRequests(mpsc::UnboundedReceiver); + +/// Returns a channel to send request cancellation messages. +pub fn cancellations() -> (RequestCancellation, CanceledRequests) { + // Unbounded because messages are sent in the drop fn. This is fine, because it's still + // bounded by the number of in-flight requests. + let (tx, rx) = mpsc::unbounded_channel(); + (RequestCancellation(tx), CanceledRequests(rx)) +} + +impl RequestCancellation { + /// Cancels the request with ID `request_id`. + pub fn cancel(&self, request_id: u64) { + let _ = self.0.send(request_id); + } +} + +impl CanceledRequests { + /// Polls for a cancelled request. + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { + self.0.poll_recv(cx) + } +} + +impl Stream for CanceledRequests { + type Item = u64; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_recv(cx) + } +} diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index f2bae32..60b34c2 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -8,7 +8,10 @@ mod in_flight_requests; -use crate::{context, trace, ClientMessage, Request, Response, ServerError, Transport}; +use crate::{ + cancellations::{cancellations, CanceledRequests, RequestCancellation}, + context, trace, ClientMessage, Request, Response, ServerError, Transport, +}; use futures::{prelude::*, ready, stream::Fuse, task::*}; use in_flight_requests::{DeadlineExceededError, InFlightRequests}; use pin_project::pin_project; @@ -603,49 +606,9 @@ struct DispatchRequest { pub response_completion: oneshot::Sender, DeadlineExceededError>>, } -/// Sends request cancellation signals. -#[derive(Debug, Clone)] -struct RequestCancellation(mpsc::UnboundedSender); - -/// A stream of IDs of requests that have been canceled. -#[derive(Debug)] -struct CanceledRequests(mpsc::UnboundedReceiver); - -/// Returns a channel to send request cancellation messages. -fn cancellations() -> (RequestCancellation, CanceledRequests) { - // Unbounded because messages are sent in the drop fn. This is fine, because it's still - // bounded by the number of in-flight requests. - let (tx, rx) = mpsc::unbounded_channel(); - (RequestCancellation(tx), CanceledRequests(rx)) -} - -impl RequestCancellation { - /// Cancels the request with ID `request_id`. - fn cancel(&self, request_id: u64) { - let _ = self.0.send(request_id); - } -} - -impl CanceledRequests { - fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { - self.0.poll_recv(cx) - } -} - -impl Stream for CanceledRequests { - type Item = u64; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.poll_recv(cx) - } -} - #[cfg(test)] mod tests { - use super::{ - cancellations, CanceledRequests, Channel, DispatchRequest, RequestCancellation, - RequestDispatch, ResponseGuard, - }; + use super::{cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard}; use crate::{ client::{ in_flight_requests::{DeadlineExceededError, InFlightRequests}, @@ -698,7 +661,7 @@ mod tests { }); // resp's drop() is run, which should send a cancel message. let cx = &mut Context::from_waker(&noop_waker_ref()); - assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(Some(3))); + assert_eq!(canceled_requests.poll_recv(cx), Poll::Ready(Some(3))); } #[tokio::test] @@ -721,7 +684,7 @@ mod tests { .unwrap(); drop(cancellation); let cx = &mut Context::from_waker(&noop_waker_ref()); - assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(None)); + assert_eq!(canceled_requests.poll_recv(cx), Poll::Ready(None)); } #[tokio::test] @@ -829,18 +792,17 @@ mod tests { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); let (to_dispatch, pending_requests) = mpsc::channel(1); - let (cancel_tx, canceled_requests) = mpsc::unbounded_channel(); + let (cancellation, canceled_requests) = cancellations(); let (client_channel, server_channel) = transport::channel::unbounded(); let dispatch = RequestDispatch:: { transport: client_channel.fuse(), pending_requests: pending_requests, - canceled_requests: CanceledRequests(canceled_requests), + canceled_requests, in_flight_requests: InFlightRequests::default(), config: Config::default(), }; - let cancellation = RequestCancellation(cancel_tx); let channel = Channel { to_dispatch, cancellation, diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index f944ba1..ebcbd8f 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -300,6 +300,7 @@ pub use tarpc_plugins::service; /// `async`, meaning that this should not break existing code. pub use tarpc_plugins::server; +pub(crate) mod cancellations; pub mod client; pub mod context; pub mod server;