diff --git a/rpc/Cargo.toml b/rpc/Cargo.toml index 6c9e195..c93d940 100644 --- a/rpc/Cargo.toml +++ b/rpc/Cargo.toml @@ -33,3 +33,4 @@ futures-test-preview = { version = "0.3.0-alpha.17" } env_logger = "0.6" tokio = "0.1" tokio-executor = "0.1" +assert_matches = "1.0" diff --git a/rpc/src/server/filter.rs b/rpc/src/server/filter.rs index 299b005..78419bf 100644 --- a/rpc/src/server/filter.rs +++ b/rpc/src/server/filter.rs @@ -7,7 +7,6 @@ use crate::{ server::{self, Channel}, util::Compact, - Response, }; use fnv::FnvHashMap; use futures::{ @@ -22,7 +21,7 @@ use log::{debug, info, trace}; use pin_utils::{unsafe_pinned, unsafe_unpinned}; use std::sync::{Arc, Weak}; use std::{ - collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, io, marker::Unpin, pin::Pin, + collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin, }; /// A single-threaded filter that drops channels based on per-key limits. @@ -63,34 +62,9 @@ impl Drop for Tracker { } } -/// A running handler serving all requests for a single client. -#[derive(Debug)] -pub struct TrackedHandler { - inner: Fut, - tracker: Tracker, -} - -impl TrackedHandler -where - Fut: Future, -{ - unsafe_pinned!(inner: Fut); -} - -impl Future for TrackedHandler -where - Fut: Future, -{ - type Output = Fut::Output; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.inner().poll(cx) - } -} - impl Stream for TrackedChannel where - C: Channel, + C: Stream, { type Item = ::Item; @@ -99,17 +73,17 @@ where } } -impl Sink> for TrackedChannel +impl Sink for TrackedChannel where - C: Channel, + C: Sink, { - type Error = io::Error; + type Error = C::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { self.channel().poll_ready(cx) } - fn start_send(self: Pin<&mut Self>, item: Response) -> Result<(), Self::Error> { + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { self.channel().start_send(item) } @@ -176,6 +150,7 @@ impl ChannelFilter where K: Eq + Hash, S: Stream, + F: Fn(&S::Item) -> K, { /// Sheds new channels to stay under configured limits. pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self { @@ -191,23 +166,24 @@ where } } -impl ChannelFilter +impl ChannelFilter where - S: Stream, - C: Channel, + S: Stream, K: fmt::Display + Eq + Hash + Clone + Unpin, - F: Fn(&C) -> K, + F: Fn(&S::Item) -> K, { - fn handle_new_channel(self: &mut Pin<&mut Self>, stream: C) -> Result, K> { + fn handle_new_channel( + self: &mut Pin<&mut Self>, + stream: S::Item, + ) -> Result, K> { let key = self.as_mut().keymaker()(&stream); let tracker = self.increment_channels_for_key(key.clone())?; - let max = self.as_mut().channels_per_key(); - debug!( + trace!( "[{}] Opening channel ({}/{}) channels for key.", key, Arc::strong_count(&tracker), - max + self.as_mut().channels_per_key() ); Ok(TrackedChannel { @@ -256,7 +232,7 @@ where fn poll_listener( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, K>>> { + ) -> Poll, K>>> { match ready!(self.as_mut().listener().poll_next_unpin(cx)) { Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))), None => Poll::Ready(None), @@ -276,19 +252,18 @@ where } } -impl Stream for ChannelFilter +impl Stream for ChannelFilter where - S: Stream, - C: Channel, + S: Stream, K: fmt::Display + Eq + Hash + Clone + Unpin, - F: Fn(&C) -> K, + F: Fn(&S::Item) -> K, { - type Item = TrackedChannel; + type Item = TrackedChannel; fn poll_next( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { loop { match ( self.as_mut().poll_listener(cx), @@ -310,3 +285,199 @@ where } } } + +#[cfg(test)] +fn ctx() -> Context<'static> { + use futures_test::task::noop_waker_ref; + + Context::from_waker(&noop_waker_ref()) +} + +#[test] +fn tracker_drop() { + use assert_matches::assert_matches; + + let (tx, mut rx) = mpsc::unbounded(); + Tracker { + key: Some(1), + dropped_keys: tx, + }; + assert_matches!(rx.try_next(), Ok(Some(1))); +} + +#[test] +fn tracked_channel_stream() { + use assert_matches::assert_matches; + use pin_utils::pin_mut; + + let (chan_tx, chan) = mpsc::unbounded(); + let (dropped_keys, _) = mpsc::unbounded(); + let channel = TrackedChannel { + inner: chan, + tracker: Arc::new(Tracker { + key: Some(1), + dropped_keys, + }), + }; + + chan_tx.unbounded_send("test").unwrap(); + pin_mut!(channel); + assert_matches!(channel.poll_next(&mut ctx()), Poll::Ready(Some("test"))); +} + +#[test] +fn tracked_channel_sink() { + use assert_matches::assert_matches; + use pin_utils::pin_mut; + + let (chan, mut chan_rx) = mpsc::unbounded(); + let (dropped_keys, _) = mpsc::unbounded(); + let channel = TrackedChannel { + inner: chan, + tracker: Arc::new(Tracker { + key: Some(1), + dropped_keys, + }), + }; + + pin_mut!(channel); + assert_matches!(channel.as_mut().poll_ready(&mut ctx()), Poll::Ready(Ok(()))); + assert_matches!(channel.as_mut().start_send("test"), Ok(())); + assert_matches!(channel.as_mut().poll_flush(&mut ctx()), Poll::Ready(Ok(()))); + assert_matches!(chan_rx.try_next(), Ok(Some("test"))); +} + +#[test] +fn channel_filter_increment_channels_for_key() { + use assert_matches::assert_matches; + use pin_utils::pin_mut; + + struct TestChannel { + key: &'static str, + } + let (_, listener) = mpsc::unbounded(); + let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); + pin_mut!(filter); + let tracker1 = filter.increment_channels_for_key("key").unwrap(); + assert_eq!(Arc::strong_count(&tracker1), 1); + let tracker2 = filter.increment_channels_for_key("key").unwrap(); + assert_eq!(Arc::strong_count(&tracker1), 2); + assert_matches!(filter.increment_channels_for_key("key"), Err("key")); + drop(tracker2); + assert_eq!(Arc::strong_count(&tracker1), 1); +} + +#[test] +fn channel_filter_handle_new_channel() { + use assert_matches::assert_matches; + use pin_utils::pin_mut; + + #[derive(Debug)] + struct TestChannel { + key: &'static str, + } + let (_, listener) = mpsc::unbounded(); + let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); + pin_mut!(filter); + let channel1 = filter + .handle_new_channel(TestChannel { key: "key" }) + .unwrap(); + assert_eq!(Arc::strong_count(&channel1.tracker), 1); + + let channel2 = filter + .handle_new_channel(TestChannel { key: "key" }) + .unwrap(); + assert_eq!(Arc::strong_count(&channel1.tracker), 2); + + assert_matches!( + filter.handle_new_channel(TestChannel { key: "key" }), + Err("key") + ); + drop(channel2); + assert_eq!(Arc::strong_count(&channel1.tracker), 1); +} + +#[test] +fn channel_filter_poll_listener() { + use assert_matches::assert_matches; + use pin_utils::pin_mut; + + #[derive(Debug)] + struct TestChannel { + key: &'static str, + } + let (new_channels, listener) = mpsc::unbounded(); + let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); + pin_mut!(filter); + + new_channels + .unbounded_send(TestChannel { key: "key" }) + .unwrap(); + let channel1 = + assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c); + assert_eq!(Arc::strong_count(&channel1.tracker), 1); + + new_channels + .unbounded_send(TestChannel { key: "key" }) + .unwrap(); + let _channel2 = + assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c); + assert_eq!(Arc::strong_count(&channel1.tracker), 2); + + new_channels + .unbounded_send(TestChannel { key: "key" }) + .unwrap(); + let key = + assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k); + assert_eq!(key, "key"); + assert_eq!(Arc::strong_count(&channel1.tracker), 2); +} + +#[test] +fn channel_filter_poll_closed_channels() { + use assert_matches::assert_matches; + use pin_utils::pin_mut; + + #[derive(Debug)] + struct TestChannel { + key: &'static str, + } + let (new_channels, listener) = mpsc::unbounded(); + let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); + pin_mut!(filter); + + new_channels + .unbounded_send(TestChannel { key: "key" }) + .unwrap(); + let channel = + assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c); + assert_eq!(filter.key_counts.len(), 1); + + drop(channel); + assert_matches!(filter.poll_closed_channels(&mut ctx()), Poll::Ready(())); + assert!(filter.key_counts.is_empty()); +} + +#[test] +fn channel_filter_stream() { + use assert_matches::assert_matches; + use pin_utils::pin_mut; + + #[derive(Debug)] + struct TestChannel { + key: &'static str, + } + let (new_channels, listener) = mpsc::unbounded(); + let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); + pin_mut!(filter); + + new_channels + .unbounded_send(TestChannel { key: "key" }) + .unwrap(); + let channel = assert_matches!(filter.as_mut().poll_next(&mut ctx()), Poll::Ready(Some(c)) => c); + assert_eq!(filter.key_counts.len(), 1); + + drop(channel); + assert_matches!(filter.as_mut().poll_next(&mut ctx()), Poll::Pending); + assert!(filter.key_counts.is_empty()); +}