mirror of
https://github.com/OMGeeky/tarpc.git
synced 2025-12-29 15:49:52 +01:00
Port to pin-project
This commit is contained in:
@@ -15,7 +15,7 @@ description = "A bincode-based transport for tarpc services."
|
||||
[dependencies]
|
||||
futures-preview = { version = "0.3.0-alpha.18", features = ["compat"] }
|
||||
futures_legacy = { version = "0.1", package = "futures" }
|
||||
pin-utils = "0.1.0-alpha.4"
|
||||
pin-project = "0.4"
|
||||
serde = "1.0"
|
||||
tokio-io = "0.1"
|
||||
async-bincode = "0.4"
|
||||
@@ -24,3 +24,4 @@ tokio-tcp = "0.1"
|
||||
[dev-dependencies]
|
||||
futures-test-preview = { version = "0.3.0-alpha.18" }
|
||||
assert_matches = "1.0"
|
||||
pin-utils = "0.1.0-alpha"
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
|
||||
use async_bincode::{AsyncBincodeStream, AsyncDestination};
|
||||
use futures::{compat::*, prelude::*, ready};
|
||||
use pin_utils::unsafe_pinned;
|
||||
use pin_project::pin_project;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
error::Error,
|
||||
@@ -24,17 +24,13 @@ use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use tokio_tcp::{TcpListener, TcpStream};
|
||||
|
||||
/// A transport that serializes to, and deserializes from, a [`TcpStream`].
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct Transport<S, Item, SinkItem> {
|
||||
#[pin]
|
||||
inner: Compat01As03Sink<AsyncBincodeStream<S, Item, SinkItem, AsyncDestination>, SinkItem>,
|
||||
}
|
||||
|
||||
impl<S, Item, SinkItem> Transport<S, Item, SinkItem> {
|
||||
unsafe_pinned!(
|
||||
inner: Compat01As03Sink<AsyncBincodeStream<S, Item, SinkItem, AsyncDestination>, SinkItem>
|
||||
);
|
||||
}
|
||||
|
||||
impl<S, Item, SinkItem> Stream for Transport<S, Item, SinkItem>
|
||||
where
|
||||
S: AsyncRead,
|
||||
@@ -43,7 +39,7 @@ where
|
||||
type Item = io::Result<Item>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
|
||||
match self.inner().poll_next(cx) {
|
||||
match self.project().inner.poll_next(cx) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Ready(Some(Ok(next))) => Poll::Ready(Some(Ok(next))),
|
||||
@@ -62,21 +58,22 @@ where
|
||||
type Error = io::Error;
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
|
||||
self.inner()
|
||||
self.project()
|
||||
.inner
|
||||
.start_send(item)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
convert(self.inner().poll_ready(cx))
|
||||
convert(self.project().inner.poll_ready(cx))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
convert(self.inner().poll_flush(cx))
|
||||
convert(self.project().inner.poll_flush(cx))
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
convert(self.inner().poll_close(cx))
|
||||
convert(self.project().inner.poll_close(cx))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,16 +150,16 @@ where
|
||||
}
|
||||
|
||||
/// A [`TcpListener`] that wraps connections in bincode transports.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct Incoming<Item, SinkItem> {
|
||||
#[pin]
|
||||
incoming: Compat01As03<tokio_tcp::Incoming>,
|
||||
local_addr: SocketAddr,
|
||||
ghost: PhantomData<(Item, SinkItem)>,
|
||||
}
|
||||
|
||||
impl<Item, SinkItem> Incoming<Item, SinkItem> {
|
||||
unsafe_pinned!(incoming: Compat01As03<tokio_tcp::Incoming>);
|
||||
|
||||
/// Returns the address being listened on.
|
||||
pub fn local_addr(&self) -> SocketAddr {
|
||||
self.local_addr
|
||||
@@ -177,7 +174,7 @@ where
|
||||
type Item = io::Result<Transport<TcpStream, Item, SinkItem>>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let next = ready!(self.incoming().poll_next(cx)?);
|
||||
let next = ready!(self.project().incoming.poll_next(cx)?);
|
||||
Poll::Ready(next.map(|conn| Ok(new(conn))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ fnv = "1.0"
|
||||
futures-preview = { version = "0.3.0-alpha.18" }
|
||||
humantime = "1.0"
|
||||
log = "0.4"
|
||||
pin-utils = "0.1.0-alpha.4"
|
||||
pin-project = "0.4"
|
||||
raii-counter = "0.2"
|
||||
rand = "0.7"
|
||||
tokio-timer = "0.3.0-alpha.4"
|
||||
@@ -34,3 +34,4 @@ tokio = { optional = true, version = "0.2.0-alpha.4" }
|
||||
futures-test-preview = { version = "0.3.0-alpha.18" }
|
||||
env_logger = "0.6"
|
||||
assert_matches = "1.0"
|
||||
pin-utils = "0.1.0-alpha"
|
||||
|
||||
@@ -19,10 +19,9 @@ use futures::{
|
||||
Poll,
|
||||
};
|
||||
use log::{debug, info, trace};
|
||||
use pin_utils::{unsafe_pinned, unsafe_unpinned};
|
||||
use pin_project::{pin_project, pinned_drop};
|
||||
use std::{
|
||||
io,
|
||||
marker::Unpin,
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
@@ -55,9 +54,11 @@ impl<Req, Resp> Clone for Channel<Req, Resp> {
|
||||
}
|
||||
|
||||
/// A future returned by [`Channel::send`] that resolves to a server response.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
struct Send<'a, Req, Resp> {
|
||||
#[pin]
|
||||
fut: MapOkDispatchResponse<SendMapErrConnectionReset<'a, Req, Resp>, Resp>,
|
||||
}
|
||||
|
||||
@@ -65,45 +66,28 @@ type SendMapErrConnectionReset<'a, Req, Resp> = MapErrConnectionReset<
|
||||
futures::sink::Send<'a, mpsc::Sender<DispatchRequest<Req, Resp>>, DispatchRequest<Req, Resp>>,
|
||||
>;
|
||||
|
||||
impl<'a, Req, Resp> Send<'a, Req, Resp> {
|
||||
unsafe_pinned!(
|
||||
fut: MapOkDispatchResponse<
|
||||
MapErrConnectionReset<
|
||||
futures::sink::Send<
|
||||
'a,
|
||||
mpsc::Sender<DispatchRequest<Req, Resp>>,
|
||||
DispatchRequest<Req, Resp>,
|
||||
>,
|
||||
>,
|
||||
Resp,
|
||||
>
|
||||
);
|
||||
}
|
||||
|
||||
impl<'a, Req, Resp> Future for Send<'a, Req, Resp> {
|
||||
type Output = io::Result<DispatchResponse<Resp>>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.as_mut().fut().poll(cx)
|
||||
self.as_mut().project().fut.poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
/// A future returned by [`Channel::call`] that resolves to a server response.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct Call<'a, Req, Resp> {
|
||||
#[pin]
|
||||
fut: AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>,
|
||||
}
|
||||
|
||||
impl<'a, Req, Resp> Call<'a, Req, Resp> {
|
||||
unsafe_pinned!(fut: AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>);
|
||||
}
|
||||
|
||||
impl<'a, Req, Resp> Future for Call<'a, Req, Resp> {
|
||||
type Output = io::Result<Resp>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.as_mut().fut().poll(cx)
|
||||
self.as_mut().project().fut.poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,6 +139,7 @@ impl<Req, Resp> Channel<Req, Resp> {
|
||||
|
||||
/// A server response that is completed by request dispatch when the corresponding response
|
||||
/// arrives off the wire.
|
||||
#[pin_project(PinnedDrop)]
|
||||
#[derive(Debug)]
|
||||
struct DispatchResponse<Resp> {
|
||||
response: Timeout<oneshot::Receiver<Response<Resp>>>,
|
||||
@@ -164,10 +149,6 @@ struct DispatchResponse<Resp> {
|
||||
request_id: u64,
|
||||
}
|
||||
|
||||
impl<Resp> DispatchResponse<Resp> {
|
||||
unsafe_pinned!(ctx: context::Context);
|
||||
}
|
||||
|
||||
impl<Resp> Future for DispatchResponse<Resp> {
|
||||
type Output = io::Result<Resp>;
|
||||
|
||||
@@ -196,8 +177,9 @@ impl<Resp> Future for DispatchResponse<Resp> {
|
||||
}
|
||||
|
||||
// Cancels the request when dropped, if not already complete.
|
||||
impl<Resp> Drop for DispatchResponse<Resp> {
|
||||
fn drop(&mut self) {
|
||||
#[pinned_drop]
|
||||
impl<Resp> PinnedDrop for DispatchResponse<Resp> {
|
||||
fn drop(mut self: Pin<&mut Self>) {
|
||||
if !self.complete {
|
||||
// The receiver needs to be closed to handle the edge case that the request has not
|
||||
// yet been received by the dispatch task. It is possible for the cancel message to
|
||||
@@ -210,7 +192,8 @@ impl<Resp> Drop for DispatchResponse<Resp> {
|
||||
// dispatch task misses an early-arriving cancellation message, then it will see the
|
||||
// receiver as closed.
|
||||
self.response.get_mut().close();
|
||||
self.cancellation.cancel(self.request_id);
|
||||
let request_id = self.request_id;
|
||||
self.cancellation.cancel(request_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -246,13 +229,17 @@ where
|
||||
|
||||
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
|
||||
/// and dispatching responses to the appropriate channel.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct RequestDispatch<Req, Resp, C> {
|
||||
/// Writes requests to the wire and reads responses off the wire.
|
||||
#[pin]
|
||||
transport: Fuse<C>,
|
||||
/// Requests waiting to be written to the wire.
|
||||
#[pin]
|
||||
pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>,
|
||||
/// Requests that were dropped.
|
||||
#[pin]
|
||||
canceled_requests: Fuse<CanceledRequests>,
|
||||
/// Requests already written to the wire that haven't yet received responses.
|
||||
in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>,
|
||||
@@ -264,19 +251,16 @@ impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
|
||||
where
|
||||
C: Transport<ClientMessage<Req>, Response<Resp>>,
|
||||
{
|
||||
unsafe_pinned!(in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>);
|
||||
unsafe_pinned!(canceled_requests: Fuse<CanceledRequests>);
|
||||
unsafe_pinned!(pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>);
|
||||
unsafe_pinned!(transport: Fuse<C>);
|
||||
|
||||
fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
||||
Poll::Ready(match ready!(self.as_mut().transport().poll_next(cx)?) {
|
||||
Some(response) => {
|
||||
self.complete(response);
|
||||
Some(Ok(()))
|
||||
}
|
||||
None => None,
|
||||
})
|
||||
Poll::Ready(
|
||||
match ready!(self.as_mut().project().transport.poll_next(cx)?) {
|
||||
Some(response) => {
|
||||
self.complete(response);
|
||||
Some(Ok(()))
|
||||
}
|
||||
None => None,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
|
||||
@@ -305,12 +289,12 @@ where
|
||||
|
||||
match (pending_requests_status, canceled_requests_status) {
|
||||
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
|
||||
ready!(self.as_mut().transport().poll_flush(cx)?);
|
||||
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
||||
Poll::Ready(None)
|
||||
}
|
||||
(ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
|
||||
// No more messages to process, so flush any messages buffered in the transport.
|
||||
ready!(self.as_mut().transport().poll_flush(cx)?);
|
||||
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
||||
|
||||
// Even if we fully-flush, we return Pending, because we have no more requests
|
||||
// or cancellations right now.
|
||||
@@ -324,10 +308,10 @@ where
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<DispatchRequest<Req, Resp>> {
|
||||
if self.as_mut().in_flight_requests().len() >= self.config.max_in_flight_requests {
|
||||
if self.as_mut().project().in_flight_requests.len() >= self.config.max_in_flight_requests {
|
||||
info!(
|
||||
"At in-flight request capacity ({}/{}).",
|
||||
self.as_mut().in_flight_requests().len(),
|
||||
self.as_mut().project().in_flight_requests.len(),
|
||||
self.config.max_in_flight_requests
|
||||
);
|
||||
|
||||
@@ -336,13 +320,13 @@ where
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
while let Poll::Pending = self.as_mut().transport().poll_ready(cx)? {
|
||||
while let Poll::Pending = self.as_mut().project().transport.poll_ready(cx)? {
|
||||
// We can't yield a request-to-be-sent before the transport is capable of buffering it.
|
||||
ready!(self.as_mut().transport().poll_flush(cx)?);
|
||||
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
||||
}
|
||||
|
||||
loop {
|
||||
match ready!(self.as_mut().pending_requests().poll_next_unpin(cx)) {
|
||||
match ready!(self.as_mut().project().pending_requests.poll_next_unpin(cx)) {
|
||||
Some(request) => {
|
||||
if request.response_completion.is_canceled() {
|
||||
trace!(
|
||||
@@ -364,18 +348,25 @@ where
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<(context::Context, u64)> {
|
||||
while let Poll::Pending = self.as_mut().transport().poll_ready(cx)? {
|
||||
ready!(self.as_mut().transport().poll_flush(cx)?);
|
||||
while let Poll::Pending = self.as_mut().project().transport.poll_ready(cx)? {
|
||||
ready!(self.as_mut().project().transport.poll_flush(cx)?);
|
||||
}
|
||||
|
||||
loop {
|
||||
let cancellation = self.as_mut().canceled_requests().poll_next_unpin(cx);
|
||||
let cancellation = self
|
||||
.as_mut()
|
||||
.project()
|
||||
.canceled_requests
|
||||
.poll_next_unpin(cx);
|
||||
match ready!(cancellation) {
|
||||
Some(request_id) => {
|
||||
if let Some(in_flight_data) =
|
||||
self.as_mut().in_flight_requests().remove(&request_id)
|
||||
if let Some(in_flight_data) = self
|
||||
.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&request_id)
|
||||
{
|
||||
self.as_mut().in_flight_requests().compact(0.1);
|
||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
||||
debug!("[{}] Removed request.", in_flight_data.ctx.trace_id());
|
||||
return Poll::Ready(Some(Ok((in_flight_data.ctx, request_id))));
|
||||
}
|
||||
@@ -400,8 +391,8 @@ where
|
||||
},
|
||||
_non_exhaustive: (),
|
||||
});
|
||||
self.as_mut().transport().start_send(request)?;
|
||||
self.as_mut().in_flight_requests().insert(
|
||||
self.as_mut().project().transport.start_send(request)?;
|
||||
self.as_mut().project().in_flight_requests.insert(
|
||||
request_id,
|
||||
InFlightData {
|
||||
ctx: dispatch_request.ctx,
|
||||
@@ -421,7 +412,7 @@ where
|
||||
trace_context: context.trace_context,
|
||||
request_id,
|
||||
};
|
||||
self.as_mut().transport().start_send(cancel)?;
|
||||
self.as_mut().project().transport.start_send(cancel)?;
|
||||
trace!("[{}] Cancel message sent.", trace_id);
|
||||
Ok(())
|
||||
}
|
||||
@@ -430,10 +421,11 @@ where
|
||||
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
|
||||
if let Some(in_flight_data) = self
|
||||
.as_mut()
|
||||
.in_flight_requests()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&response.request_id)
|
||||
{
|
||||
self.as_mut().in_flight_requests().compact(0.1);
|
||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
||||
|
||||
trace!("[{}] Received response.", in_flight_data.ctx.trace_id());
|
||||
let _ = in_flight_data.response_completion.send(response);
|
||||
@@ -460,13 +452,13 @@ where
|
||||
loop {
|
||||
match (self.as_mut().pump_read(cx)?, self.as_mut().pump_write(cx)?) {
|
||||
(read, Poll::Ready(None)) => {
|
||||
if self.as_mut().in_flight_requests().is_empty() {
|
||||
if self.as_mut().project().in_flight_requests.is_empty() {
|
||||
info!("Shutdown: write half closed, and no requests in flight.");
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
info!(
|
||||
"Shutdown: write half closed, and {} requests in flight.",
|
||||
self.as_mut().in_flight_requests().len()
|
||||
self.as_mut().project().in_flight_requests.len()
|
||||
);
|
||||
match read {
|
||||
Poll::Ready(Some(())) => continue,
|
||||
@@ -529,17 +521,16 @@ impl Stream for CanceledRequests {
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
struct MapErrConnectionReset<Fut> {
|
||||
#[pin]
|
||||
future: Fut,
|
||||
finished: Option<()>,
|
||||
}
|
||||
|
||||
impl<Fut> MapErrConnectionReset<Fut> {
|
||||
unsafe_pinned!(future: Fut);
|
||||
unsafe_unpinned!(finished: Option<()>);
|
||||
|
||||
fn new(future: Fut) -> MapErrConnectionReset<Fut> {
|
||||
MapErrConnectionReset {
|
||||
future,
|
||||
@@ -548,8 +539,6 @@ impl<Fut> MapErrConnectionReset<Fut> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fut: Unpin> Unpin for MapErrConnectionReset<Fut> {}
|
||||
|
||||
impl<Fut> Future for MapErrConnectionReset<Fut>
|
||||
where
|
||||
Fut: TryFuture,
|
||||
@@ -557,10 +546,10 @@ where
|
||||
type Output = io::Result<Fut::Ok>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.as_mut().future().try_poll(cx) {
|
||||
match self.as_mut().project().future.try_poll(cx) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(result) => {
|
||||
self.finished().take().expect(
|
||||
self.project().finished.take().expect(
|
||||
"MapErrConnectionReset must not be polled after it returned `Poll::Ready`",
|
||||
);
|
||||
Poll::Ready(result.map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset)))
|
||||
@@ -569,17 +558,16 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
struct MapOkDispatchResponse<Fut, Resp> {
|
||||
#[pin]
|
||||
future: Fut,
|
||||
response: Option<DispatchResponse<Resp>>,
|
||||
}
|
||||
|
||||
impl<Fut, Resp> MapOkDispatchResponse<Fut, Resp> {
|
||||
unsafe_pinned!(future: Fut);
|
||||
unsafe_unpinned!(response: Option<DispatchResponse<Resp>>);
|
||||
|
||||
fn new(future: Fut, response: DispatchResponse<Resp>) -> MapOkDispatchResponse<Fut, Resp> {
|
||||
MapOkDispatchResponse {
|
||||
future,
|
||||
@@ -588,8 +576,6 @@ impl<Fut, Resp> MapOkDispatchResponse<Fut, Resp> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fut: Unpin, Resp> Unpin for MapOkDispatchResponse<Fut, Resp> {}
|
||||
|
||||
impl<Fut, Resp> Future for MapOkDispatchResponse<Fut, Resp>
|
||||
where
|
||||
Fut: TryFuture,
|
||||
@@ -597,12 +583,13 @@ where
|
||||
type Output = Result<DispatchResponse<Resp>, Fut::Error>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.as_mut().future().try_poll(cx) {
|
||||
match self.as_mut().project().future.try_poll(cx) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(result) => {
|
||||
let response = self
|
||||
.as_mut()
|
||||
.response()
|
||||
.project()
|
||||
.response
|
||||
.take()
|
||||
.expect("MapOk must not be polled after it returned `Poll::Ready`");
|
||||
Poll::Ready(result.map(|_| response))
|
||||
@@ -611,9 +598,11 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
struct AndThenIdent<Fut1, Fut2> {
|
||||
#[pin]
|
||||
try_chain: TryChain<Fut1, Fut2>,
|
||||
}
|
||||
|
||||
@@ -622,8 +611,6 @@ where
|
||||
Fut1: TryFuture<Ok = Fut2>,
|
||||
Fut2: TryFuture,
|
||||
{
|
||||
unsafe_pinned!(try_chain: TryChain<Fut1, Fut2>);
|
||||
|
||||
/// Creates a new `Then`.
|
||||
fn new(future: Fut1) -> AndThenIdent<Fut1, Fut2> {
|
||||
AndThenIdent {
|
||||
@@ -640,7 +627,7 @@ where
|
||||
type Output = Result<Fut2::Ok, Fut2::Error>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.try_chain().poll(cx, |result| match result {
|
||||
self.project().try_chain.poll(cx, |result| match result {
|
||||
Ok(ok) => TryChainAction::Future(ok),
|
||||
Err(err) => TryChainAction::Output(Err(err)),
|
||||
})
|
||||
@@ -830,7 +817,7 @@ mod tests {
|
||||
let req = send_request(&mut channel, "hi");
|
||||
|
||||
assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
|
||||
assert!(!dispatch.as_mut().in_flight_requests().is_empty());
|
||||
assert!(!dispatch.as_mut().project().in_flight_requests.is_empty());
|
||||
|
||||
// Test that a request future dropped after it's processed by dispatch will cause the request
|
||||
// to be removed from the in-flight request map.
|
||||
@@ -840,7 +827,7 @@ mod tests {
|
||||
} else {
|
||||
panic!("Expected request to be cancelled")
|
||||
};
|
||||
assert!(dispatch.in_flight_requests().is_empty());
|
||||
assert!(dispatch.project().in_flight_requests.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -18,7 +18,7 @@ use futures::{
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use log::{debug, info, trace};
|
||||
use pin_utils::{unsafe_pinned, unsafe_unpinned};
|
||||
use pin_project::pin_project;
|
||||
use raii_counter::{Counter, WeakCounter};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::{
|
||||
@@ -26,30 +26,32 @@ use std::{
|
||||
};
|
||||
|
||||
/// A single-threaded filter that drops channels based on per-key limits.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct ChannelFilter<S, K, F>
|
||||
where
|
||||
K: Eq + Hash,
|
||||
{
|
||||
#[pin]
|
||||
listener: Fuse<S>,
|
||||
channels_per_key: u32,
|
||||
#[pin]
|
||||
dropped_keys: mpsc::UnboundedReceiver<K>,
|
||||
#[pin]
|
||||
dropped_keys_tx: mpsc::UnboundedSender<K>,
|
||||
key_counts: FnvHashMap<K, TrackerPrototype<K>>,
|
||||
keymaker: F,
|
||||
}
|
||||
|
||||
/// A channel that is tracked by a ChannelFilter.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct TrackedChannel<C, K> {
|
||||
#[pin]
|
||||
inner: C,
|
||||
tracker: Tracker<K>,
|
||||
}
|
||||
|
||||
impl<C, K> TrackedChannel<C, K> {
|
||||
unsafe_pinned!(inner: C);
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Tracker<K> {
|
||||
key: Option<Arc<K>>,
|
||||
@@ -130,11 +132,11 @@ where
|
||||
}
|
||||
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
self.inner().in_flight_requests()
|
||||
self.project().inner.in_flight_requests()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
self.inner().start_request(request_id)
|
||||
self.project().inner.start_request(request_id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,22 +148,10 @@ impl<C, K> TrackedChannel<C, K> {
|
||||
|
||||
/// Returns the pinned inner channel.
|
||||
fn channel<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut C> {
|
||||
self.inner()
|
||||
self.project().inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, K, F> ChannelFilter<S, K, F>
|
||||
where
|
||||
K: fmt::Display + Eq + Hash + Clone,
|
||||
{
|
||||
unsafe_pinned!(listener: Fuse<S>);
|
||||
unsafe_pinned!(dropped_keys: mpsc::UnboundedReceiver<K>);
|
||||
unsafe_pinned!(dropped_keys_tx: mpsc::UnboundedSender<K>);
|
||||
unsafe_unpinned!(key_counts: FnvHashMap<K, TrackerPrototype<K>>);
|
||||
unsafe_unpinned!(channels_per_key: u32);
|
||||
unsafe_unpinned!(keymaker: F);
|
||||
}
|
||||
|
||||
impl<S, K, F> ChannelFilter<S, K, F>
|
||||
where
|
||||
K: Eq + Hash,
|
||||
@@ -192,14 +182,14 @@ where
|
||||
mut self: Pin<&mut Self>,
|
||||
stream: S::Item,
|
||||
) -> Result<TrackedChannel<S::Item, K>, K> {
|
||||
let key = self.as_mut().keymaker()(&stream);
|
||||
let key = (self.as_mut().keymaker)(&stream);
|
||||
let tracker = self.as_mut().increment_channels_for_key(key.clone())?;
|
||||
|
||||
trace!(
|
||||
"[{}] Opening channel ({}/{}) channels for key.",
|
||||
key,
|
||||
tracker.counter.count(),
|
||||
self.as_mut().channels_per_key()
|
||||
self.as_mut().project().channels_per_key
|
||||
);
|
||||
|
||||
Ok(TrackedChannel {
|
||||
@@ -211,7 +201,7 @@ where
|
||||
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Tracker<K>, K> {
|
||||
let channels_per_key = self.channels_per_key;
|
||||
let dropped_keys = self.dropped_keys_tx.clone();
|
||||
let key_counts = &mut self.as_mut().key_counts();
|
||||
let key_counts = &mut self.as_mut().project().key_counts;
|
||||
match key_counts.entry(key.clone()) {
|
||||
Entry::Vacant(vacant) => {
|
||||
let key = Arc::new(key);
|
||||
@@ -256,18 +246,18 @@ where
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
|
||||
match ready!(self.as_mut().listener().poll_next_unpin(cx)) {
|
||||
match ready!(self.as_mut().project().listener.poll_next_unpin(cx)) {
|
||||
Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_closed_channels(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
match ready!(self.as_mut().dropped_keys().poll_next_unpin(cx)) {
|
||||
match ready!(self.as_mut().project().dropped_keys.poll_next_unpin(cx)) {
|
||||
Some(key) => {
|
||||
debug!("All channels dropped for key [{}]", key);
|
||||
self.as_mut().key_counts().remove(&key);
|
||||
self.as_mut().key_counts().compact(0.1);
|
||||
self.as_mut().project().key_counts.remove(&key);
|
||||
self.as_mut().project().key_counts.compact(0.1);
|
||||
Poll::Ready(())
|
||||
}
|
||||
None => unreachable!("Holding a copy of closed_channels and didn't close it."),
|
||||
|
||||
@@ -21,7 +21,7 @@ use futures::{
|
||||
};
|
||||
use humantime::format_rfc3339;
|
||||
use log::{debug, trace};
|
||||
use pin_utils::{unsafe_pinned, unsafe_unpinned};
|
||||
use pin_project::pin_project;
|
||||
use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
|
||||
use tokio_timer::{timeout, Timeout};
|
||||
|
||||
@@ -165,10 +165,12 @@ where
|
||||
}
|
||||
|
||||
/// BaseChannel lifts a Transport to a Channel by tracking in-flight requests.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct BaseChannel<Req, Resp, T> {
|
||||
config: Config,
|
||||
/// Writes responses to the wire and reads requests off the wire.
|
||||
#[pin]
|
||||
transport: Fuse<T>,
|
||||
/// Number of requests currently being responded to.
|
||||
in_flight_requests: FnvHashMap<u64, AbortHandle>,
|
||||
@@ -176,10 +178,6 @@ pub struct BaseChannel<Req, Resp, T> {
|
||||
ghost: PhantomData<(Req, Resp)>,
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> BaseChannel<Req, Resp, T> {
|
||||
unsafe_unpinned!(in_flight_requests: FnvHashMap<u64, AbortHandle>);
|
||||
}
|
||||
|
||||
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
|
||||
where
|
||||
T: Transport<Response<Resp>, ClientMessage<Req>>,
|
||||
@@ -204,19 +202,19 @@ where
|
||||
self.transport.get_ref()
|
||||
}
|
||||
|
||||
/// Returns the pinned inner transport.
|
||||
pub fn transport<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut T> {
|
||||
unsafe { self.map_unchecked_mut(|me| me.transport.get_mut()) }
|
||||
}
|
||||
|
||||
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
|
||||
// It's possible the request was already completed, so it's fine
|
||||
// if this is None.
|
||||
if let Some(cancel_handle) = self.as_mut().in_flight_requests().remove(&request_id) {
|
||||
self.as_mut().in_flight_requests().compact(0.1);
|
||||
if let Some(cancel_handle) = self
|
||||
.as_mut()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&request_id)
|
||||
{
|
||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
||||
|
||||
cancel_handle.abort();
|
||||
let remaining = self.as_mut().in_flight_requests().len();
|
||||
let remaining = self.as_mut().project().in_flight_requests.len();
|
||||
trace!(
|
||||
"[{}] Request canceled. In-flight requests = {}",
|
||||
trace_context.trace_id,
|
||||
@@ -295,7 +293,7 @@ where
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
match ready!(self.as_mut().transport().poll_next(cx)?) {
|
||||
match ready!(self.as_mut().project().transport.poll_next(cx)?) {
|
||||
Some(message) => match message {
|
||||
ClientMessage::Request(request) => {
|
||||
return Poll::Ready(Some(Ok(request)));
|
||||
@@ -321,28 +319,29 @@ where
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.transport().poll_ready(cx)
|
||||
self.project().transport.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
|
||||
if self
|
||||
.as_mut()
|
||||
.in_flight_requests()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&response.request_id)
|
||||
.is_some()
|
||||
{
|
||||
self.as_mut().in_flight_requests().compact(0.1);
|
||||
self.as_mut().project().in_flight_requests.compact(0.1);
|
||||
}
|
||||
|
||||
self.transport().start_send(response)
|
||||
self.project().transport.start_send(response)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.transport().poll_flush(cx)
|
||||
self.project().transport.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.transport().poll_close(cx)
|
||||
self.project().transport.poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -364,13 +363,14 @@ where
|
||||
}
|
||||
|
||||
fn in_flight_requests(mut self: Pin<&mut Self>) -> usize {
|
||||
self.as_mut().in_flight_requests().len()
|
||||
self.as_mut().project().in_flight_requests.len()
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
let (abort_handle, abort_registration) = AbortHandle::new_pair();
|
||||
assert!(self
|
||||
.in_flight_requests()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.insert(request_id, abort_handle)
|
||||
.is_none());
|
||||
abort_registration
|
||||
@@ -378,32 +378,24 @@ where
|
||||
}
|
||||
|
||||
/// A running handler serving all requests coming over a channel.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct ClientHandler<C, S>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
#[pin]
|
||||
channel: C,
|
||||
/// Responses waiting to be written to the wire.
|
||||
#[pin]
|
||||
pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>,
|
||||
/// Handed out to request handlers to fan in responses.
|
||||
#[pin]
|
||||
responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
|
||||
/// Server
|
||||
server: S,
|
||||
}
|
||||
|
||||
impl<C, S> ClientHandler<C, S>
|
||||
where
|
||||
C: Channel,
|
||||
{
|
||||
unsafe_pinned!(channel: C);
|
||||
unsafe_pinned!(pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>);
|
||||
unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>);
|
||||
// For this to be safe, field f must be private, and code in this module must never
|
||||
// construct PinMut<S>.
|
||||
unsafe_unpinned!(server: S);
|
||||
}
|
||||
|
||||
impl<C, S> ClientHandler<C, S>
|
||||
where
|
||||
C: Channel,
|
||||
@@ -413,7 +405,7 @@ where
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<RequestHandler<S::Fut, C::Resp>> {
|
||||
match ready!(self.as_mut().channel().poll_next(cx)?) {
|
||||
match ready!(self.as_mut().project().channel.poll_next(cx)?) {
|
||||
Some(request) => Poll::Ready(Some(Ok(self.handle_request(request)))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
@@ -429,24 +421,24 @@ where
|
||||
trace!(
|
||||
"[{}] Staging response. In-flight requests = {}.",
|
||||
ctx.trace_id(),
|
||||
self.as_mut().channel().in_flight_requests(),
|
||||
self.as_mut().project().channel.in_flight_requests(),
|
||||
);
|
||||
self.as_mut().channel().start_send(response)?;
|
||||
self.as_mut().project().channel.start_send(response)?;
|
||||
Poll::Ready(Some(Ok(())))
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
// Shutdown can't be done before we finish pumping out remaining responses.
|
||||
ready!(self.as_mut().channel().poll_flush(cx)?);
|
||||
ready!(self.as_mut().project().channel.poll_flush(cx)?);
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Pending => {
|
||||
// No more requests to process, so flush any requests buffered in the transport.
|
||||
ready!(self.as_mut().channel().poll_flush(cx)?);
|
||||
ready!(self.as_mut().project().channel.poll_flush(cx)?);
|
||||
|
||||
// Being here means there are no staged requests and all written responses are
|
||||
// fully flushed. So, if the read half is closed and there are no in-flight
|
||||
// requests, then we can close the write half.
|
||||
if read_half_closed && self.as_mut().channel().in_flight_requests() == 0 {
|
||||
if read_half_closed && self.as_mut().project().channel.in_flight_requests() == 0 {
|
||||
Poll::Ready(None)
|
||||
} else {
|
||||
Poll::Pending
|
||||
@@ -460,11 +452,11 @@ where
|
||||
cx: &mut Context<'_>,
|
||||
) -> PollIo<(context::Context, Response<C::Resp>)> {
|
||||
// Ensure there's room to write a response.
|
||||
while let Poll::Pending = self.as_mut().channel().poll_ready(cx)? {
|
||||
ready!(self.as_mut().channel().poll_flush(cx)?);
|
||||
while let Poll::Pending = self.as_mut().project().channel.poll_ready(cx)? {
|
||||
ready!(self.as_mut().project().channel.poll_flush(cx)?);
|
||||
}
|
||||
|
||||
match ready!(self.as_mut().pending_responses().poll_next(cx)) {
|
||||
match ready!(self.as_mut().project().pending_responses.poll_next(cx)) {
|
||||
Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))),
|
||||
None => {
|
||||
// This branch likely won't happen, since the ClientHandler is holding a Sender.
|
||||
@@ -489,7 +481,7 @@ where
|
||||
let ctx = request.context;
|
||||
let request = request.message;
|
||||
|
||||
let response = self.as_mut().server().clone().serve(ctx, request);
|
||||
let response = self.as_mut().project().server.clone().serve(ctx, request);
|
||||
let response = Resp {
|
||||
state: RespState::PollResp,
|
||||
request_id,
|
||||
@@ -497,9 +489,9 @@ where
|
||||
deadline,
|
||||
f: Timeout::new(response, timeout),
|
||||
response: None,
|
||||
response_tx: self.as_mut().responses_tx().clone(),
|
||||
response_tx: self.as_mut().project().responses_tx.clone(),
|
||||
};
|
||||
let abort_registration = self.as_mut().channel().start_request(request_id);
|
||||
let abort_registration = self.as_mut().project().channel.start_request(request_id);
|
||||
RequestHandler {
|
||||
resp: Abortable::new(response, abort_registration),
|
||||
}
|
||||
@@ -507,15 +499,13 @@ where
|
||||
}
|
||||
|
||||
/// A future fulfilling a single client request.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct RequestHandler<F, R> {
|
||||
#[pin]
|
||||
resp: Abortable<Resp<F, R>>,
|
||||
}
|
||||
|
||||
impl<F, R> RequestHandler<F, R> {
|
||||
unsafe_pinned!(resp: Abortable<Resp<F, R>>);
|
||||
}
|
||||
|
||||
impl<F, R> Future for RequestHandler<F, R>
|
||||
where
|
||||
F: Future<Output = R>,
|
||||
@@ -523,19 +513,22 @@ where
|
||||
type Output = ();
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
let _ = ready!(self.resp().poll(cx));
|
||||
let _ = ready!(self.project().resp.poll(cx));
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
struct Resp<F, R> {
|
||||
state: RespState,
|
||||
request_id: u64,
|
||||
ctx: context::Context,
|
||||
deadline: SystemTime,
|
||||
#[pin]
|
||||
f: Timeout<F>,
|
||||
response: Option<Response<R>>,
|
||||
#[pin]
|
||||
response_tx: mpsc::Sender<(context::Context, Response<R>)>,
|
||||
}
|
||||
|
||||
@@ -546,13 +539,6 @@ enum RespState {
|
||||
PollFlush,
|
||||
}
|
||||
|
||||
impl<F, R> Resp<F, R> {
|
||||
unsafe_pinned!(f: Timeout<F>);
|
||||
unsafe_pinned!(response_tx: mpsc::Sender<(context::Context, Response<R>)>);
|
||||
unsafe_unpinned!(response: Option<Response<R>>);
|
||||
unsafe_unpinned!(state: RespState);
|
||||
}
|
||||
|
||||
impl<F, R> Future for Resp<F, R>
|
||||
where
|
||||
F: Future<Output = R>,
|
||||
@@ -561,10 +547,10 @@ where
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
loop {
|
||||
match self.as_mut().state() {
|
||||
match self.as_mut().project().state {
|
||||
RespState::PollResp => {
|
||||
let result = ready!(self.as_mut().f().poll(cx));
|
||||
*self.as_mut().response() = Some(Response {
|
||||
let result = ready!(self.as_mut().project().f.poll(cx));
|
||||
*self.as_mut().project().response = Some(Response {
|
||||
request_id: self.request_id,
|
||||
message: match result {
|
||||
Ok(message) => Ok(message),
|
||||
@@ -588,21 +574,27 @@ where
|
||||
},
|
||||
_non_exhaustive: (),
|
||||
});
|
||||
*self.as_mut().state() = RespState::PollReady;
|
||||
*self.as_mut().project().state = RespState::PollReady;
|
||||
}
|
||||
RespState::PollReady => {
|
||||
let ready = ready!(self.as_mut().response_tx().poll_ready(cx));
|
||||
let ready = ready!(self.as_mut().project().response_tx.poll_ready(cx));
|
||||
if ready.is_err() {
|
||||
return Poll::Ready(());
|
||||
}
|
||||
let resp = (self.ctx, self.as_mut().response().take().unwrap());
|
||||
if self.as_mut().response_tx().start_send(resp).is_err() {
|
||||
let resp = (self.ctx, self.as_mut().project().response.take().unwrap());
|
||||
if self
|
||||
.as_mut()
|
||||
.project()
|
||||
.response_tx
|
||||
.start_send(resp)
|
||||
.is_err()
|
||||
{
|
||||
return Poll::Ready(());
|
||||
}
|
||||
*self.as_mut().state() = RespState::PollFlush;
|
||||
*self.as_mut().project().state = RespState::PollFlush;
|
||||
}
|
||||
RespState::PollFlush => {
|
||||
let ready = ready!(self.as_mut().response_tx().poll_flush(cx));
|
||||
let ready = ready!(self.as_mut().project().response_tx.poll_flush(cx));
|
||||
if ready.is_err() {
|
||||
return Poll::Ready(());
|
||||
}
|
||||
@@ -672,19 +664,15 @@ where
|
||||
|
||||
/// A future that drives the server by spawning channels and request handlers on the default
|
||||
/// executor.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
#[cfg(feature = "tokio1")]
|
||||
pub struct Running<St, Se> {
|
||||
#[pin]
|
||||
incoming: St,
|
||||
server: Se,
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
impl<St, Se> Running<St, Se> {
|
||||
unsafe_pinned!(incoming: St);
|
||||
unsafe_unpinned!(server: Se);
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio1")]
|
||||
impl<St, C, Se> Future for Running<St, Se>
|
||||
where
|
||||
@@ -700,10 +688,10 @@ where
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
use log::info;
|
||||
|
||||
while let Some(channel) = ready!(self.as_mut().incoming().poll_next(cx)) {
|
||||
while let Some(channel) = ready!(self.as_mut().project().incoming.poll_next(cx)) {
|
||||
tokio::spawn(
|
||||
channel
|
||||
.respond_with(self.as_mut().server().clone())
|
||||
.respond_with(self.as_mut().project().server.clone())
|
||||
.execute(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -4,26 +4,23 @@ use fnv::FnvHashSet;
|
||||
use futures::future::{AbortHandle, AbortRegistration};
|
||||
use futures::{Sink, Stream};
|
||||
use futures_test::task::noop_waker_ref;
|
||||
use pin_utils::{unsafe_pinned, unsafe_unpinned};
|
||||
use pin_project::pin_project;
|
||||
use std::collections::VecDeque;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::SystemTime;
|
||||
|
||||
#[pin_project]
|
||||
pub(crate) struct FakeChannel<In, Out> {
|
||||
#[pin]
|
||||
pub stream: VecDeque<In>,
|
||||
#[pin]
|
||||
pub sink: VecDeque<Out>,
|
||||
pub config: Config,
|
||||
pub in_flight_requests: FnvHashSet<u64>,
|
||||
}
|
||||
|
||||
impl<In, Out> FakeChannel<In, Out> {
|
||||
unsafe_pinned!(stream: VecDeque<In>);
|
||||
unsafe_pinned!(sink: VecDeque<Out>);
|
||||
unsafe_unpinned!(in_flight_requests: FnvHashSet<u64>);
|
||||
}
|
||||
|
||||
impl<In, Out> Stream for FakeChannel<In, Out>
|
||||
where
|
||||
In: Unpin,
|
||||
@@ -31,7 +28,7 @@ where
|
||||
type Item = In;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
self.stream().poll_next(cx)
|
||||
self.project().stream.poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,22 +36,26 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.sink().poll_ready(cx).map_err(|e| match e {})
|
||||
self.project().sink.poll_ready(cx).map_err(|e| match e {})
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
|
||||
self.as_mut()
|
||||
.in_flight_requests()
|
||||
.project()
|
||||
.in_flight_requests
|
||||
.remove(&response.request_id);
|
||||
self.sink().start_send(response).map_err(|e| match e {})
|
||||
self.project()
|
||||
.sink
|
||||
.start_send(response)
|
||||
.map_err(|e| match e {})
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.sink().poll_flush(cx).map_err(|e| match e {})
|
||||
self.project().sink.poll_flush(cx).map_err(|e| match e {})
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.sink().poll_close(cx).map_err(|e| match e {})
|
||||
self.project().sink.poll_close(cx).map_err(|e| match e {})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +75,7 @@ where
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, id: u64) -> AbortRegistration {
|
||||
self.in_flight_requests().insert(id);
|
||||
self.project().in_flight_requests.insert(id);
|
||||
AbortHandle::new_pair().1
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,21 +7,20 @@ use futures::{
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use log::debug;
|
||||
use pin_utils::{unsafe_pinned, unsafe_unpinned};
|
||||
use pin_project::pin_project;
|
||||
use std::{io, pin::Pin};
|
||||
|
||||
/// A [`Channel`] that limits the number of concurrent
|
||||
/// requests by throttling.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct Throttler<C> {
|
||||
max_in_flight_requests: usize,
|
||||
#[pin]
|
||||
inner: C,
|
||||
}
|
||||
|
||||
impl<C> Throttler<C> {
|
||||
unsafe_unpinned!(max_in_flight_requests: usize);
|
||||
unsafe_pinned!(inner: C);
|
||||
|
||||
/// Returns the inner channel.
|
||||
pub fn get_ref(&self) -> &C {
|
||||
&self.inner
|
||||
@@ -49,16 +48,17 @@ where
|
||||
type Item = <C as Stream>::Item;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
while self.as_mut().in_flight_requests() >= *self.as_mut().max_in_flight_requests() {
|
||||
ready!(self.as_mut().inner().poll_ready(cx)?);
|
||||
while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
|
||||
{
|
||||
ready!(self.as_mut().project().inner.poll_ready(cx)?);
|
||||
|
||||
match ready!(self.as_mut().inner().poll_next(cx)?) {
|
||||
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
|
||||
Some(request) => {
|
||||
debug!(
|
||||
"[{}] Client has reached in-flight request limit ({}/{}).",
|
||||
request.context.trace_id(),
|
||||
self.as_mut().in_flight_requests(),
|
||||
self.as_mut().max_in_flight_requests(),
|
||||
self.as_mut().project().max_in_flight_requests,
|
||||
);
|
||||
|
||||
self.as_mut().start_send(Response {
|
||||
@@ -74,7 +74,7 @@ where
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
self.inner().poll_next(cx)
|
||||
self.project().inner.poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,19 +85,19 @@ where
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner().poll_ready(cx)
|
||||
self.project().inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: Response<<C as Channel>::Resp>) -> io::Result<()> {
|
||||
self.inner().start_send(item)
|
||||
self.project().inner.start_send(item)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
self.inner().poll_flush(cx)
|
||||
self.project().inner.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
self.inner().poll_close(cx)
|
||||
self.project().inner.poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,7 +115,7 @@ where
|
||||
type Resp = <C as Channel>::Resp;
|
||||
|
||||
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
|
||||
self.inner().in_flight_requests()
|
||||
self.project().inner.in_flight_requests()
|
||||
}
|
||||
|
||||
fn config(&self) -> &Config {
|
||||
@@ -123,13 +123,15 @@ where
|
||||
}
|
||||
|
||||
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
|
||||
self.inner().start_request(request_id)
|
||||
self.project().inner.start_request(request_id)
|
||||
}
|
||||
}
|
||||
|
||||
/// A stream of throttling channels.
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct ThrottlerStream<S> {
|
||||
#[pin]
|
||||
inner: S,
|
||||
max_in_flight_requests: usize,
|
||||
}
|
||||
@@ -139,9 +141,6 @@ where
|
||||
S: Stream,
|
||||
<S as Stream>::Item: Channel,
|
||||
{
|
||||
unsafe_pinned!(inner: S);
|
||||
unsafe_unpinned!(max_in_flight_requests: usize);
|
||||
|
||||
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
@@ -158,10 +157,10 @@ where
|
||||
type Item = Throttler<<S as Stream>::Item>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
match ready!(self.as_mut().inner().poll_next(cx)) {
|
||||
match ready!(self.as_mut().project().inner.poll_next(cx)) {
|
||||
Some(channel) => Poll::Ready(Some(Throttler::new(
|
||||
channel,
|
||||
*self.max_in_flight_requests(),
|
||||
*self.project().max_in_flight_requests,
|
||||
))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
use crate::PollIo;
|
||||
use futures::{channel::mpsc, task::Context, Poll, Sink, Stream};
|
||||
use pin_utils::unsafe_pinned;
|
||||
use pin_project::pin_project;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
|
||||
@@ -28,22 +28,20 @@ pub fn unbounded<SinkItem, Item>() -> (
|
||||
|
||||
/// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender)
|
||||
/// and [`UnboundedReceiver`](mpsc::UnboundedReceiver).
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct UnboundedChannel<Item, SinkItem> {
|
||||
#[pin]
|
||||
rx: mpsc::UnboundedReceiver<Item>,
|
||||
#[pin]
|
||||
tx: mpsc::UnboundedSender<SinkItem>,
|
||||
}
|
||||
|
||||
impl<Item, SinkItem> UnboundedChannel<Item, SinkItem> {
|
||||
unsafe_pinned!(rx: mpsc::UnboundedReceiver<Item>);
|
||||
unsafe_pinned!(tx: mpsc::UnboundedSender<SinkItem>);
|
||||
}
|
||||
|
||||
impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
|
||||
type Item = Result<Item, io::Error>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Item> {
|
||||
self.rx().poll_next(cx).map(|option| option.map(Ok))
|
||||
self.project().rx.poll_next(cx).map(|option| option.map(Ok))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,25 +49,29 @@ impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.tx()
|
||||
self.project()
|
||||
.tx
|
||||
.poll_ready(cx)
|
||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
|
||||
self.tx()
|
||||
self.project()
|
||||
.tx
|
||||
.start_send(item)
|
||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.tx()
|
||||
self.project()
|
||||
.tx
|
||||
.poll_flush(cx)
|
||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.tx()
|
||||
self.project()
|
||||
.tx
|
||||
.poll_close(cx)
|
||||
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user