From 72d5dbba89d4de7092d75c13812260b977ad069a Mon Sep 17 00:00:00 2001
From: Tim Kuehn
Date: Sun, 7 Mar 2021 17:42:50 -0800
Subject: [PATCH] Cleanup wrap-up.

- Remove unnecessary Sync and Clone bounds.
- Merge client and client::channel modules.
- Run cargo clippy in the pre-push hook.
- Put DispatchResponse.cancellation in an Option.
  Previously, the cancellation logic looked to see if `complete == true`, but
  it's a bit less error prone to put the Cancellation in an Option, so that
  the request can't accidentally be cancelled.
- Remove some unnecessary pins/projections.
- Clean up docs a bit. rustdoc had some warnings that are now gone.
---
 hooks/pre-push                    |  11 +-
 plugins/src/lib.rs                |  21 +-
 tarpc/src/client.rs               | 705 +++++++++++++++++++++++++++++-
 tarpc/src/client/channel.rs       | 684 -----------------------------
 tarpc/src/lib.rs                  |   2 +-
 tarpc/src/server.rs               |  25 +-
 tarpc/src/server/filter.rs        |  70 +--
 tarpc/tests/service_functional.rs |  35 ++
 8 files changed, 803 insertions(+), 750 deletions(-)
 delete mode 100644 tarpc/src/client/channel.rs

diff --git a/hooks/pre-push b/hooks/pre-push
index 9b62c71..7b527e0 100755
--- a/hooks/pre-push
+++ b/hooks/pre-push
@@ -84,11 +84,6 @@ command -v rustup &>/dev/null
 if [ "$?" == 0 ]; then
     printf "${SUCCESS}\n"
 
-    check_toolchain nightly
-    if [ ${TOOLCHAIN_RESULT} == 1 ]; then
-        exit 1
-    fi
-
     try_run "Building ... " cargo +stable build --color=always
     try_run "Testing ... " cargo +stable test --color=always
     try_run "Testing with all features enabled ... " cargo +stable test --all-features --color=always
@@ -97,6 +92,12 @@ if [ "$?" == 0 ]; then
         try_run "Running example \"$EXAMPLE\" ... " cargo +stable run --example $EXAMPLE
     done
 
+    check_toolchain nightly
+    if [ ${TOOLCHAIN_RESULT} != 1 ]; then
+        try_run "Running clippy ... " cargo +nightly clippy --color=always -Z unstable-options -- --deny warnings
+    fi
+
+
 fi
 
 exit $PREPUSH_RESULT
diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs
index 0964b29..b6f5b9e 100644
--- a/plugins/src/lib.rs
+++ b/plugins/src/lib.rs
@@ -215,9 +215,15 @@ impl Parse for DeriveSerde {
     }
 }
 
-/// Generates:
-/// - derive of Debug, serde Serialize & Deserialize
-/// - serde crate annotation
+/// A helper attribute to avoid a direct dependency on Serde.
+///
+/// Adds the following annotations to the annotated item:
+///
+/// ```rust
+/// #[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
+/// #[serde(crate = "tarpc::serde")]
+/// # struct Foo;
+/// ```
 #[proc_macro_attribute]
 pub fn derive_serde(_attr: TokenStream, item: TokenStream) -> TokenStream {
     let mut gen: proc_macro2::TokenStream = quote! {
@@ -482,10 +488,11 @@ impl<'a> ServiceGenerator<'a> {
         quote! {
             #( #attrs )*
-            #vis trait #service_ident: Clone {
+            #vis trait #service_ident: Sized {
                 #( #types_and_fns )*
 
-                /// Returns a serving function to use with [tarpc::server::InFlightRequest::execute].
+                /// Returns a serving function to use with
+                /// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute).
                 fn serve(self) -> #server_ident<Self> {
                     #server_ident { service: self }
                 }
@@ -662,7 +669,7 @@ impl<'a> ServiceGenerator<'a> {
         quote! {
             #[allow(unused)]
             #[derive(Clone, Debug)]
-            /// The client stub that makes RPC calls to the server. ALl request methods return
+            /// The client stub that makes RPC calls to the server. All request methods return
             /// [Futures](std::future::Future).
             #vis struct #client_ident(tarpc::client::Channel<#request_ident, #response_ident>);
         }
@@ -683,7 +690,7 @@ impl<'a> ServiceGenerator<'a> {
             #vis fn new<T>(config: tarpc::client::Config, transport: T)
                 -> tarpc::client::NewClient<
                     Self,
-                    tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T>
+                    tarpc::client::RequestDispatch<#request_ident, #response_ident, T>
                 >
             where
                 T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>>
diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs
index 8cc24ae..95fc91b 100644
--- a/tarpc/src/client.rs
+++ b/tarpc/src/client.rs
@@ -6,14 +6,25 @@
 //! Provides a client that connects to a server and sends multiplexed requests.
 
-use futures::prelude::*;
-use std::fmt;
-use std::io;
-
-/// Provides a [`Client`] backed by a transport.
-pub mod channel;
 mod in_flight_requests;
-pub use channel::{new, Channel};
+
+use crate::{
+    context, trace::SpanId, ClientMessage, PollContext, PollIo, Request, Response, Transport,
+};
+use futures::{prelude::*, ready, stream::Fuse, task::*};
+use in_flight_requests::InFlightRequests;
+use log::{info, trace};
+use pin_project::{pin_project, pinned_drop};
+use std::{
+    convert::TryFrom,
+    fmt, io,
+    pin::Pin,
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
+};
+use tokio::sync::{mpsc, oneshot};
 
 /// Settings that control the behavior of the client.
 #[derive(Clone, Debug)]
@@ -71,3 +82,683 @@ impl<C, D> fmt::Debug for NewClient<C, D> {
         write!(fmt, "NewClient")
     }
 }
+
+#[allow(dead_code)]
+#[allow(clippy::no_effect)]
+const CHECK_USIZE: () = {
+    if std::mem::size_of::<usize>() > std::mem::size_of::<u64>() {
+        // TODO: replace this with panic!() as soon as RFC 2345 gets stabilized
+        ["usize is too big to fit in u64"][42];
+    }
+};
+
+/// Handles communication from the client to request dispatch.
+#[derive(Debug)]
+pub struct Channel<Req, Resp> {
+    to_dispatch: mpsc::Sender<DispatchRequest<Req, Resp>>,
+    /// Channel to send a cancel message to the dispatcher.
+    cancellation: RequestCancellation,
+    /// The ID to use for the next request to stage.
+    next_request_id: Arc<AtomicUsize>,
+}
+
+impl<Req, Resp> Clone for Channel<Req, Resp> {
+    fn clone(&self) -> Self {
+        Self {
+            to_dispatch: self.to_dispatch.clone(),
+            cancellation: self.cancellation.clone(),
+            next_request_id: self.next_request_id.clone(),
+        }
+    }
+}
+
+impl<Req, Resp> Channel<Req, Resp> {
+    /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
+    /// resolves when the request is sent (not when the response is received).
+    fn send(
+        &self,
+        mut ctx: context::Context,
+        request: Req,
+    ) -> impl Future<Output = io::Result<DispatchResponse<Resp>>> + '_ {
+        // Convert the context to the call context.
+        ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
+        ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
+
+        let (response_completion, response) = oneshot::channel();
+        let cancellation = self.cancellation.clone();
+        let request_id =
+            u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
+
+        // DispatchResponse impls Drop to cancel in-flight requests. It should be created before
+        // sending out the request; otherwise, the response future could be dropped after the
+        // request is sent out but before DispatchResponse is created, rendering the cancellation
+        // logic inactive.
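+        // The cancellation handle lives in an Option so that polling the response to
+        // completion can take() it, disarming the drop-time cancellation in PinnedDrop.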
+        let response = DispatchResponse {
+            response,
+            request_id,
+            cancellation: Some(cancellation),
+            ctx,
+        };
+        async move {
+            self.to_dispatch
+                .send(DispatchRequest {
+                    ctx,
+                    request_id,
+                    request,
+                    response_completion,
+                })
+                .await
+                .map_err(|mpsc::error::SendError(_)| {
+                    io::Error::from(io::ErrorKind::ConnectionReset)
+                })?;
+            Ok(response)
+        }
+    }
+
+    /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
+    /// resolves to the response.
+    pub async fn call(&self, ctx: context::Context, request: Req) -> io::Result<Resp> {
+        let dispatch_response = self.send(ctx, request).await?;
+        dispatch_response.await
+    }
+}
+
+/// A server response that is completed by request dispatch when the corresponding response
+/// arrives off the wire.
+#[pin_project(PinnedDrop)]
+#[derive(Debug)]
+struct DispatchResponse<Resp> {
+    response: oneshot::Receiver<Response<Resp>>,
+    ctx: context::Context,
+    cancellation: Option<RequestCancellation>,
+    request_id: u64,
+}
+
+impl<Resp> Future for DispatchResponse<Resp> {
+    type Output = io::Result<Resp>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> {
+        let resp = ready!(self.response.poll_unpin(cx));
+        self.cancellation.take();
+        Poll::Ready(match resp {
+            Ok(resp) => Ok(resp.message?),
+            Err(oneshot::error::RecvError { .. }) => {
+                // The oneshot is Canceled when the dispatch task ends. In that case,
+                // there's nothing listening on the other side, so there's no point in
+                // propagating cancellation.
+                Err(io::Error::from(io::ErrorKind::ConnectionReset))
+            }
+        })
+    }
+}
+
+// Cancels the request when dropped, if not already complete.
+#[pinned_drop]
+impl<Resp> PinnedDrop for DispatchResponse<Resp> {
+    fn drop(mut self: Pin<&mut Self>) {
+        let self_ = self.project();
+        if let Some(cancellation) = self_.cancellation {
+            // The receiver needs to be closed to handle the edge case that the request has not
+            // yet been received by the dispatch task. It is possible for the cancel message to
+            // arrive before the request itself, in which case the request could get stuck in the
+            // dispatch map forever if the server never responds (e.g. if the server dies while
+            // responding). Even if the server does respond, it will have unnecessarily done work
+            // for a client no longer waiting for a response. To avoid this, the dispatch task
+            // checks if the receiver is closed before inserting the request in the map. By
+            // closing the receiver before sending the cancel message, it is guaranteed that if the
+            // dispatch task misses an early-arriving cancellation message, then it will see the
+            // receiver as closed.
+            self_.response.close();
+            cancellation.cancel(*self_.request_id);
+        }
+    }
+}
+
+/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the
+/// channel.
+pub fn new<Req, Resp, C>(
+    config: Config,
+    transport: C,
+) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C>>
+where
+    C: Transport<ClientMessage<Req>, Response<Resp>>,
+{
+    let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
+    let (cancellation, canceled_requests) = cancellations();
+    let canceled_requests = canceled_requests;
+
+    NewClient {
+        client: Channel {
+            to_dispatch,
+            cancellation,
+            next_request_id: Arc::new(AtomicUsize::new(0)),
+        },
+        dispatch: RequestDispatch {
+            config,
+            canceled_requests,
+            transport: transport.fuse(),
+            in_flight_requests: InFlightRequests::default(),
+            pending_requests,
+        },
+    }
+}
+
+/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
+/// and dispatching responses to the appropriate channel.
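+///
+/// `RequestDispatch` implements [`Future`] and only makes progress while it is polled; users of
+/// a [`Channel`] typically spawn it (for example with `tokio::spawn`, assuming a tokio runtime)
+/// so that requests and responses keep flowing.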
+#[pin_project]
+#[derive(Debug)]
+pub struct RequestDispatch<Req, Resp, C> {
+    /// Writes requests to the wire and reads responses off the wire.
+    #[pin]
+    transport: Fuse<C>,
+    /// Requests waiting to be written to the wire.
+    #[pin]
+    pending_requests: mpsc::Receiver<DispatchRequest<Req, Resp>>,
+    /// Requests that were dropped.
+    #[pin]
+    canceled_requests: CanceledRequests,
+    /// Requests already written to the wire that haven't yet received responses.
+    in_flight_requests: InFlightRequests<Resp>,
+    /// Configures limits to prevent unlimited resource usage.
+    config: Config,
+}
+
+impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
+where
+    C: Transport<ClientMessage<Req>, Response<Resp>>,
+{
+    fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests<Resp> {
+        self.as_mut().project().in_flight_requests
+    }
+
+    fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
+        Poll::Ready(
+            match ready!(self.as_mut().project().transport.poll_next(cx)?) {
+                Some(response) => {
+                    self.complete(response);
+                    Some(Ok(()))
+                }
+                None => None,
+            },
+        )
+    }
+
+    fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
+        enum ReceiverStatus {
+            NotReady,
+            Closed,
+        }
+
+        let pending_requests_status = match self.as_mut().poll_next_request(cx)? {
+            Poll::Ready(Some(dispatch_request)) => {
+                self.as_mut().write_request(dispatch_request)?;
+                return Poll::Ready(Some(Ok(())));
+            }
+            Poll::Ready(None) => ReceiverStatus::Closed,
+            Poll::Pending => ReceiverStatus::NotReady,
+        };
+
+        let canceled_requests_status = match self.as_mut().poll_next_cancellation(cx)? {
+            Poll::Ready(Some((context, request_id))) => {
+                self.as_mut().write_cancel(context, request_id)?;
+                return Poll::Ready(Some(Ok(())));
+            }
+            Poll::Ready(None) => ReceiverStatus::Closed,
+            Poll::Pending => ReceiverStatus::NotReady,
+        };
+
+        // Receiving Poll::Ready(None) when polling expired requests never indicates "Closed",
+        // because there can temporarily be zero in-flight requests. Therefore, there is no need
+        // to track the status as is done with pending and canceled requests.
+        if let Poll::Ready(Some(_)) = self.in_flight_requests().poll_expired(cx)? {
+            // Expired requests are considered complete; there is no compelling reason to send a
+            // cancellation message to the server, since it will have already exhausted its
+            // allotted processing time.
+            return Poll::Ready(Some(Ok(())));
+        }
+
+        match (pending_requests_status, canceled_requests_status) {
+            (ReceiverStatus::Closed, ReceiverStatus::Closed) => {
+                ready!(self.as_mut().project().transport.poll_flush(cx)?);
+                Poll::Ready(None)
+            }
+            (ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
+                // No more messages to process, so flush any messages buffered in the transport.
+                ready!(self.as_mut().project().transport.poll_flush(cx)?);
+
+                // Even if we fully-flush, we return Pending, because we have no more requests
+                // or cancellations right now.
+                Poll::Pending
+            }
+        }
+    }
+
+    /// Yields the next pending request, if one is ready to be sent.
+    fn poll_next_request(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> PollIo<DispatchRequest<Req, Resp>> {
+        if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
+            info!(
+                "At in-flight request capacity ({}/{}).",
+                self.in_flight_requests().len(),
+                self.config.max_in_flight_requests
+            );
+
+            // No need to schedule a wakeup, because timers and responses are responsible
+            // for clearing out in-flight requests.
+            return Poll::Pending;
+        }
+
+        while self
+            .as_mut()
+            .project()
+            .transport
+            .poll_ready(cx)?
+            .is_pending()
+        {
+            // We can't yield a request-to-be-sent before the transport is capable of buffering it.
+            ready!(self.as_mut().project().transport.poll_flush(cx)?);
+        }
+
+        loop {
+            match ready!(self.as_mut().project().pending_requests.poll_recv(cx)) {
+                Some(request) => {
+                    if request.response_completion.is_closed() {
+                        trace!(
+                            "[{}] Request canceled before being sent.",
+                            request.ctx.trace_id()
+                        );
+                        continue;
+                    }
+
+                    return Poll::Ready(Some(Ok(request)));
+                }
+                None => return Poll::Ready(None),
+            }
+        }
+    }
+
+    /// Yields the next pending cancellation, and, if one is ready, cancels the associated request.
+    fn poll_next_cancellation(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> PollIo<(context::Context, u64)> {
+        while self
+            .as_mut()
+            .project()
+            .transport
+            .poll_ready(cx)?
+            .is_pending()
+        {
+            ready!(self.as_mut().project().transport.poll_flush(cx)?);
+        }
+
+        loop {
+            let cancellation = self
+                .as_mut()
+                .project()
+                .canceled_requests
+                .poll_next_unpin(cx);
+            match ready!(cancellation) {
+                Some(request_id) => {
+                    if let Some(ctx) = self.in_flight_requests().cancel_request(request_id) {
+                        return Poll::Ready(Some(Ok((ctx, request_id))));
+                    }
+                }
+                None => return Poll::Ready(None),
+            }
+        }
+    }
+
+    fn write_request(
+        mut self: Pin<&mut Self>,
+        dispatch_request: DispatchRequest<Req, Resp>,
+    ) -> io::Result<()> {
+        let request_id = dispatch_request.request_id;
+        let request = ClientMessage::Request(Request {
+            id: request_id,
+            message: dispatch_request.request,
+            context: context::Context {
+                deadline: dispatch_request.ctx.deadline,
+                trace_context: dispatch_request.ctx.trace_context,
+            },
+        });
+        self.as_mut().project().transport.start_send(request)?;
+        self.in_flight_requests()
+            .insert_request(
+                request_id,
+                dispatch_request.ctx,
+                dispatch_request.response_completion,
+            )
+            .expect("Request IDs should be unique");
+        Ok(())
+    }
+
+    fn write_cancel(
+        mut self: Pin<&mut Self>,
+        context: context::Context,
+        request_id: u64,
+    ) -> io::Result<()> {
+        let trace_id = *context.trace_id();
+        let cancel = ClientMessage::Cancel {
+            trace_context: context.trace_context,
+            request_id,
+        };
+        self.as_mut().project().transport.start_send(cancel)?;
+        trace!("[{}] Cancel message sent.", trace_id);
+        Ok(())
+    }
+
+    /// Sends a server response to the client task that initiated the associated request.
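+    /// Returns true if the response was delivered to a waiting request; false presumably means
+    /// no matching request was in flight (e.g. it was already canceled or had expired).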
+    fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
+        self.in_flight_requests().complete_request(response)
+    }
+}
+
+impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
+where
+    C: Transport<ClientMessage<Req>, Response<Resp>>,
+{
+    type Output = anyhow::Result<()>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<anyhow::Result<()>> {
+        loop {
+            match (
+                self.as_mut()
+                    .pump_read(cx)
+                    .context("failed to read from transport")?,
+                self.as_mut()
+                    .pump_write(cx)
+                    .context("failed to write to transport")?,
+            ) {
+                (Poll::Ready(None), _) => {
+                    info!("Shutdown: read half closed, so shutting down.");
+                    return Poll::Ready(Ok(()));
+                }
+                (read, Poll::Ready(None)) => {
+                    if self.in_flight_requests().is_empty() {
+                        info!("Shutdown: write half closed, and no requests in flight.");
+                        return Poll::Ready(Ok(()));
+                    }
+                    info!(
+                        "Shutdown: write half closed, and {} requests in flight.",
+                        self.in_flight_requests().len()
+                    );
+                    match read {
+                        Poll::Ready(Some(())) => continue,
+                        _ => return Poll::Pending,
+                    }
+                }
+                (Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
+                _ => return Poll::Pending,
+            }
+        }
+    }
+}
+
+/// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage
+/// the lifecycle of the request.
+#[derive(Debug)]
+struct DispatchRequest<Req, Resp> {
+    pub ctx: context::Context,
+    pub request_id: u64,
+    pub request: Req,
+    pub response_completion: oneshot::Sender<Response<Resp>>,
+}
+
+/// Sends request cancellation signals.
+#[derive(Debug, Clone)]
+struct RequestCancellation(mpsc::UnboundedSender<u64>);
+
+/// A stream of IDs of requests that have been canceled.
+#[derive(Debug)]
+struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
+
+/// Returns a channel to send request cancellation messages.
+fn cancellations() -> (RequestCancellation, CanceledRequests) {
+    // Unbounded because messages are sent in the drop fn. This is fine, because it's still
+    // bounded by the number of in-flight requests. Additionally, each request has a clone
+    // of the sender, so the bounded channel would have the same behavior,
+    // since it guarantees a slot.
+    let (tx, rx) = mpsc::unbounded_channel();
+    (RequestCancellation(tx), CanceledRequests(rx))
+}
+
+impl RequestCancellation {
+    /// Cancels the request with ID `request_id`.
+    fn cancel(&mut self, request_id: u64) {
+        let _ = self.0.send(request_id);
+    }
+}
+
+impl Stream for CanceledRequests {
+    type Item = u64;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
+        self.0.poll_recv(cx)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{
+        cancellations, CanceledRequests, Channel, DispatchResponse, RequestCancellation,
+        RequestDispatch,
+    };
+    use crate::{
+        client::{in_flight_requests::InFlightRequests, Config},
+        context,
+        transport::{self, channel::UnboundedChannel},
+        ClientMessage, Response,
+    };
+    use futures::{prelude::*, task::*};
+    use std::{pin::Pin, sync::atomic::AtomicUsize, sync::Arc};
+    use tokio::sync::{mpsc, oneshot};
+
+    #[tokio::test]
+    async fn dispatch_response_cancels_on_drop() {
+        let (cancellation, mut canceled_requests) = cancellations();
+        let (_, response) = oneshot::channel();
+        drop(DispatchResponse::<u32> {
+            response,
+            cancellation: Some(cancellation),
+            request_id: 3,
+            ctx: context::current(),
+        });
+        // resp's drop() is run, which should send a cancel message.
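+        // The cancel message is sent synchronously on an unbounded channel in drop(), so
+        // polling with a no-op waker is presumably enough to observe it.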
+        let cx = &mut Context::from_waker(&noop_waker_ref());
+        assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(Some(3)));
+    }
+
+    #[tokio::test]
+    async fn dispatch_response_doesnt_cancel_after_complete() {
+        let (cancellation, mut canceled_requests) = cancellations();
+        let (tx, response) = oneshot::channel();
+        tx.send(Response {
+            request_id: 0,
+            message: Ok("well done"),
+        })
+        .unwrap();
+        {
+            DispatchResponse {
+                response,
+                cancellation: Some(cancellation),
+                request_id: 3,
+                ctx: context::current(),
+            }
+            .await
+            .unwrap();
+            // resp's drop() is run, but should not send a cancel message.
+        }
+        let cx = &mut Context::from_waker(&noop_waker_ref());
+        assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(None));
+    }
+
+    #[tokio::test]
+    async fn stage_request() {
+        let (mut dispatch, mut channel, _server_channel) = set_up();
+        let dispatch = Pin::new(&mut dispatch);
+        let cx = &mut Context::from_waker(&noop_waker_ref());
+
+        let _resp = send_request(&mut channel, "hi").await;
+
+        let req = dispatch.poll_next_request(cx).ready();
+        assert!(req.is_some());
+
+        let req = req.unwrap();
+        assert_eq!(req.request_id, 0);
+        assert_eq!(req.request, "hi".to_string());
+    }
+
+    // Regression test for https://github.com/google/tarpc/issues/220
+    #[tokio::test]
+    async fn stage_request_channel_dropped_doesnt_panic() {
+        let (mut dispatch, mut channel, mut server_channel) = set_up();
+        let mut dispatch = Pin::new(&mut dispatch);
+        let cx = &mut Context::from_waker(&noop_waker_ref());
+
+        let _ = send_request(&mut channel, "hi").await;
+        drop(channel);
+
+        assert!(dispatch.as_mut().poll(cx).is_ready());
+        send_response(
+            &mut server_channel,
+            Response {
+                request_id: 0,
+                message: Ok("hello".into()),
+            },
+        )
+        .await;
+        dispatch.await.unwrap();
+    }
+
+    #[tokio::test]
+    async fn stage_request_response_future_dropped_is_canceled_before_sending() {
+        let (mut dispatch, mut channel, _server_channel) = set_up();
+        let dispatch = Pin::new(&mut dispatch);
+        let cx = &mut Context::from_waker(&noop_waker_ref());
+
+        let _ = send_request(&mut channel, "hi").await;
+
+        // Drop the channel so polling returns none if no requests are currently ready.
+        drop(channel);
+        // Test that a request future dropped before it's processed by dispatch will cause the request
+        // to not be added to the in-flight request map.
+        assert!(dispatch.poll_next_request(cx).ready().is_none());
+    }
+
+    #[tokio::test]
+    async fn stage_request_response_future_dropped_is_canceled_after_sending() {
+        let (mut dispatch, mut channel, _server_channel) = set_up();
+        let cx = &mut Context::from_waker(&noop_waker_ref());
+        let mut dispatch = Pin::new(&mut dispatch);
+
+        let req = send_request(&mut channel, "hi").await;
+
+        assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
+        assert!(!dispatch.in_flight_requests().is_empty());
+
+        // Test that a request future dropped after it's processed by dispatch will cause the request
+        // to be removed from the in-flight request map.
+        drop(req);
+        if let Poll::Ready(Some(_)) = dispatch.as_mut().poll_next_cancellation(cx).unwrap() {
+            // ok
+        } else {
+            panic!("Expected request to be cancelled")
+        };
+        assert!(dispatch.in_flight_requests().is_empty());
+    }
+
+    #[tokio::test]
+    async fn stage_request_response_closed_skipped() {
+        let (mut dispatch, mut channel, _server_channel) = set_up();
+        let dispatch = Pin::new(&mut dispatch);
+        let cx = &mut Context::from_waker(&noop_waker_ref());
+
+        // Test that a request future that's closed its receiver but not yet canceled its request --
+        // i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
+        // map.
+        let mut resp = send_request(&mut channel, "hi").await;
+        resp.response.close();
+
+        assert!(dispatch.poll_next_request(cx).is_pending());
+    }
+
+    fn set_up() -> (
+        RequestDispatch<String, String, UnboundedChannel<Response<String>, ClientMessage<String>>>,
+        Channel<String, String>,
+        UnboundedChannel<ClientMessage<String>, Response<String>>,
+    ) {
+        let _ = env_logger::try_init();
+
+        let (to_dispatch, pending_requests) = mpsc::channel(1);
+        let (cancel_tx, canceled_requests) = mpsc::unbounded_channel();
+        let (client_channel, server_channel) = transport::channel::unbounded();
+
+        let dispatch = RequestDispatch::<String, String, _> {
+            transport: client_channel.fuse(),
+            pending_requests,
+            canceled_requests: CanceledRequests(canceled_requests),
+            in_flight_requests: InFlightRequests::default(),
+            config: Config::default(),
+        };
+
+        let cancellation = RequestCancellation(cancel_tx);
+        let channel = Channel {
+            to_dispatch,
+            cancellation,
+            next_request_id: Arc::new(AtomicUsize::new(0)),
+        };
+
+        (dispatch, channel, server_channel)
+    }
+
+    async fn send_request(
+        channel: &mut Channel<String, String>,
+        request: &str,
+    ) -> DispatchResponse<String> {
+        channel
+            .send(context::current(), request.to_string())
+            .await
+            .unwrap()
+    }
+
+    async fn send_response(
+        channel: &mut UnboundedChannel<ClientMessage<String>, Response<String>>,
+        response: Response<String>,
+    ) {
+        channel.send(response).await.unwrap();
+    }
+
+    trait PollTest {
+        type T;
+        fn unwrap(self) -> Poll<Self::T>;
+        fn ready(self) -> Self::T;
+    }
+
+    impl<T, E> PollTest for Poll<Option<Result<T, E>>>
+    where
+        E: ::std::fmt::Display,
+    {
+        type T = Option<T>;
+
+        fn unwrap(self) -> Poll<Option<T>> {
+            match self {
+                Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)),
+                Poll::Ready(None) => Poll::Ready(None),
+                Poll::Ready(Some(Err(e))) => panic!("{}", e.to_string()),
+                Poll::Pending => Poll::Pending,
+            }
+        }
+
+        fn ready(self) -> Option<T> {
+            match self {
+                Poll::Ready(Some(Ok(t))) => Some(t),
+                Poll::Ready(None) => None,
+                Poll::Ready(Some(Err(e))) => panic!("{}", e.to_string()),
+                Poll::Pending => panic!("Pending"),
+            }
+        }
+    }
+}
diff --git a/tarpc/src/client/channel.rs b/tarpc/src/client/channel.rs
deleted file mode 100644
index bd12135..0000000
--- a/tarpc/src/client/channel.rs
+++ /dev/null
@@ -1,684 +0,0 @@
-// Copyright 2018 Google LLC
-//
-// Use of this source code is governed by an MIT-style
-// license that can be found in the LICENSE file or at
-// https://opensource.org/licenses/MIT.
- -use crate::{ - client::in_flight_requests::InFlightRequests, context, trace::SpanId, ClientMessage, - PollContext, PollIo, Request, Response, Transport, -}; -use futures::{prelude::*, ready, stream::Fuse, task::*}; -use log::{info, trace}; -use pin_project::{pin_project, pinned_drop}; -use std::{ - convert::TryFrom, - io, - pin::Pin, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, -}; -use tokio::sync::{mpsc, oneshot}; - -#[allow(dead_code)] -#[allow(clippy::no_effect)] -const CHECK_USIZE: () = { - if std::mem::size_of::() > std::mem::size_of::() { - // TODO: replace this with panic!() as soon as RFC 2345 gets stabilized - ["usize is too big to fit in u64"][42]; - } -}; - -use super::{Config, NewClient}; - -/// Handles communication from the client to request dispatch. -#[derive(Debug)] -pub struct Channel { - to_dispatch: mpsc::Sender>, - /// Channel to send a cancel message to the dispatcher. - cancellation: RequestCancellation, - /// The ID to use for the next request to stage. - next_request_id: Arc, -} - -impl Clone for Channel { - fn clone(&self) -> Self { - Self { - to_dispatch: self.to_dispatch.clone(), - cancellation: self.cancellation.clone(), - next_request_id: self.next_request_id.clone(), - } - } -} - -impl Channel { - /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that - /// resolves when the request is sent (not when the response is received). - fn send( - &self, - mut ctx: context::Context, - request: Req, - ) -> impl Future>> + '_ { - // Convert the context to the call context. - ctx.trace_context.parent_id = Some(ctx.trace_context.span_id); - ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng()); - - let (response_completion, response) = oneshot::channel(); - let cancellation = self.cancellation.clone(); - let request_id = - u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); - - // DispatchResponse impls Drop to cancel in-flight requests. It should be created before - // sending out the request; otherwise, the response future could be dropped after the - // request is sent out but before DispatchResponse is created, rendering the cancellation - // logic inactive. - let response = DispatchResponse { - response, - complete: false, - request_id, - cancellation, - ctx, - }; - async move { - self.to_dispatch - .send(DispatchRequest { - ctx, - request_id, - request, - response_completion, - }) - .await - .map_err(|mpsc::error::SendError(_)| { - io::Error::from(io::ErrorKind::ConnectionReset) - })?; - Ok(response) - } - } - - /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that - /// resolves to the response. - pub async fn call(&self, ctx: context::Context, request: Req) -> io::Result { - let dispatch_response = self.send(ctx, request).await?; - dispatch_response.await - } -} - -/// A server response that is completed by request dispatch when the corresponding response -/// arrives off the wire. -#[pin_project(PinnedDrop)] -#[derive(Debug)] -struct DispatchResponse { - response: oneshot::Receiver>, - ctx: context::Context, - complete: bool, - cancellation: RequestCancellation, - request_id: u64, -} - -impl Future for DispatchResponse { - type Output = io::Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let resp = ready!(self.response.poll_unpin(cx)); - self.complete = true; - Poll::Ready(match resp { - Ok(resp) => Ok(resp.message?), - Err(oneshot::error::RecvError { .. 
}) => { - // The oneshot is Canceled when the dispatch task ends. In that case, - // there's nothing listening on the other side, so there's no point in - // propagating cancellation. - Err(io::Error::from(io::ErrorKind::ConnectionReset)) - } - }) - } -} - -// Cancels the request when dropped, if not already complete. -#[pinned_drop] -impl PinnedDrop for DispatchResponse { - fn drop(mut self: Pin<&mut Self>) { - if !self.complete { - // The receiver needs to be closed to handle the edge case that the request has not - // yet been received by the dispatch task. It is possible for the cancel message to - // arrive before the request itself, in which case the request could get stuck in the - // dispatch map forever if the server never responds (e.g. if the server dies while - // responding). Even if the server does respond, it will have unnecessarily done work - // for a client no longer waiting for a response. To avoid this, the dispatch task - // checks if the receiver is closed before inserting the request in the map. By - // closing the receiver before sending the cancel message, it is guaranteed that if the - // dispatch task misses an early-arriving cancellation message, then it will see the - // receiver as closed. - self.response.close(); - let request_id = self.request_id; - self.cancellation.cancel(request_id); - } - } -} - -/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the -/// channel. -pub fn new( - config: Config, - transport: C, -) -> NewClient, RequestDispatch> -where - C: Transport, Response>, -{ - let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); - let (cancellation, canceled_requests) = cancellations(); - let canceled_requests = canceled_requests; - - NewClient { - client: Channel { - to_dispatch, - cancellation, - next_request_id: Arc::new(AtomicUsize::new(0)), - }, - dispatch: RequestDispatch { - config, - canceled_requests, - transport: transport.fuse(), - in_flight_requests: InFlightRequests::default(), - pending_requests, - }, - } -} - -/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations, -/// and dispatching responses to the appropriate channel. -#[pin_project] -#[derive(Debug)] -pub struct RequestDispatch { - /// Writes requests to the wire and reads responses off the wire. - #[pin] - transport: Fuse, - /// Requests waiting to be written to the wire. - #[pin] - pending_requests: mpsc::Receiver>, - /// Requests that were dropped. - #[pin] - canceled_requests: CanceledRequests, - /// Requests already written to the wire that haven't yet received responses. - in_flight_requests: InFlightRequests, - /// Configures limits to prevent unlimited resource usage. - config: Config, -} - -impl RequestDispatch -where - C: Transport, Response>, -{ - fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { - self.as_mut().project().in_flight_requests - } - - fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> { - Poll::Ready( - match ready!(self.as_mut().project().transport.poll_next(cx)?) { - Some(response) => { - self.complete(response); - Some(Ok(())) - } - None => None, - }, - ) - } - - fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> { - enum ReceiverStatus { - NotReady, - Closed, - } - - let pending_requests_status = match self.as_mut().poll_next_request(cx)? 
{ - Poll::Ready(Some(dispatch_request)) => { - self.as_mut().write_request(dispatch_request)?; - return Poll::Ready(Some(Ok(()))); - } - Poll::Ready(None) => ReceiverStatus::Closed, - Poll::Pending => ReceiverStatus::NotReady, - }; - - let canceled_requests_status = match self.as_mut().poll_next_cancellation(cx)? { - Poll::Ready(Some((context, request_id))) => { - self.as_mut().write_cancel(context, request_id)?; - return Poll::Ready(Some(Ok(()))); - } - Poll::Ready(None) => ReceiverStatus::Closed, - Poll::Pending => ReceiverStatus::NotReady, - }; - - // Receiving Poll::Ready(None) when polling expired requests never indicates "Closed", - // because there can temporarily be zero in-flight rquests. Therefore, there is no need to - // track the status like is done with pending and cancelled requests. - if let Poll::Ready(Some(_)) = self.in_flight_requests().poll_expired(cx)? { - // Expired requests are considered complete; there is no compelling reason to send a - // cancellation message to the server, since it will have already exhausted its - // allotted processing time. - return Poll::Ready(Some(Ok(()))); - } - - match (pending_requests_status, canceled_requests_status) { - (ReceiverStatus::Closed, ReceiverStatus::Closed) => { - ready!(self.as_mut().project().transport.poll_flush(cx)?); - Poll::Ready(None) - } - (ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => { - // No more messages to process, so flush any messages buffered in the transport. - ready!(self.as_mut().project().transport.poll_flush(cx)?); - - // Even if we fully-flush, we return Pending, because we have no more requests - // or cancellations right now. - Poll::Pending - } - } - } - - /// Yields the next pending request, if one is ready to be sent. - fn poll_next_request( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> PollIo> { - if self.in_flight_requests().len() >= self.config.max_in_flight_requests { - info!( - "At in-flight request capacity ({}/{}).", - self.in_flight_requests().len(), - self.config.max_in_flight_requests - ); - - // No need to schedule a wakeup, because timers and responses are responsible - // for clearing out in-flight requests. - return Poll::Pending; - } - - while self - .as_mut() - .project() - .transport - .poll_ready(cx)? - .is_pending() - { - // We can't yield a request-to-be-sent before the transport is capable of buffering it. - ready!(self.as_mut().project().transport.poll_flush(cx)?); - } - - loop { - match ready!(self.as_mut().project().pending_requests.poll_recv(cx)) { - Some(request) => { - if request.response_completion.is_closed() { - trace!( - "[{}] Request canceled before being sent.", - request.ctx.trace_id() - ); - continue; - } - - return Poll::Ready(Some(Ok(request))); - } - None => return Poll::Ready(None), - } - } - } - - /// Yields the next pending cancellation, and, if one is ready, cancels the associated request. - fn poll_next_cancellation( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> PollIo<(context::Context, u64)> { - while self - .as_mut() - .project() - .transport - .poll_ready(cx)? 
- .is_pending() - { - ready!(self.as_mut().project().transport.poll_flush(cx)?); - } - - loop { - let cancellation = self - .as_mut() - .project() - .canceled_requests - .poll_next_unpin(cx); - match ready!(cancellation) { - Some(request_id) => { - if let Some(ctx) = self.in_flight_requests().cancel_request(request_id) { - return Poll::Ready(Some(Ok((ctx, request_id)))); - } - } - None => return Poll::Ready(None), - } - } - } - - fn write_request( - mut self: Pin<&mut Self>, - dispatch_request: DispatchRequest, - ) -> io::Result<()> { - let request_id = dispatch_request.request_id; - let request = ClientMessage::Request(Request { - id: request_id, - message: dispatch_request.request, - context: context::Context { - deadline: dispatch_request.ctx.deadline, - trace_context: dispatch_request.ctx.trace_context, - }, - }); - self.as_mut().project().transport.start_send(request)?; - self.in_flight_requests() - .insert_request( - request_id, - dispatch_request.ctx, - dispatch_request.response_completion, - ) - .expect("Request IDs should be unique"); - Ok(()) - } - - fn write_cancel( - mut self: Pin<&mut Self>, - context: context::Context, - request_id: u64, - ) -> io::Result<()> { - let trace_id = *context.trace_id(); - let cancel = ClientMessage::Cancel { - trace_context: context.trace_context, - request_id, - }; - self.as_mut().project().transport.start_send(cancel)?; - trace!("[{}] Cancel message sent.", trace_id); - Ok(()) - } - - /// Sends a server response to the client task that initiated the associated request. - fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { - self.in_flight_requests().complete_request(response) - } -} - -impl Future for RequestDispatch -where - C: Transport, Response>, -{ - type Output = anyhow::Result<()>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - match ( - self.as_mut() - .pump_read(cx) - .context("failed to read from transport")?, - self.as_mut() - .pump_write(cx) - .context("failed to write to transport")?, - ) { - (Poll::Ready(None), _) => { - info!("Shutdown: read half closed, so shutting down."); - return Poll::Ready(Ok(())); - } - (read, Poll::Ready(None)) => { - if self.in_flight_requests().is_empty() { - info!("Shutdown: write half closed, and no requests in flight."); - return Poll::Ready(Ok(())); - } - info!( - "Shutdown: write half closed, and {} requests in flight.", - self.in_flight_requests().len() - ); - match read { - Poll::Ready(Some(())) => continue, - _ => return Poll::Pending, - } - } - (Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {} - _ => return Poll::Pending, - } - } - } -} - -/// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage -/// the lifecycle of the request. -#[derive(Debug)] -struct DispatchRequest { - pub ctx: context::Context, - pub request_id: u64, - pub request: Req, - pub response_completion: oneshot::Sender>, -} - -/// Sends request cancellation signals. -#[derive(Debug, Clone)] -struct RequestCancellation(mpsc::UnboundedSender); - -/// A stream of IDs of requests that have been canceled. -#[derive(Debug)] -struct CanceledRequests(mpsc::UnboundedReceiver); - -/// Returns a channel to send request cancellation messages. -fn cancellations() -> (RequestCancellation, CanceledRequests) { - // Unbounded because messages are sent in the drop fn. This is fine, because it's still - // bounded by the number of in-flight requests. 
Additionally, each request has a clone - // of the sender, so the bounded channel would have the same behavior, - // since it guarantees a slot. - let (tx, rx) = mpsc::unbounded_channel(); - (RequestCancellation(tx), CanceledRequests(rx)) -} - -impl RequestCancellation { - /// Cancels the request with ID `request_id`. - fn cancel(&mut self, request_id: u64) { - let _ = self.0.send(request_id); - } -} - -impl Stream for CanceledRequests { - type Item = u64; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.0.poll_recv(cx) - } -} - -#[cfg(test)] -mod tests { - use super::{ - cancellations, CanceledRequests, Channel, DispatchResponse, RequestCancellation, - RequestDispatch, - }; - use crate::{ - client::{in_flight_requests::InFlightRequests, Config}, - context, - transport::{self, channel::UnboundedChannel}, - ClientMessage, Response, - }; - use futures::{prelude::*, task::*}; - use std::{pin::Pin, sync::atomic::AtomicUsize, sync::Arc}; - use tokio::sync::{mpsc, oneshot}; - - #[tokio::test] - async fn dispatch_response_cancels_on_drop() { - let (cancellation, mut canceled_requests) = cancellations(); - let (_, response) = oneshot::channel(); - drop(DispatchResponse:: { - response, - cancellation, - complete: false, - request_id: 3, - ctx: context::current(), - }); - // resp's drop() is run, which should send a cancel message. - let cx = &mut Context::from_waker(&noop_waker_ref()); - assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(Some(3))); - } - - #[tokio::test] - async fn stage_request() { - let (mut dispatch, mut channel, _server_channel) = set_up(); - let dispatch = Pin::new(&mut dispatch); - let cx = &mut Context::from_waker(&noop_waker_ref()); - - let _resp = send_request(&mut channel, "hi").await; - - let req = dispatch.poll_next_request(cx).ready(); - assert!(req.is_some()); - - let req = req.unwrap(); - assert_eq!(req.request_id, 0); - assert_eq!(req.request, "hi".to_string()); - } - - // Regression test for https://github.com/google/tarpc/issues/220 - #[tokio::test] - async fn stage_request_channel_dropped_doesnt_panic() { - let (mut dispatch, mut channel, mut server_channel) = set_up(); - let mut dispatch = Pin::new(&mut dispatch); - let cx = &mut Context::from_waker(&noop_waker_ref()); - - let _ = send_request(&mut channel, "hi").await; - drop(channel); - - assert!(dispatch.as_mut().poll(cx).is_ready()); - send_response( - &mut server_channel, - Response { - request_id: 0, - message: Ok("hello".into()), - }, - ) - .await; - dispatch.await.unwrap(); - } - - #[tokio::test] - async fn stage_request_response_future_dropped_is_canceled_before_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up(); - let dispatch = Pin::new(&mut dispatch); - let cx = &mut Context::from_waker(&noop_waker_ref()); - - let _ = send_request(&mut channel, "hi").await; - - // Drop the channel so polling returns none if no requests are currently ready. - drop(channel); - // Test that a request future dropped before it's processed by dispatch will cause the request - // to not be added to the in-flight request map. 
- assert!(dispatch.poll_next_request(cx).ready().is_none()); - } - - #[tokio::test] - async fn stage_request_response_future_dropped_is_canceled_after_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up(); - let cx = &mut Context::from_waker(&noop_waker_ref()); - let mut dispatch = Pin::new(&mut dispatch); - - let req = send_request(&mut channel, "hi").await; - - assert!(dispatch.as_mut().pump_write(cx).ready().is_some()); - assert!(!dispatch.in_flight_requests().is_empty()); - - // Test that a request future dropped after it's processed by dispatch will cause the request - // to be removed from the in-flight request map. - drop(req); - if let Poll::Ready(Some(_)) = dispatch.as_mut().poll_next_cancellation(cx).unwrap() { - // ok - } else { - panic!("Expected request to be cancelled") - }; - assert!(dispatch.in_flight_requests().is_empty()); - } - - #[tokio::test] - async fn stage_request_response_closed_skipped() { - let (mut dispatch, mut channel, _server_channel) = set_up(); - let dispatch = Pin::new(&mut dispatch); - let cx = &mut Context::from_waker(&noop_waker_ref()); - - // Test that a request future that's closed its receiver but not yet canceled its request -- - // i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request - // map. - let mut resp = send_request(&mut channel, "hi").await; - resp.response.close(); - - assert!(dispatch.poll_next_request(cx).is_pending()); - } - - fn set_up() -> ( - RequestDispatch, ClientMessage>>, - Channel, - UnboundedChannel, Response>, - ) { - let _ = env_logger::try_init(); - - let (to_dispatch, pending_requests) = mpsc::channel(1); - let (cancel_tx, canceled_requests) = mpsc::unbounded_channel(); - let (client_channel, server_channel) = transport::channel::unbounded(); - - let dispatch = RequestDispatch:: { - transport: client_channel.fuse(), - pending_requests: pending_requests, - canceled_requests: CanceledRequests(canceled_requests), - in_flight_requests: InFlightRequests::default(), - config: Config::default(), - }; - - let cancellation = RequestCancellation(cancel_tx); - let channel = Channel { - to_dispatch, - cancellation, - next_request_id: Arc::new(AtomicUsize::new(0)), - }; - - (dispatch, channel, server_channel) - } - - async fn send_request( - channel: &mut Channel, - request: &str, - ) -> DispatchResponse { - channel - .send(context::current(), request.to_string()) - .await - .unwrap() - } - - async fn send_response( - channel: &mut UnboundedChannel, Response>, - response: Response, - ) { - channel.send(response).await.unwrap(); - } - - trait PollTest { - type T; - fn unwrap(self) -> Poll; - fn ready(self) -> Self::T; - } - - impl PollTest for Poll>> - where - E: ::std::fmt::Display, - { - type T = Option; - - fn unwrap(self) -> Poll> { - match self { - Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)), - Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(Err(e))) => panic!("{}", e.to_string()), - Poll::Pending => Poll::Pending, - } - } - - fn ready(self) -> Option { - match self { - Poll::Ready(Some(Ok(t))) => Some(t), - Poll::Ready(None) => None, - Poll::Ready(Some(Err(e))) => panic!("{}", e.to_string()), - Poll::Pending => panic!("Pending"), - } - } - } -} diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index f08e471..21972a3 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -131,7 +131,7 @@ //! ``` //! //! Lastly let's write our `main` that will start the server. While this example uses an -//! 
[in-process channel](rpc::transport::channel), tarpc also ships a generic [`serde_transport`]
+//! [in-process channel](transport::channel), tarpc also ships a generic [`serde_transport`]
 //! behind the `serde-transport` feature, with additional [TCP](serde_transport::tcp) functionality
 //! available behind the `tcp` feature.
 //!
diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs
index 70485d7..abfc39e 100644
--- a/tarpc/src/server.rs
+++ b/tarpc/src/server.rs
@@ -72,7 +72,7 @@ pub trait Serve<Req> {
 
 impl<Req, Resp, Fut, F> Serve<Req> for F
 where
-    F: FnOnce(context::Context, Req) -> Fut + Clone,
+    F: FnOnce(context::Context, Req) -> Fut,
     Fut: Future<Output = Resp>,
 {
     type Resp = Resp;
@@ -182,13 +182,12 @@ impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
 /// The server end of an open connection with a client, streaming in requests from, and sinking
 /// responses to, the client.
 ///
-///
 /// The ways to use a Channel, in order of simplest to most complex, is:
 /// 1. [Channel::execute] - Requires the `tokio1` feature. This method is best for those who
 ///    do not have specific scheduling needs and whose services are `Send + 'static`.
 /// 2. [Channel::requests] - This method is best for those who need direct access to individual
 ///    requests, or are not using `tokio`, or want control over [futures](Future) scheduling.
-/// 3. [Raw stream]() - A user is free to manually handle requests produced by
+/// 3. [Raw stream](Stream) - A user is free to manually handle requests produced by
 ///    Channel. If they do so, they should uphold the service contract:
 ///    1. All work being done as part of processing request `request_id` is aborted when
 ///       either of the following occurs:
@@ -199,8 +198,10 @@ impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
 ///    [sent](Sink::start_send) into the Channel. Because there is no guarantee that a
 ///    cancellation message will ever be received for a request, services should strive to clean
 ///    up Channel resources by sending a response for every request. For example, [`BaseChannel`]
-///    has a map of requests to [abort handles][AbortHandle] whose entries are only removed
-///    upon either request cancellation or response completion.
+///    has a map of requests to [abort handles][futures::future::AbortHandle] whose entries are
+///    only removed upon either request cancellation, response completion, or deadline
+///    expiration. For requests with long deadlines that have been abandoned without a response,
+///    some cleanup may never happen.
 pub trait Channel
 where
     Self: Transport<Response<<Self as Channel>::Resp>, Request<<Self as Channel>::Req>>,
@@ -260,7 +261,7 @@ where
     fn execute<S>(self, serve: S) -> TokioChannelExecutor<Requests<Self>, S>
     where
         Self: Sized,
-        S: Serve<Self::Req, Resp = Self::Resp> + Send + Sync + 'static,
+        S: Serve<Self::Req, Resp = Self::Resp> + Send + 'static,
         S::Fut: Send,
         Self::Req: Send + 'static,
         Self::Resp: Send + 'static,
@@ -406,7 +407,7 @@ where
         cx: &mut Context<'_>,
     ) -> PollIo<Request<C::Req>> {
         loop {
-            match ready!(self.as_mut().project().channel.poll_next(cx)?) {
+            match ready!(self.channel_pin_mut().poll_next(cx)?) {
                 Some(request) => {
                     trace!(
                         "[{}] Handling request with deadline {}.",
@@ -617,7 +618,7 @@ where
     /// by [spawning](tokio::spawn) each handler on tokio's default executor.
     pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
     where
-        S: Serve<C::Req, Resp = C::Resp> + Send + Sync + 'static,
+        S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
     {
         TokioChannelExecutor { inner: self, serve }
     }
@@ -635,8 +636,8 @@ pub struct TokioServerExecutor<T, S> {
     serve: S,
 }
 
-/// A future that drives the server by [spawning](tokio::spawn) each [response handler](ResponseHandler)
-/// on tokio's default executor.
+/// A future that drives the server by [spawning](tokio::spawn) each [response
+/// handler](InFlightRequest::execute) on tokio's default executor.
 #[pin_project]
 #[derive(Debug)]
 #[cfg(feature = "tokio1")]
@@ -670,7 +671,7 @@ where
     C: Channel + Send + 'static,
     C::Req: Send + 'static,
     C::Resp: Send + 'static,
-    Se: Serve<C::Req, Resp = C::Resp> + Send + Sync + 'static + Clone,
+    Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
     Se::Fut: Send,
 {
     type Output = ();
@@ -690,7 +691,7 @@ where
     C: Channel + 'static,
     C::Req: Send + 'static,
     C::Resp: Send + 'static,
-    S: Serve<C::Req, Resp = C::Resp> + Send + Sync + 'static + Clone,
+    S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
     S::Fut: Send,
 {
     type Output = ();
diff --git a/tarpc/src/server/filter.rs b/tarpc/src/server/filter.rs
index 2fc8d60..0bee5ec 100644
--- a/tarpc/src/server/filter.rs
+++ b/tarpc/src/server/filter.rs
@@ -15,7 +15,7 @@ use log::{debug, info, trace};
 use pin_project::pin_project;
 use std::sync::{Arc, Weak};
 use std::{
-    collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin,
+    collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin,
     time::SystemTime,
 };
 use tokio::sync::mpsc;
@@ -30,9 +30,7 @@ where
     #[pin]
     listener: Fuse<S>,
     channels_per_key: u32,
-    #[pin]
     dropped_keys: mpsc::UnboundedReceiver<K>,
-    #[pin]
     dropped_keys_tx: mpsc::UnboundedSender<K>,
     key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
     keymaker: F,
 }
@@ -66,8 +64,8 @@ where
 {
     type Item = <C as Stream>::Item;
 
-    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
-        self.channel().poll_next(cx)
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
+        self.inner_pin_mut().poll_next(cx)
     }
 }
 
@@ -77,20 +75,20 @@ where
 {
     type Error = C::Error;
 
-    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
-        self.channel().poll_ready(cx)
+    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+        self.inner_pin_mut().poll_ready(cx)
     }
 
-    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
-        self.channel().start_send(item)
+    fn start_send(mut self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
+        self.inner_pin_mut().start_send(item)
     }
 
-    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
-        self.channel().poll_flush(cx)
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+        self.inner_pin_mut().poll_flush(cx)
    }
 
-    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
-        self.channel().poll_close(cx)
+    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+        self.inner_pin_mut().poll_close(cx)
     }
 }
 
@@ -116,15 +114,15 @@ where
     }
 
     fn start_request(
-        self: Pin<&mut Self>,
+        mut self: Pin<&mut Self>,
         id: u64,
         deadline: SystemTime,
     ) -> Result<AbortRegistration, AlreadyExistsError> {
-        self.project().inner.start_request(id, deadline)
+        self.inner_pin_mut().start_request(id, deadline)
     }
 
-    fn poll_expired(self: Pin<&mut Self>, cx: &mut Context) -> PollIo<u64> {
-        self.project().inner.poll_expired(cx)
+    fn poll_expired(mut self: Pin<&mut Self>, cx: &mut Context) -> PollIo<u64> {
+        self.inner_pin_mut().poll_expired(cx)
     }
 }
 
@@ -135,8 +133,8 @@ impl<C, K> TrackedChannel<C, K> {
     }
 
     /// Returns the pinned inner channel.
-    fn channel(self: Pin<&mut Self>) -> Pin<&mut C> {
-        self.project().inner
+    fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> {
+        self.as_mut().project().inner
     }
 }
 
@@ -166,6 +164,10 @@ where
     K: fmt::Display + Eq + Hash + Clone + Unpin,
     F: Fn(&S::Item) -> K,
 {
+    fn listener_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<S>> {
+        self.as_mut().project().listener
+    }
+
     fn handle_new_channel(
         mut self: Pin<&mut Self>,
         stream: S::Item,
    ) -> Result<TrackedChannel<S::Item, K>, K> {
@@ -177,7 +179,7 @@ where
             "[{}] Opening channel ({}/{}) channels for key.",
             key,
             Arc::strong_count(&tracker),
-            self.as_mut().project().channels_per_key
+            self.channels_per_key
         );
 
         Ok(TrackedChannel {
@@ -186,15 +188,14 @@ where
         })
    }
 
-    fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
-        let channels_per_key = self.channels_per_key;
-        let dropped_keys = self.dropped_keys_tx.clone();
-        let key_counts = &mut self.as_mut().project().key_counts;
-        match key_counts.entry(key.clone()) {
+    fn increment_channels_for_key(self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
+        let self_ = self.project();
+        let dropped_keys = self_.dropped_keys_tx;
+        match self_.key_counts.entry(key.clone()) {
             Entry::Vacant(vacant) => {
                 let tracker = Arc::new(Tracker {
                     key: Some(key),
-                    dropped_keys,
+                    dropped_keys: dropped_keys.clone(),
                 });
 
                 vacant.insert(Arc::downgrade(&tracker));
@@ -202,17 +203,17 @@ where
             }
             Entry::Occupied(mut o) => {
                 let count = o.get().strong_count();
-                if count >= channels_per_key.try_into().unwrap() {
+                if count >= TryFrom::try_from(*self_.channels_per_key).unwrap() {
                     info!(
                         "[{}] Opened max channels from key ({}/{}).",
-                        key, count, channels_per_key
+                        key, count, self_.channels_per_key
                     );
                     Err(key)
                 } else {
                     Ok(o.get().upgrade().unwrap_or_else(|| {
                         let tracker = Arc::new(Tracker {
                             key: Some(key),
-                            dropped_keys,
+                            dropped_keys: dropped_keys.clone(),
                         });
 
                         *o.get_mut() = Arc::downgrade(&tracker);
@@ -227,18 +228,19 @@ where
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
-        match ready!(self.as_mut().project().listener.poll_next_unpin(cx)) {
+        match ready!(self.listener_pin_mut().poll_next_unpin(cx)) {
             Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
             None => Poll::Ready(None),
         }
     }
 
-    fn poll_closed_channels(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
-        match ready!(self.as_mut().project().dropped_keys.poll_recv(cx)) {
+    fn poll_closed_channels(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
+        let self_ = self.project();
+        match ready!(self_.dropped_keys.poll_recv(cx)) {
             Some(key) => {
                 debug!("All channels dropped for key [{}]", key);
-                self.as_mut().project().key_counts.remove(&key);
-                self.as_mut().project().key_counts.compact(0.1);
+                self_.key_counts.remove(&key);
+                self_.key_counts.compact(0.1);
                 Poll::Ready(())
             }
             None => unreachable!("Holding a copy of closed_channels and didn't close it."),
diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs
index 488e37c..0fc16eb 100644
--- a/tarpc/tests/service_functional.rs
+++ b/tarpc/tests/service_functional.rs
@@ -212,3 +212,38 @@ async fn concurrent_join_all() -> io::Result<()> {
 
     Ok(())
 }
+
+#[tokio::test]
+async fn counter() -> io::Result<()> {
+    #[tarpc::service]
+    trait Counter {
+        async fn count() -> u32;
+    }
+
+    struct CountService(u32);
+
+    impl Counter for &mut CountService {
+        type CountFut = futures::future::Ready<u32>;
+
+        fn count(self, _: context::Context) -> Self::CountFut {
+            self.0 += 1;
+            futures::future::ready(self.0)
+        }
+    }
+
+    let (tx, rx) = channel::unbounded();
+    tokio::spawn(async {
+        let mut requests = BaseChannel::with_defaults(rx).requests();
+        let mut counter = CountService(0);
+
+        // The server loop awaits each execution before taking the next request, so the
+        // counter increments strictly in request order.
+        while let Some(Ok(request)) = requests.next().await {
+            request.execute(counter.serve()).await;
+        }
+    });
+
+    let client = CounterClient::new(client::Config::default(), tx).spawn()?;
+    assert_matches!(client.count(context::current()).await, Ok(1));
+    assert_matches!(client.count(context::current()).await, Ok(2));
+
+    Ok(())
+}
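
For a sense of how the merged `tarpc::client` module fits together end to end, the sketch below stages a single request by hand, the way the generated clients do internally. It is illustrative only: the `call_once` name, the `String` request/response types, and the tokio runtime are assumptions, not part of the patch.

```rust
use tarpc::client::{self, Config, NewClient};
use tarpc::{context, ClientMessage, Response, Transport};

// Sketch: drive a Channel/RequestDispatch pair by hand. `transport` is any
// client-side transport; String request/response types are placeholders.
async fn call_once<T>(transport: T) -> std::io::Result<String>
where
    T: Transport<ClientMessage<String>, Response<String>> + Send + 'static,
{
    // `new` pairs the Channel (which stages requests) with the RequestDispatch
    // future that actually moves them over the transport.
    let NewClient { client, dispatch } = client::new(Config::default(), transport);

    // The dispatch future must be polled for any request to make progress.
    tokio::spawn(dispatch);

    // `call` resolves once dispatch delivers the server's response.
    client.call(context::current(), "hello".to_string()).await
}
```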