mirror of
https://github.com/OMGeeky/tarpc.git
synced 2026-01-04 02:21:02 +01:00
922 lines
30 KiB
Rust
922 lines
30 KiB
Rust
// Copyright 2018 Google LLC
|
|
//
|
|
// Use of this source code is governed by an MIT-style
|
|
// license that can be found in the LICENSE file or at
|
|
// https://opensource.org/licenses/MIT.
|
|
|
|
use crate::{
|
|
context,
|
|
util::{Compact, TimeUntil},
|
|
ClientMessage, PollIo, Request, Response, Transport,
|
|
};
|
|
use fnv::FnvHashMap;
|
|
use futures::{
|
|
channel::{mpsc, oneshot},
|
|
prelude::*,
|
|
ready,
|
|
stream::Fuse,
|
|
task::Context,
|
|
Poll,
|
|
};
|
|
use log::{debug, info, trace};
|
|
use pin_project::{pin_project, pinned_drop};
|
|
use std::{
|
|
io,
|
|
pin::Pin,
|
|
sync::{
|
|
atomic::{AtomicU64, Ordering},
|
|
Arc,
|
|
},
|
|
};
|
|
use tokio_timer::{timeout, Timeout};
|
|
use trace::SpanId;
|
|
|
|
use super::{Config, NewClient};
|
|
|
|
/// Handles communication from the client to request dispatch.
|
|
#[derive(Debug)]
|
|
pub struct Channel<Req, Resp> {
|
|
to_dispatch: mpsc::Sender<DispatchRequest<Req, Resp>>,
|
|
/// Channel to send a cancel message to the dispatcher.
|
|
cancellation: RequestCancellation,
|
|
/// The ID to use for the next request to stage.
|
|
next_request_id: Arc<AtomicU64>,
|
|
}
|
|
|
|
impl<Req, Resp> Clone for Channel<Req, Resp> {
|
|
fn clone(&self) -> Self {
|
|
Self {
|
|
to_dispatch: self.to_dispatch.clone(),
|
|
cancellation: self.cancellation.clone(),
|
|
next_request_id: self.next_request_id.clone(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A future returned by [`Channel::send`] that resolves to a server response.
|
|
#[pin_project]
|
|
#[derive(Debug)]
|
|
#[must_use = "futures do nothing unless polled"]
|
|
struct Send<'a, Req, Resp> {
|
|
#[pin]
|
|
fut: MapOkDispatchResponse<SendMapErrConnectionReset<'a, Req, Resp>, Resp>,
|
|
}
|
|
|
|
type SendMapErrConnectionReset<'a, Req, Resp> = MapErrConnectionReset<
|
|
futures::sink::Send<'a, mpsc::Sender<DispatchRequest<Req, Resp>>, DispatchRequest<Req, Resp>>,
|
|
>;
|
|
|
|
impl<'a, Req, Resp> Future for Send<'a, Req, Resp> {
|
|
type Output = io::Result<DispatchResponse<Resp>>;
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
self.as_mut().project().fut.poll(cx)
|
|
}
|
|
}
|
|
|
|
/// A future returned by [`Channel::call`] that resolves to a server response.
|
|
#[pin_project]
|
|
#[derive(Debug)]
|
|
#[must_use = "futures do nothing unless polled"]
|
|
pub struct Call<'a, Req, Resp> {
|
|
#[pin]
|
|
fut: AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>,
|
|
}
|
|
|
|
impl<'a, Req, Resp> Future for Call<'a, Req, Resp> {
|
|
type Output = io::Result<Resp>;
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
self.as_mut().project().fut.poll(cx)
|
|
}
|
|
}
|
|
|
|
impl<Req, Resp> Channel<Req, Resp> {
|
|
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
|
|
/// resolves when the request is sent (not when the response is received).
|
|
fn send(&mut self, mut ctx: context::Context, request: Req) -> Send<Req, Resp> {
|
|
// Convert the context to the call context.
|
|
ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
|
|
ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
|
|
|
|
let timeout = ctx.deadline.time_until();
|
|
trace!(
|
|
"[{}] Queuing request with timeout {:?}.",
|
|
ctx.trace_id(),
|
|
timeout,
|
|
);
|
|
|
|
let (response_completion, response) = oneshot::channel();
|
|
let cancellation = self.cancellation.clone();
|
|
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
|
Send {
|
|
fut: MapOkDispatchResponse::new(
|
|
MapErrConnectionReset::new(self.to_dispatch.send(DispatchRequest {
|
|
ctx,
|
|
request_id,
|
|
request,
|
|
response_completion,
|
|
})),
|
|
DispatchResponse {
|
|
response: Timeout::new(response, timeout),
|
|
complete: false,
|
|
request_id,
|
|
cancellation,
|
|
ctx,
|
|
},
|
|
),
|
|
}
|
|
}
|
|
|
|
/// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
|
|
/// resolves to the response.
|
|
pub fn call(&mut self, context: context::Context, request: Req) -> Call<Req, Resp> {
|
|
Call {
|
|
fut: AndThenIdent::new(self.send(context, request)),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A server response that is completed by request dispatch when the corresponding response
|
|
/// arrives off the wire.
|
|
#[pin_project(PinnedDrop)]
|
|
#[derive(Debug)]
|
|
struct DispatchResponse<Resp> {
|
|
response: Timeout<oneshot::Receiver<Response<Resp>>>,
|
|
ctx: context::Context,
|
|
complete: bool,
|
|
cancellation: RequestCancellation,
|
|
request_id: u64,
|
|
}
|
|
|
|
impl<Resp> Future for DispatchResponse<Resp> {
|
|
type Output = io::Result<Resp>;
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> {
|
|
let resp = ready!(self.response.poll_unpin(cx));
|
|
|
|
Poll::Ready(match resp {
|
|
Ok(resp) => {
|
|
self.complete = true;
|
|
match resp {
|
|
Ok(resp) => Ok(resp.message?),
|
|
Err(oneshot::Canceled) => {
|
|
// The oneshot is Canceled when the dispatch task ends. In that case,
|
|
// there's nothing listening on the other side, so there's no point in
|
|
// propagating cancellation.
|
|
Err(io::Error::from(io::ErrorKind::ConnectionReset))
|
|
}
|
|
}
|
|
}
|
|
Err(timeout::Elapsed { .. }) => Err(io::Error::new(
|
|
io::ErrorKind::TimedOut,
|
|
"Client dropped expired request.".to_string(),
|
|
)),
|
|
})
|
|
}
|
|
}
|
|
|
|
// Cancels the request when dropped, if not already complete.
|
|
#[pinned_drop]
|
|
impl<Resp> PinnedDrop for DispatchResponse<Resp> {
|
|
fn drop(mut self: Pin<&mut Self>) {
|
|
if !self.complete {
|
|
// The receiver needs to be closed to handle the edge case that the request has not
|
|
// yet been received by the dispatch task. It is possible for the cancel message to
|
|
// arrive before the request itself, in which case the request could get stuck in the
|
|
// dispatch map forever if the server never responds (e.g. if the server dies while
|
|
// responding). Even if the server does respond, it will have unnecessarily done work
|
|
// for a client no longer waiting for a response. To avoid this, the dispatch task
|
|
// checks if the receiver is closed before inserting the request in the map. By
|
|
// closing the receiver before sending the cancel message, it is guaranteed that if the
|
|
// dispatch task misses an early-arriving cancellation message, then it will see the
|
|
// receiver as closed.
|
|
self.response.get_mut().close();
|
|
let request_id = self.request_id;
|
|
self.cancellation.cancel(request_id);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the
|
|
/// channel.
|
|
pub fn new<Req, Resp, C>(
|
|
config: Config,
|
|
transport: C,
|
|
) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C>>
|
|
where
|
|
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
|
{
|
|
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
|
|
let (cancellation, canceled_requests) = cancellations();
|
|
let canceled_requests = canceled_requests.fuse();
|
|
|
|
NewClient {
|
|
client: Channel {
|
|
to_dispatch,
|
|
cancellation,
|
|
next_request_id: Arc::new(AtomicU64::new(0)),
|
|
},
|
|
dispatch: RequestDispatch {
|
|
config,
|
|
canceled_requests,
|
|
transport: transport.fuse(),
|
|
in_flight_requests: FnvHashMap::default(),
|
|
pending_requests: pending_requests.fuse(),
|
|
},
|
|
}
|
|
}
|
|
|
|
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
|
|
/// and dispatching responses to the appropriate channel.
|
|
#[pin_project]
|
|
#[derive(Debug)]
|
|
pub struct RequestDispatch<Req, Resp, C> {
|
|
/// Writes requests to the wire and reads responses off the wire.
|
|
#[pin]
|
|
transport: Fuse<C>,
|
|
/// Requests waiting to be written to the wire.
|
|
#[pin]
|
|
pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>,
|
|
/// Requests that were dropped.
|
|
#[pin]
|
|
canceled_requests: Fuse<CanceledRequests>,
|
|
/// Requests already written to the wire that haven't yet received responses.
|
|
in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>,
|
|
/// Configures limits to prevent unlimited resource usage.
|
|
config: Config,
|
|
}
|
|
|
|
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
|
|
where
|
|
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
|
{
|
|
fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
|
Poll::Ready(
|
|
match ready!(self.as_mut().project().transport.poll_next(cx)?) {
|
|
Some(response) => {
|
|
self.complete(response);
|
|
Some(Ok(()))
|
|
}
|
|
None => None,
|
|
},
|
|
)
|
|
}
|
|
|
|
fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
|
enum ReceiverStatus {
|
|
NotReady,
|
|
Closed,
|
|
}
|
|
|
|
let pending_requests_status = match self.as_mut().poll_next_request(cx)? {
|
|
Poll::Ready(Some(dispatch_request)) => {
|
|
self.as_mut().write_request(dispatch_request)?;
|
|
return Poll::Ready(Some(Ok(())));
|
|
}
|
|
Poll::Ready(None) => ReceiverStatus::Closed,
|
|
Poll::Pending => ReceiverStatus::NotReady,
|
|
};
|
|
|
|
let canceled_requests_status = match self.as_mut().poll_next_cancellation(cx)? {
|
|
Poll::Ready(Some((context, request_id))) => {
|
|
self.as_mut().write_cancel(context, request_id)?;
|
|
return Poll::Ready(Some(Ok(())));
|
|
}
|
|
Poll::Ready(None) => ReceiverStatus::Closed,
|
|
Poll::Pending => ReceiverStatus::NotReady,
|
|
};
|
|
|
|
match (pending_requests_status, canceled_requests_status) {
|
|
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
|
|
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
|
Poll::Ready(None)
|
|
}
|
|
(ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
|
|
// No more messages to process, so flush any messages buffered in the transport.
|
|
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
|
|
|
// Even if we fully-flush, we return Pending, because we have no more requests
|
|
// or cancellations right now.
|
|
Poll::Pending
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Yields the next pending request, if one is ready to be sent.
|
|
fn poll_next_request(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
) -> PollIo<DispatchRequest<Req, Resp>> {
|
|
if self.as_mut().project().in_flight_requests.len() >= self.config.max_in_flight_requests {
|
|
info!(
|
|
"At in-flight request capacity ({}/{}).",
|
|
self.as_mut().project().in_flight_requests.len(),
|
|
self.config.max_in_flight_requests
|
|
);
|
|
|
|
// No need to schedule a wakeup, because timers and responses are responsible
|
|
// for clearing out in-flight requests.
|
|
return Poll::Pending;
|
|
}
|
|
|
|
while let Poll::Pending = self.as_mut().project().transport.poll_ready(cx)? {
|
|
// We can't yield a request-to-be-sent before the transport is capable of buffering it.
|
|
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
|
}
|
|
|
|
loop {
|
|
match ready!(self.as_mut().project().pending_requests.poll_next_unpin(cx)) {
|
|
Some(request) => {
|
|
if request.response_completion.is_canceled() {
|
|
trace!(
|
|
"[{}] Request canceled before being sent.",
|
|
request.ctx.trace_id()
|
|
);
|
|
continue;
|
|
}
|
|
|
|
return Poll::Ready(Some(Ok(request)));
|
|
}
|
|
None => return Poll::Ready(None),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Yields the next pending cancellation, and, if one is ready, cancels the associated request.
|
|
fn poll_next_cancellation(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
) -> PollIo<(context::Context, u64)> {
|
|
while let Poll::Pending = self.as_mut().project().transport.poll_ready(cx)? {
|
|
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
|
}
|
|
|
|
loop {
|
|
let cancellation = self
|
|
.as_mut()
|
|
.project()
|
|
.canceled_requests
|
|
.poll_next_unpin(cx);
|
|
match ready!(cancellation) {
|
|
Some(request_id) => {
|
|
if let Some(in_flight_data) = self
|
|
.as_mut()
|
|
.project()
|
|
.in_flight_requests
|
|
.remove(&request_id)
|
|
{
|
|
self.as_mut().project().in_flight_requests.compact(0.1);
|
|
debug!("[{}] Removed request.", in_flight_data.ctx.trace_id());
|
|
return Poll::Ready(Some(Ok((in_flight_data.ctx, request_id))));
|
|
}
|
|
}
|
|
None => return Poll::Ready(None),
|
|
}
|
|
}
|
|
}
|
|
|
|
fn write_request(
|
|
mut self: Pin<&mut Self>,
|
|
dispatch_request: DispatchRequest<Req, Resp>,
|
|
) -> io::Result<()> {
|
|
let request_id = dispatch_request.request_id;
|
|
let request = ClientMessage::Request(Request {
|
|
id: request_id,
|
|
message: dispatch_request.request,
|
|
context: context::Context {
|
|
deadline: dispatch_request.ctx.deadline,
|
|
trace_context: dispatch_request.ctx.trace_context,
|
|
_non_exhaustive: (),
|
|
},
|
|
_non_exhaustive: (),
|
|
});
|
|
self.as_mut().project().transport.start_send(request)?;
|
|
self.as_mut().project().in_flight_requests.insert(
|
|
request_id,
|
|
InFlightData {
|
|
ctx: dispatch_request.ctx,
|
|
response_completion: dispatch_request.response_completion,
|
|
},
|
|
);
|
|
Ok(())
|
|
}
|
|
|
|
fn write_cancel(
|
|
mut self: Pin<&mut Self>,
|
|
context: context::Context,
|
|
request_id: u64,
|
|
) -> io::Result<()> {
|
|
let trace_id = *context.trace_id();
|
|
let cancel = ClientMessage::Cancel {
|
|
trace_context: context.trace_context,
|
|
request_id,
|
|
};
|
|
self.as_mut().project().transport.start_send(cancel)?;
|
|
trace!("[{}] Cancel message sent.", trace_id);
|
|
Ok(())
|
|
}
|
|
|
|
/// Sends a server response to the client task that initiated the associated request.
|
|
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
|
|
if let Some(in_flight_data) = self
|
|
.as_mut()
|
|
.project()
|
|
.in_flight_requests
|
|
.remove(&response.request_id)
|
|
{
|
|
self.as_mut().project().in_flight_requests.compact(0.1);
|
|
|
|
trace!("[{}] Received response.", in_flight_data.ctx.trace_id());
|
|
let _ = in_flight_data.response_completion.send(response);
|
|
return true;
|
|
}
|
|
|
|
debug!(
|
|
"No in-flight request found for request_id = {}.",
|
|
response.request_id
|
|
);
|
|
|
|
// If the response completion was absent, then the request was already canceled.
|
|
false
|
|
}
|
|
}
|
|
|
|
impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
|
|
where
|
|
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
|
{
|
|
type Output = io::Result<()>;
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
loop {
|
|
match (self.as_mut().pump_read(cx)?, self.as_mut().pump_write(cx)?) {
|
|
(read, Poll::Ready(None)) => {
|
|
if self.as_mut().project().in_flight_requests.is_empty() {
|
|
info!("Shutdown: write half closed, and no requests in flight.");
|
|
return Poll::Ready(Ok(()));
|
|
}
|
|
info!(
|
|
"Shutdown: write half closed, and {} requests in flight.",
|
|
self.as_mut().project().in_flight_requests.len()
|
|
);
|
|
match read {
|
|
Poll::Ready(Some(())) => continue,
|
|
_ => return Poll::Pending,
|
|
}
|
|
}
|
|
(Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
|
|
_ => return Poll::Pending,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage
|
|
/// the lifecycle of the request.
|
|
#[derive(Debug)]
|
|
struct DispatchRequest<Req, Resp> {
|
|
ctx: context::Context,
|
|
request_id: u64,
|
|
request: Req,
|
|
response_completion: oneshot::Sender<Response<Resp>>,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct InFlightData<Resp> {
|
|
ctx: context::Context,
|
|
response_completion: oneshot::Sender<Response<Resp>>,
|
|
}
|
|
|
|
/// Sends request cancellation signals.
|
|
#[derive(Debug, Clone)]
|
|
struct RequestCancellation(mpsc::UnboundedSender<u64>);
|
|
|
|
/// A stream of IDs of requests that have been canceled.
|
|
#[derive(Debug)]
|
|
struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
|
|
|
|
/// Returns a channel to send request cancellation messages.
|
|
fn cancellations() -> (RequestCancellation, CanceledRequests) {
|
|
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
|
|
// bounded by the number of in-flight requests. Additionally, each request has a clone
|
|
// of the sender, so the bounded channel would have the same behavior,
|
|
// since it guarantees a slot.
|
|
let (tx, rx) = mpsc::unbounded();
|
|
(RequestCancellation(tx), CanceledRequests(rx))
|
|
}
|
|
|
|
impl RequestCancellation {
|
|
/// Cancels the request with ID `request_id`.
|
|
fn cancel(&mut self, request_id: u64) {
|
|
let _ = self.0.unbounded_send(request_id);
|
|
}
|
|
}
|
|
|
|
impl Stream for CanceledRequests {
|
|
type Item = u64;
|
|
|
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
|
|
self.0.poll_next_unpin(cx)
|
|
}
|
|
}
|
|
|
|
#[pin_project]
|
|
#[derive(Debug)]
|
|
#[must_use = "futures do nothing unless polled"]
|
|
struct MapErrConnectionReset<Fut> {
|
|
#[pin]
|
|
future: Fut,
|
|
finished: Option<()>,
|
|
}
|
|
|
|
impl<Fut> MapErrConnectionReset<Fut> {
    /// Wraps `future`, which has not yet resolved.
    fn new(future: Fut) -> MapErrConnectionReset<Fut> {
        let finished = Some(());
        Self { future, finished }
    }
}
|
|
|
|
impl<Fut> Future for MapErrConnectionReset<Fut>
|
|
where
|
|
Fut: TryFuture,
|
|
{
|
|
type Output = io::Result<Fut::Ok>;
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
match self.as_mut().project().future.try_poll(cx) {
|
|
Poll::Pending => Poll::Pending,
|
|
Poll::Ready(result) => {
|
|
self.project().finished.take().expect(
|
|
"MapErrConnectionReset must not be polled after it returned `Poll::Ready`",
|
|
);
|
|
Poll::Ready(result.map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset)))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[pin_project]
|
|
#[derive(Debug)]
|
|
#[must_use = "futures do nothing unless polled"]
|
|
struct MapOkDispatchResponse<Fut, Resp> {
|
|
#[pin]
|
|
future: Fut,
|
|
response: Option<DispatchResponse<Resp>>,
|
|
}
|
|
|
|
impl<Fut, Resp> MapOkDispatchResponse<Fut, Resp> {
|
|
fn new(future: Fut, response: DispatchResponse<Resp>) -> MapOkDispatchResponse<Fut, Resp> {
|
|
MapOkDispatchResponse {
|
|
future,
|
|
response: Some(response),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<Fut, Resp> Future for MapOkDispatchResponse<Fut, Resp>
|
|
where
|
|
Fut: TryFuture,
|
|
{
|
|
type Output = Result<DispatchResponse<Resp>, Fut::Error>;
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
match self.as_mut().project().future.try_poll(cx) {
|
|
Poll::Pending => Poll::Pending,
|
|
Poll::Ready(result) => {
|
|
let response = self
|
|
.as_mut()
|
|
.project()
|
|
.response
|
|
.take()
|
|
.expect("MapOk must not be polled after it returned `Poll::Ready`");
|
|
Poll::Ready(result.map(|_| response))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[pin_project]
|
|
#[derive(Debug)]
|
|
#[must_use = "futures do nothing unless polled"]
|
|
struct AndThenIdent<Fut1, Fut2> {
|
|
#[pin]
|
|
try_chain: TryChain<Fut1, Fut2>,
|
|
}
|
|
|
|
impl<Fut1, Fut2> AndThenIdent<Fut1, Fut2>
|
|
where
|
|
Fut1: TryFuture<Ok = Fut2>,
|
|
Fut2: TryFuture,
|
|
{
|
|
/// Creates a new `Then`.
|
|
fn new(future: Fut1) -> AndThenIdent<Fut1, Fut2> {
|
|
AndThenIdent {
|
|
try_chain: TryChain::new(future),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<Fut1, Fut2> Future for AndThenIdent<Fut1, Fut2>
|
|
where
|
|
Fut1: TryFuture<Ok = Fut2>,
|
|
Fut2: TryFuture<Error = Fut1::Error>,
|
|
{
|
|
type Output = Result<Fut2::Ok, Fut2::Error>;
|
|
|
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
self.project().try_chain.poll(cx, |result| match result {
|
|
Ok(ok) => TryChainAction::Future(ok),
|
|
Err(err) => TryChainAction::Output(Err(err)),
|
|
})
|
|
}
|
|
}
|
|
|
|
/// State of a two-stage future chain.
#[must_use = "futures do nothing unless polled"]
#[derive(Debug)]
enum TryChain<Fut1, Fut2> {
    /// The first future is still running.
    First(Fut1),
    /// The first future finished and the second is now running.
    Second(Fut2),
    /// The first future has been dropped and no second future is stored.
    Empty,
}
|
|
|
|
enum TryChainAction<Fut2>
|
|
where
|
|
Fut2: TryFuture,
|
|
{
|
|
Future(Fut2),
|
|
Output(Result<Fut2::Ok, Fut2::Error>),
|
|
}
|
|
|
|
impl<Fut1, Fut2> TryChain<Fut1, Fut2>
|
|
where
|
|
Fut1: TryFuture<Ok = Fut2>,
|
|
Fut2: TryFuture,
|
|
{
|
|
fn new(fut1: Fut1) -> TryChain<Fut1, Fut2> {
|
|
TryChain::First(fut1)
|
|
}
|
|
|
|
fn poll<F>(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
f: F,
|
|
) -> Poll<Result<Fut2::Ok, Fut2::Error>>
|
|
where
|
|
F: FnOnce(Result<Fut1::Ok, Fut1::Error>) -> TryChainAction<Fut2>,
|
|
{
|
|
let mut f = Some(f);
|
|
|
|
// Safe to call `get_unchecked_mut` because we won't move the futures.
|
|
let this = unsafe { Pin::get_unchecked_mut(self) };
|
|
|
|
loop {
|
|
let output = match this {
|
|
TryChain::First(fut1) => {
|
|
// Poll the first future
|
|
match unsafe { Pin::new_unchecked(fut1) }.try_poll(cx) {
|
|
Poll::Pending => return Poll::Pending,
|
|
Poll::Ready(output) => output,
|
|
}
|
|
}
|
|
TryChain::Second(fut2) => {
|
|
// Poll the second future
|
|
return unsafe { Pin::new_unchecked(fut2) }.try_poll(cx);
|
|
}
|
|
TryChain::Empty => {
|
|
panic!("future must not be polled after it returned `Poll::Ready`");
|
|
}
|
|
};
|
|
|
|
*this = TryChain::Empty; // Drop fut1
|
|
let f = f.take().unwrap();
|
|
match f(output) {
|
|
TryChainAction::Future(fut2) => *this = TryChain::Second(fut2),
|
|
TryChainAction::Output(output) => return Poll::Ready(output),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::{
|
|
cancellations, CanceledRequests, Channel, DispatchResponse, RequestCancellation,
|
|
RequestDispatch,
|
|
};
|
|
use crate::{
|
|
client::Config,
|
|
context,
|
|
transport::{self, channel::UnboundedChannel},
|
|
ClientMessage, Response,
|
|
};
|
|
use fnv::FnvHashMap;
|
|
use futures::{
|
|
channel::{mpsc, oneshot},
|
|
prelude::*,
|
|
task::Context,
|
|
Poll,
|
|
};
|
|
use futures_test::task::noop_waker_ref;
|
|
use std::time::Duration;
|
|
use std::{pin::Pin, sync::atomic::AtomicU64, sync::Arc};
|
|
use tokio::runtime::current_thread;
|
|
use tokio_timer::Timeout;
|
|
|
|
/// Polling and then dropping a response whose timeout is zero must emit a cancellation for
/// its request ID.
#[test]
fn dispatch_response_cancels_on_timeout() {
    let (_response_completion, response) = oneshot::channel();
    let (cancellation, mut canceled_requests) = cancellations();
    let resp = DispatchResponse::<u64> {
        // A zero-length timeout should make `resp` error out when polled.
        response: Timeout::new(response, Duration::from_secs(0)),
        complete: false,
        request_id: 3,
        cancellation,
        ctx: context::current(),
    };
    {
        pin_utils::pin_mut!(resp);
        let timer = tokio_timer::Timer::default();
        let handle = timer.handle();
        let _guard = tokio_timer::set_default(&handle);

        let mut cx = Context::from_waker(&noop_waker_ref());
        let _ = resp.as_mut().poll(&mut cx);
        // Leaving this scope drops `resp`, which should send the cancel message.
    }
    let canceled = canceled_requests.0.try_next().unwrap();
    assert!(canceled == Some(3));
}
|
|
|
|
/// A request sent through the channel is yielded by the dispatcher with ID 0 and the
/// original payload.
#[test]
fn stage_request() {
    let (mut dispatch, mut channel, _server_channel) = set_up();
    let dispatch = Pin::new(&mut dispatch);
    let cx = &mut Context::from_waker(&noop_waker_ref());

    // Keep the response future alive so the request is not treated as canceled.
    let _pending_response = send_request(&mut channel, "hi");

    let staged = dispatch.poll_next_request(cx).ready();
    assert!(staged.is_some());

    let staged = staged.unwrap();
    assert_eq!(staged.request_id, 0);
    assert_eq!(staged.request, "hi".to_string());
}
|
|
|
|
fn block_on<F: Future>(f: F) -> F::Output {
|
|
current_thread::Runtime::new().unwrap().block_on(f)
|
|
}
|
|
|
|
// Regression test for https://github.com/google/tarpc/issues/220
#[test]
fn stage_request_channel_dropped_doesnt_panic() {
    let (mut dispatch, mut channel, mut server_channel) = set_up();
    let mut dispatch = Pin::new(&mut dispatch);
    let cx = &mut Context::from_waker(&noop_waker_ref());

    // The response future is dropped immediately, followed by the channel itself.
    drop(send_request(&mut channel, "hi"));
    drop(channel);

    assert!(dispatch.as_mut().poll(cx).is_ready());

    let reply = Response {
        request_id: 0,
        message: Ok("hello".into()),
        _non_exhaustive: (),
    };
    send_response(&mut server_channel, reply);
    block_on(dispatch).unwrap();
}
|
|
|
|
/// A request whose response future is dropped before the dispatcher processes it must be
/// skipped rather than yielded for sending.
#[test]
fn stage_request_response_future_dropped_is_canceled_before_sending() {
    let (mut dispatch, mut channel, _server_channel) = set_up();
    let dispatch = Pin::new(&mut dispatch);
    let cx = &mut Context::from_waker(&noop_waker_ref());

    // The response future is dropped right away.
    drop(send_request(&mut channel, "hi"));

    // With the channel gone, polling yields `None` when no request is currently ready.
    drop(channel);

    // The dropped request must not surface, so it never reaches the in-flight map.
    let next = dispatch.poll_next_request(cx).ready();
    assert!(next.is_none());
}
|
|
|
|
/// A request whose response future is dropped after the dispatcher wrote it must be removed
/// from the in-flight map when the cancellation is processed.
#[test]
fn stage_request_response_future_dropped_is_canceled_after_sending() {
    let (mut dispatch, mut channel, _server_channel) = set_up();
    let cx = &mut Context::from_waker(&noop_waker_ref());
    let mut dispatch = Pin::new(&mut dispatch);

    let pending_response = send_request(&mut channel, "hi");

    assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
    assert!(!dispatch.as_mut().project().in_flight_requests.is_empty());

    // Dropping the response future after the request went out triggers the cancellation.
    drop(pending_response);
    match dispatch.as_mut().poll_next_cancellation(cx).unwrap() {
        Poll::Ready(Some(_)) => {
            // ok
        }
        _ => panic!("Expected request to be cancelled"),
    }
    assert!(dispatch.project().in_flight_requests.is_empty());
}
|
|
|
|
/// A request whose receiver is closed but whose cancellation has not been sent yet (i.e. the
/// response future is still inside its `drop fn`) must not be yielded for sending.
#[test]
fn stage_request_response_closed_skipped() {
    let (mut dispatch, mut channel, _server_channel) = set_up();
    let dispatch = Pin::new(&mut dispatch);
    let cx = &mut Context::from_waker(&noop_waker_ref());

    // Close the receiver by hand without dropping the response future.
    let mut half_dropped = send_request(&mut channel, "hi");
    half_dropped.response.get_mut().close();

    // The channel is still open, so skipping the request leaves nothing ready.
    let polled = dispatch.poll_next_request(cx);
    assert!(polled.is_pending());
}
|
|
|
|
fn set_up() -> (
|
|
RequestDispatch<String, String, UnboundedChannel<Response<String>, ClientMessage<String>>>,
|
|
Channel<String, String>,
|
|
UnboundedChannel<ClientMessage<String>, Response<String>>,
|
|
) {
|
|
let _ = env_logger::try_init();
|
|
|
|
let (to_dispatch, pending_requests) = mpsc::channel(1);
|
|
let (cancel_tx, canceled_requests) = mpsc::unbounded();
|
|
let (client_channel, server_channel) = transport::channel::unbounded();
|
|
|
|
let dispatch = RequestDispatch::<String, String, _> {
|
|
transport: client_channel.fuse(),
|
|
pending_requests: pending_requests.fuse(),
|
|
canceled_requests: CanceledRequests(canceled_requests).fuse(),
|
|
in_flight_requests: FnvHashMap::default(),
|
|
config: Config::default(),
|
|
};
|
|
|
|
let cancellation = RequestCancellation(cancel_tx);
|
|
let channel = Channel {
|
|
to_dispatch,
|
|
cancellation,
|
|
next_request_id: Arc::new(AtomicU64::new(0)),
|
|
};
|
|
|
|
(dispatch, channel, server_channel)
|
|
}
|
|
|
|
fn send_request(
|
|
channel: &mut Channel<String, String>,
|
|
request: &str,
|
|
) -> DispatchResponse<String> {
|
|
block_on(channel.send(context::current(), request.to_string())).unwrap()
|
|
}
|
|
|
|
fn send_response(
|
|
channel: &mut UnboundedChannel<ClientMessage<String>, Response<String>>,
|
|
response: Response<String>,
|
|
) {
|
|
block_on(channel.send(response)).unwrap();
|
|
}
|
|
|
|
/// Test helpers for unpacking a `Poll<Option<Result<T, E>>>`.
trait PollTest {
    /// The value left after the `Result` layer has been stripped.
    type T;

    /// Strips the `Result` layer, panicking with the error's message on `Err`.
    fn unwrap(self) -> Poll<Self::T>;

    /// Strips both the `Poll` and `Result` layers, panicking on `Pending` or `Err`.
    fn ready(self) -> Self::T;
}

impl<T, E> PollTest for Poll<Option<Result<T, E>>>
where
    E: ::std::fmt::Display,
{
    type T = Option<T>;

    fn unwrap(self) -> Poll<Option<T>> {
        match self {
            Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)),
            Poll::Ready(None) => Poll::Ready(None),
            // An explicit format string keeps the error text from being treated as one
            // (a non-literal sole argument to `panic!` is rejected in edition 2021).
            Poll::Ready(Some(Err(e))) => panic!("{}", e),
            Poll::Pending => Poll::Pending,
        }
    }

    fn ready(self) -> Option<T> {
        match self {
            Poll::Ready(Some(Ok(t))) => Some(t),
            Poll::Ready(None) => None,
            Poll::Ready(Some(Err(e))) => panic!("{}", e),
            Poll::Pending => panic!("Pending"),
        }
    }
}
|
|
}
|