diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 2869b4d..6fa0143 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -37,7 +37,7 @@ serde = { optional = true, version = "1.0", features = ["derive"] } static_assertions = "1.1.0" tarpc-plugins = { path = "../plugins", version = "0.9" } tokio = { version = "1", features = ["time"] } -tokio-util = { optional = true, version = "0.6" } +tokio-util = { version = "0.6.3", features = ["time"] } tokio-serde = { optional = true, version = "0.8" } [dev-dependencies] @@ -46,10 +46,11 @@ bincode = "1.3" bytes = { version = "1", features = ["serde"] } env_logger = "0.8" flate2 = "1.0" +futures-test = "0.3" log = "0.4" pin-utils = "0.1.0-alpha" serde_bytes = "0.11" -tokio = { version = "1", features = ["full"] } +tokio = { version = "1", features = ["full", "test-util"] } tokio-serde = { version = "0.8", features = ["json", "bincode"] } trybuild = "1.0" diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 26a63ba..b85a0b8 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -4,16 +4,12 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use futures::{ - future::{self, Ready}, - prelude::*, -}; +use futures::future::{self, Ready}; use std::io; use tarpc::{ client, context, server::{self, Channel}, }; -use tokio_serde::formats::Json; /// This is the service definition. It looks a lot like a trait definition. /// It defines one RPC, hello, which takes one arg, name, and returns a String. diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 295d0c4..f24e2cb 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -6,25 +6,22 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. -use crate::{ - context, trace, util::Compact, util::TimeUntil, ClientMessage, PollIo, Request, Response, - ServerError, Transport, -}; -use fnv::FnvHashMap; +use crate::{context, ClientMessage, PollIo, Request, Response, ServerError, Transport}; use futures::{ channel::mpsc, - future::{AbortHandle, AbortRegistration, Abortable}, + future::{AbortRegistration, Abortable}, prelude::*, ready, stream::Fuse, task::*, }; use humantime::format_rfc3339; -use log::{debug, trace}; -use pin_project::{pin_project, pinned_drop}; -use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin}; +use log::{debug, info, trace}; +use pin_project::pin_project; +use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime}; mod filter; +mod in_flight_requests; #[cfg(test)] mod testing; mod throttle; @@ -134,14 +131,14 @@ where /// messages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation /// messages. Instead, it internally handles them by cancelling corresponding requests (removing /// the corresponding in-flight requests and aborting their handlers). -#[pin_project(PinnedDrop)] +#[pin_project] pub struct BaseChannel { config: Config, /// Writes responses to the wire and reads requests off the wire. #[pin] transport: Fuse, - /// Number of requests currently being responded to. - in_flight_requests: FnvHashMap, + /// Holds data necessary to clean up in-flight requests. + in_flight_requests: in_flight_requests::InFlightRequests, /// Types the request and response. ghost: PhantomData<(Req, Resp)>, } @@ -155,7 +152,7 @@ where BaseChannel { config, transport: transport.fuse(), - in_flight_requests: FnvHashMap::default(), + in_flight_requests: in_flight_requests::InFlightRequests::default(), ghost: PhantomData, } } @@ -176,35 +173,6 @@ where } } -impl BaseChannel { - fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) { - // It's possible the request was already completed, so it's fine - // if this is None. - if let Some(cancel_handle) = self - .as_mut() - .project() - .in_flight_requests - .remove(&request_id) - { - self.as_mut().project().in_flight_requests.compact(0.1); - - cancel_handle.abort(); - let remaining = self.as_mut().project().in_flight_requests.len(); - trace!( - "[{}] Request canceled. In-flight requests = {}", - trace_context.trace_id, - remaining, - ); - } else { - trace!( - "[{}] Received cancellation, but response handler \ - is already complete.", - trace_context.trace_id, - ); - } - } -} - impl fmt::Debug for BaseChannel { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "BaseChannel") @@ -260,7 +228,14 @@ where /// Tells the Channel that request with ID `request_id` is being handled. /// The request will be tracked until a response with the same ID is sent /// to the Channel. - fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration; + fn start_request( + self: Pin<&mut Self>, + id: u64, + deadline: SystemTime, + ) -> Result; + + /// Yields a request that has expired, aborting any ongoing processing of that request. + fn poll_expired(self: Pin<&mut Self>, cx: &mut Context) -> PollIo; /// Returns a stream of requests that automatically handle request cancellation and response /// routing. @@ -312,7 +287,25 @@ where trace_context, request_id, } => { - self.as_mut().cancel_request(&trace_context, request_id); + if self + .as_mut() + .project() + .in_flight_requests + .cancel_request(request_id) + { + let remaining = self.in_flight_requests.len(); + trace!( + "[{}] Request canceled. In-flight requests = {}", + trace_context.trace_id, + remaining, + ); + } else { + trace!( + "[{}] Received cancellation, but response handler \ + is already complete.", + trace_context.trace_id, + ); + } } }, None => return Poll::Ready(None), @@ -332,16 +325,10 @@ where } fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { - if self - .as_mut() + self.as_mut() .project() .in_flight_requests - .remove(&response.request_id) - .is_some() - { - self.as_mut().project().in_flight_requests.compact(0.1); - } - + .remove_request(response.request_id); self.project().transport.start_send(response) } @@ -354,17 +341,6 @@ where } } -#[pinned_drop] -impl PinnedDrop for BaseChannel { - fn drop(mut self: Pin<&mut Self>) { - self.as_mut() - .project() - .in_flight_requests - .values() - .for_each(AbortHandle::abort); - } -} - impl AsRef for BaseChannel { fn as_ref(&self) -> &T { self.transport.get_ref() @@ -386,14 +362,18 @@ where self.in_flight_requests.len() } - fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration { - let (abort_handle, abort_registration) = AbortHandle::new_pair(); - assert!(self - .project() + fn start_request( + self: Pin<&mut Self>, + id: u64, + deadline: SystemTime, + ) -> Result { + self.project() .in_flight_requests - .insert(request_id, abort_handle) - .is_none()); - abort_registration + .start_request(id, deadline) + } + + fn poll_expired(self: Pin<&mut Self>, cx: &mut Context) -> PollIo { + self.project().in_flight_requests.poll_expired(cx) } } @@ -426,16 +406,41 @@ where mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> PollIo> { - match ready!(self.as_mut().project().channel.poll_next(cx)?) { - Some(request) => { - let abort_registration = self.as_mut().project().channel.start_request(request.id); - Poll::Ready(Some(Ok(InFlightRequest { - request, - response_tx: self.responses_tx.clone(), - abort_registration, - }))) + loop { + match ready!(self.as_mut().project().channel.poll_next(cx)?) { + Some(request) => { + trace!( + "[{}] Handling request with deadline {}.", + request.context.trace_id(), + format_rfc3339(request.context.deadline), + ); + + match self + .channel_pin_mut() + .start_request(request.id, request.context.deadline) + { + Ok(abort_registration) => { + return Poll::Ready(Some(Ok(InFlightRequest { + request, + response_tx: self.responses_tx.clone(), + abort_registration, + }))) + } + // Instead of closing the channel if a duplicate request is sent, just + // ignore it, since it's already being processed. Note that we cannot + // return Poll::Pending here, since nothing has scheduled a wakeup yet. + Err(in_flight_requests::AlreadyExistsError) => { + info!( + "[{}] Request ID {} delivered more than once.", + request.context.trace_id(), + request.id + ); + continue; + } + } + } + None => return Poll::Ready(None), } - None => Poll::Ready(None), } } @@ -444,6 +449,17 @@ where cx: &mut Context<'_>, read_half_closed: bool, ) -> PollIo<()> { + if let Poll::Ready(Some(request_id)) = self.channel_pin_mut().poll_expired(cx)? { + debug!("Request {} did not complete before deadline", request_id); + self.channel_pin_mut().start_send(Response { + request_id, + message: Err(ServerError { + kind: io::ErrorKind::TimedOut, + detail: Some(format!("Request did not complete before deadline.")), + }), + })?; + return Poll::Ready(Some(Ok(()))); + } match self.as_mut().poll_next_response(cx)? { Poll::Ready(Some((context, response))) => { trace!( @@ -451,6 +467,12 @@ where context.trace_id(), self.channel.in_flight_requests(), ); + // TODO: it's possible for poll_flush to be starved and start_send to end up full. + // Currently that would cause the channel to shut down. serde_transport internally + // uses tokio-util Framed, which will allocate as much as needed. But other + // transports may work differently. + // + // There should be a way to know if a flush is needed soon. self.channel_pin_mut().start_send(response)?; Poll::Ready(Some(Ok(()))) } @@ -543,39 +565,10 @@ impl InFlightRequest { message, id: request_id, } = request; - let trace_id = *request.context.trace_id(); - let deadline = request.context.deadline; - let timeout = deadline.time_until(); - trace!( - "[{}] Handling request with deadline {} (timeout {:?}).", - trace_id, - format_rfc3339(deadline), - timeout, - ); - let result = - tokio::time::timeout(timeout, async { serve.serve(context, message).await }) - .await; + let response = serve.serve(context, message).await; let response = Response { request_id, - message: match result { - Ok(message) => Ok(message), - Err(tokio::time::error::Elapsed { .. }) => { - debug!( - "[{}] Response did not complete before deadline of {}s.", - trace_id, - format_rfc3339(deadline) - ); - // No point in responding, since the client will have dropped the - // request. - Err(ServerError { - kind: io::ErrorKind::TimedOut, - detail: Some(format!( - "Response did not complete before deadline of {}s.", - format_rfc3339(deadline) - )), - }) - } - }, + message: Ok(response), }; let _ = response_tx.send((context, response)).await; }, @@ -687,7 +680,7 @@ where while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) { tokio::spawn(channel.execute(self.serve.clone())); } - log::info!("Server shutting down."); + info!("Server shutting down."); Poll::Ready(()) } } @@ -713,7 +706,7 @@ where }); } Err(e) => { - log::info!("Requests stream errored out: {}", e); + info!("Requests stream errored out: {}", e); break; } } diff --git a/tarpc/src/server/filter.rs b/tarpc/src/server/filter.rs index b457304..902cede 100644 --- a/tarpc/src/server/filter.rs +++ b/tarpc/src/server/filter.rs @@ -7,6 +7,7 @@ use crate::{ server::{self, Channel}, util::Compact, + PollIo, }; use fnv::FnvHashMap; use futures::{channel::mpsc, future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*}; @@ -15,6 +16,7 @@ use pin_project::pin_project; use std::sync::{Arc, Weak}; use std::{ collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin, + time::SystemTime, }; /// A single-threaded filter that drops channels based on per-key limits. @@ -112,8 +114,16 @@ where self.inner.in_flight_requests() } - fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration { - self.project().inner.start_request(request_id) + fn start_request( + self: Pin<&mut Self>, + id: u64, + deadline: SystemTime, + ) -> Result { + self.project().inner.start_request(id, deadline) + } + + fn poll_expired(self: Pin<&mut Self>, cx: &mut Context) -> PollIo { + self.project().inner.poll_expired(cx) } } diff --git a/tarpc/src/server/in_flight_requests.rs b/tarpc/src/server/in_flight_requests.rs new file mode 100644 index 0000000..b8f0d23 --- /dev/null +++ b/tarpc/src/server/in_flight_requests.rs @@ -0,0 +1,193 @@ +use crate::{ + util::{Compact, TimeUntil}, + PollIo, +}; +use fnv::FnvHashMap; +use futures::{ + future::{AbortHandle, AbortRegistration}, + ready, +}; +use std::{ + collections::hash_map, + io, + task::{Context, Poll}, + time::SystemTime, +}; +use tokio_util::time::delay_queue::{self, DelayQueue}; + +/// A data structure that tracks in-flight requests. It aborts requests, +/// either on demand or when a request deadline expires. +#[derive(Debug, Default)] +pub struct InFlightRequests { + request_data: FnvHashMap, + deadlines: DelayQueue, +} + +#[derive(Debug)] +/// Data needed to clean up a single in-flight request. +struct RequestData { + /// Aborts the response handler for the associated request. + abort_handle: AbortHandle, + /// The key to remove the timer for the request's deadline. + deadline_key: delay_queue::Key, +} + +/// An error returned when a request attempted to start with the same ID as a request already +/// in flight. +#[derive(Debug)] +pub struct AlreadyExistsError; + +impl InFlightRequests { + pub fn len(&self) -> usize { + self.request_data.len() + } + + /// Starts a request, unless a request with the same ID is already in flight. + pub fn start_request( + &mut self, + request_id: u64, + deadline: SystemTime, + ) -> Result { + let timeout = deadline.time_until(); + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + let deadline_key = self.deadlines.insert(request_id, timeout); + match self.request_data.entry(request_id) { + hash_map::Entry::Vacant(vacant) => { + vacant.insert(RequestData { + abort_handle, + deadline_key, + }); + Ok(abort_registration) + } + hash_map::Entry::Occupied(_) => { + self.deadlines.remove(&deadline_key); + Err(AlreadyExistsError) + } + } + } + + /// Cancels an in-flight request. Returns true iff the request was found. + pub fn cancel_request(&mut self, request_id: u64) -> bool { + if let Some(request_data) = self.request_data.remove(&request_id) { + self.request_data.compact(0.1); + + request_data.abort_handle.abort(); + self.deadlines.remove(&request_data.deadline_key); + + true + } else { + false + } + } + + /// Removes a request without aborting. Returns true iff the request was found. + pub fn remove_request(&mut self, request_id: u64) -> bool { + if let Some(request_data) = self.request_data.remove(&request_id) { + self.request_data.compact(0.1); + + self.deadlines.remove(&request_data.deadline_key); + + true + } else { + false + } + } + + /// Yields a request that has expired, aborting any ongoing processing of that request. + pub fn poll_expired(&mut self, cx: &mut Context) -> PollIo { + Poll::Ready(match ready!(self.deadlines.poll_expired(cx)) { + Some(Ok(expired)) => { + if let Some(request_data) = self.request_data.remove(expired.get_ref()) { + self.request_data.compact(0.1); + request_data.abort_handle.abort(); + } + Some(Ok(expired.into_inner())) + } + Some(Err(e)) => Some(Err(io::Error::new(io::ErrorKind::Other, e))), + None => None, + }) + } +} + +/// When InFlightRequests is dropped, any requests still in flight are aborted. +impl Drop for InFlightRequests { + fn drop(self: &mut Self) { + self.request_data + .values() + .for_each(|request_data| request_data.abort_handle.abort()) + } +} + +#[cfg(test)] +use { + assert_matches::assert_matches, + futures::{ + future::{pending, Abortable}, + FutureExt, + }, + futures_test::task::noop_context, +}; + +#[tokio::test] +async fn start_request_increases_len() { + let mut in_flight_requests = InFlightRequests::default(); + assert_eq!(in_flight_requests.len(), 0); + in_flight_requests + .start_request(0, SystemTime::now()) + .unwrap(); + assert_eq!(in_flight_requests.len(), 1); +} + +#[tokio::test] +async fn polling_expired_aborts() { + let mut in_flight_requests = InFlightRequests::default(); + let abort_registration = in_flight_requests + .start_request(0, SystemTime::now()) + .unwrap(); + let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); + + tokio::time::pause(); + tokio::time::advance(std::time::Duration::from_secs(1000)).await; + + assert_matches!( + in_flight_requests.poll_expired(&mut noop_context()), + Poll::Ready(Some(Ok(_))) + ); + assert_matches!( + abortable_future.poll_unpin(&mut noop_context()), + Poll::Ready(Err(_)) + ); + assert_eq!(in_flight_requests.len(), 0); +} + +#[tokio::test] +async fn cancel_request_aborts() { + let mut in_flight_requests = InFlightRequests::default(); + let abort_registration = in_flight_requests + .start_request(0, SystemTime::now()) + .unwrap(); + let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); + + assert_eq!(in_flight_requests.cancel_request(0), true); + assert_matches!( + abortable_future.poll_unpin(&mut noop_context()), + Poll::Ready(Err(_)) + ); + assert_eq!(in_flight_requests.len(), 0); +} + +#[tokio::test] +async fn remove_request_doesnt_abort() { + let mut in_flight_requests = InFlightRequests::default(); + let abort_registration = in_flight_requests + .start_request(0, SystemTime::now()) + .unwrap(); + let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); + + assert_eq!(in_flight_requests.remove_request(0), true); + assert_matches!( + abortable_future.poll_unpin(&mut noop_context()), + Poll::Pending + ); + assert_eq!(in_flight_requests.len(), 0); +} diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index cc274d6..9b0aca4 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -4,14 +4,12 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use crate::server::{Channel, Config}; -use crate::{context, Request, Response}; -use fnv::FnvHashSet; -use futures::{ - future::{AbortHandle, AbortRegistration}, - task::*, - Sink, Stream, +use crate::{ + context, + server::{Channel, Config}, + PollIo, Request, Response, }; +use futures::{future::AbortRegistration, task::*, Sink, Stream}; use pin_project::pin_project; use std::collections::VecDeque; use std::io; @@ -25,7 +23,7 @@ pub(crate) struct FakeChannel { #[pin] pub sink: VecDeque, pub config: Config, - pub in_flight_requests: FnvHashSet, + pub in_flight_requests: super::in_flight_requests::InFlightRequests, } impl Stream for FakeChannel @@ -50,7 +48,7 @@ impl Sink> for FakeChannel> { self.as_mut() .project() .in_flight_requests - .remove(&response.request_id); + .remove_request(response.request_id); self.project() .sink .start_send(response) @@ -81,9 +79,18 @@ where self.in_flight_requests.len() } - fn start_request(self: Pin<&mut Self>, id: u64) -> AbortRegistration { - self.project().in_flight_requests.insert(id); - AbortHandle::new_pair().1 + fn start_request( + self: Pin<&mut Self>, + id: u64, + deadline: SystemTime, + ) -> Result { + self.project() + .in_flight_requests + .start_request(id, deadline) + } + + fn poll_expired(self: Pin<&mut Self>, cx: &mut Context) -> PollIo { + self.project().in_flight_requests.poll_expired(cx) } } diff --git a/tarpc/src/server/throttle.rs b/tarpc/src/server/throttle.rs index 807ae43..f640c0e 100644 --- a/tarpc/src/server/throttle.rs +++ b/tarpc/src/server/throttle.rs @@ -5,11 +5,11 @@ // https://opensource.org/licenses/MIT. use super::{Channel, Config}; -use crate::{Response, ServerError}; +use crate::{PollIo, Response, ServerError}; use futures::{future::AbortRegistration, prelude::*, ready, task::*}; use log::debug; use pin_project::pin_project; -use std::{io, pin::Pin}; +use std::{io, pin::Pin, time::SystemTime}; /// A [`Channel`] that limits the number of concurrent /// requests by throttling. @@ -121,8 +121,16 @@ where self.inner.config() } - fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration { - self.project().inner.start_request(request_id) + fn start_request( + self: Pin<&mut Self>, + id: u64, + deadline: SystemTime, + ) -> Result { + self.project().inner.start_request(id, deadline) + } + + fn poll_expired(self: Pin<&mut Self>, cx: &mut Context) -> PollIo { + self.project().inner.poll_expired(cx) } } @@ -173,10 +181,10 @@ use crate::Request; #[cfg(test)] use pin_utils::pin_mut; #[cfg(test)] -use std::marker::PhantomData; +use std::{marker::PhantomData, time::Duration}; -#[test] -fn throttler_in_flight_requests() { +#[tokio::test] +async fn throttler_in_flight_requests() { let throttler = Throttler { max_in_flight_requests: 0, inner: FakeChannel::default::(), @@ -184,20 +192,27 @@ fn throttler_in_flight_requests() { pin_mut!(throttler); for i in 0..5 { - throttler.inner.in_flight_requests.insert(i); + throttler + .inner + .in_flight_requests + .start_request(i, SystemTime::now() + Duration::from_secs(1)) + .unwrap(); } assert_eq!(throttler.as_mut().in_flight_requests(), 5); } -#[test] -fn throttler_start_request() { +#[tokio::test] +async fn throttler_start_request() { let throttler = Throttler { max_in_flight_requests: 0, inner: FakeChannel::default::(), }; pin_mut!(throttler); - throttler.as_mut().start_request(1); + throttler + .as_mut() + .start_request(1, SystemTime::now() + Duration::from_secs(1)) + .unwrap(); assert_eq!(throttler.inner.in_flight_requests.len(), 1); } @@ -295,21 +310,32 @@ fn throttler_poll_next_throttled_sink_not_ready() { fn in_flight_requests(&self) -> usize { 0 } - fn start_request(self: Pin<&mut Self>, _: u64) -> AbortRegistration { + fn start_request( + self: Pin<&mut Self>, + _id: u64, + _deadline: SystemTime, + ) -> Result { + unimplemented!() + } + fn poll_expired(self: Pin<&mut Self>, _cx: &mut Context) -> PollIo { unimplemented!() } } } -#[test] -fn throttler_start_send() { +#[tokio::test] +async fn throttler_start_send() { let throttler = Throttler { max_in_flight_requests: 0, inner: FakeChannel::default::(), }; pin_mut!(throttler); - throttler.inner.in_flight_requests.insert(0); + throttler + .inner + .in_flight_requests + .start_request(0, SystemTime::now() + Duration::from_secs(1)) + .unwrap(); throttler .as_mut() .start_send(Response { @@ -317,7 +343,7 @@ fn throttler_start_send() { message: Ok(1), }) .unwrap(); - assert!(throttler.inner.in_flight_requests.is_empty()); + assert_eq!(throttler.inner.in_flight_requests.len(), 0); assert_eq!( throttler.inner.sink.get(0), Some(&Response {