mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-01-22 03:04:07 +01:00
@@ -144,11 +144,11 @@ where
|
||||
#[cfg(feature = "tokio1")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
|
||||
pub fn spawn(self) -> io::Result<C> {
|
||||
use log::error;
|
||||
use log::warn;
|
||||
|
||||
let dispatch = self
|
||||
.dispatch
|
||||
.unwrap_or_else(move |e| error!("Connection broken: {}", e));
|
||||
.unwrap_or_else(move |e| warn!("Connection broken: {}", e));
|
||||
tokio::spawn(dispatch);
|
||||
Ok(self.client)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ use futures::{
|
||||
};
|
||||
use humantime::format_rfc3339;
|
||||
use log::{debug, trace};
|
||||
use pin_project::pin_project;
|
||||
use pin_project::{pin_project, pinned_drop};
|
||||
use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
|
||||
use tokio::time::Timeout;
|
||||
|
||||
@@ -171,7 +171,7 @@ where
|
||||
}
|
||||
|
||||
/// BaseChannel lifts a Transport to a Channel by tracking in-flight requests.
|
||||
#[pin_project]
|
||||
#[pin_project(PinnedDrop)]
|
||||
pub struct BaseChannel<Req, Resp, T> {
|
||||
config: Config,
|
||||
/// Writes responses to the wire and reads requests off the wire.
|
||||
@@ -211,7 +211,9 @@ where
|
||||
pub fn get_pin_ref(self: Pin<&mut Self>) -> Pin<&mut T> {
|
||||
self.project().transport.get_pin_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> BaseChannel<Req, Resp, T> {
|
||||
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
|
||||
// It's possible the request was already completed, so it's fine
|
||||
// if this is None.
|
||||
@@ -360,6 +362,17 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[pinned_drop]
|
||||
impl<Req, Resp, T> PinnedDrop for BaseChannel<Req, Resp, T> {
|
||||
fn drop(mut self: Pin<&mut Self>) {
|
||||
self.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.values()
|
||||
.for_each(AbortHandle::abort);
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
|
||||
fn as_ref(&self) -> &T {
|
||||
self.transport.get_ref()
|
||||
@@ -729,3 +742,19 @@ where
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn abort_in_flight_requests_on_channel_drop() {
|
||||
use assert_matches::assert_matches;
|
||||
use futures::future::Aborted;
|
||||
|
||||
let (_, server_transport) =
|
||||
super::transport::channel::unbounded::<Response<()>, ClientMessage<()>>();
|
||||
let channel = BaseChannel::with_defaults(server_transport);
|
||||
let mut channel = Box::pin(channel);
|
||||
|
||||
let abort_registration = channel.as_mut().start_request(1);
|
||||
let future = Abortable::new(async { () }, abort_registration);
|
||||
drop(channel);
|
||||
assert_matches!(future.await, Err(Aborted));
|
||||
}
|
||||
|
||||
@@ -3,7 +3,11 @@ use futures::{
|
||||
future::{join_all, ready, Ready},
|
||||
prelude::*,
|
||||
};
|
||||
use std::io;
|
||||
use std::{
|
||||
io,
|
||||
sync::Arc,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
use tarpc::{
|
||||
client::{self},
|
||||
context,
|
||||
@@ -57,6 +61,65 @@ async fn sequential() -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
|
||||
#[tarpc_plugins::service]
|
||||
trait Loop {
|
||||
async fn r#loop();
|
||||
}
|
||||
|
||||
struct LoopServer(tokio::sync::mpsc::UnboundedSender<AllHandlersComplete>);
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AllHandlersComplete;
|
||||
|
||||
impl Drop for LoopServer {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.0.send(AllHandlersComplete);
|
||||
}
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl Loop for Arc<LoopServer> {
|
||||
async fn r#loop(self, _: context::Context) {
|
||||
loop {
|
||||
futures::pending!();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
let (tx, rx) = channel::unbounded();
|
||||
let (rpc_finished_tx, mut rpc_finished) = tokio::sync::mpsc::unbounded_channel();
|
||||
|
||||
// Set up a client that initiates a long-lived request.
|
||||
// The request will complete in error when the server drops the connection.
|
||||
tokio::spawn(async move {
|
||||
let mut client = LoopClient::new(client::Config::default(), tx)
|
||||
.spawn()
|
||||
.unwrap();
|
||||
|
||||
let mut ctx = context::current();
|
||||
ctx.deadline = SystemTime::now() + Duration::from_secs(60 * 60);
|
||||
let _ = client.r#loop(ctx).await;
|
||||
});
|
||||
|
||||
let mut server =
|
||||
BaseChannel::with_defaults(rx).respond_with(Arc::new(LoopServer(rpc_finished_tx)).serve());
|
||||
let first_handler = server.next().await.unwrap()?;
|
||||
|
||||
drop(server);
|
||||
first_handler.await;
|
||||
|
||||
// At this point, a single RPC has been sent and a single response initiated.
|
||||
// The request handler will loop for a long time unless aborted.
|
||||
// Now, we assert that the act of disconnecting a client is sufficient to abort all
|
||||
// handlers initiated by the connection's RPCs.
|
||||
assert_matches!(rpc_finished.recv().await, Some(AllHandlersComplete));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "serde-transport", feature = "tcp"))]
|
||||
#[tokio::test]
|
||||
async fn serde() -> io::Result<()> {
|
||||
|
||||
Reference in New Issue
Block a user