Cleanup wrap-up.

- Remove unnecessary Sync and Clone bounds.
- Merge client and client::channel modules.
- Run cargo clippy in the pre-push hook.
- Put DispatchResponse.cancellation in an Option.  Previously, the
  cancellation logic looked to see if `complete == true`, but it's a bit
  less error prone to put the Cancellation in an Option, so that the
  request can't accidentally be cancelled.
- Remove some unnecessary pins/projections.
- Clean up docs a bit. rustdoc had some warnings that are now gone.
This commit is contained in:
Tim Kuehn
2021-03-07 17:42:50 -08:00
parent e75193c191
commit 72d5dbba89
8 changed files with 803 additions and 750 deletions

View File

@@ -84,11 +84,6 @@ command -v rustup &>/dev/null
if [ "$?" == 0 ]; then
printf "${SUCCESS}\n"
check_toolchain nightly
if [ ${TOOLCHAIN_RESULT} == 1 ]; then
exit 1
fi
try_run "Building ... " cargo +stable build --color=always
try_run "Testing ... " cargo +stable test --color=always
try_run "Testing with all features enabled ... " cargo +stable test --all-features --color=always
@@ -97,6 +92,12 @@ if [ "$?" == 0 ]; then
try_run "Running example \"$EXAMPLE\" ... " cargo +stable run --example $EXAMPLE
done
check_toolchain nightly
if [ ${TOOLCHAIN_RESULT} != 1 ]; then
try_run "Running clippy ... " cargo +nightly clippy --color=always -Z unstable-options -- --deny warnings
fi
fi
exit $PREPUSH_RESULT

View File

@@ -215,9 +215,15 @@ impl Parse for DeriveSerde {
}
}
/// Generates:
/// - derive of Debug, serde Serialize & Deserialize
/// - serde crate annotation
/// A helper attribute to avoid a direct dependency on Serde.
///
/// Adds the following annotations to the annotated item:
///
/// ```rust
/// #[derive(tarpc::serde::Serialize, tarpc::serde::Deserialize)]
/// #[serde(crate = "tarpc::serde")]
/// # struct Foo;
/// ```
#[proc_macro_attribute]
pub fn derive_serde(_attr: TokenStream, item: TokenStream) -> TokenStream {
let mut gen: proc_macro2::TokenStream = quote! {
@@ -482,10 +488,11 @@ impl<'a> ServiceGenerator<'a> {
quote! {
#( #attrs )*
#vis trait #service_ident: Clone {
#vis trait #service_ident: Sized {
#( #types_and_fns )*
/// Returns a serving function to use with [tarpc::server::InFlightRequest::execute].
/// Returns a serving function to use with
/// [InFlightRequest::execute](tarpc::server::InFlightRequest::execute).
fn serve(self) -> #server_ident<Self> {
#server_ident { service: self }
}
@@ -662,7 +669,7 @@ impl<'a> ServiceGenerator<'a> {
quote! {
#[allow(unused)]
#[derive(Clone, Debug)]
/// The client stub that makes RPC calls to the server. ALl request methods return
/// The client stub that makes RPC calls to the server. All request methods return
/// [Futures](std::future::Future).
#vis struct #client_ident(tarpc::client::Channel<#request_ident, #response_ident>);
}
@@ -683,7 +690,7 @@ impl<'a> ServiceGenerator<'a> {
#vis fn new<T>(config: tarpc::client::Config, transport: T)
-> tarpc::client::NewClient<
Self,
tarpc::client::channel::RequestDispatch<#request_ident, #response_ident, T>
tarpc::client::RequestDispatch<#request_ident, #response_ident, T>
>
where
T: tarpc::Transport<tarpc::ClientMessage<#request_ident>, tarpc::Response<#response_ident>>

View File

@@ -6,14 +6,25 @@
//! Provides a client that connects to a server and sends multiplexed requests.
use futures::prelude::*;
use std::fmt;
use std::io;
/// Provides a [`Client`] backed by a transport.
pub mod channel;
mod in_flight_requests;
pub use channel::{new, Channel};
use crate::{
context, trace::SpanId, ClientMessage, PollContext, PollIo, Request, Response, Transport,
};
use futures::{prelude::*, ready, stream::Fuse, task::*};
use in_flight_requests::InFlightRequests;
use log::{info, trace};
use pin_project::{pin_project, pinned_drop};
use std::{
convert::TryFrom,
fmt, io,
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use tokio::sync::{mpsc, oneshot};
/// Settings that control the behavior of the client.
#[derive(Clone, Debug)]
@@ -71,3 +82,683 @@ impl<C, D> fmt::Debug for NewClient<C, D> {
write!(fmt, "NewClient")
}
}
#[allow(dead_code)]
#[allow(clippy::no_effect)]
const CHECK_USIZE: () = {
if std::mem::size_of::<usize>() > std::mem::size_of::<u64>() {
// TODO: replace this with panic!() as soon as RFC 2345 gets stabilized
["usize is too big to fit in u64"][42];
}
};
/// Handles communication from the client to request dispatch.
#[derive(Debug)]
pub struct Channel<Req, Resp> {
to_dispatch: mpsc::Sender<DispatchRequest<Req, Resp>>,
/// Channel to send a cancel message to the dispatcher.
cancellation: RequestCancellation,
/// The ID to use for the next request to stage.
next_request_id: Arc<AtomicUsize>,
}
impl<Req, Resp> Clone for Channel<Req, Resp> {
fn clone(&self) -> Self {
Self {
to_dispatch: self.to_dispatch.clone(),
cancellation: self.cancellation.clone(),
next_request_id: self.next_request_id.clone(),
}
}
}
impl<Req, Resp> Channel<Req, Resp> {
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
/// resolves when the request is sent (not when the response is received).
fn send(
&self,
mut ctx: context::Context,
request: Req,
) -> impl Future<Output = io::Result<DispatchResponse<Resp>>> + '_ {
// Convert the context to the call context.
ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
let (response_completion, response) = oneshot::channel();
let cancellation = self.cancellation.clone();
let request_id =
u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
// DispatchResponse impls Drop to cancel in-flight requests. It should be created before
// sending out the request; otherwise, the response future could be dropped after the
// request is sent out but before DispatchResponse is created, rendering the cancellation
// logic inactive.
let response = DispatchResponse {
response,
request_id,
cancellation: Some(cancellation),
ctx,
};
async move {
self.to_dispatch
.send(DispatchRequest {
ctx,
request_id,
request,
response_completion,
})
.await
.map_err(|mpsc::error::SendError(_)| {
io::Error::from(io::ErrorKind::ConnectionReset)
})?;
Ok(response)
}
}
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
/// resolves to the response.
pub async fn call(&self, ctx: context::Context, request: Req) -> io::Result<Resp> {
let dispatch_response = self.send(ctx, request).await?;
dispatch_response.await
}
}
/// A server response that is completed by request dispatch when the corresponding response
/// arrives off the wire.
#[pin_project(PinnedDrop)]
#[derive(Debug)]
struct DispatchResponse<Resp> {
response: oneshot::Receiver<Response<Resp>>,
ctx: context::Context,
cancellation: Option<RequestCancellation>,
request_id: u64,
}
impl<Resp> Future for DispatchResponse<Resp> {
type Output = io::Result<Resp>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> {
let resp = ready!(self.response.poll_unpin(cx));
self.cancellation.take();
Poll::Ready(match resp {
Ok(resp) => Ok(resp.message?),
Err(oneshot::error::RecvError { .. }) => {
// The oneshot is Canceled when the dispatch task ends. In that case,
// there's nothing listening on the other side, so there's no point in
// propagating cancellation.
Err(io::Error::from(io::ErrorKind::ConnectionReset))
}
})
}
}
// Cancels the request when dropped, if not already complete.
#[pinned_drop]
impl<Resp> PinnedDrop for DispatchResponse<Resp> {
fn drop(mut self: Pin<&mut Self>) {
let self_ = self.project();
if let Some(cancellation) = self_.cancellation {
// The receiver needs to be closed to handle the edge case that the request has not
// yet been received by the dispatch task. It is possible for the cancel message to
// arrive before the request itself, in which case the request could get stuck in the
// dispatch map forever if the server never responds (e.g. if the server dies while
// responding). Even if the server does respond, it will have unnecessarily done work
// for a client no longer waiting for a response. To avoid this, the dispatch task
// checks if the receiver is closed before inserting the request in the map. By
// closing the receiver before sending the cancel message, it is guaranteed that if the
// dispatch task misses an early-arriving cancellation message, then it will see the
// receiver as closed.
self_.response.close();
cancellation.cancel(*self_.request_id);
}
}
}
/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the
/// channel.
pub fn new<Req, Resp, C>(
config: Config,
transport: C,
) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C>>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
let (cancellation, canceled_requests) = cancellations();
let canceled_requests = canceled_requests;
NewClient {
client: Channel {
to_dispatch,
cancellation,
next_request_id: Arc::new(AtomicUsize::new(0)),
},
dispatch: RequestDispatch {
config,
canceled_requests,
transport: transport.fuse(),
in_flight_requests: InFlightRequests::default(),
pending_requests,
},
}
}
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
/// and dispatching responses to the appropriate channel.
#[pin_project]
#[derive(Debug)]
pub struct RequestDispatch<Req, Resp, C> {
/// Writes requests to the wire and reads responses off the wire.
#[pin]
transport: Fuse<C>,
/// Requests waiting to be written to the wire.
#[pin]
pending_requests: mpsc::Receiver<DispatchRequest<Req, Resp>>,
/// Requests that were dropped.
#[pin]
canceled_requests: CanceledRequests,
/// Requests already written to the wire that haven't yet received responses.
in_flight_requests: InFlightRequests<Resp>,
/// Configures limits to prevent unlimited resource usage.
config: Config,
}
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests<Resp> {
self.as_mut().project().in_flight_requests
}
fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
Poll::Ready(
match ready!(self.as_mut().project().transport.poll_next(cx)?) {
Some(response) => {
self.complete(response);
Some(Ok(()))
}
None => None,
},
)
}
fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
enum ReceiverStatus {
NotReady,
Closed,
}
let pending_requests_status = match self.as_mut().poll_next_request(cx)? {
Poll::Ready(Some(dispatch_request)) => {
self.as_mut().write_request(dispatch_request)?;
return Poll::Ready(Some(Ok(())));
}
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::NotReady,
};
let canceled_requests_status = match self.as_mut().poll_next_cancellation(cx)? {
Poll::Ready(Some((context, request_id))) => {
self.as_mut().write_cancel(context, request_id)?;
return Poll::Ready(Some(Ok(())));
}
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::NotReady,
};
// Receiving Poll::Ready(None) when polling expired requests never indicates "Closed",
// because there can temporarily be zero in-flight rquests. Therefore, there is no need to
// track the status like is done with pending and cancelled requests.
if let Poll::Ready(Some(_)) = self.in_flight_requests().poll_expired(cx)? {
// Expired requests are considered complete; there is no compelling reason to send a
// cancellation message to the server, since it will have already exhausted its
// allotted processing time.
return Poll::Ready(Some(Ok(())));
}
match (pending_requests_status, canceled_requests_status) {
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
ready!(self.as_mut().project().transport.poll_flush(cx)?);
Poll::Ready(None)
}
(ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
// No more messages to process, so flush any messages buffered in the transport.
ready!(self.as_mut().project().transport.poll_flush(cx)?);
// Even if we fully-flush, we return Pending, because we have no more requests
// or cancellations right now.
Poll::Pending
}
}
}
/// Yields the next pending request, if one is ready to be sent.
fn poll_next_request(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<DispatchRequest<Req, Resp>> {
if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
info!(
"At in-flight request capacity ({}/{}).",
self.in_flight_requests().len(),
self.config.max_in_flight_requests
);
// No need to schedule a wakeup, because timers and responses are responsible
// for clearing out in-flight requests.
return Poll::Pending;
}
while self
.as_mut()
.project()
.transport
.poll_ready(cx)?
.is_pending()
{
// We can't yield a request-to-be-sent before the transport is capable of buffering it.
ready!(self.as_mut().project().transport.poll_flush(cx)?);
}
loop {
match ready!(self.as_mut().project().pending_requests.poll_recv(cx)) {
Some(request) => {
if request.response_completion.is_closed() {
trace!(
"[{}] Request canceled before being sent.",
request.ctx.trace_id()
);
continue;
}
return Poll::Ready(Some(Ok(request)));
}
None => return Poll::Ready(None),
}
}
}
/// Yields the next pending cancellation, and, if one is ready, cancels the associated request.
fn poll_next_cancellation(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<(context::Context, u64)> {
while self
.as_mut()
.project()
.transport
.poll_ready(cx)?
.is_pending()
{
ready!(self.as_mut().project().transport.poll_flush(cx)?);
}
loop {
let cancellation = self
.as_mut()
.project()
.canceled_requests
.poll_next_unpin(cx);
match ready!(cancellation) {
Some(request_id) => {
if let Some(ctx) = self.in_flight_requests().cancel_request(request_id) {
return Poll::Ready(Some(Ok((ctx, request_id))));
}
}
None => return Poll::Ready(None),
}
}
}
fn write_request(
mut self: Pin<&mut Self>,
dispatch_request: DispatchRequest<Req, Resp>,
) -> io::Result<()> {
let request_id = dispatch_request.request_id;
let request = ClientMessage::Request(Request {
id: request_id,
message: dispatch_request.request,
context: context::Context {
deadline: dispatch_request.ctx.deadline,
trace_context: dispatch_request.ctx.trace_context,
},
});
self.as_mut().project().transport.start_send(request)?;
self.in_flight_requests()
.insert_request(
request_id,
dispatch_request.ctx,
dispatch_request.response_completion,
)
.expect("Request IDs should be unique");
Ok(())
}
fn write_cancel(
mut self: Pin<&mut Self>,
context: context::Context,
request_id: u64,
) -> io::Result<()> {
let trace_id = *context.trace_id();
let cancel = ClientMessage::Cancel {
trace_context: context.trace_context,
request_id,
};
self.as_mut().project().transport.start_send(cancel)?;
trace!("[{}] Cancel message sent.", trace_id);
Ok(())
}
/// Sends a server response to the client task that initiated the associated request.
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
self.in_flight_requests().complete_request(response)
}
}
impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
type Output = anyhow::Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<anyhow::Result<()>> {
loop {
match (
self.as_mut()
.pump_read(cx)
.context("failed to read from transport")?,
self.as_mut()
.pump_write(cx)
.context("failed to write to transport")?,
) {
(Poll::Ready(None), _) => {
info!("Shutdown: read half closed, so shutting down.");
return Poll::Ready(Ok(()));
}
(read, Poll::Ready(None)) => {
if self.in_flight_requests().is_empty() {
info!("Shutdown: write half closed, and no requests in flight.");
return Poll::Ready(Ok(()));
}
info!(
"Shutdown: write half closed, and {} requests in flight.",
self.in_flight_requests().len()
);
match read {
Poll::Ready(Some(())) => continue,
_ => return Poll::Pending,
}
}
(Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
_ => return Poll::Pending,
}
}
}
}
/// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage
/// the lifecycle of the request.
#[derive(Debug)]
struct DispatchRequest<Req, Resp> {
pub ctx: context::Context,
pub request_id: u64,
pub request: Req,
pub response_completion: oneshot::Sender<Response<Resp>>,
}
/// Sends request cancellation signals.
#[derive(Debug, Clone)]
struct RequestCancellation(mpsc::UnboundedSender<u64>);
/// A stream of IDs of requests that have been canceled.
#[derive(Debug)]
struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
/// Returns a channel to send request cancellation messages.
fn cancellations() -> (RequestCancellation, CanceledRequests) {
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
// bounded by the number of in-flight requests. Additionally, each request has a clone
// of the sender, so the bounded channel would have the same behavior,
// since it guarantees a slot.
let (tx, rx) = mpsc::unbounded_channel();
(RequestCancellation(tx), CanceledRequests(rx))
}
impl RequestCancellation {
/// Cancels the request with ID `request_id`.
fn cancel(&mut self, request_id: u64) {
let _ = self.0.send(request_id);
}
}
impl Stream for CanceledRequests {
type Item = u64;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
self.0.poll_recv(cx)
}
}
#[cfg(test)]
mod tests {
use super::{
cancellations, CanceledRequests, Channel, DispatchResponse, RequestCancellation,
RequestDispatch,
};
use crate::{
client::{in_flight_requests::InFlightRequests, Config},
context,
transport::{self, channel::UnboundedChannel},
ClientMessage, Response,
};
use futures::{prelude::*, task::*};
use std::{pin::Pin, sync::atomic::AtomicUsize, sync::Arc};
use tokio::sync::{mpsc, oneshot};
#[tokio::test]
async fn dispatch_response_cancels_on_drop() {
let (cancellation, mut canceled_requests) = cancellations();
let (_, response) = oneshot::channel();
drop(DispatchResponse::<u32> {
response,
cancellation: Some(cancellation),
request_id: 3,
ctx: context::current(),
});
// resp's drop() is run, which should send a cancel message.
let cx = &mut Context::from_waker(&noop_waker_ref());
assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(Some(3)));
}
#[tokio::test]
async fn dispatch_response_doesnt_cancel_after_complete() {
let (cancellation, mut canceled_requests) = cancellations();
let (tx, response) = oneshot::channel();
tx.send(Response {
request_id: 0,
message: Ok("well done"),
})
.unwrap();
{
DispatchResponse {
response,
cancellation: Some(cancellation),
request_id: 3,
ctx: context::current(),
}
.await
.unwrap();
// resp's drop() is run, but should not send a cancel message.
}
let cx = &mut Context::from_waker(&noop_waker_ref());
assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(None));
}
#[tokio::test]
async fn stage_request() {
let (mut dispatch, mut channel, _server_channel) = set_up();
let dispatch = Pin::new(&mut dispatch);
let cx = &mut Context::from_waker(&noop_waker_ref());
let _resp = send_request(&mut channel, "hi").await;
let req = dispatch.poll_next_request(cx).ready();
assert!(req.is_some());
let req = req.unwrap();
assert_eq!(req.request_id, 0);
assert_eq!(req.request, "hi".to_string());
}
// Regression test for https://github.com/google/tarpc/issues/220
#[tokio::test]
async fn stage_request_channel_dropped_doesnt_panic() {
let (mut dispatch, mut channel, mut server_channel) = set_up();
let mut dispatch = Pin::new(&mut dispatch);
let cx = &mut Context::from_waker(&noop_waker_ref());
let _ = send_request(&mut channel, "hi").await;
drop(channel);
assert!(dispatch.as_mut().poll(cx).is_ready());
send_response(
&mut server_channel,
Response {
request_id: 0,
message: Ok("hello".into()),
},
)
.await;
dispatch.await.unwrap();
}
#[tokio::test]
async fn stage_request_response_future_dropped_is_canceled_before_sending() {
let (mut dispatch, mut channel, _server_channel) = set_up();
let dispatch = Pin::new(&mut dispatch);
let cx = &mut Context::from_waker(&noop_waker_ref());
let _ = send_request(&mut channel, "hi").await;
// Drop the channel so polling returns none if no requests are currently ready.
drop(channel);
// Test that a request future dropped before it's processed by dispatch will cause the request
// to not be added to the in-flight request map.
assert!(dispatch.poll_next_request(cx).ready().is_none());
}
#[tokio::test]
async fn stage_request_response_future_dropped_is_canceled_after_sending() {
let (mut dispatch, mut channel, _server_channel) = set_up();
let cx = &mut Context::from_waker(&noop_waker_ref());
let mut dispatch = Pin::new(&mut dispatch);
let req = send_request(&mut channel, "hi").await;
assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
assert!(!dispatch.in_flight_requests().is_empty());
// Test that a request future dropped after it's processed by dispatch will cause the request
// to be removed from the in-flight request map.
drop(req);
if let Poll::Ready(Some(_)) = dispatch.as_mut().poll_next_cancellation(cx).unwrap() {
// ok
} else {
panic!("Expected request to be cancelled")
};
assert!(dispatch.in_flight_requests().is_empty());
}
#[tokio::test]
async fn stage_request_response_closed_skipped() {
let (mut dispatch, mut channel, _server_channel) = set_up();
let dispatch = Pin::new(&mut dispatch);
let cx = &mut Context::from_waker(&noop_waker_ref());
// Test that a request future that's closed its receiver but not yet canceled its request --
// i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
// map.
let mut resp = send_request(&mut channel, "hi").await;
resp.response.close();
assert!(dispatch.poll_next_request(cx).is_pending());
}
fn set_up() -> (
RequestDispatch<String, String, UnboundedChannel<Response<String>, ClientMessage<String>>>,
Channel<String, String>,
UnboundedChannel<ClientMessage<String>, Response<String>>,
) {
let _ = env_logger::try_init();
let (to_dispatch, pending_requests) = mpsc::channel(1);
let (cancel_tx, canceled_requests) = mpsc::unbounded_channel();
let (client_channel, server_channel) = transport::channel::unbounded();
let dispatch = RequestDispatch::<String, String, _> {
transport: client_channel.fuse(),
pending_requests: pending_requests,
canceled_requests: CanceledRequests(canceled_requests),
in_flight_requests: InFlightRequests::default(),
config: Config::default(),
};
let cancellation = RequestCancellation(cancel_tx);
let channel = Channel {
to_dispatch,
cancellation,
next_request_id: Arc::new(AtomicUsize::new(0)),
};
(dispatch, channel, server_channel)
}
async fn send_request(
channel: &mut Channel<String, String>,
request: &str,
) -> DispatchResponse<String> {
channel
.send(context::current(), request.to_string())
.await
.unwrap()
}
async fn send_response(
channel: &mut UnboundedChannel<ClientMessage<String>, Response<String>>,
response: Response<String>,
) {
channel.send(response).await.unwrap();
}
trait PollTest {
type T;
fn unwrap(self) -> Poll<Self::T>;
fn ready(self) -> Self::T;
}
impl<T, E> PollTest for Poll<Option<Result<T, E>>>
where
E: ::std::fmt::Display,
{
type T = Option<T>;
fn unwrap(self) -> Poll<Option<T>> {
match self {
Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)),
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Err(e))) => panic!("{}", e.to_string()),
Poll::Pending => Poll::Pending,
}
}
fn ready(self) -> Option<T> {
match self {
Poll::Ready(Some(Ok(t))) => Some(t),
Poll::Ready(None) => None,
Poll::Ready(Some(Err(e))) => panic!("{}", e.to_string()),
Poll::Pending => panic!("Pending"),
}
}
}
}

View File

@@ -1,684 +0,0 @@
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
use crate::{
client::in_flight_requests::InFlightRequests, context, trace::SpanId, ClientMessage,
PollContext, PollIo, Request, Response, Transport,
};
use futures::{prelude::*, ready, stream::Fuse, task::*};
use log::{info, trace};
use pin_project::{pin_project, pinned_drop};
use std::{
convert::TryFrom,
io,
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use tokio::sync::{mpsc, oneshot};
#[allow(dead_code)]
#[allow(clippy::no_effect)]
const CHECK_USIZE: () = {
if std::mem::size_of::<usize>() > std::mem::size_of::<u64>() {
// TODO: replace this with panic!() as soon as RFC 2345 gets stabilized
["usize is too big to fit in u64"][42];
}
};
use super::{Config, NewClient};
/// Handles communication from the client to request dispatch.
#[derive(Debug)]
pub struct Channel<Req, Resp> {
to_dispatch: mpsc::Sender<DispatchRequest<Req, Resp>>,
/// Channel to send a cancel message to the dispatcher.
cancellation: RequestCancellation,
/// The ID to use for the next request to stage.
next_request_id: Arc<AtomicUsize>,
}
impl<Req, Resp> Clone for Channel<Req, Resp> {
fn clone(&self) -> Self {
Self {
to_dispatch: self.to_dispatch.clone(),
cancellation: self.cancellation.clone(),
next_request_id: self.next_request_id.clone(),
}
}
}
impl<Req, Resp> Channel<Req, Resp> {
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
/// resolves when the request is sent (not when the response is received).
fn send(
&self,
mut ctx: context::Context,
request: Req,
) -> impl Future<Output = io::Result<DispatchResponse<Resp>>> + '_ {
// Convert the context to the call context.
ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
let (response_completion, response) = oneshot::channel();
let cancellation = self.cancellation.clone();
let request_id =
u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
// DispatchResponse impls Drop to cancel in-flight requests. It should be created before
// sending out the request; otherwise, the response future could be dropped after the
// request is sent out but before DispatchResponse is created, rendering the cancellation
// logic inactive.
let response = DispatchResponse {
response,
complete: false,
request_id,
cancellation,
ctx,
};
async move {
self.to_dispatch
.send(DispatchRequest {
ctx,
request_id,
request,
response_completion,
})
.await
.map_err(|mpsc::error::SendError(_)| {
io::Error::from(io::ErrorKind::ConnectionReset)
})?;
Ok(response)
}
}
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
/// resolves to the response.
pub async fn call(&self, ctx: context::Context, request: Req) -> io::Result<Resp> {
let dispatch_response = self.send(ctx, request).await?;
dispatch_response.await
}
}
/// A server response that is completed by request dispatch when the corresponding response
/// arrives off the wire.
#[pin_project(PinnedDrop)]
#[derive(Debug)]
struct DispatchResponse<Resp> {
response: oneshot::Receiver<Response<Resp>>,
ctx: context::Context,
complete: bool,
cancellation: RequestCancellation,
request_id: u64,
}
impl<Resp> Future for DispatchResponse<Resp> {
type Output = io::Result<Resp>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> {
let resp = ready!(self.response.poll_unpin(cx));
self.complete = true;
Poll::Ready(match resp {
Ok(resp) => Ok(resp.message?),
Err(oneshot::error::RecvError { .. }) => {
// The oneshot is Canceled when the dispatch task ends. In that case,
// there's nothing listening on the other side, so there's no point in
// propagating cancellation.
Err(io::Error::from(io::ErrorKind::ConnectionReset))
}
})
}
}
// Cancels the request when dropped, if not already complete.
#[pinned_drop]
impl<Resp> PinnedDrop for DispatchResponse<Resp> {
fn drop(mut self: Pin<&mut Self>) {
if !self.complete {
// The receiver needs to be closed to handle the edge case that the request has not
// yet been received by the dispatch task. It is possible for the cancel message to
// arrive before the request itself, in which case the request could get stuck in the
// dispatch map forever if the server never responds (e.g. if the server dies while
// responding). Even if the server does respond, it will have unnecessarily done work
// for a client no longer waiting for a response. To avoid this, the dispatch task
// checks if the receiver is closed before inserting the request in the map. By
// closing the receiver before sending the cancel message, it is guaranteed that if the
// dispatch task misses an early-arriving cancellation message, then it will see the
// receiver as closed.
self.response.close();
let request_id = self.request_id;
self.cancellation.cancel(request_id);
}
}
}
/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the
/// channel.
pub fn new<Req, Resp, C>(
config: Config,
transport: C,
) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C>>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
let (cancellation, canceled_requests) = cancellations();
let canceled_requests = canceled_requests;
NewClient {
client: Channel {
to_dispatch,
cancellation,
next_request_id: Arc::new(AtomicUsize::new(0)),
},
dispatch: RequestDispatch {
config,
canceled_requests,
transport: transport.fuse(),
in_flight_requests: InFlightRequests::default(),
pending_requests,
},
}
}
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
/// and dispatching responses to the appropriate channel.
#[pin_project]
#[derive(Debug)]
pub struct RequestDispatch<Req, Resp, C> {
/// Writes requests to the wire and reads responses off the wire.
#[pin]
transport: Fuse<C>,
/// Requests waiting to be written to the wire.
#[pin]
pending_requests: mpsc::Receiver<DispatchRequest<Req, Resp>>,
/// Requests that were dropped.
#[pin]
canceled_requests: CanceledRequests,
/// Requests already written to the wire that haven't yet received responses.
in_flight_requests: InFlightRequests<Resp>,
/// Configures limits to prevent unlimited resource usage.
config: Config,
}
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests<Resp> {
self.as_mut().project().in_flight_requests
}
fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
Poll::Ready(
match ready!(self.as_mut().project().transport.poll_next(cx)?) {
Some(response) => {
self.complete(response);
Some(Ok(()))
}
None => None,
},
)
}
fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
enum ReceiverStatus {
NotReady,
Closed,
}
let pending_requests_status = match self.as_mut().poll_next_request(cx)? {
Poll::Ready(Some(dispatch_request)) => {
self.as_mut().write_request(dispatch_request)?;
return Poll::Ready(Some(Ok(())));
}
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::NotReady,
};
let canceled_requests_status = match self.as_mut().poll_next_cancellation(cx)? {
Poll::Ready(Some((context, request_id))) => {
self.as_mut().write_cancel(context, request_id)?;
return Poll::Ready(Some(Ok(())));
}
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::NotReady,
};
// Receiving Poll::Ready(None) when polling expired requests never indicates "Closed",
// because there can temporarily be zero in-flight rquests. Therefore, there is no need to
// track the status like is done with pending and cancelled requests.
if let Poll::Ready(Some(_)) = self.in_flight_requests().poll_expired(cx)? {
// Expired requests are considered complete; there is no compelling reason to send a
// cancellation message to the server, since it will have already exhausted its
// allotted processing time.
return Poll::Ready(Some(Ok(())));
}
match (pending_requests_status, canceled_requests_status) {
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
ready!(self.as_mut().project().transport.poll_flush(cx)?);
Poll::Ready(None)
}
(ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
// No more messages to process, so flush any messages buffered in the transport.
ready!(self.as_mut().project().transport.poll_flush(cx)?);
// Even if we fully-flush, we return Pending, because we have no more requests
// or cancellations right now.
Poll::Pending
}
}
}
/// Yields the next pending request, if one is ready to be sent.
fn poll_next_request(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<DispatchRequest<Req, Resp>> {
if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
info!(
"At in-flight request capacity ({}/{}).",
self.in_flight_requests().len(),
self.config.max_in_flight_requests
);
// No need to schedule a wakeup, because timers and responses are responsible
// for clearing out in-flight requests.
return Poll::Pending;
}
while self
.as_mut()
.project()
.transport
.poll_ready(cx)?
.is_pending()
{
// We can't yield a request-to-be-sent before the transport is capable of buffering it.
ready!(self.as_mut().project().transport.poll_flush(cx)?);
}
loop {
match ready!(self.as_mut().project().pending_requests.poll_recv(cx)) {
Some(request) => {
if request.response_completion.is_closed() {
trace!(
"[{}] Request canceled before being sent.",
request.ctx.trace_id()
);
continue;
}
return Poll::Ready(Some(Ok(request)));
}
None => return Poll::Ready(None),
}
}
}
/// Yields the next pending cancellation, and, if one is ready, cancels the associated request.
fn poll_next_cancellation(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<(context::Context, u64)> {
while self
.as_mut()
.project()
.transport
.poll_ready(cx)?
.is_pending()
{
ready!(self.as_mut().project().transport.poll_flush(cx)?);
}
loop {
let cancellation = self
.as_mut()
.project()
.canceled_requests
.poll_next_unpin(cx);
match ready!(cancellation) {
Some(request_id) => {
if let Some(ctx) = self.in_flight_requests().cancel_request(request_id) {
return Poll::Ready(Some(Ok((ctx, request_id))));
}
}
None => return Poll::Ready(None),
}
}
}
fn write_request(
mut self: Pin<&mut Self>,
dispatch_request: DispatchRequest<Req, Resp>,
) -> io::Result<()> {
let request_id = dispatch_request.request_id;
let request = ClientMessage::Request(Request {
id: request_id,
message: dispatch_request.request,
context: context::Context {
deadline: dispatch_request.ctx.deadline,
trace_context: dispatch_request.ctx.trace_context,
},
});
self.as_mut().project().transport.start_send(request)?;
self.in_flight_requests()
.insert_request(
request_id,
dispatch_request.ctx,
dispatch_request.response_completion,
)
.expect("Request IDs should be unique");
Ok(())
}
fn write_cancel(
mut self: Pin<&mut Self>,
context: context::Context,
request_id: u64,
) -> io::Result<()> {
let trace_id = *context.trace_id();
let cancel = ClientMessage::Cancel {
trace_context: context.trace_context,
request_id,
};
self.as_mut().project().transport.start_send(cancel)?;
trace!("[{}] Cancel message sent.", trace_id);
Ok(())
}
/// Sends a server response to the client task that initiated the associated request.
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
self.in_flight_requests().complete_request(response)
}
}
impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
type Output = anyhow::Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<anyhow::Result<()>> {
loop {
match (
self.as_mut()
.pump_read(cx)
.context("failed to read from transport")?,
self.as_mut()
.pump_write(cx)
.context("failed to write to transport")?,
) {
(Poll::Ready(None), _) => {
info!("Shutdown: read half closed, so shutting down.");
return Poll::Ready(Ok(()));
}
(read, Poll::Ready(None)) => {
if self.in_flight_requests().is_empty() {
info!("Shutdown: write half closed, and no requests in flight.");
return Poll::Ready(Ok(()));
}
info!(
"Shutdown: write half closed, and {} requests in flight.",
self.in_flight_requests().len()
);
match read {
Poll::Ready(Some(())) => continue,
_ => return Poll::Pending,
}
}
(Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
_ => return Poll::Pending,
}
}
}
}
/// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage
/// the lifecycle of the request.
#[derive(Debug)]
struct DispatchRequest<Req, Resp> {
pub ctx: context::Context,
pub request_id: u64,
pub request: Req,
pub response_completion: oneshot::Sender<Response<Resp>>,
}
/// Sends request cancellation signals.
#[derive(Debug, Clone)]
struct RequestCancellation(mpsc::UnboundedSender<u64>);
/// A stream of IDs of requests that have been canceled.
#[derive(Debug)]
struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
/// Returns a channel to send request cancellation messages.
fn cancellations() -> (RequestCancellation, CanceledRequests) {
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
// bounded by the number of in-flight requests. Additionally, each request has a clone
// of the sender, so the bounded channel would have the same behavior,
// since it guarantees a slot.
let (tx, rx) = mpsc::unbounded_channel();
(RequestCancellation(tx), CanceledRequests(rx))
}
impl RequestCancellation {
/// Cancels the request with ID `request_id`.
fn cancel(&mut self, request_id: u64) {
let _ = self.0.send(request_id);
}
}
impl Stream for CanceledRequests {
type Item = u64;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
self.0.poll_recv(cx)
}
}
#[cfg(test)]
mod tests {
use super::{
cancellations, CanceledRequests, Channel, DispatchResponse, RequestCancellation,
RequestDispatch,
};
use crate::{
client::{in_flight_requests::InFlightRequests, Config},
context,
transport::{self, channel::UnboundedChannel},
ClientMessage, Response,
};
use futures::{prelude::*, task::*};
use std::{pin::Pin, sync::atomic::AtomicUsize, sync::Arc};
use tokio::sync::{mpsc, oneshot};
#[tokio::test]
async fn dispatch_response_cancels_on_drop() {
let (cancellation, mut canceled_requests) = cancellations();
let (_, response) = oneshot::channel();
drop(DispatchResponse::<u32> {
response,
cancellation,
complete: false,
request_id: 3,
ctx: context::current(),
});
// resp's drop() is run, which should send a cancel message.
let cx = &mut Context::from_waker(&noop_waker_ref());
assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(Some(3)));
}
#[tokio::test]
async fn stage_request() {
let (mut dispatch, mut channel, _server_channel) = set_up();
let dispatch = Pin::new(&mut dispatch);
let cx = &mut Context::from_waker(&noop_waker_ref());
let _resp = send_request(&mut channel, "hi").await;
let req = dispatch.poll_next_request(cx).ready();
assert!(req.is_some());
let req = req.unwrap();
assert_eq!(req.request_id, 0);
assert_eq!(req.request, "hi".to_string());
}
// Regression test for https://github.com/google/tarpc/issues/220
#[tokio::test]
async fn stage_request_channel_dropped_doesnt_panic() {
let (mut dispatch, mut channel, mut server_channel) = set_up();
let mut dispatch = Pin::new(&mut dispatch);
let cx = &mut Context::from_waker(&noop_waker_ref());
let _ = send_request(&mut channel, "hi").await;
drop(channel);
assert!(dispatch.as_mut().poll(cx).is_ready());
send_response(
&mut server_channel,
Response {
request_id: 0,
message: Ok("hello".into()),
},
)
.await;
dispatch.await.unwrap();
}
#[tokio::test]
async fn stage_request_response_future_dropped_is_canceled_before_sending() {
let (mut dispatch, mut channel, _server_channel) = set_up();
let dispatch = Pin::new(&mut dispatch);
let cx = &mut Context::from_waker(&noop_waker_ref());
let _ = send_request(&mut channel, "hi").await;
// Drop the channel so polling returns none if no requests are currently ready.
drop(channel);
// Test that a request future dropped before it's processed by dispatch will cause the request
// to not be added to the in-flight request map.
assert!(dispatch.poll_next_request(cx).ready().is_none());
}
#[tokio::test]
async fn stage_request_response_future_dropped_is_canceled_after_sending() {
let (mut dispatch, mut channel, _server_channel) = set_up();
let cx = &mut Context::from_waker(&noop_waker_ref());
let mut dispatch = Pin::new(&mut dispatch);
let req = send_request(&mut channel, "hi").await;
assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
assert!(!dispatch.in_flight_requests().is_empty());
// Test that a request future dropped after it's processed by dispatch will cause the request
// to be removed from the in-flight request map.
drop(req);
if let Poll::Ready(Some(_)) = dispatch.as_mut().poll_next_cancellation(cx).unwrap() {
// ok
} else {
panic!("Expected request to be cancelled")
};
assert!(dispatch.in_flight_requests().is_empty());
}
#[tokio::test]
async fn stage_request_response_closed_skipped() {
let (mut dispatch, mut channel, _server_channel) = set_up();
let dispatch = Pin::new(&mut dispatch);
let cx = &mut Context::from_waker(&noop_waker_ref());
// Test that a request future that's closed its receiver but not yet canceled its request --
// i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
// map.
let mut resp = send_request(&mut channel, "hi").await;
resp.response.close();
assert!(dispatch.poll_next_request(cx).is_pending());
}
fn set_up() -> (
RequestDispatch<String, String, UnboundedChannel<Response<String>, ClientMessage<String>>>,
Channel<String, String>,
UnboundedChannel<ClientMessage<String>, Response<String>>,
) {
let _ = env_logger::try_init();
let (to_dispatch, pending_requests) = mpsc::channel(1);
let (cancel_tx, canceled_requests) = mpsc::unbounded_channel();
let (client_channel, server_channel) = transport::channel::unbounded();
let dispatch = RequestDispatch::<String, String, _> {
transport: client_channel.fuse(),
pending_requests: pending_requests,
canceled_requests: CanceledRequests(canceled_requests),
in_flight_requests: InFlightRequests::default(),
config: Config::default(),
};
let cancellation = RequestCancellation(cancel_tx);
let channel = Channel {
to_dispatch,
cancellation,
next_request_id: Arc::new(AtomicUsize::new(0)),
};
(dispatch, channel, server_channel)
}
async fn send_request(
channel: &mut Channel<String, String>,
request: &str,
) -> DispatchResponse<String> {
channel
.send(context::current(), request.to_string())
.await
.unwrap()
}
async fn send_response(
channel: &mut UnboundedChannel<ClientMessage<String>, Response<String>>,
response: Response<String>,
) {
channel.send(response).await.unwrap();
}
trait PollTest {
type T;
fn unwrap(self) -> Poll<Self::T>;
fn ready(self) -> Self::T;
}
impl<T, E> PollTest for Poll<Option<Result<T, E>>>
where
E: ::std::fmt::Display,
{
type T = Option<T>;
fn unwrap(self) -> Poll<Option<T>> {
match self {
Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)),
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Err(e))) => panic!("{}", e.to_string()),
Poll::Pending => Poll::Pending,
}
}
fn ready(self) -> Option<T> {
match self {
Poll::Ready(Some(Ok(t))) => Some(t),
Poll::Ready(None) => None,
Poll::Ready(Some(Err(e))) => panic!("{}", e.to_string()),
Poll::Pending => panic!("Pending"),
}
}
}
}

View File

@@ -131,7 +131,7 @@
//! ```
//!
//! Lastly let's write our `main` that will start the server. While this example uses an
//! [in-process channel](rpc::transport::channel), tarpc also ships a generic [`serde_transport`]
//! [in-process channel](transport::channel), tarpc also ships a generic [`serde_transport`]
//! behind the `serde-transport` feature, with additional [TCP](serde_transport::tcp) functionality
//! available behind the `tcp` feature.
//!

View File

@@ -72,7 +72,7 @@ pub trait Serve<Req> {
impl<Req, Resp, Fut, F> Serve<Req> for F
where
F: FnOnce(context::Context, Req) -> Fut + Clone,
F: FnOnce(context::Context, Req) -> Fut,
Fut: Future<Output = Resp>,
{
type Resp = Resp;
@@ -182,13 +182,12 @@ impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
/// The server end of an open connection with a client, streaming in requests from, and sinking
/// responses to, the client.
///
///
/// The ways to use a Channel, in order of simplest to most complex, is:
/// 1. [Channel::execute] - Requires the `tokio1` feature. This method is best for those who
/// do not have specific scheduling needs and whose services are `Send + 'static`.
/// 2. [Channel::requests] - This method is best for those who need direct access to individual
/// requests, or are not using `tokio`, or want control over [futures](Future) scheduling.
/// 3. [Raw stream](<Channel as Stream>) - A user is free to manually handle requests produced by
/// 3. [Raw stream](Stream) - A user is free to manually handle requests produced by
/// Channel. If they do so, they should uphold the service contract:
/// 1. All work being done as part of processing request `request_id` is aborted when
/// either of the following occurs:
@@ -199,8 +198,10 @@ impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
/// [sent](Sink::start_send) into the Channel. Because there is no guarantee that a
/// cancellation message will ever be received for a request, services should strive to clean
/// up Channel resources by sending a response for every request. For example, [`BaseChannel`]
/// has a map of requests to [abort handles][AbortHandle] whose entries are only removed
/// upon either request cancellation or response completion.
/// has a map of requests to [abort handles][futures::future::AbortHandle] whose entries are
/// only removed upon either request cancellation, response completion, or deadline
/// expiration. For requests with long deadlines that have been abandoned without a response,
/// some cleanup may never happen.
pub trait Channel
where
Self: Transport<Response<<Self as Channel>::Resp>, Request<<Self as Channel>::Req>>,
@@ -260,7 +261,7 @@ where
fn execute<S>(self, serve: S) -> TokioChannelExecutor<Requests<Self>, S>
where
Self: Sized,
S: Serve<Self::Req, Resp = Self::Resp> + Send + Sync + 'static,
S: Serve<Self::Req, Resp = Self::Resp> + Send + 'static,
S::Fut: Send,
Self::Req: Send + 'static,
Self::Resp: Send + 'static,
@@ -406,7 +407,7 @@ where
cx: &mut Context<'_>,
) -> PollIo<InFlightRequest<C::Req, C::Resp>> {
loop {
match ready!(self.as_mut().project().channel.poll_next(cx)?) {
match ready!(self.channel_pin_mut().poll_next(cx)?) {
Some(request) => {
trace!(
"[{}] Handling request with deadline {}.",
@@ -617,7 +618,7 @@ where
/// by [spawning](tokio::spawn) each handler on tokio's default executor.
pub fn execute<S>(self, serve: S) -> TokioChannelExecutor<Self, S>
where
S: Serve<C::Req, Resp = C::Resp> + Send + Sync + 'static,
S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
{
TokioChannelExecutor { inner: self, serve }
}
@@ -635,8 +636,8 @@ pub struct TokioServerExecutor<T, S> {
serve: S,
}
/// A future that drives the server by [spawning](tokio::spawn) each [response handler](ResponseHandler)
/// on tokio's default executor.
/// A future that drives the server by [spawning](tokio::spawn) each [response
/// handler](InFlightRequest::execute) on tokio's default executor.
#[pin_project]
#[derive(Debug)]
#[cfg(feature = "tokio1")]
@@ -670,7 +671,7 @@ where
C: Channel + Send + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
Se: Serve<C::Req, Resp = C::Resp> + Send + Sync + 'static + Clone,
Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
Se::Fut: Send,
{
type Output = ();
@@ -690,7 +691,7 @@ where
C: Channel + 'static,
C::Req: Send + 'static,
C::Resp: Send + 'static,
S: Serve<C::Req, Resp = C::Resp> + Send + Sync + 'static + Clone,
S: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
S::Fut: Send,
{
type Output = ();

View File

@@ -15,7 +15,7 @@ use log::{debug, info, trace};
use pin_project::pin_project;
use std::sync::{Arc, Weak};
use std::{
collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin,
collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin,
time::SystemTime,
};
use tokio::sync::mpsc;
@@ -30,9 +30,7 @@ where
#[pin]
listener: Fuse<S>,
channels_per_key: u32,
#[pin]
dropped_keys: mpsc::UnboundedReceiver<K>,
#[pin]
dropped_keys_tx: mpsc::UnboundedSender<K>,
key_counts: FnvHashMap<K, Weak<Tracker<K>>>,
keymaker: F,
@@ -66,8 +64,8 @@ where
{
type Item = <C as Stream>::Item;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
self.channel().poll_next(cx)
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
self.inner_pin_mut().poll_next(cx)
}
}
@@ -77,20 +75,20 @@ where
{
type Error = C::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.channel().poll_ready(cx)
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.inner_pin_mut().poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
self.channel().start_send(item)
fn start_send(mut self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
self.inner_pin_mut().start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.channel().poll_flush(cx)
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.inner_pin_mut().poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.channel().poll_close(cx)
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.inner_pin_mut().poll_close(cx)
}
}
@@ -116,15 +114,15 @@ where
}
fn start_request(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
id: u64,
deadline: SystemTime,
) -> Result<AbortRegistration, super::in_flight_requests::AlreadyExistsError> {
self.project().inner.start_request(id, deadline)
self.inner_pin_mut().start_request(id, deadline)
}
fn poll_expired(self: Pin<&mut Self>, cx: &mut Context) -> PollIo<u64> {
self.project().inner.poll_expired(cx)
fn poll_expired(mut self: Pin<&mut Self>, cx: &mut Context) -> PollIo<u64> {
self.inner_pin_mut().poll_expired(cx)
}
}
@@ -135,8 +133,8 @@ impl<C, K> TrackedChannel<C, K> {
}
/// Returns the pinned inner channel.
fn channel(self: Pin<&mut Self>) -> Pin<&mut C> {
self.project().inner
fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> {
self.as_mut().project().inner
}
}
@@ -166,6 +164,10 @@ where
K: fmt::Display + Eq + Hash + Clone + Unpin,
F: Fn(&S::Item) -> K,
{
fn listener_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<S>> {
self.as_mut().project().listener
}
fn handle_new_channel(
mut self: Pin<&mut Self>,
stream: S::Item,
@@ -177,7 +179,7 @@ where
"[{}] Opening channel ({}/{}) channels for key.",
key,
Arc::strong_count(&tracker),
self.as_mut().project().channels_per_key
self.channels_per_key
);
Ok(TrackedChannel {
@@ -186,15 +188,14 @@ where
})
}
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
let channels_per_key = self.channels_per_key;
let dropped_keys = self.dropped_keys_tx.clone();
let key_counts = &mut self.as_mut().project().key_counts;
match key_counts.entry(key.clone()) {
fn increment_channels_for_key(self: Pin<&mut Self>, key: K) -> Result<Arc<Tracker<K>>, K> {
let self_ = self.project();
let dropped_keys = self_.dropped_keys_tx;
match self_.key_counts.entry(key.clone()) {
Entry::Vacant(vacant) => {
let tracker = Arc::new(Tracker {
key: Some(key),
dropped_keys,
dropped_keys: dropped_keys.clone(),
});
vacant.insert(Arc::downgrade(&tracker));
@@ -202,17 +203,17 @@ where
}
Entry::Occupied(mut o) => {
let count = o.get().strong_count();
if count >= channels_per_key.try_into().unwrap() {
if count >= TryFrom::try_from(*self_.channels_per_key).unwrap() {
info!(
"[{}] Opened max channels from key ({}/{}).",
key, count, channels_per_key
key, count, self_.channels_per_key
);
Err(key)
} else {
Ok(o.get().upgrade().unwrap_or_else(|| {
let tracker = Arc::new(Tracker {
key: Some(key),
dropped_keys,
dropped_keys: dropped_keys.clone(),
});
*o.get_mut() = Arc::downgrade(&tracker);
@@ -227,18 +228,19 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
match ready!(self.as_mut().project().listener.poll_next_unpin(cx)) {
match ready!(self.listener_pin_mut().poll_next_unpin(cx)) {
Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
None => Poll::Ready(None),
}
}
fn poll_closed_channels(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
match ready!(self.as_mut().project().dropped_keys.poll_recv(cx)) {
fn poll_closed_channels(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
let self_ = self.project();
match ready!(self_.dropped_keys.poll_recv(cx)) {
Some(key) => {
debug!("All channels dropped for key [{}]", key);
self.as_mut().project().key_counts.remove(&key);
self.as_mut().project().key_counts.compact(0.1);
self_.key_counts.remove(&key);
self_.key_counts.compact(0.1);
Poll::Ready(())
}
None => unreachable!("Holding a copy of closed_channels and didn't close it."),

View File

@@ -212,3 +212,38 @@ async fn concurrent_join_all() -> io::Result<()> {
Ok(())
}
#[tokio::test]
async fn counter() -> io::Result<()> {
#[tarpc::service]
trait Counter {
async fn count() -> u32;
}
struct CountService(u32);
impl Counter for &mut CountService {
type CountFut = futures::future::Ready<u32>;
fn count(self, _: context::Context) -> Self::CountFut {
self.0 += 1;
futures::future::ready(self.0)
}
}
let (tx, rx) = channel::unbounded();
tokio::spawn(async {
let mut requests = BaseChannel::with_defaults(rx).requests();
let mut counter = CountService(0);
while let Some(Ok(request)) = requests.next().await {
request.execute(counter.serve()).await;
}
});
let client = CounterClient::new(client::Config::default(), tx).spawn()?;
assert_matches!(client.count(context::current()).await, Ok(1));
assert_matches!(client.count(context::current()).await, Ok(2));
Ok(())
}