mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-01-07 03:56:48 +01:00
Handle deadlines in BaseChannel.
Before this commit, deadlines were handled by a timeout future that wrapped each request handler. However, request handlers can be dropped before sending a response back to the channel, so they can't be relied on for channel state cleanup. Additionally, clients can't be relied on to send cancellation messages. It was therefore theoretically possible for pathological behaviors to cause an unbounded growth in orphan request data in the Channel. With this change, as long as requests sent have reasonable deadlines, then the channel will be able to clean itself up. It is still possible for requests to be sent with very large deadlines, which would prevent the channel from cleaning itself up.
This commit is contained in:
@@ -37,7 +37,7 @@ serde = { optional = true, version = "1.0", features = ["derive"] }
|
||||
static_assertions = "1.1.0"
|
||||
tarpc-plugins = { path = "../plugins", version = "0.9" }
|
||||
tokio = { version = "1", features = ["time"] }
|
||||
tokio-util = { optional = true, version = "0.6" }
|
||||
tokio-util = { version = "0.6.3", features = ["time"] }
|
||||
tokio-serde = { optional = true, version = "0.8" }
|
||||
|
||||
[dev-dependencies]
|
||||
@@ -46,10 +46,11 @@ bincode = "1.3"
|
||||
bytes = { version = "1", features = ["serde"] }
|
||||
env_logger = "0.8"
|
||||
flate2 = "1.0"
|
||||
futures-test = "0.3"
|
||||
log = "0.4"
|
||||
pin-utils = "0.1.0-alpha"
|
||||
serde_bytes = "0.11"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tokio = { version = "1", features = ["full", "test-util"] }
|
||||
tokio-serde = { version = "0.8", features = ["json", "bincode"] }
|
||||
trybuild = "1.0"
|
||||
|
||||
|
||||
@@ -4,16 +4,12 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use futures::{
|
||||
future::{self, Ready},
|
||||
prelude::*,
|
||||
};
|
||||
use futures::future::{self, Ready};
|
||||
use std::io;
|
||||
use tarpc::{
|
||||
client, context,
|
||||
server::{self, Channel},
|
||||
};
|
||||
use tokio_serde::formats::Json;
|
||||
|
||||
/// This is the service definition. It looks a lot like a trait definition.
|
||||
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
|
||||
|
||||
@@ -6,25 +6,22 @@
|
||||
|
||||
//! Provides a server that concurrently handles many connections sending multiplexed requests.
|
||||
|
||||
use crate::{
|
||||
context, trace, util::Compact, util::TimeUntil, ClientMessage, PollIo, Request, Response,
|
||||
ServerError, Transport,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use crate::{context, ClientMessage, PollIo, Request, Response, ServerError, Transport};
|
||||
use futures::{
|
||||
channel::mpsc,
|
||||
future::{AbortHandle, AbortRegistration, Abortable},
|
||||
future::{AbortRegistration, Abortable},
|
||||
prelude::*,
|
||||
ready,
|
||||
stream::Fuse,
|
||||
task::*,
|
||||
};
|
||||
use humantime::format_rfc3339;
|
||||
use log::{debug, trace};
|
||||
use pin_project::{pin_project, pinned_drop};
|
||||
use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin};
|
||||
use log::{debug, info, trace};
|
||||
use pin_project::pin_project;
|
||||
use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
|
||||
|
||||
mod filter;
|
||||
mod in_flight_requests;
|
||||
#[cfg(test)]
|
||||
mod testing;
|
||||
mod throttle;
|
||||
@@ -134,14 +131,14 @@ where
|
||||
/// messages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation
|
||||
/// messages. Instead, it internally handles them by cancelling corresponding requests (removing
|
||||
/// the corresponding in-flight requests and aborting their handlers).
|
||||
#[pin_project(PinnedDrop)]
|
||||
#[pin_project]
|
||||
pub struct BaseChannel<Req, Resp, T> {
|
||||
config: Config,
|
||||
/// Writes responses to the wire and reads requests off the wire.
|
||||
#[pin]
|
||||
transport: Fuse<T>,
|
||||
/// Number of requests currently being responded to.
|
||||
in_flight_requests: FnvHashMap<u64, AbortHandle>,
|
||||
/// Holds data necessary to clean up in-flight requests.
|
||||
in_flight_requests: in_flight_requests::InFlightRequests,
|
||||
/// Types the request and response.
|
||||
ghost: PhantomData<(Req, Resp)>,
|
||||
}
|
||||
@@ -155,7 +152,7 @@ where
|
||||
BaseChannel {
|
||||
config,
|
||||
transport: transport.fuse(),
|
||||
in_flight_requests: FnvHashMap::default(),
|
||||
in_flight_requests: in_flight_requests::InFlightRequests::default(),
|
||||
ghost: PhantomData,
|
||||
}
|
||||
}
|
||||
@@ -176,35 +173,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> BaseChannel<Req, Resp, T> {
|
||||
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
|
||||
// It's possible the request was already completed, so it's fine
|
||||
// if this is None.
|
||||
if let Some(cancel_handle) = self
|
||||
.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&request_id)
|
||||
{
|
||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
||||
|
||||
cancel_handle.abort();
|
||||
let remaining = self.as_mut().project().in_flight_requests.len();
|
||||
trace!(
|
||||
"[{}] Request canceled. In-flight requests = {}",
|
||||
trace_context.trace_id,
|
||||
remaining,
|
||||
);
|
||||
} else {
|
||||
trace!(
|
||||
"[{}] Received cancellation, but response handler \
|
||||
is already complete.",
|
||||
trace_context.trace_id,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "BaseChannel")
|
||||
@@ -260,7 +228,14 @@ where
|
||||
/// Tells the Channel that request with ID `request_id` is being handled.
|
||||
/// The request will be tracked until a response with the same ID is sent
|
||||
/// to the Channel.
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration;
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
id: u64,
|
||||
deadline: SystemTime,
|
||||
) -> Result<AbortRegistration, in_flight_requests::AlreadyExistsError>;
|
||||
|
||||
/// Yields a request that has expired, aborting any ongoing processing of that request.
|
||||
fn poll_expired(self: Pin<&mut Self>, cx: &mut Context) -> PollIo<u64>;
|
||||
|
||||
/// Returns a stream of requests that automatically handle request cancellation and response
|
||||
/// routing.
|
||||
@@ -312,7 +287,25 @@ where
|
||||
trace_context,
|
||||
request_id,
|
||||
} => {
|
||||
self.as_mut().cancel_request(&trace_context, request_id);
|
||||
if self
|
||||
.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.cancel_request(request_id)
|
||||
{
|
||||
let remaining = self.in_flight_requests.len();
|
||||
trace!(
|
||||
"[{}] Request canceled. In-flight requests = {}",
|
||||
trace_context.trace_id,
|
||||
remaining,
|
||||
);
|
||||
} else {
|
||||
trace!(
|
||||
"[{}] Received cancellation, but response handler \
|
||||
is already complete.",
|
||||
trace_context.trace_id,
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
None => return Poll::Ready(None),
|
||||
@@ -332,16 +325,10 @@ where
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
|
||||
if self
|
||||
.as_mut()
|
||||
self.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&response.request_id)
|
||||
.is_some()
|
||||
{
|
||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
||||
}
|
||||
|
||||
.remove_request(response.request_id);
|
||||
self.project().transport.start_send(response)
|
||||
}
|
||||
|
||||
@@ -354,17 +341,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[pinned_drop]
|
||||
impl<Req, Resp, T> PinnedDrop for BaseChannel<Req, Resp, T> {
|
||||
fn drop(mut self: Pin<&mut Self>) {
|
||||
self.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.values()
|
||||
.for_each(AbortHandle::abort);
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
|
||||
fn as_ref(&self) -> &T {
|
||||
self.transport.get_ref()
|
||||
@@ -386,14 +362,18 @@ where
|
||||
self.in_flight_requests.len()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
let (abort_handle, abort_registration) = AbortHandle::new_pair();
|
||||
assert!(self
|
||||
.project()
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
id: u64,
|
||||
deadline: SystemTime,
|
||||
) -> Result<AbortRegistration, in_flight_requests::AlreadyExistsError> {
|
||||
self.project()
|
||||
.in_flight_requests
|
||||
.insert(request_id, abort_handle)
|
||||
.is_none());
|
||||
abort_registration
|
||||
.start_request(id, deadline)
|
||||
}
|
||||
|
||||
fn poll_expired(self: Pin<&mut Self>, cx: &mut Context) -> PollIo<u64> {
|
||||
self.project().in_flight_requests.poll_expired(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -426,16 +406,41 @@ where
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<InFlightRequest<C::Req, C::Resp>> {
|
||||
match ready!(self.as_mut().project().channel.poll_next(cx)?) {
|
||||
Some(request) => {
|
||||
let abort_registration = self.as_mut().project().channel.start_request(request.id);
|
||||
Poll::Ready(Some(Ok(InFlightRequest {
|
||||
request,
|
||||
response_tx: self.responses_tx.clone(),
|
||||
abort_registration,
|
||||
})))
|
||||
loop {
|
||||
match ready!(self.as_mut().project().channel.poll_next(cx)?) {
|
||||
Some(request) => {
|
||||
trace!(
|
||||
"[{}] Handling request with deadline {}.",
|
||||
request.context.trace_id(),
|
||||
format_rfc3339(request.context.deadline),
|
||||
);
|
||||
|
||||
match self
|
||||
.channel_pin_mut()
|
||||
.start_request(request.id, request.context.deadline)
|
||||
{
|
||||
Ok(abort_registration) => {
|
||||
return Poll::Ready(Some(Ok(InFlightRequest {
|
||||
request,
|
||||
response_tx: self.responses_tx.clone(),
|
||||
abort_registration,
|
||||
})))
|
||||
}
|
||||
// Instead of closing the channel if a duplicate request is sent, just
|
||||
// ignore it, since it's already being processed. Note that we cannot
|
||||
// return Poll::Pending here, since nothing has scheduled a wakeup yet.
|
||||
Err(in_flight_requests::AlreadyExistsError) => {
|
||||
info!(
|
||||
"[{}] Request ID {} delivered more than once.",
|
||||
request.context.trace_id(),
|
||||
request.id
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -444,6 +449,17 @@ where
|
||||
cx: &mut Context<'_>,
|
||||
read_half_closed: bool,
|
||||
) -> PollIo<()> {
|
||||
if let Poll::Ready(Some(request_id)) = self.channel_pin_mut().poll_expired(cx)? {
|
||||
debug!("Request {} did not complete before deadline", request_id);
|
||||
self.channel_pin_mut().start_send(Response {
|
||||
request_id,
|
||||
message: Err(ServerError {
|
||||
kind: io::ErrorKind::TimedOut,
|
||||
detail: Some(format!("Request did not complete before deadline.")),
|
||||
}),
|
||||
})?;
|
||||
return Poll::Ready(Some(Ok(())));
|
||||
}
|
||||
match self.as_mut().poll_next_response(cx)? {
|
||||
Poll::Ready(Some((context, response))) => {
|
||||
trace!(
|
||||
@@ -451,6 +467,12 @@ where
|
||||
context.trace_id(),
|
||||
self.channel.in_flight_requests(),
|
||||
);
|
||||
// TODO: it's possible for poll_flush to be starved and start_send to end up full.
|
||||
// Currently that would cause the channel to shut down. serde_transport internally
|
||||
// uses tokio-util Framed, which will allocate as much as needed. But other
|
||||
// transports may work differently.
|
||||
//
|
||||
// There should be a way to know if a flush is needed soon.
|
||||
self.channel_pin_mut().start_send(response)?;
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
@@ -543,39 +565,10 @@ impl<Req, Res> InFlightRequest<Req, Res> {
|
||||
message,
|
||||
id: request_id,
|
||||
} = request;
|
||||
let trace_id = *request.context.trace_id();
|
||||
let deadline = request.context.deadline;
|
||||
let timeout = deadline.time_until();
|
||||
trace!(
|
||||
"[{}] Handling request with deadline {} (timeout {:?}).",
|
||||
trace_id,
|
||||
format_rfc3339(deadline),
|
||||
timeout,
|
||||
);
|
||||
let result =
|
||||
tokio::time::timeout(timeout, async { serve.serve(context, message).await })
|
||||
.await;
|
||||
let response = serve.serve(context, message).await;
|
||||
let response = Response {
|
||||
request_id,
|
||||
message: match result {
|
||||
Ok(message) => Ok(message),
|
||||
Err(tokio::time::error::Elapsed { .. }) => {
|
||||
debug!(
|
||||
"[{}] Response did not complete before deadline of {}s.",
|
||||
trace_id,
|
||||
format_rfc3339(deadline)
|
||||
);
|
||||
// No point in responding, since the client will have dropped the
|
||||
// request.
|
||||
Err(ServerError {
|
||||
kind: io::ErrorKind::TimedOut,
|
||||
detail: Some(format!(
|
||||
"Response did not complete before deadline of {}s.",
|
||||
format_rfc3339(deadline)
|
||||
)),
|
||||
})
|
||||
}
|
||||
},
|
||||
message: Ok(response),
|
||||
};
|
||||
let _ = response_tx.send((context, response)).await;
|
||||
},
|
||||
@@ -687,7 +680,7 @@ where
|
||||
while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) {
|
||||
tokio::spawn(channel.execute(self.serve.clone()));
|
||||
}
|
||||
log::info!("Server shutting down.");
|
||||
info!("Server shutting down.");
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
@@ -713,7 +706,7 @@ where
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
log::info!("Requests stream errored out: {}", e);
|
||||
info!("Requests stream errored out: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
use crate::{
|
||||
server::{self, Channel},
|
||||
util::Compact,
|
||||
PollIo,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{channel::mpsc, future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*};
|
||||
@@ -15,6 +16,7 @@ use pin_project::pin_project;
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::{
|
||||
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin,
|
||||
time::SystemTime,
|
||||
};
|
||||
|
||||
/// A single-threaded filter that drops channels based on per-key limits.
|
||||
@@ -112,8 +114,16 @@ where
|
||||
self.inner.in_flight_requests()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
self.project().inner.start_request(request_id)
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
id: u64,
|
||||
deadline: SystemTime,
|
||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
||||
self.project().inner.start_request(id, deadline)
|
||||
}
|
||||
|
||||
fn poll_expired(self: Pin<&mut Self>, cx: &mut Context) -> PollIo<u64> {
|
||||
self.project().inner.poll_expired(cx)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
193
tarpc/src/server/in_flight_requests.rs
Normal file
193
tarpc/src/server/in_flight_requests.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
use crate::{
|
||||
util::{Compact, TimeUntil},
|
||||
PollIo,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{
|
||||
future::{AbortHandle, AbortRegistration},
|
||||
ready,
|
||||
};
|
||||
use std::{
|
||||
collections::hash_map,
|
||||
io,
|
||||
task::{Context, Poll},
|
||||
time::SystemTime,
|
||||
};
|
||||
use tokio_util::time::delay_queue::{self, DelayQueue};
|
||||
|
||||
/// A data structure that tracks in-flight requests. It aborts requests,
|
||||
/// either on demand or when a request deadline expires.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct InFlightRequests {
|
||||
request_data: FnvHashMap<u64, RequestData>,
|
||||
deadlines: DelayQueue<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Data needed to clean up a single in-flight request.
|
||||
struct RequestData {
|
||||
/// Aborts the response handler for the associated request.
|
||||
abort_handle: AbortHandle,
|
||||
/// The key to remove the timer for the request's deadline.
|
||||
deadline_key: delay_queue::Key,
|
||||
}
|
||||
|
||||
/// An error returned when a request attempted to start with the same ID as a request already
|
||||
/// in flight.
|
||||
#[derive(Debug)]
|
||||
pub struct AlreadyExistsError;
|
||||
|
||||
impl InFlightRequests {
|
||||
pub fn len(&self) -> usize {
|
||||
self.request_data.len()
|
||||
}
|
||||
|
||||
/// Starts a request, unless a request with the same ID is already in flight.
|
||||
pub fn start_request(
|
||||
&mut self,
|
||||
request_id: u64,
|
||||
deadline: SystemTime,
|
||||
) -> Result<AbortRegistration, AlreadyExistsError> {
|
||||
let timeout = deadline.time_until();
|
||||
let (abort_handle, abort_registration) = AbortHandle::new_pair();
|
||||
let deadline_key = self.deadlines.insert(request_id, timeout);
|
||||
match self.request_data.entry(request_id) {
|
||||
hash_map::Entry::Vacant(vacant) => {
|
||||
vacant.insert(RequestData {
|
||||
abort_handle,
|
||||
deadline_key,
|
||||
});
|
||||
Ok(abort_registration)
|
||||
}
|
||||
hash_map::Entry::Occupied(_) => {
|
||||
self.deadlines.remove(&deadline_key);
|
||||
Err(AlreadyExistsError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancels an in-flight request. Returns true iff the request was found.
|
||||
pub fn cancel_request(&mut self, request_id: u64) -> bool {
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
self.request_data.compact(0.1);
|
||||
|
||||
request_data.abort_handle.abort();
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes a request without aborting. Returns true iff the request was found.
|
||||
pub fn remove_request(&mut self, request_id: u64) -> bool {
|
||||
if let Some(request_data) = self.request_data.remove(&request_id) {
|
||||
self.request_data.compact(0.1);
|
||||
|
||||
self.deadlines.remove(&request_data.deadline_key);
|
||||
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Yields a request that has expired, aborting any ongoing processing of that request.
|
||||
pub fn poll_expired(&mut self, cx: &mut Context) -> PollIo<u64> {
|
||||
Poll::Ready(match ready!(self.deadlines.poll_expired(cx)) {
|
||||
Some(Ok(expired)) => {
|
||||
if let Some(request_data) = self.request_data.remove(expired.get_ref()) {
|
||||
self.request_data.compact(0.1);
|
||||
request_data.abort_handle.abort();
|
||||
}
|
||||
Some(Ok(expired.into_inner()))
|
||||
}
|
||||
Some(Err(e)) => Some(Err(io::Error::new(io::ErrorKind::Other, e))),
|
||||
None => None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// When InFlightRequests is dropped, any requests still in flight are aborted.
|
||||
impl Drop for InFlightRequests {
|
||||
fn drop(self: &mut Self) {
|
||||
self.request_data
|
||||
.values()
|
||||
.for_each(|request_data| request_data.abort_handle.abort())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
use {
|
||||
assert_matches::assert_matches,
|
||||
futures::{
|
||||
future::{pending, Abortable},
|
||||
FutureExt,
|
||||
},
|
||||
futures_test::task::noop_context,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_request_increases_len() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
in_flight_requests
|
||||
.start_request(0, SystemTime::now())
|
||||
.unwrap();
|
||||
assert_eq!(in_flight_requests.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn polling_expired_aborts() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
let abort_registration = in_flight_requests
|
||||
.start_request(0, SystemTime::now())
|
||||
.unwrap();
|
||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||
|
||||
tokio::time::pause();
|
||||
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
|
||||
|
||||
assert_matches!(
|
||||
in_flight_requests.poll_expired(&mut noop_context()),
|
||||
Poll::Ready(Some(Ok(_)))
|
||||
);
|
||||
assert_matches!(
|
||||
abortable_future.poll_unpin(&mut noop_context()),
|
||||
Poll::Ready(Err(_))
|
||||
);
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancel_request_aborts() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
let abort_registration = in_flight_requests
|
||||
.start_request(0, SystemTime::now())
|
||||
.unwrap();
|
||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||
|
||||
assert_eq!(in_flight_requests.cancel_request(0), true);
|
||||
assert_matches!(
|
||||
abortable_future.poll_unpin(&mut noop_context()),
|
||||
Poll::Ready(Err(_))
|
||||
);
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remove_request_doesnt_abort() {
|
||||
let mut in_flight_requests = InFlightRequests::default();
|
||||
let abort_registration = in_flight_requests
|
||||
.start_request(0, SystemTime::now())
|
||||
.unwrap();
|
||||
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
|
||||
|
||||
assert_eq!(in_flight_requests.remove_request(0), true);
|
||||
assert_matches!(
|
||||
abortable_future.poll_unpin(&mut noop_context()),
|
||||
Poll::Pending
|
||||
);
|
||||
assert_eq!(in_flight_requests.len(), 0);
|
||||
}
|
||||
@@ -4,14 +4,12 @@
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use crate::server::{Channel, Config};
|
||||
use crate::{context, Request, Response};
|
||||
use fnv::FnvHashSet;
|
||||
use futures::{
|
||||
future::{AbortHandle, AbortRegistration},
|
||||
task::*,
|
||||
Sink, Stream,
|
||||
use crate::{
|
||||
context,
|
||||
server::{Channel, Config},
|
||||
PollIo, Request, Response,
|
||||
};
|
||||
use futures::{future::AbortRegistration, task::*, Sink, Stream};
|
||||
use pin_project::pin_project;
|
||||
use std::collections::VecDeque;
|
||||
use std::io;
|
||||
@@ -25,7 +23,7 @@ pub(crate) struct FakeChannel<In, Out> {
|
||||
#[pin]
|
||||
pub sink: VecDeque<Out>,
|
||||
pub config: Config,
|
||||
pub in_flight_requests: FnvHashSet<u64>,
|
||||
pub in_flight_requests: super::in_flight_requests::InFlightRequests,
|
||||
}
|
||||
|
||||
impl<In, Out> Stream for FakeChannel<In, Out>
|
||||
@@ -50,7 +48,7 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
|
||||
self.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&response.request_id);
|
||||
.remove_request(response.request_id);
|
||||
self.project()
|
||||
.sink
|
||||
.start_send(response)
|
||||
@@ -81,9 +79,18 @@ where
|
||||
self.in_flight_requests.len()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, id: u64) -> AbortRegistration {
|
||||
self.project().in_flight_requests.insert(id);
|
||||
AbortHandle::new_pair().1
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
id: u64,
|
||||
deadline: SystemTime,
|
||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
||||
self.project()
|
||||
.in_flight_requests
|
||||
.start_request(id, deadline)
|
||||
}
|
||||
|
||||
fn poll_expired(self: Pin<&mut Self>, cx: &mut Context) -> PollIo<u64> {
|
||||
self.project().in_flight_requests.poll_expired(cx)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,11 +5,11 @@
|
||||
// https://opensource.org/licenses/MIT.
|
||||
|
||||
use super::{Channel, Config};
|
||||
use crate::{Response, ServerError};
|
||||
use crate::{PollIo, Response, ServerError};
|
||||
use futures::{future::AbortRegistration, prelude::*, ready, task::*};
|
||||
use log::debug;
|
||||
use pin_project::pin_project;
|
||||
use std::{io, pin::Pin};
|
||||
use std::{io, pin::Pin, time::SystemTime};
|
||||
|
||||
/// A [`Channel`] that limits the number of concurrent
|
||||
/// requests by throttling.
|
||||
@@ -121,8 +121,16 @@ where
|
||||
self.inner.config()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
self.project().inner.start_request(request_id)
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
id: u64,
|
||||
deadline: SystemTime,
|
||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
||||
self.project().inner.start_request(id, deadline)
|
||||
}
|
||||
|
||||
fn poll_expired(self: Pin<&mut Self>, cx: &mut Context) -> PollIo<u64> {
|
||||
self.project().inner.poll_expired(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,10 +181,10 @@ use crate::Request;
|
||||
#[cfg(test)]
|
||||
use pin_utils::pin_mut;
|
||||
#[cfg(test)]
|
||||
use std::marker::PhantomData;
|
||||
use std::{marker::PhantomData, time::Duration};
|
||||
|
||||
#[test]
|
||||
fn throttler_in_flight_requests() {
|
||||
#[tokio::test]
|
||||
async fn throttler_in_flight_requests() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
@@ -184,20 +192,27 @@ fn throttler_in_flight_requests() {
|
||||
|
||||
pin_mut!(throttler);
|
||||
for i in 0..5 {
|
||||
throttler.inner.in_flight_requests.insert(i);
|
||||
throttler
|
||||
.inner
|
||||
.in_flight_requests
|
||||
.start_request(i, SystemTime::now() + Duration::from_secs(1))
|
||||
.unwrap();
|
||||
}
|
||||
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_start_request() {
|
||||
#[tokio::test]
|
||||
async fn throttler_start_request() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler.as_mut().start_request(1);
|
||||
throttler
|
||||
.as_mut()
|
||||
.start_request(1, SystemTime::now() + Duration::from_secs(1))
|
||||
.unwrap();
|
||||
assert_eq!(throttler.inner.in_flight_requests.len(), 1);
|
||||
}
|
||||
|
||||
@@ -295,21 +310,32 @@ fn throttler_poll_next_throttled_sink_not_ready() {
|
||||
fn in_flight_requests(&self) -> usize {
|
||||
0
|
||||
}
|
||||
fn start_request(self: Pin<&mut Self>, _: u64) -> AbortRegistration {
|
||||
fn start_request(
|
||||
self: Pin<&mut Self>,
|
||||
_id: u64,
|
||||
_deadline: SystemTime,
|
||||
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
|
||||
unimplemented!()
|
||||
}
|
||||
fn poll_expired(self: Pin<&mut Self>, _cx: &mut Context) -> PollIo<u64> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throttler_start_send() {
|
||||
#[tokio::test]
|
||||
async fn throttler_start_send() {
|
||||
let throttler = Throttler {
|
||||
max_in_flight_requests: 0,
|
||||
inner: FakeChannel::default::<isize, isize>(),
|
||||
};
|
||||
|
||||
pin_mut!(throttler);
|
||||
throttler.inner.in_flight_requests.insert(0);
|
||||
throttler
|
||||
.inner
|
||||
.in_flight_requests
|
||||
.start_request(0, SystemTime::now() + Duration::from_secs(1))
|
||||
.unwrap();
|
||||
throttler
|
||||
.as_mut()
|
||||
.start_send(Response {
|
||||
@@ -317,7 +343,7 @@ fn throttler_start_send() {
|
||||
message: Ok(1),
|
||||
})
|
||||
.unwrap();
|
||||
assert!(throttler.inner.in_flight_requests.is_empty());
|
||||
assert_eq!(throttler.inner.in_flight_requests.len(), 0);
|
||||
assert_eq!(
|
||||
throttler.inner.sink.get(0),
|
||||
Some(&Response {
|
||||
|
||||
Reference in New Issue
Block a user