Add tests for rpc/server/filter.rs

This commit is contained in:
Tim Kuehn
2019-07-16 21:48:11 -07:00
parent 9863433fea
commit 94b5b2c431
2 changed files with 219 additions and 47 deletions

View File

@@ -33,3 +33,4 @@ futures-test-preview = { version = "0.3.0-alpha.17" }
env_logger = "0.6"
tokio = "0.1"
tokio-executor = "0.1"
assert_matches = "1.0"

View File

@@ -7,7 +7,6 @@
use crate::{
server::{self, Channel},
util::Compact,
Response,
};
use fnv::FnvHashMap;
use futures::{
@@ -22,7 +21,7 @@ use log::{debug, info, trace};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::sync::{Arc, Weak};
use std::{
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, io, marker::Unpin, pin::Pin,
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin,
};
/// A single-threaded filter that drops channels based on per-key limits.
@@ -63,34 +62,9 @@ impl<K> Drop for Tracker<K> {
}
}
/// A running handler serving all requests for a single client.
#[derive(Debug)]
pub struct TrackedHandler<K, Fut> {
inner: Fut,
tracker: Tracker<K>,
}
impl<K, Fut> TrackedHandler<K, Fut>
where
Fut: Future,
{
unsafe_pinned!(inner: Fut);
}
impl<K, Fut> Future for TrackedHandler<K, Fut>
where
Fut: Future,
{
type Output = Fut::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.inner().poll(cx)
}
}
impl<C, K> Stream for TrackedChannel<C, K>
where
C: Channel,
C: Stream,
{
type Item = <C as Stream>::Item;
@@ -99,17 +73,17 @@ where
}
}
impl<C, K> Sink<Response<C::Resp>> for TrackedChannel<C, K>
impl<C, I, K> Sink<I> for TrackedChannel<C, K>
where
C: Channel,
C: Sink<I>,
{
type Error = io::Error;
type Error = C::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.channel().poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Response<C::Resp>) -> Result<(), Self::Error> {
fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
self.channel().start_send(item)
}
@@ -176,6 +150,7 @@ impl<S, K, F> ChannelFilter<S, K, F>
where
K: Eq + Hash,
S: Stream,
F: Fn(&S::Item) -> K,
{
/// Sheds new channels to stay under configured limits.
pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self {
@@ -191,23 +166,24 @@ where
}
}
impl<S, C, K, F> ChannelFilter<S, K, F>
impl<S, K, F> ChannelFilter<S, K, F>
where
S: Stream<Item = C>,
C: Channel,
S: Stream,
K: fmt::Display + Eq + Hash + Clone + Unpin,
F: Fn(&C) -> K,
F: Fn(&S::Item) -> K,
{
fn handle_new_channel(self: &mut Pin<&mut Self>, stream: C) -> Result<TrackedChannel<C, K>, K> {
fn handle_new_channel(
self: &mut Pin<&mut Self>,
stream: S::Item,
) -> Result<TrackedChannel<S::Item, K>, K> {
let key = self.as_mut().keymaker()(&stream);
let tracker = self.increment_channels_for_key(key.clone())?;
let max = self.as_mut().channels_per_key();
debug!(
trace!(
"[{}] Opening channel ({}/{}) channels for key.",
key,
Arc::strong_count(&tracker),
max
self.as_mut().channels_per_key()
);
Ok(TrackedChannel {
@@ -256,7 +232,7 @@ where
fn poll_listener(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<TrackedChannel<C, K>, K>>> {
) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
match ready!(self.as_mut().listener().poll_next_unpin(cx)) {
Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
None => Poll::Ready(None),
@@ -276,19 +252,18 @@ where
}
}
impl<S, C, K, F> Stream for ChannelFilter<S, K, F>
impl<S, K, F> Stream for ChannelFilter<S, K, F>
where
S: Stream<Item = C>,
C: Channel,
S: Stream,
K: fmt::Display + Eq + Hash + Clone + Unpin,
F: Fn(&C) -> K,
F: Fn(&S::Item) -> K,
{
type Item = TrackedChannel<C, K>;
type Item = TrackedChannel<S::Item, K>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<TrackedChannel<C, K>>> {
) -> Poll<Option<TrackedChannel<S::Item, K>>> {
loop {
match (
self.as_mut().poll_listener(cx),
@@ -310,3 +285,199 @@ where
}
}
}
#[cfg(test)]
fn ctx() -> Context<'static> {
use futures_test::task::noop_waker_ref;
Context::from_waker(&noop_waker_ref())
}
#[test]
fn tracker_drop() {
use assert_matches::assert_matches;
let (tx, mut rx) = mpsc::unbounded();
Tracker {
key: Some(1),
dropped_keys: tx,
};
assert_matches!(rx.try_next(), Ok(Some(1)));
}
#[test]
fn tracked_channel_stream() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
let (chan_tx, chan) = mpsc::unbounded();
let (dropped_keys, _) = mpsc::unbounded();
let channel = TrackedChannel {
inner: chan,
tracker: Arc::new(Tracker {
key: Some(1),
dropped_keys,
}),
};
chan_tx.unbounded_send("test").unwrap();
pin_mut!(channel);
assert_matches!(channel.poll_next(&mut ctx()), Poll::Ready(Some("test")));
}
#[test]
fn tracked_channel_sink() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
let (chan, mut chan_rx) = mpsc::unbounded();
let (dropped_keys, _) = mpsc::unbounded();
let channel = TrackedChannel {
inner: chan,
tracker: Arc::new(Tracker {
key: Some(1),
dropped_keys,
}),
};
pin_mut!(channel);
assert_matches!(channel.as_mut().poll_ready(&mut ctx()), Poll::Ready(Ok(())));
assert_matches!(channel.as_mut().start_send("test"), Ok(()));
assert_matches!(channel.as_mut().poll_flush(&mut ctx()), Poll::Ready(Ok(())));
assert_matches!(chan_rx.try_next(), Ok(Some("test")));
}
#[test]
fn channel_filter_increment_channels_for_key() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
struct TestChannel {
key: &'static str,
}
let (_, listener) = mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
let tracker1 = filter.increment_channels_for_key("key").unwrap();
assert_eq!(Arc::strong_count(&tracker1), 1);
let tracker2 = filter.increment_channels_for_key("key").unwrap();
assert_eq!(Arc::strong_count(&tracker1), 2);
assert_matches!(filter.increment_channels_for_key("key"), Err("key"));
drop(tracker2);
assert_eq!(Arc::strong_count(&tracker1), 1);
}
#[test]
fn channel_filter_handle_new_channel() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
#[derive(Debug)]
struct TestChannel {
key: &'static str,
}
let (_, listener) = mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
let channel1 = filter
.handle_new_channel(TestChannel { key: "key" })
.unwrap();
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
let channel2 = filter
.handle_new_channel(TestChannel { key: "key" })
.unwrap();
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
assert_matches!(
filter.handle_new_channel(TestChannel { key: "key" }),
Err("key")
);
drop(channel2);
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
}
#[test]
fn channel_filter_poll_listener() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
#[derive(Debug)]
struct TestChannel {
key: &'static str,
}
let (new_channels, listener) = mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels
.unbounded_send(TestChannel { key: "key" })
.unwrap();
let channel1 =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
assert_eq!(Arc::strong_count(&channel1.tracker), 1);
new_channels
.unbounded_send(TestChannel { key: "key" })
.unwrap();
let _channel2 =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
new_channels
.unbounded_send(TestChannel { key: "key" })
.unwrap();
let key =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k);
assert_eq!(key, "key");
assert_eq!(Arc::strong_count(&channel1.tracker), 2);
}
#[test]
fn channel_filter_poll_closed_channels() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
#[derive(Debug)]
struct TestChannel {
key: &'static str,
}
let (new_channels, listener) = mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels
.unbounded_send(TestChannel { key: "key" })
.unwrap();
let channel =
assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c);
assert_eq!(filter.key_counts.len(), 1);
drop(channel);
assert_matches!(filter.poll_closed_channels(&mut ctx()), Poll::Ready(()));
assert!(filter.key_counts.is_empty());
}
#[test]
fn channel_filter_stream() {
use assert_matches::assert_matches;
use pin_utils::pin_mut;
#[derive(Debug)]
struct TestChannel {
key: &'static str,
}
let (new_channels, listener) = mpsc::unbounded();
let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key);
pin_mut!(filter);
new_channels
.unbounded_send(TestChannel { key: "key" })
.unwrap();
let channel = assert_matches!(filter.as_mut().poll_next(&mut ctx()), Poll::Ready(Some(c)) => c);
assert_eq!(filter.key_counts.len(), 1);
drop(channel);
assert_matches!(filter.as_mut().poll_next(&mut ctx()), Poll::Pending);
assert!(filter.key_counts.is_empty());
}