diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index b2fcb4e..13fd429 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -19,10 +19,7 @@ use futures::{ }; use in_flight_requests::{AlreadyExistsError, InFlightRequests}; use pin_project::pin_project; -use std::{ - convert::TryFrom, error::Error, fmt, hash::Hash, marker::PhantomData, pin::Pin, - time::SystemTime, -}; +use std::{convert::TryFrom, error::Error, fmt, hash::Hash, marker::PhantomData, pin::Pin}; use tokio::sync::mpsc; use tracing::{info_span, instrument::Instrument, Span}; @@ -190,6 +187,47 @@ where fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse> { self.as_mut().project().transport } + + fn start_request( + self: Pin<&mut Self>, + mut request: Request, + ) -> Result, AlreadyExistsError> { + let span = info_span!( + "RPC", + rpc.trace_id = %request.context.trace_id(), + otel.kind = "server", + otel.name = tracing::field::Empty, + ); + span.set_context(&request.context); + request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { + tracing::trace!( + "OpenTelemetry subscriber not installed; making unsampled \ + child context." + ); + request.context.trace_context.new_child() + }); + let entered = span.enter(); + tracing::info!("ReceiveRequest"); + let start = self.project().in_flight_requests.start_request( + request.id, + request.context.deadline, + span.clone(), + ); + match start { + Ok(abort_registration) => { + drop(entered); + return Ok(TrackedRequest { + request, + abort_registration, + span, + }); + } + Err(AlreadyExistsError) => { + tracing::trace!("DuplicateRequest"); + return Err(AlreadyExistsError); + } + } + } } impl fmt::Debug for BaseChannel { @@ -198,6 +236,18 @@ impl fmt::Debug for BaseChannel { } } +/// A request tracked by a [`Channel`]. +#[derive(Debug)] +pub struct TrackedRequest { + /// The request sent by the client. + pub request: Request, + /// A registration to abort a future when the [`Channel`] that produced this request stops + /// tracking it. + pub abort_registration: AbortRegistration, + /// A span representing the server processing of this request. + pub span: Span, +} + /// The server end of an open connection with a client, streaming in requests from, and sinking /// responses to, the client. /// @@ -210,18 +260,20 @@ impl fmt::Debug for BaseChannel { /// [`execute`](InFlightRequest::execute) method. If using `execute`, request processing will /// automatically cease when either the request deadline is reached or when a corresponding /// cancellation message is received by the Channel. -/// 3. [`Sink::start_send`] - A user is free to manually send responses to requests produced by a -/// Channel using [`Sink::start_send`] in lieu of the previous methods. If not using one of the -/// previous execute methods, then nothing will automatically cancel requests or set up the -/// request context. However, the Channel will still clean up resources upon deadline expiration -/// or cancellation. In the case that the Channel cleans up resources related to a request -/// before the response is sent, the response can still be sent into the Channel later on. -/// Because there is no guarantee that a cancellation message will ever be received for a -/// request, or that requests come with reasonably short deadlines, services should strive to -/// clean up Channel resources by sending a response for every request. +/// 3. [`Stream::next`](futures::stream::StreamExt::next) / +/// [`Sink::send`](futures::sink::SinkExt::send) - A user is free to manually read requests +/// from, and send responses into, a Channel in lieu of the previous methods. Channels stream +/// [`TrackedRequests`](TrackedRequest), which, in addition to the request itself, contains the +/// server [`Span`] and request lifetime [`AbortRegistration`]. Wrapping response +/// logic in an [`Abortable`] future using the abort registration will ensure that the response +/// does not execute longer than the request deadline. The `Channel` itself will clean up +/// request state once either the deadline expires, or a cancellation message is received, or a +/// response is sent. Because there is no guarantee that a cancellation message will ever be +/// received for a request, or that requests come with reasonably short deadlines, services +/// should strive to clean up Channel resources by sending a response for every request. pub trait Channel where - Self: Transport::Resp>, Request<::Req>>, + Self: Transport::Resp>, TrackedRequest<::Req>>, { /// Type of request item. type Req; @@ -249,16 +301,6 @@ where Throttler::new(self, limit) } - /// Tells the Channel that request with ID `request_id` is being handled. - /// The request will be tracked until a response with the same ID is sent - /// to the Channel or the deadline expires, whichever happens first. - fn start_request( - self: Pin<&mut Self>, - request_id: u64, - deadline: SystemTime, - span: Span, - ) -> Result; - /// Returns a stream of requests that automatically handle request cancellation and response /// routing. /// @@ -313,7 +355,7 @@ impl Stream for BaseChannel where T: Transport, ClientMessage>, { - type Item = Result, ChannelError>; + type Item = Result, ChannelError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { enum ReceiverStatus { @@ -343,7 +385,16 @@ where { Poll::Ready(Some(message)) => match message { ClientMessage::Request(request) => { - return Poll::Ready(Some(Ok(request))); + match self.as_mut().start_request(request) { + Ok(request) => return Poll::Ready(Some(Ok(request))), + Err(AlreadyExistsError) => { + // Instead of closing the channel if a duplicate request is sent, + // just ignore it, since it's already being processed. Note that we + // cannot return Poll::Pending here, since nothing has scheduled a + // wakeup yet. + continue; + } + } } ClientMessage::Cancel { trace_context, @@ -405,6 +456,7 @@ where } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + tracing::trace!("poll_flush"); self.project() .transport .poll_flush(cx) @@ -444,17 +496,6 @@ where fn transport(&self) -> &Self::Transport { self.get_ref() } - - fn start_request( - self: Pin<&mut Self>, - request_id: u64, - deadline: SystemTime, - span: Span, - ) -> Result { - self.project() - .in_flight_requests - .start_request(request_id, deadline, span) - } } /// A stream of requests coming over a channel. `Requests` also drives the sending of responses, so @@ -494,48 +535,12 @@ where ) -> Poll, C::Error>>> { loop { match ready!(self.channel_pin_mut().poll_next(cx)?) { - Some(mut request) => { - let span = info_span!( - "RPC", - rpc.trace_id = %request.context.trace_id(), - otel.kind = "server", - otel.name = tracing::field::Empty, - ); - span.set_context(&request.context); - request.context.trace_context = - trace::Context::try_from(&span).unwrap_or_else(|_| { - tracing::trace!( - "OpenTelemetry subscriber not installed; making unsampled \ - child context." - ); - request.context.trace_context.new_child() - }); - let entered = span.enter(); - tracing::info!("ReceiveRequest"); - let start = self.channel_pin_mut().start_request( - request.id, - request.context.deadline, - span.clone(), - ); - match start { - Ok(abort_registration) => { - let response_tx = self.responses_tx.clone(); - drop(entered); - return Poll::Ready(Some(Ok(InFlightRequest { - request, - response_tx, - abort_registration, - span, - }))); - } - // Instead of closing the channel if a duplicate request is sent, just - // ignore it, since it's already being processed. Note that we cannot - // return Poll::Pending here, since nothing has scheduled a wakeup yet. - Err(AlreadyExistsError) => { - tracing::trace!("DuplicateRequest"); - continue; - } - } + Some(request) => { + let response_tx = self.responses_tx.clone(); + return Poll::Ready(Some(Ok(InFlightRequest { + request, + response_tx, + }))); } None => return Poll::Ready(None), } @@ -619,16 +624,14 @@ where /// A request produced by [Channel::requests]. #[derive(Debug)] pub struct InFlightRequest { - request: Request, + request: TrackedRequest, response_tx: mpsc::Sender>, - abort_registration: AbortRegistration, - span: Span, } impl InFlightRequest { /// Returns a reference to the request. pub fn get(&self) -> &Request { - &self.request + &self.request.request } /// Returns a [future](Future) that executes the request using the given [service @@ -647,15 +650,18 @@ impl InFlightRequest { S: Serve, { let Self { - abort_registration, - request: - Request { - context, - message, - id: request_id, - }, response_tx, - span, + request: + TrackedRequest { + abort_registration, + span, + request: + Request { + context, + message, + id: request_id, + }, + }, } = self; let method = serve.method(&message); span.record("otel.name", &method.unwrap_or("")); @@ -826,7 +832,6 @@ mod tests { use assert_matches::assert_matches; use futures::future::{pending, Aborted}; use futures_test::task::noop_context; - use std::time::Duration; fn test_channel() -> ( Pin, Response>>>>, @@ -892,12 +897,18 @@ mod tests { channel .as_mut() - .start_request(0, SystemTime::now(), Span::current()) + .start_request(Request { + id: 0, + context: context::current(), + message: (), + }) .unwrap(); assert_matches!( - channel - .as_mut() - .start_request(0, SystemTime::now(), Span::current()), + channel.as_mut().start_request(Request { + id: 0, + context: context::current(), + message: () + }), Err(AlreadyExistsError) ); } @@ -907,13 +918,21 @@ mod tests { let (mut channel, _tx) = test_channel::<(), ()>(); tokio::time::pause(); - let abort_registration0 = channel + let req0 = channel .as_mut() - .start_request(0, SystemTime::now(), Span::current()) + .start_request(Request { + id: 0, + context: context::current(), + message: (), + }) .unwrap(); - let abort_registration1 = channel + let req1 = channel .as_mut() - .start_request(1, SystemTime::now(), Span::current()) + .start_request(Request { + id: 1, + context: context::current(), + message: (), + }) .unwrap(); tokio::time::advance(std::time::Duration::from_secs(1000)).await; @@ -921,8 +940,8 @@ mod tests { channel.as_mut().poll_next(&mut noop_context()), Poll::Pending ); - assert_matches!(test_abortable(abort_registration0).await, Err(Aborted)); - assert_matches!(test_abortable(abort_registration1).await, Err(Aborted)); + assert_matches!(test_abortable(req0.abort_registration).await, Err(Aborted)); + assert_matches!(test_abortable(req1.abort_registration).await, Err(Aborted)); } #[tokio::test] @@ -930,13 +949,13 @@ mod tests { let (mut channel, mut tx) = test_channel::<(), ()>(); tokio::time::pause(); - let abort_registration = channel + let req = channel .as_mut() - .start_request( - 0, - SystemTime::now() + Duration::from_millis(100), - Span::current(), - ) + .start_request(Request { + id: 0, + context: context::current(), + message: (), + }) .unwrap(); tx.send(ClientMessage::Cancel { @@ -951,7 +970,7 @@ mod tests { Poll::Pending ); - assert_matches!(test_abortable(abort_registration).await, Err(Aborted)); + assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted)); } #[tokio::test] @@ -961,11 +980,11 @@ mod tests { tokio::time::pause(); let _abort_registration = channel .as_mut() - .start_request( - 0, - SystemTime::now() + Duration::from_millis(100), - Span::current(), - ) + .start_request(Request { + id: 0, + context: context::current(), + message: (), + }) .unwrap(); drop(tx); @@ -1001,9 +1020,13 @@ mod tests { let (mut channel, mut tx) = test_channel::<(), ()>(); tokio::time::pause(); - let abort_registration = channel + let req = channel .as_mut() - .start_request(0, SystemTime::now(), Span::current()) + .start_request(Request { + id: 0, + context: context::current(), + message: (), + }) .unwrap(); tokio::time::advance(std::time::Duration::from_secs(1000)).await; @@ -1013,7 +1036,7 @@ mod tests { channel.as_mut().poll_next(&mut noop_context()), Poll::Ready(Some(Ok(_))) ); - assert_matches!(test_abortable(abort_registration).await, Err(Aborted)); + assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted)); } #[tokio::test] @@ -1022,7 +1045,11 @@ mod tests { channel .as_mut() - .start_request(0, SystemTime::now(), Span::current()) + .start_request(Request { + id: 0, + context: context::current(), + message: (), + }) .unwrap(); assert_eq!(channel.in_flight_requests(), 1); channel @@ -1043,7 +1070,11 @@ mod tests { requests .as_mut() .channel_pin_mut() - .start_request(0, SystemTime::now(), Span::current()) + .start_request(Request { + id: 0, + context: context::current(), + message: (), + }) .unwrap(); requests .as_mut() @@ -1069,7 +1100,11 @@ mod tests { requests .as_mut() .channel_pin_mut() - .start_request(1, SystemTime::now(), Span::current()) + .start_request(Request { + id: 1, + context: context::current(), + message: (), + }) .unwrap(); assert_matches!( @@ -1086,7 +1121,11 @@ mod tests { requests .as_mut() .channel_pin_mut() - .start_request(0, SystemTime::now(), Span::current()) + .start_request(Request { + id: 0, + context: context::current(), + message: (), + }) .unwrap(); requests .as_mut() @@ -1101,7 +1140,11 @@ mod tests { requests .as_mut() .channel_pin_mut() - .start_request(1, SystemTime::now(), Span::current()) + .start_request(Request { + id: 1, + context: context::current(), + message: (), + }) .unwrap(); requests .as_mut() diff --git a/tarpc/src/server/filter.rs b/tarpc/src/server/filter.rs index b5b5a95..295fd95 100644 --- a/tarpc/src/server/filter.rs +++ b/tarpc/src/server/filter.rs @@ -9,12 +9,11 @@ use crate::{ util::Compact, }; use fnv::FnvHashMap; -use futures::{future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*}; +use futures::{prelude::*, ready, stream::Fuse, task::*}; use pin_project::pin_project; use std::sync::{Arc, Weak}; use std::{ collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin, - time::SystemTime, }; use tokio::sync::mpsc; use tracing::{debug, info, trace}; @@ -116,15 +115,6 @@ where fn transport(&self) -> &Self::Transport { self.inner.transport() } - - fn start_request( - mut self: Pin<&mut Self>, - id: u64, - deadline: SystemTime, - span: tracing::Span, - ) -> Result { - self.inner_pin_mut().start_request(id, deadline, span) - } } impl TrackedChannel { diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 94ffce3..709c90c 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -6,10 +6,10 @@ use crate::{ context, - server::{Channel, Config}, + server::{Channel, Config, TrackedRequest}, Request, Response, }; -use futures::{future::AbortRegistration, task::*, Sink, Stream}; +use futures::{task::*, Sink, Stream}; use pin_project::pin_project; use std::{collections::VecDeque, io, pin::Pin, time::SystemTime}; use tracing::Span; @@ -62,7 +62,7 @@ impl Sink> for FakeChannel> { } } -impl Channel for FakeChannel>, Response> +impl Channel for FakeChannel>, Response> where Req: Unpin, { @@ -81,34 +81,28 @@ where fn transport(&self) -> &() { &() } - - fn start_request( - self: Pin<&mut Self>, - id: u64, - deadline: SystemTime, - span: Span, - ) -> Result { - self.project() - .in_flight_requests - .start_request(id, deadline, span) - } } -impl FakeChannel>, Response> { +impl FakeChannel>, Response> { pub fn push_req(&mut self, id: u64, message: Req) { - self.stream.push_back(Ok(Request { - context: context::Context { - deadline: SystemTime::UNIX_EPOCH, - trace_context: Default::default(), + let (_, abort_registration) = futures::future::AbortHandle::new_pair(); + self.stream.push_back(Ok(TrackedRequest { + request: Request { + context: context::Context { + deadline: SystemTime::UNIX_EPOCH, + trace_context: Default::default(), + }, + id, + message, }, - id, - message, + abort_registration, + span: Span::none(), })); } } impl FakeChannel<(), ()> { - pub fn default() -> FakeChannel>, Response> { + pub fn default() -> FakeChannel>, Response> { FakeChannel { stream: Default::default(), sink: Default::default(), diff --git a/tarpc/src/server/throttle.rs b/tarpc/src/server/throttle.rs index 477a920..a02e60f 100644 --- a/tarpc/src/server/throttle.rs +++ b/tarpc/src/server/throttle.rs @@ -6,10 +6,9 @@ use super::{Channel, Config}; use crate::{Response, ServerError}; -use futures::{future::AbortRegistration, prelude::*, ready, task::*}; +use futures::{prelude::*, ready, task::*}; use pin_project::pin_project; -use std::{io, pin::Pin, time::SystemTime}; -use tracing::Span; +use std::{io, pin::Pin}; /// A [`Channel`] that limits the number of concurrent /// requests by throttling. @@ -54,19 +53,18 @@ where ready!(self.as_mut().project().inner.poll_ready(cx)?); match ready!(self.as_mut().project().inner.poll_next(cx)?) { - Some(request) => { - tracing::debug!( - rpc.trace_id = %request.context.trace_id(), + Some(r) => { + let _entered = r.span.enter(); + tracing::info!( in_flight_requests = self.as_mut().in_flight_requests(), - max_in_flight_requests = *self.as_mut().project().max_in_flight_requests, - "At in-flight request limit", + "ThrottleRequest", ); self.as_mut().start_send(Response { - request_id: request.id, + request_id: r.request.id, message: Err(ServerError { kind: io::ErrorKind::WouldBlock, - detail: "Server throttled the request.".into(), + detail: "server throttled the request.".into(), }), })?; } @@ -128,15 +126,6 @@ where fn transport(&self) -> &Self::Transport { self.inner.transport() } - - fn start_request( - self: Pin<&mut Self>, - id: u64, - deadline: SystemTime, - span: Span, - ) -> Result { - self.project().inner.start_request(id, deadline, span) - } } /// A stream of throttling channels. @@ -183,15 +172,16 @@ where mod tests { use super::*; - use crate::{ - server::{ - in_flight_requests::AlreadyExistsError, - testing::{self, FakeChannel, PollExt}, - }, - Request, + use crate::server::{ + testing::{self, FakeChannel, PollExt}, + TrackedRequest, }; use pin_utils::pin_mut; - use std::{marker::PhantomData, time::Duration}; + use std::{ + marker::PhantomData, + time::{Duration, SystemTime}, + }; + use tracing::Span; #[tokio::test] async fn throttler_in_flight_requests() { @@ -215,25 +205,6 @@ mod tests { assert_eq!(throttler.as_mut().in_flight_requests(), 5); } - #[tokio::test] - async fn throttler_start_request() { - let throttler = Throttler { - max_in_flight_requests: 0, - inner: FakeChannel::default::(), - }; - - pin_mut!(throttler); - throttler - .as_mut() - .start_request( - 1, - SystemTime::now() + Duration::from_secs(1), - Span::current(), - ) - .unwrap(); - assert_eq!(throttler.inner.in_flight_requests.len(), 1); - } - #[test] fn throttler_poll_next_done() { let throttler = Throttler { @@ -259,7 +230,7 @@ mod tests { throttler .as_mut() .poll_next(&mut testing::cx())? - .map(|r| r.map(|r| (r.id, r.message))), + .map(|r| r.map(|r| (r.request.id, r.request.message))), Poll::Ready(Some((0, 1))) ); Ok(()) @@ -294,7 +265,8 @@ mod tests { ghost: PhantomData In>, } impl PendingSink<(), ()> { - pub fn default() -> PendingSink>, Response> { + pub fn default( + ) -> PendingSink>, Response> { PendingSink { ghost: PhantomData } } } @@ -319,7 +291,7 @@ mod tests { Poll::Pending } } - impl Channel for PendingSink>, Response> { + impl Channel for PendingSink>, Response> { type Req = Req; type Resp = Resp; type Transport = (); @@ -332,14 +304,6 @@ mod tests { fn transport(&self) -> &() { &() } - fn start_request( - self: Pin<&mut Self>, - _id: u64, - _deadline: SystemTime, - _span: tracing::Span, - ) -> Result { - unimplemented!() - } } }