mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-01-21 10:38:26 +01:00
Add tests for rpc/server/filter.rs
This commit is contained in:
@@ -33,3 +33,4 @@ futures-test-preview = { version = "0.3.0-alpha.17" }
|
||||
env_logger = "0.6"
|
||||
tokio = "0.1"
|
||||
tokio-executor = "0.1"
|
||||
assert_matches = "1.0"
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
use crate::{
|
||||
server::{self, Channel},
|
||||
util::Compact,
|
||||
Response,
|
||||
};
|
||||
use fnv::FnvHashMap;
|
||||
use futures::{
|
||||
@@ -22,7 +21,7 @@ use log::{debug, info, trace};
|
||||
use pin_utils::{unsafe_pinned, unsafe_unpinned};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::{
|
||||
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, io, marker::Unpin, pin::Pin,
|
||||
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin,
|
||||
};
|
||||
|
||||
/// A single-threaded filter that drops channels based on per-key limits.
|
||||
@@ -63,34 +62,9 @@ impl<K> Drop for Tracker<K> {
|
||||
}
|
||||
}
|
||||
|
||||
/// A running handler serving all requests for a single client.
|
||||
#[derive(Debug)]
|
||||
pub struct TrackedHandler<K, Fut> {
|
||||
inner: Fut,
|
||||
tracker: Tracker<K>,
|
||||
}
|
||||
|
||||
impl<K, Fut> TrackedHandler<K, Fut>
|
||||
where
|
||||
Fut: Future,
|
||||
{
|
||||
unsafe_pinned!(inner: Fut);
|
||||
}
|
||||
|
||||
impl<K, Fut> Future for TrackedHandler<K, Fut>
|
||||
where
|
||||
Fut: Future,
|
||||
{
|
||||
type Output = Fut::Output;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.inner().poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, K> Stream for TrackedChannel<C, K>
|
||||
where
|
||||
C: Channel,
|
||||
C: Stream,
|
||||
{
|
||||
type Item = <C as Stream>::Item;
|
||||
|
||||
@@ -99,17 +73,17 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, K> Sink<Response<C::Resp>> for TrackedChannel<C, K>
|
||||
impl<C, I, K> Sink<I> for TrackedChannel<C, K>
|
||||
where
|
||||
C: Channel,
|
||||
C: Sink<I>,
|
||||
{
|
||||
type Error = io::Error;
|
||||
type Error = C::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.channel().poll_ready(cx)
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: Response<C::Resp>) -> Result<(), Self::Error> {
|
||||
fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
|
||||
self.channel().start_send(item)
|
||||
}
|
||||
|
||||
@@ -176,6 +150,7 @@ impl<S, K, F> ChannelFilter<S, K, F>
|
||||
where
|
||||
K: Eq + Hash,
|
||||
S: Stream,
|
||||
F: Fn(&S::Item) -> K,
|
||||
{
|
||||
/// Sheds new channels to stay under configured limits.
|
||||
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
|
||||
@@ -191,23 +166,24 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, C, K, F> ChannelFilter<S, K, F>
|
||||
impl<S, K, F> ChannelFilter<S, K, F>
|
||||
where
|
||||
S: Stream<Item = C>,
|
||||
C: Channel,
|
||||
S: Stream,
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
F: Fn(&C) -> K,
|
||||
F: Fn(&S::Item) -> K,
|
||||
{
|
||||
fn handle_new_channel(self: &mut Pin<&mut Self>, stream: C) -> Result<TrackedChannel<C, K>, K> {
|
||||
fn handle_new_channel(
|
||||
self: &mut Pin<&mut Self>,
|
||||
stream: S::Item,
|
||||
) -> Result<TrackedChannel<S::Item, K>, K> {
|
||||
let key = self.as_mut().keymaker()(&stream);
|
||||
let tracker = self.increment_channels_for_key(key.clone())?;
|
||||
let max = self.as_mut().channels_per_key();
|
||||
|
||||
debug!(
|
||||
trace!(
|
||||
"[{}] Opening channel ({}/{}) channels for key.",
|
||||
key,
|
||||
Arc::strong_count(&tracker),
|
||||
max
|
||||
self.as_mut().channels_per_key()
|
||||
);
|
||||
|
||||
Ok(TrackedChannel {
|
||||
@@ -256,7 +232,7 @@ where
|
||||
fn poll_listener(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<TrackedChannel<C, K>, K>>> {
|
||||
) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
|
||||
match ready!(self.as_mut().listener().poll_next_unpin(cx)) {
|
||||
Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
|
||||
None => Poll::Ready(None),
|
||||
@@ -276,19 +252,18 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, C, K, F> Stream for ChannelFilter<S, K, F>
|
||||
impl<S, K, F> Stream for ChannelFilter<S, K, F>
|
||||
where
|
||||
S: Stream<Item = C>,
|
||||
C: Channel,
|
||||
S: Stream,
|
||||
K: fmt::Display + Eq + Hash + Clone + Unpin,
|
||||
F: Fn(&C) -> K,
|
||||
F: Fn(&S::Item) -> K,
|
||||
{
|
||||
type Item = TrackedChannel<C, K>;
|
||||
type Item = TrackedChannel<S::Item, K>;
|
||||
|
||||
fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<TrackedChannel<C, K>>> {
|
||||
) -> Poll<Option<TrackedChannel<S::Item, K>>> {
|
||||
loop {
|
||||
match (
|
||||
self.as_mut().poll_listener(cx),
|
||||
@@ -310,3 +285,199 @@ where
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn ctx() -> Context<'static> {
|
||||
use futures_test::task::noop_waker_ref;
|
||||
|
||||
Context::from_waker(&noop_waker_ref())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tracker_drop() {
|
||||
use assert_matches::assert_matches;
|
||||
|
||||
let (tx, mut rx) = mpsc::unbounded();
|
||||
Tracker {
|
||||
key: Some(1),
|
||||
dropped_keys: tx,
|
||||
};
|
||||
assert_matches!(rx.try_next(), Ok(Some(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tracked_channel_stream() {
|
||||
use assert_matches::assert_matches;
|
||||
use pin_utils::pin_mut;
|
||||
|
||||
let (chan_tx, chan) = mpsc::unbounded();
|
||||
let (dropped_keys, _) = mpsc::unbounded();
|
||||
let channel = TrackedChannel {
|
||||
inner: chan,
|
||||
tracker: Arc::new(Tracker {
|
||||
key: Some(1),
|
||||
dropped_keys,
|
||||
}),
|
||||
};
|
||||
|
||||
chan_tx.unbounded_send("test").unwrap();
|
||||
pin_mut!(channel);
|
||||
assert_matches!(channel.poll_next(&mut ctx()), Poll::Ready(Some("test")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tracked_channel_sink() {
|
||||
use assert_matches::assert_matches;
|
||||
use pin_utils::pin_mut;
|
||||
|
||||
let (chan, mut chan_rx) = mpsc::unbounded();
|
||||
let (dropped_keys, _) = mpsc::unbounded();
|
||||
let channel = TrackedChannel {
|
||||
inner: chan,
|
||||
tracker: Arc::new(Tracker {
|
||||
key: Some(1),
|
||||
dropped_keys,
|
||||
}),
|
||||
};
|
||||
|
||||
pin_mut!(channel);
|
||||
assert_matches!(channel.as_mut().poll_ready(&mut ctx()), Poll::Ready(Ok(())));
|
||||
assert_matches!(channel.as_mut().start_send("test"), Ok(()));
|
||||
assert_matches!(channel.as_mut().poll_flush(&mut ctx()), Poll::Ready(Ok(())));
|
||||
assert_matches!(chan_rx.try_next(), Ok(Some("test")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channel_filter_increment_channels_for_key() {
|
||||
use assert_matches::assert_matches;
|
||||
use pin_utils::pin_mut;
|
||||
|
||||
struct TestChannel {
|
||||
key: &'static str,
|
||||
}
|
||||
let (_, listener) = mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
let tracker1 = filter.increment_channels_for_key("key").unwrap();
|
||||
assert_eq!(Arc::strong_count(&tracker1), 1);
|
||||
let tracker2 = filter.increment_channels_for_key("key").unwrap();
|
||||
assert_eq!(Arc::strong_count(&tracker1), 2);
|
||||
assert_matches!(filter.increment_channels_for_key("key"), Err("key"));
|
||||
drop(tracker2);
|
||||
assert_eq!(Arc::strong_count(&tracker1), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channel_filter_handle_new_channel() {
|
||||
use assert_matches::assert_matches;
|
||||
use pin_utils::pin_mut;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TestChannel {
|
||||
key: &'static str,
|
||||
}
|
||||
let (_, listener) = mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
let channel1 = filter
|
||||
.handle_new_channel(TestChannel { key: "key" })
|
||||
.unwrap();
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
|
||||
|
||||
let channel2 = filter
|
||||
.handle_new_channel(TestChannel { key: "key" })
|
||||
.unwrap();
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
|
||||
|
||||
assert_matches!(
|
||||
filter.handle_new_channel(TestChannel { key: "key" }),
|
||||
Err("key")
|
||||
);
|
||||
drop(channel2);
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channel_filter_poll_listener() {
|
||||
use assert_matches::assert_matches;
|
||||
use pin_utils::pin_mut;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TestChannel {
|
||||
key: &'static str,
|
||||
}
|
||||
let (new_channels, listener) = mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
|
||||
new_channels
|
||||
.unbounded_send(TestChannel { key: "key" })
|
||||
.unwrap();
|
||||
let channel1 =
|
||||
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
|
||||
|
||||
new_channels
|
||||
.unbounded_send(TestChannel { key: "key" })
|
||||
.unwrap();
|
||||
let _channel2 =
|
||||
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
|
||||
|
||||
new_channels
|
||||
.unbounded_send(TestChannel { key: "key" })
|
||||
.unwrap();
|
||||
let key =
|
||||
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k);
|
||||
assert_eq!(key, "key");
|
||||
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channel_filter_poll_closed_channels() {
|
||||
use assert_matches::assert_matches;
|
||||
use pin_utils::pin_mut;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TestChannel {
|
||||
key: &'static str,
|
||||
}
|
||||
let (new_channels, listener) = mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
|
||||
new_channels
|
||||
.unbounded_send(TestChannel { key: "key" })
|
||||
.unwrap();
|
||||
let channel =
|
||||
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
|
||||
assert_eq!(filter.key_counts.len(), 1);
|
||||
|
||||
drop(channel);
|
||||
assert_matches!(filter.poll_closed_channels(&mut ctx()), Poll::Ready(()));
|
||||
assert!(filter.key_counts.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channel_filter_stream() {
|
||||
use assert_matches::assert_matches;
|
||||
use pin_utils::pin_mut;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TestChannel {
|
||||
key: &'static str,
|
||||
}
|
||||
let (new_channels, listener) = mpsc::unbounded();
|
||||
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
|
||||
pin_mut!(filter);
|
||||
|
||||
new_channels
|
||||
.unbounded_send(TestChannel { key: "key" })
|
||||
.unwrap();
|
||||
let channel = assert_matches!(filter.as_mut().poll_next(&mut ctx()), Poll::Ready(Some(c)) => c);
|
||||
assert_eq!(filter.key_counts.len(), 1);
|
||||
|
||||
drop(channel);
|
||||
assert_matches!(filter.as_mut().poll_next(&mut ctx()), Poll::Pending);
|
||||
assert!(filter.key_counts.is_empty());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user