mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-02-23 15:49:54 +01:00
This trait fn returns a private type, which means it's useless for anyone using the Channel. Instead, add an inert (now-public) ResponseGuard to TrackedRequest that, when taken out of the ManuallyDrop, ensures a Channel's request state is cleaned up. It's preferable to make ResponseGuard public instead of RequestCancellations because it's a smaller API surface (no public methods, just a Drop fn) and harder to misuse, because it is already associated with the correct request ID to cancel.
350 lines
9.8 KiB
Rust
350 lines
9.8 KiB
Rust
// Copyright 2020 Google LLC
|
|
//
|
|
// Use of this source code is governed by an MIT-style
|
|
// license that can be found in the LICENSE file or at
|
|
// https://opensource.org/licenses/MIT.
|
|
|
|
use crate::{
|
|
server::{Channel, Config},
|
|
Response, ServerError,
|
|
};
|
|
use futures::{prelude::*, ready, task::*};
|
|
use pin_project::pin_project;
|
|
use std::{io, pin::Pin};
|
|
|
|
/// A [`Channel`] that limits the number of concurrent requests by throttling.
|
|
///
|
|
/// Note that this is a very basic throttling heuristic. It is easy to set a number that is too low
|
|
/// for the resources available to the server. For production use cases, a more advanced throttler
|
|
/// is likely needed.
|
|
#[pin_project]
|
|
#[derive(Debug)]
|
|
pub struct MaxRequests<C> {
|
|
max_in_flight_requests: usize,
|
|
#[pin]
|
|
inner: C,
|
|
}
|
|
|
|
impl<C> MaxRequests<C> {
|
|
/// Returns the inner channel.
|
|
pub fn get_ref(&self) -> &C {
|
|
&self.inner
|
|
}
|
|
}
|
|
|
|
impl<C> MaxRequests<C>
|
|
where
|
|
C: Channel,
|
|
{
|
|
/// Returns a new `MaxRequests` that wraps the given channel and limits concurrent requests to
|
|
/// `max_in_flight_requests`.
|
|
pub fn new(inner: C, max_in_flight_requests: usize) -> Self {
|
|
MaxRequests {
|
|
max_in_flight_requests,
|
|
inner,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<C> Stream for MaxRequests<C>
|
|
where
|
|
C: Channel,
|
|
{
|
|
type Item = <C as Stream>::Item;
|
|
|
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
|
while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
|
|
{
|
|
ready!(self.as_mut().project().inner.poll_ready(cx)?);
|
|
|
|
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
|
|
Some(r) => {
|
|
let _entered = r.span.enter();
|
|
tracing::info!(
|
|
in_flight_requests = self.as_mut().in_flight_requests(),
|
|
"ThrottleRequest",
|
|
);
|
|
|
|
self.as_mut().start_send(Response {
|
|
request_id: r.request.id,
|
|
message: Err(ServerError {
|
|
kind: io::ErrorKind::WouldBlock,
|
|
detail: "server throttled the request.".into(),
|
|
}),
|
|
})?;
|
|
}
|
|
None => return Poll::Ready(None),
|
|
}
|
|
}
|
|
self.project().inner.poll_next(cx)
|
|
}
|
|
}
|
|
|
|
impl<C> Sink<Response<<C as Channel>::Resp>> for MaxRequests<C>
|
|
where
|
|
C: Channel,
|
|
{
|
|
type Error = C::Error;
|
|
|
|
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
|
self.project().inner.poll_ready(cx)
|
|
}
|
|
|
|
fn start_send(
|
|
self: Pin<&mut Self>,
|
|
item: Response<<C as Channel>::Resp>,
|
|
) -> Result<(), Self::Error> {
|
|
self.project().inner.start_send(item)
|
|
}
|
|
|
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
|
self.project().inner.poll_flush(cx)
|
|
}
|
|
|
|
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
|
self.project().inner.poll_close(cx)
|
|
}
|
|
}
|
|
|
|
impl<C> AsRef<C> for MaxRequests<C> {
|
|
fn as_ref(&self) -> &C {
|
|
&self.inner
|
|
}
|
|
}
|
|
|
|
impl<C> Channel for MaxRequests<C>
|
|
where
|
|
C: Channel,
|
|
{
|
|
type Req = <C as Channel>::Req;
|
|
type Resp = <C as Channel>::Resp;
|
|
type Transport = <C as Channel>::Transport;
|
|
|
|
fn in_flight_requests(&self) -> usize {
|
|
self.inner.in_flight_requests()
|
|
}
|
|
|
|
fn config(&self) -> &Config {
|
|
self.inner.config()
|
|
}
|
|
|
|
fn transport(&self) -> &Self::Transport {
|
|
self.inner.transport()
|
|
}
|
|
}
|
|
|
|
/// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on
|
|
/// the number of in-flight requests.
|
|
#[pin_project]
|
|
#[derive(Debug)]
|
|
pub struct MaxRequestsPerChannel<S> {
|
|
#[pin]
|
|
inner: S,
|
|
max_in_flight_requests: usize,
|
|
}
|
|
|
|
impl<S> MaxRequestsPerChannel<S>
|
|
where
|
|
S: Stream,
|
|
<S as Stream>::Item: Channel,
|
|
{
|
|
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
|
|
Self {
|
|
inner,
|
|
max_in_flight_requests,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<S> Stream for MaxRequestsPerChannel<S>
|
|
where
|
|
S: Stream,
|
|
<S as Stream>::Item: Channel,
|
|
{
|
|
type Item = MaxRequests<<S as Stream>::Item>;
|
|
|
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
|
match ready!(self.as_mut().project().inner.poll_next(cx)) {
|
|
Some(channel) => Poll::Ready(Some(MaxRequests::new(
|
|
channel,
|
|
*self.project().max_in_flight_requests,
|
|
))),
|
|
None => Poll::Ready(None),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
use crate::server::{
|
|
testing::{self, FakeChannel, PollExt},
|
|
TrackedRequest,
|
|
};
|
|
use pin_utils::pin_mut;
|
|
use std::{
|
|
marker::PhantomData,
|
|
time::{Duration, SystemTime},
|
|
};
|
|
use tracing::Span;
|
|
|
|
#[tokio::test]
|
|
async fn throttler_in_flight_requests() {
|
|
let throttler = MaxRequests {
|
|
max_in_flight_requests: 0,
|
|
inner: FakeChannel::default::<isize, isize>(),
|
|
};
|
|
|
|
pin_mut!(throttler);
|
|
for i in 0..5 {
|
|
throttler
|
|
.inner
|
|
.in_flight_requests
|
|
.start_request(
|
|
i,
|
|
SystemTime::now() + Duration::from_secs(1),
|
|
Span::current(),
|
|
)
|
|
.unwrap();
|
|
}
|
|
assert_eq!(throttler.as_mut().in_flight_requests(), 5);
|
|
}
|
|
|
|
#[test]
|
|
fn throttler_poll_next_done() {
|
|
let throttler = MaxRequests {
|
|
max_in_flight_requests: 0,
|
|
inner: FakeChannel::default::<isize, isize>(),
|
|
};
|
|
|
|
pin_mut!(throttler);
|
|
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
|
|
}
|
|
|
|
#[test]
|
|
fn throttler_poll_next_some() -> io::Result<()> {
|
|
let throttler = MaxRequests {
|
|
max_in_flight_requests: 1,
|
|
inner: FakeChannel::default::<isize, isize>(),
|
|
};
|
|
|
|
pin_mut!(throttler);
|
|
throttler.inner.push_req(0, 1);
|
|
assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready());
|
|
assert_eq!(
|
|
throttler
|
|
.as_mut()
|
|
.poll_next(&mut testing::cx())?
|
|
.map(|r| r.map(|r| (r.request.id, r.request.message))),
|
|
Poll::Ready(Some((0, 1)))
|
|
);
|
|
Ok(())
|
|
}
|
|
|
|
#[test]
|
|
fn throttler_poll_next_throttled() {
|
|
let throttler = MaxRequests {
|
|
max_in_flight_requests: 0,
|
|
inner: FakeChannel::default::<isize, isize>(),
|
|
};
|
|
|
|
pin_mut!(throttler);
|
|
throttler.inner.push_req(1, 1);
|
|
assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done());
|
|
assert_eq!(throttler.inner.sink.len(), 1);
|
|
let resp = throttler.inner.sink.get(0).unwrap();
|
|
assert_eq!(resp.request_id, 1);
|
|
assert!(resp.message.is_err());
|
|
}
|
|
|
|
#[test]
|
|
fn throttler_poll_next_throttled_sink_not_ready() {
|
|
let throttler = MaxRequests {
|
|
max_in_flight_requests: 0,
|
|
inner: PendingSink::default::<isize, isize>(),
|
|
};
|
|
pin_mut!(throttler);
|
|
assert!(throttler.poll_next(&mut testing::cx()).is_pending());
|
|
|
|
struct PendingSink<In, Out> {
|
|
ghost: PhantomData<fn(Out) -> In>,
|
|
}
|
|
impl PendingSink<(), ()> {
|
|
pub fn default<Req, Resp>(
|
|
) -> PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
|
PendingSink { ghost: PhantomData }
|
|
}
|
|
}
|
|
impl<In, Out> Stream for PendingSink<In, Out> {
|
|
type Item = In;
|
|
fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
|
|
unimplemented!()
|
|
}
|
|
}
|
|
impl<In, Out> Sink<Out> for PendingSink<In, Out> {
|
|
type Error = io::Error;
|
|
fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
|
Poll::Pending
|
|
}
|
|
fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> {
|
|
Err(io::Error::from(io::ErrorKind::WouldBlock))
|
|
}
|
|
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
|
Poll::Pending
|
|
}
|
|
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
|
|
Poll::Pending
|
|
}
|
|
}
|
|
impl<Req, Resp> Channel for PendingSink<io::Result<TrackedRequest<Req>>, Response<Resp>> {
|
|
type Req = Req;
|
|
type Resp = Resp;
|
|
type Transport = ();
|
|
fn config(&self) -> &Config {
|
|
unimplemented!()
|
|
}
|
|
fn in_flight_requests(&self) -> usize {
|
|
0
|
|
}
|
|
fn transport(&self) -> &() {
|
|
&()
|
|
}
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn throttler_start_send() {
|
|
let throttler = MaxRequests {
|
|
max_in_flight_requests: 0,
|
|
inner: FakeChannel::default::<isize, isize>(),
|
|
};
|
|
|
|
pin_mut!(throttler);
|
|
throttler
|
|
.inner
|
|
.in_flight_requests
|
|
.start_request(
|
|
0,
|
|
SystemTime::now() + Duration::from_secs(1),
|
|
Span::current(),
|
|
)
|
|
.unwrap();
|
|
throttler
|
|
.as_mut()
|
|
.start_send(Response {
|
|
request_id: 0,
|
|
message: Ok(1),
|
|
})
|
|
.unwrap();
|
|
assert_eq!(throttler.inner.in_flight_requests.len(), 0);
|
|
assert_eq!(
|
|
throttler.inner.sink.get(0),
|
|
Some(&Response {
|
|
request_id: 0,
|
|
message: Ok(1),
|
|
})
|
|
);
|
|
}
|
|
}
|