diff --git a/example-service/Cargo.toml b/example-service/Cargo.toml index bc40e6d..20c1458 100644 --- a/example-service/Cargo.toml +++ b/example-service/Cargo.toml @@ -17,8 +17,8 @@ clap = "2.33" env_logger = "0.8" futures = "0.3" serde = { version = "1.0" } -tarpc = { version = "0.25", path = "../tarpc", features = ["full"] } -tokio = { version = "1", features = ["full"] } +tarpc = { path = "../tarpc", features = ["full"] } +tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] } [lib] name = "service" diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 51fafd7..abf964c 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -747,308 +747,315 @@ where } #[cfg(test)] -use { - crate::{ - trace, - transport::channel::{self, UnboundedChannel}, - }, - assert_matches::assert_matches, - futures::future::{pending, Aborted}, - futures_test::task::noop_context, - std::time::Duration, -}; +mod tests { + use super::*; -#[cfg(test)] -fn test_channel() -> ( - Pin, Response>>>>, - UnboundedChannel, ClientMessage>, -) { - let (tx, rx) = crate::transport::channel::unbounded(); - (Box::pin(BaseChannel::new(Config::default(), rx)), tx) -} + use { + crate::{ + trace, + transport::channel::{self, UnboundedChannel}, + }, + assert_matches::assert_matches, + futures::future::{pending, Aborted}, + futures_test::task::noop_context, + std::time::Duration, + }; -#[cfg(test)] -fn test_requests() -> ( - Pin< - Box, Response>>>>, - >, - UnboundedChannel, ClientMessage>, -) { - let (tx, rx) = crate::transport::channel::unbounded(); - ( - Box::pin(BaseChannel::new(Config::default(), rx).requests()), - tx, - ) -} + fn test_channel() -> ( + Pin, Response>>>>, + UnboundedChannel, ClientMessage>, + ) { + let (tx, rx) = crate::transport::channel::unbounded(); + (Box::pin(BaseChannel::new(Config::default(), rx)), tx) + } -#[cfg(test)] -fn test_bounded_requests( - capacity: usize, -) -> ( - Pin< - Box, Response>>>>, - >, - channel::Channel, ClientMessage>, -) { - let (tx, rx) = crate::transport::channel::bounded(capacity); - let mut config = Config::default(); - // Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded). - config.pending_response_buffer = capacity + 1; - (Box::pin(BaseChannel::new(config, rx).requests()), tx) -} + fn test_requests() -> ( + Pin< + Box< + Requests< + BaseChannel, Response>>, + >, + >, + >, + UnboundedChannel, ClientMessage>, + ) { + let (tx, rx) = crate::transport::channel::unbounded(); + ( + Box::pin(BaseChannel::new(Config::default(), rx).requests()), + tx, + ) + } -#[cfg(test)] -fn fake_request(req: Req) -> ClientMessage { - ClientMessage::Request(Request { - context: context::current(), - id: 0, - message: req, - }) -} + fn test_bounded_requests( + capacity: usize, + ) -> ( + Pin< + Box< + Requests< + BaseChannel, Response>>, + >, + >, + >, + channel::Channel, ClientMessage>, + ) { + let (tx, rx) = crate::transport::channel::bounded(capacity); + let mut config = Config::default(); + // Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded). + config.pending_response_buffer = capacity + 1; + (Box::pin(BaseChannel::new(config, rx).requests()), tx) + } -#[cfg(test)] -fn test_abortable( - abort_registration: AbortRegistration, -) -> impl Future> { - Abortable::new(pending(), abort_registration) -} - -#[tokio::test] -async fn base_channel_start_send_duplicate_request_returns_error() { - let (mut channel, _tx) = test_channel::<(), ()>(); - - channel - .as_mut() - .start_request(0, SystemTime::now()) - .unwrap(); - assert_matches!( - channel.as_mut().start_request(0, SystemTime::now()), - Err(AlreadyExistsError) - ); -} - -#[tokio::test] -async fn base_channel_poll_next_aborts_multiple_requests() { - let (mut channel, _tx) = test_channel::<(), ()>(); - - tokio::time::pause(); - let abort_registration0 = channel - .as_mut() - .start_request(0, SystemTime::now()) - .unwrap(); - let abort_registration1 = channel - .as_mut() - .start_request(1, SystemTime::now()) - .unwrap(); - tokio::time::advance(std::time::Duration::from_secs(1000)).await; - - assert_matches!( - channel.as_mut().poll_next(&mut noop_context()), - Poll::Pending - ); - assert_matches!(test_abortable(abort_registration0).await, Err(Aborted)); - assert_matches!(test_abortable(abort_registration1).await, Err(Aborted)); -} - -#[tokio::test] -async fn base_channel_poll_next_aborts_canceled_request() { - let (mut channel, mut tx) = test_channel::<(), ()>(); - - tokio::time::pause(); - let abort_registration = channel - .as_mut() - .start_request(0, SystemTime::now() + Duration::from_millis(100)) - .unwrap(); - - tx.send(ClientMessage::Cancel { - trace_context: trace::Context::default(), - request_id: 0, - }) - .await - .unwrap(); - - assert_matches!( - channel.as_mut().poll_next(&mut noop_context()), - Poll::Pending - ); - - assert_matches!(test_abortable(abort_registration).await, Err(Aborted)); -} - -#[tokio::test] -async fn base_channel_with_closed_transport_and_in_flight_request_returns_pending() { - let (mut channel, tx) = test_channel::<(), ()>(); - - tokio::time::pause(); - let _abort_registration = channel - .as_mut() - .start_request(0, SystemTime::now() + Duration::from_millis(100)) - .unwrap(); - - drop(tx); - assert_matches!( - channel.as_mut().poll_next(&mut noop_context()), - Poll::Pending - ); -} - -#[tokio::test] -async fn base_channel_with_closed_transport_and_no_in_flight_requests_returns_closed() { - let (mut channel, tx) = test_channel::<(), ()>(); - drop(tx); - assert_matches!( - channel.as_mut().poll_next(&mut noop_context()), - Poll::Ready(None) - ); -} - -#[tokio::test] -async fn base_channel_poll_next_yields_request() { - let (mut channel, mut tx) = test_channel::<(), ()>(); - tx.send(fake_request(())).await.unwrap(); - - assert_matches!( - channel.as_mut().poll_next(&mut noop_context()), - Poll::Ready(Some(Ok(_))) - ); -} - -#[tokio::test] -async fn base_channel_poll_next_aborts_request_and_yields_request() { - let (mut channel, mut tx) = test_channel::<(), ()>(); - - tokio::time::pause(); - let abort_registration = channel - .as_mut() - .start_request(0, SystemTime::now()) - .unwrap(); - tokio::time::advance(std::time::Duration::from_secs(1000)).await; - - tx.send(fake_request(())).await.unwrap(); - - assert_matches!( - channel.as_mut().poll_next(&mut noop_context()), - Poll::Ready(Some(Ok(_))) - ); - assert_matches!(test_abortable(abort_registration).await, Err(Aborted)); -} - -#[tokio::test] -async fn base_channel_start_send_removes_in_flight_request() { - let (mut channel, _tx) = test_channel::<(), ()>(); - - channel - .as_mut() - .start_request(0, SystemTime::now()) - .unwrap(); - assert_eq!(channel.in_flight_requests(), 1); - channel - .as_mut() - .start_send(Response { - request_id: 0, - message: Ok(()), + fn fake_request(req: Req) -> ClientMessage { + ClientMessage::Request(Request { + context: context::current(), + id: 0, + message: req, }) - .unwrap(); - assert_eq!(channel.in_flight_requests(), 0); -} + } -#[tokio::test] -async fn requests_poll_next_response_returns_pending_when_buffer_full() { - let (mut requests, _tx) = test_bounded_requests::<(), ()>(0); + fn test_abortable( + abort_registration: AbortRegistration, + ) -> impl Future> { + Abortable::new(pending(), abort_registration) + } - // Response written to the transport. - requests - .as_mut() - .channel_pin_mut() - .start_send(Response { + #[tokio::test] + async fn base_channel_start_send_duplicate_request_returns_error() { + let (mut channel, _tx) = test_channel::<(), ()>(); + + channel + .as_mut() + .start_request(0, SystemTime::now()) + .unwrap(); + assert_matches!( + channel.as_mut().start_request(0, SystemTime::now()), + Err(AlreadyExistsError) + ); + } + + #[tokio::test] + async fn base_channel_poll_next_aborts_multiple_requests() { + let (mut channel, _tx) = test_channel::<(), ()>(); + + tokio::time::pause(); + let abort_registration0 = channel + .as_mut() + .start_request(0, SystemTime::now()) + .unwrap(); + let abort_registration1 = channel + .as_mut() + .start_request(1, SystemTime::now()) + .unwrap(); + tokio::time::advance(std::time::Duration::from_secs(1000)).await; + + assert_matches!( + channel.as_mut().poll_next(&mut noop_context()), + Poll::Pending + ); + assert_matches!(test_abortable(abort_registration0).await, Err(Aborted)); + assert_matches!(test_abortable(abort_registration1).await, Err(Aborted)); + } + + #[tokio::test] + async fn base_channel_poll_next_aborts_canceled_request() { + let (mut channel, mut tx) = test_channel::<(), ()>(); + + tokio::time::pause(); + let abort_registration = channel + .as_mut() + .start_request(0, SystemTime::now() + Duration::from_millis(100)) + .unwrap(); + + tx.send(ClientMessage::Cancel { + trace_context: trace::Context::default(), request_id: 0, - message: Ok(()), }) - .unwrap(); - - // Response waiting to be written. - requests - .as_mut() - .project() - .responses_tx - .send(( - context::current(), - Response { - request_id: 1, - message: Ok(()), - }, - )) .await .unwrap(); - requests - .as_mut() - .channel_pin_mut() - .start_request(1, SystemTime::now()) - .unwrap(); + assert_matches!( + channel.as_mut().poll_next(&mut noop_context()), + Poll::Pending + ); - assert_matches!( - requests.as_mut().poll_next_response(&mut noop_context()), - Poll::Pending - ); -} + assert_matches!(test_abortable(abort_registration).await, Err(Aborted)); + } -#[tokio::test] -async fn requests_pump_write_returns_pending_when_buffer_full() { - let (mut requests, _tx) = test_bounded_requests::<(), ()>(0); + #[tokio::test] + async fn base_channel_with_closed_transport_and_in_flight_request_returns_pending() { + let (mut channel, tx) = test_channel::<(), ()>(); - // Response written to the transport. - requests - .as_mut() - .channel_pin_mut() - .start_send(Response { - request_id: 0, - message: Ok(()), - }) - .unwrap(); + tokio::time::pause(); + let _abort_registration = channel + .as_mut() + .start_request(0, SystemTime::now() + Duration::from_millis(100)) + .unwrap(); - // Response waiting to be written. - requests - .as_mut() - .project() - .responses_tx - .send(( - context::current(), - Response { - request_id: 1, + drop(tx); + assert_matches!( + channel.as_mut().poll_next(&mut noop_context()), + Poll::Pending + ); + } + + #[tokio::test] + async fn base_channel_with_closed_transport_and_no_in_flight_requests_returns_closed() { + let (mut channel, tx) = test_channel::<(), ()>(); + drop(tx); + assert_matches!( + channel.as_mut().poll_next(&mut noop_context()), + Poll::Ready(None) + ); + } + + #[tokio::test] + async fn base_channel_poll_next_yields_request() { + let (mut channel, mut tx) = test_channel::<(), ()>(); + tx.send(fake_request(())).await.unwrap(); + + assert_matches!( + channel.as_mut().poll_next(&mut noop_context()), + Poll::Ready(Some(Ok(_))) + ); + } + + #[tokio::test] + async fn base_channel_poll_next_aborts_request_and_yields_request() { + let (mut channel, mut tx) = test_channel::<(), ()>(); + + tokio::time::pause(); + let abort_registration = channel + .as_mut() + .start_request(0, SystemTime::now()) + .unwrap(); + tokio::time::advance(std::time::Duration::from_secs(1000)).await; + + tx.send(fake_request(())).await.unwrap(); + + assert_matches!( + channel.as_mut().poll_next(&mut noop_context()), + Poll::Ready(Some(Ok(_))) + ); + assert_matches!(test_abortable(abort_registration).await, Err(Aborted)); + } + + #[tokio::test] + async fn base_channel_start_send_removes_in_flight_request() { + let (mut channel, _tx) = test_channel::<(), ()>(); + + channel + .as_mut() + .start_request(0, SystemTime::now()) + .unwrap(); + assert_eq!(channel.in_flight_requests(), 1); + channel + .as_mut() + .start_send(Response { + request_id: 0, message: Ok(()), - }, - )) - .await - .unwrap(); + }) + .unwrap(); + assert_eq!(channel.in_flight_requests(), 0); + } - requests - .as_mut() - .channel_pin_mut() - .start_request(1, SystemTime::now()) - .unwrap(); + #[tokio::test] + async fn requests_poll_next_response_returns_pending_when_buffer_full() { + let (mut requests, _tx) = test_bounded_requests::<(), ()>(0); - assert_matches!( - requests.as_mut().pump_write(&mut noop_context(), true), - Poll::Pending - ); - // Assert that the pending response was not polled while the channel was blocked. - assert_matches!( - requests.as_mut().pending_responses_mut().recv().await, - Some(_) - ); -} - -#[tokio::test] -async fn requests_pump_read() { - let (mut requests, mut tx) = test_requests::<(), ()>(); - - // Response written to the transport. - tx.send(fake_request(())).await.unwrap(); - - assert_matches!( - requests.as_mut().pump_read(&mut noop_context()), - Poll::Ready(Some(Ok(_))) - ); - assert_eq!(requests.channel.in_flight_requests(), 1); + // Response written to the transport. + requests + .as_mut() + .channel_pin_mut() + .start_send(Response { + request_id: 0, + message: Ok(()), + }) + .unwrap(); + + // Response waiting to be written. + requests + .as_mut() + .project() + .responses_tx + .send(( + context::current(), + Response { + request_id: 1, + message: Ok(()), + }, + )) + .await + .unwrap(); + + requests + .as_mut() + .channel_pin_mut() + .start_request(1, SystemTime::now()) + .unwrap(); + + assert_matches!( + requests.as_mut().poll_next_response(&mut noop_context()), + Poll::Pending + ); + } + + #[tokio::test] + async fn requests_pump_write_returns_pending_when_buffer_full() { + let (mut requests, _tx) = test_bounded_requests::<(), ()>(0); + + // Response written to the transport. + requests + .as_mut() + .channel_pin_mut() + .start_send(Response { + request_id: 0, + message: Ok(()), + }) + .unwrap(); + + // Response waiting to be written. + requests + .as_mut() + .project() + .responses_tx + .send(( + context::current(), + Response { + request_id: 1, + message: Ok(()), + }, + )) + .await + .unwrap(); + + requests + .as_mut() + .channel_pin_mut() + .start_request(1, SystemTime::now()) + .unwrap(); + + assert_matches!( + requests.as_mut().pump_write(&mut noop_context(), true), + Poll::Pending + ); + // Assert that the pending response was not polled while the channel was blocked. + assert_matches!( + requests.as_mut().pending_responses_mut().recv().await, + Some(_) + ); + } + + #[tokio::test] + async fn requests_pump_read() { + let (mut requests, mut tx) = test_requests::<(), ()>(); + + // Response written to the transport. + tx.send(fake_request(())).await.unwrap(); + + assert_matches!( + requests.as_mut().pump_read(&mut noop_context()), + Poll::Ready(Some(Ok(_))) + ); + assert_eq!(requests.channel.in_flight_requests(), 1); + } } diff --git a/tarpc/src/server/in_flight_requests.rs b/tarpc/src/server/in_flight_requests.rs index d17ddbd..9a1a16d 100644 --- a/tarpc/src/server/in_flight_requests.rs +++ b/tarpc/src/server/in_flight_requests.rs @@ -118,75 +118,79 @@ impl Drop for InFlightRequests { } #[cfg(test)] -use { - assert_matches::assert_matches, - futures::{ - future::{pending, Abortable}, - FutureExt, - }, - futures_test::task::noop_context, -}; +mod tests { + use super::*; -#[tokio::test] -async fn start_request_increases_len() { - let mut in_flight_requests = InFlightRequests::default(); - assert_eq!(in_flight_requests.len(), 0); - in_flight_requests - .start_request(0, SystemTime::now()) - .unwrap(); - assert_eq!(in_flight_requests.len(), 1); -} - -#[tokio::test] -async fn polling_expired_aborts() { - let mut in_flight_requests = InFlightRequests::default(); - let abort_registration = in_flight_requests - .start_request(0, SystemTime::now()) - .unwrap(); - let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); - - tokio::time::pause(); - tokio::time::advance(std::time::Duration::from_secs(1000)).await; - - assert_matches!( - in_flight_requests.poll_expired(&mut noop_context()), - Poll::Ready(Some(Ok(_))) - ); - assert_matches!( - abortable_future.poll_unpin(&mut noop_context()), - Poll::Ready(Err(_)) - ); - assert_eq!(in_flight_requests.len(), 0); -} - -#[tokio::test] -async fn cancel_request_aborts() { - let mut in_flight_requests = InFlightRequests::default(); - let abort_registration = in_flight_requests - .start_request(0, SystemTime::now()) - .unwrap(); - let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); - - assert_eq!(in_flight_requests.cancel_request(0), true); - assert_matches!( - abortable_future.poll_unpin(&mut noop_context()), - Poll::Ready(Err(_)) - ); - assert_eq!(in_flight_requests.len(), 0); -} - -#[tokio::test] -async fn remove_request_doesnt_abort() { - let mut in_flight_requests = InFlightRequests::default(); - let abort_registration = in_flight_requests - .start_request(0, SystemTime::now()) - .unwrap(); - let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); - - assert_eq!(in_flight_requests.remove_request(0), true); - assert_matches!( - abortable_future.poll_unpin(&mut noop_context()), - Poll::Pending - ); - assert_eq!(in_flight_requests.len(), 0); + use { + assert_matches::assert_matches, + futures::{ + future::{pending, Abortable}, + FutureExt, + }, + futures_test::task::noop_context, + }; + + #[tokio::test] + async fn start_request_increases_len() { + let mut in_flight_requests = InFlightRequests::default(); + assert_eq!(in_flight_requests.len(), 0); + in_flight_requests + .start_request(0, SystemTime::now()) + .unwrap(); + assert_eq!(in_flight_requests.len(), 1); + } + + #[tokio::test] + async fn polling_expired_aborts() { + let mut in_flight_requests = InFlightRequests::default(); + let abort_registration = in_flight_requests + .start_request(0, SystemTime::now()) + .unwrap(); + let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); + + tokio::time::pause(); + tokio::time::advance(std::time::Duration::from_secs(1000)).await; + + assert_matches!( + in_flight_requests.poll_expired(&mut noop_context()), + Poll::Ready(Some(Ok(_))) + ); + assert_matches!( + abortable_future.poll_unpin(&mut noop_context()), + Poll::Ready(Err(_)) + ); + assert_eq!(in_flight_requests.len(), 0); + } + + #[tokio::test] + async fn cancel_request_aborts() { + let mut in_flight_requests = InFlightRequests::default(); + let abort_registration = in_flight_requests + .start_request(0, SystemTime::now()) + .unwrap(); + let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); + + assert_eq!(in_flight_requests.cancel_request(0), true); + assert_matches!( + abortable_future.poll_unpin(&mut noop_context()), + Poll::Ready(Err(_)) + ); + assert_eq!(in_flight_requests.len(), 0); + } + + #[tokio::test] + async fn remove_request_doesnt_abort() { + let mut in_flight_requests = InFlightRequests::default(); + let abort_registration = in_flight_requests + .start_request(0, SystemTime::now()) + .unwrap(); + let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); + + assert_eq!(in_flight_requests.remove_request(0), true); + assert_matches!( + abortable_future.poll_unpin(&mut noop_context()), + Poll::Pending + ); + assert_eq!(in_flight_requests.len(), 0); + } } diff --git a/tarpc/src/server/throttle.rs b/tarpc/src/server/throttle.rs index 05dec4c..bac2029 100644 --- a/tarpc/src/server/throttle.rs +++ b/tarpc/src/server/throttle.rs @@ -171,177 +171,179 @@ where } #[cfg(test)] -use super::testing::{self, FakeChannel, PollExt}; -#[cfg(test)] -use crate::Request; -#[cfg(test)] -use pin_utils::pin_mut; -#[cfg(test)] -use std::{marker::PhantomData, time::Duration}; +mod tests { + use super::*; -#[tokio::test] -async fn throttler_in_flight_requests() { - let throttler = Throttler { - max_in_flight_requests: 0, - inner: FakeChannel::default::(), - }; + use super::super::in_flight_requests::AlreadyExistsError; + use super::super::testing::{self, FakeChannel, PollExt}; + use crate::Request; + use pin_utils::pin_mut; + use std::{marker::PhantomData, time::Duration}; - pin_mut!(throttler); - for i in 0..5 { + #[tokio::test] + async fn throttler_in_flight_requests() { + let throttler = Throttler { + max_in_flight_requests: 0, + inner: FakeChannel::default::(), + }; + + pin_mut!(throttler); + for i in 0..5 { + throttler + .inner + .in_flight_requests + .start_request(i, SystemTime::now() + Duration::from_secs(1)) + .unwrap(); + } + assert_eq!(throttler.as_mut().in_flight_requests(), 5); + } + + #[tokio::test] + async fn throttler_start_request() { + let throttler = Throttler { + max_in_flight_requests: 0, + inner: FakeChannel::default::(), + }; + + pin_mut!(throttler); + throttler + .as_mut() + .start_request(1, SystemTime::now() + Duration::from_secs(1)) + .unwrap(); + assert_eq!(throttler.inner.in_flight_requests.len(), 1); + } + + #[test] + fn throttler_poll_next_done() { + let throttler = Throttler { + max_in_flight_requests: 0, + inner: FakeChannel::default::(), + }; + + pin_mut!(throttler); + assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done()); + } + + #[test] + fn throttler_poll_next_some() -> io::Result<()> { + let throttler = Throttler { + max_in_flight_requests: 1, + inner: FakeChannel::default::(), + }; + + pin_mut!(throttler); + throttler.inner.push_req(0, 1); + assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready()); + assert_eq!( + throttler + .as_mut() + .poll_next(&mut testing::cx())? + .map(|r| r.map(|r| (r.id, r.message))), + Poll::Ready(Some((0, 1))) + ); + Ok(()) + } + + #[test] + fn throttler_poll_next_throttled() { + let throttler = Throttler { + max_in_flight_requests: 0, + inner: FakeChannel::default::(), + }; + + pin_mut!(throttler); + throttler.inner.push_req(1, 1); + assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done()); + assert_eq!(throttler.inner.sink.len(), 1); + let resp = throttler.inner.sink.get(0).unwrap(); + assert_eq!(resp.request_id, 1); + assert!(resp.message.is_err()); + } + + #[test] + fn throttler_poll_next_throttled_sink_not_ready() { + let throttler = Throttler { + max_in_flight_requests: 0, + inner: PendingSink::default::(), + }; + pin_mut!(throttler); + assert!(throttler.poll_next(&mut testing::cx()).is_pending()); + + struct PendingSink { + ghost: PhantomData In>, + } + impl PendingSink<(), ()> { + pub fn default() -> PendingSink>, Response> { + PendingSink { ghost: PhantomData } + } + } + impl Stream for PendingSink { + type Item = In; + fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll> { + unimplemented!() + } + } + impl Sink for PendingSink { + type Error = io::Error; + fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll> { + Poll::Pending + } + fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> { + Err(io::Error::from(io::ErrorKind::WouldBlock)) + } + fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll> { + Poll::Pending + } + fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll> { + Poll::Pending + } + } + impl Channel for PendingSink>, Response> { + type Req = Req; + type Resp = Resp; + fn config(&self) -> &Config { + unimplemented!() + } + fn in_flight_requests(&self) -> usize { + 0 + } + fn start_request( + self: Pin<&mut Self>, + _id: u64, + _deadline: SystemTime, + ) -> Result { + unimplemented!() + } + } + } + + #[tokio::test] + async fn throttler_start_send() { + let throttler = Throttler { + max_in_flight_requests: 0, + inner: FakeChannel::default::(), + }; + + pin_mut!(throttler); throttler .inner .in_flight_requests - .start_request(i, SystemTime::now() + Duration::from_secs(1)) + .start_request(0, SystemTime::now() + Duration::from_secs(1)) .unwrap(); - } - assert_eq!(throttler.as_mut().in_flight_requests(), 5); -} - -#[tokio::test] -async fn throttler_start_request() { - let throttler = Throttler { - max_in_flight_requests: 0, - inner: FakeChannel::default::(), - }; - - pin_mut!(throttler); - throttler - .as_mut() - .start_request(1, SystemTime::now() + Duration::from_secs(1)) - .unwrap(); - assert_eq!(throttler.inner.in_flight_requests.len(), 1); -} - -#[test] -fn throttler_poll_next_done() { - let throttler = Throttler { - max_in_flight_requests: 0, - inner: FakeChannel::default::(), - }; - - pin_mut!(throttler); - assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done()); -} - -#[test] -fn throttler_poll_next_some() -> io::Result<()> { - let throttler = Throttler { - max_in_flight_requests: 1, - inner: FakeChannel::default::(), - }; - - pin_mut!(throttler); - throttler.inner.push_req(0, 1); - assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready()); - assert_eq!( throttler .as_mut() - .poll_next(&mut testing::cx())? - .map(|r| r.map(|r| (r.id, r.message))), - Poll::Ready(Some((0, 1))) - ); - Ok(()) -} - -#[test] -fn throttler_poll_next_throttled() { - let throttler = Throttler { - max_in_flight_requests: 0, - inner: FakeChannel::default::(), - }; - - pin_mut!(throttler); - throttler.inner.push_req(1, 1); - assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done()); - assert_eq!(throttler.inner.sink.len(), 1); - let resp = throttler.inner.sink.get(0).unwrap(); - assert_eq!(resp.request_id, 1); - assert!(resp.message.is_err()); -} - -#[test] -fn throttler_poll_next_throttled_sink_not_ready() { - let throttler = Throttler { - max_in_flight_requests: 0, - inner: PendingSink::default::(), - }; - pin_mut!(throttler); - assert!(throttler.poll_next(&mut testing::cx()).is_pending()); - - struct PendingSink { - ghost: PhantomData In>, - } - impl PendingSink<(), ()> { - pub fn default() -> PendingSink>, Response> { - PendingSink { ghost: PhantomData } - } - } - impl Stream for PendingSink { - type Item = In; - fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll> { - unimplemented!() - } - } - impl Sink for PendingSink { - type Error = io::Error; - fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll> { - Poll::Pending - } - fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> { - Err(io::Error::from(io::ErrorKind::WouldBlock)) - } - fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll> { - Poll::Pending - } - fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll> { - Poll::Pending - } - } - impl Channel for PendingSink>, Response> { - type Req = Req; - type Resp = Resp; - fn config(&self) -> &Config { - unimplemented!() - } - fn in_flight_requests(&self) -> usize { - 0 - } - fn start_request( - self: Pin<&mut Self>, - _id: u64, - _deadline: SystemTime, - ) -> Result { - unimplemented!() - } + .start_send(Response { + request_id: 0, + message: Ok(1), + }) + .unwrap(); + assert_eq!(throttler.inner.in_flight_requests.len(), 0); + assert_eq!( + throttler.inner.sink.get(0), + Some(&Response { + request_id: 0, + message: Ok(1), + }) + ); } } - -#[tokio::test] -async fn throttler_start_send() { - let throttler = Throttler { - max_in_flight_requests: 0, - inner: FakeChannel::default::(), - }; - - pin_mut!(throttler); - throttler - .inner - .in_flight_requests - .start_request(0, SystemTime::now() + Duration::from_secs(1)) - .unwrap(); - throttler - .as_mut() - .start_send(Response { - request_id: 0, - message: Ok(1), - }) - .unwrap(); - assert_eq!(throttler.inner.in_flight_requests.len(), 0); - assert_eq!( - throttler.inner.sink.get(0), - Some(&Response { - request_id: 0, - message: Ok(1), - }) - ); -}