mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-01-01 00:51:25 +01:00
Move cancellation types into a dedicated module.
Cancellation utilities could be useful for both client and server code.
This commit is contained in:
41
tarpc/src/cancellations.rs
Normal file
41
tarpc/src/cancellations.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use futures::{prelude::*, task::*};
|
||||
use std::pin::Pin;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Sends request cancellation signals.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RequestCancellation(mpsc::UnboundedSender<u64>);
|
||||
|
||||
/// A stream of IDs of requests that have been canceled.
|
||||
#[derive(Debug)]
|
||||
pub struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
|
||||
|
||||
/// Returns a channel to send request cancellation messages.
|
||||
pub fn cancellations() -> (RequestCancellation, CanceledRequests) {
|
||||
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
|
||||
// bounded by the number of in-flight requests.
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
(RequestCancellation(tx), CanceledRequests(rx))
|
||||
}
|
||||
|
||||
impl RequestCancellation {
|
||||
/// Cancels the request with ID `request_id`.
|
||||
pub fn cancel(&self, request_id: u64) {
|
||||
let _ = self.0.send(request_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl CanceledRequests {
|
||||
/// Polls for a cancelled request.
|
||||
pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<u64>> {
|
||||
self.0.poll_recv(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for CanceledRequests {
|
||||
type Item = u64;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
|
||||
self.poll_recv(cx)
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,10 @@
|
||||
|
||||
mod in_flight_requests;
|
||||
|
||||
use crate::{context, trace, ClientMessage, Request, Response, ServerError, Transport};
|
||||
use crate::{
|
||||
cancellations::{cancellations, CanceledRequests, RequestCancellation},
|
||||
context, trace, ClientMessage, Request, Response, ServerError, Transport,
|
||||
};
|
||||
use futures::{prelude::*, ready, stream::Fuse, task::*};
|
||||
use in_flight_requests::{DeadlineExceededError, InFlightRequests};
|
||||
use pin_project::pin_project;
|
||||
@@ -603,49 +606,9 @@ struct DispatchRequest<Req, Resp> {
|
||||
pub response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
|
||||
}
|
||||
|
||||
/// Sends request cancellation signals.
|
||||
#[derive(Debug, Clone)]
|
||||
struct RequestCancellation(mpsc::UnboundedSender<u64>);
|
||||
|
||||
/// A stream of IDs of requests that have been canceled.
|
||||
#[derive(Debug)]
|
||||
struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
|
||||
|
||||
/// Returns a channel to send request cancellation messages.
|
||||
fn cancellations() -> (RequestCancellation, CanceledRequests) {
|
||||
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
|
||||
// bounded by the number of in-flight requests.
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
(RequestCancellation(tx), CanceledRequests(rx))
|
||||
}
|
||||
|
||||
impl RequestCancellation {
|
||||
/// Cancels the request with ID `request_id`.
|
||||
fn cancel(&self, request_id: u64) {
|
||||
let _ = self.0.send(request_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl CanceledRequests {
|
||||
fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<u64>> {
|
||||
self.0.poll_recv(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for CanceledRequests {
|
||||
type Item = u64;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
|
||||
self.poll_recv(cx)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
cancellations, CanceledRequests, Channel, DispatchRequest, RequestCancellation,
|
||||
RequestDispatch, ResponseGuard,
|
||||
};
|
||||
use super::{cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard};
|
||||
use crate::{
|
||||
client::{
|
||||
in_flight_requests::{DeadlineExceededError, InFlightRequests},
|
||||
@@ -698,7 +661,7 @@ mod tests {
|
||||
});
|
||||
// resp's drop() is run, which should send a cancel message.
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(Some(3)));
|
||||
assert_eq!(canceled_requests.poll_recv(cx), Poll::Ready(Some(3)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -721,7 +684,7 @@ mod tests {
|
||||
.unwrap();
|
||||
drop(cancellation);
|
||||
let cx = &mut Context::from_waker(&noop_waker_ref());
|
||||
assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(None));
|
||||
assert_eq!(canceled_requests.poll_recv(cx), Poll::Ready(None));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -829,18 +792,17 @@ mod tests {
|
||||
let _ = tracing_subscriber::fmt().with_test_writer().try_init();
|
||||
|
||||
let (to_dispatch, pending_requests) = mpsc::channel(1);
|
||||
let (cancel_tx, canceled_requests) = mpsc::unbounded_channel();
|
||||
let (cancellation, canceled_requests) = cancellations();
|
||||
let (client_channel, server_channel) = transport::channel::unbounded();
|
||||
|
||||
let dispatch = RequestDispatch::<String, String, _> {
|
||||
transport: client_channel.fuse(),
|
||||
pending_requests: pending_requests,
|
||||
canceled_requests: CanceledRequests(canceled_requests),
|
||||
canceled_requests,
|
||||
in_flight_requests: InFlightRequests::default(),
|
||||
config: Config::default(),
|
||||
};
|
||||
|
||||
let cancellation = RequestCancellation(cancel_tx);
|
||||
let channel = Channel {
|
||||
to_dispatch,
|
||||
cancellation,
|
||||
|
||||
@@ -300,6 +300,7 @@ pub use tarpc_plugins::service;
|
||||
/// `async`, meaning that this should not break existing code.
|
||||
pub use tarpc_plugins::server;
|
||||
|
||||
pub(crate) mod cancellations;
|
||||
pub mod client;
|
||||
pub mod context;
|
||||
pub mod server;
|
||||
|
||||
Reference in New Issue
Block a user