Port to pin-project

This commit is contained in:
Artem Vorotnikov
2019-10-09 19:07:47 +03:00
committed by Tim
parent 915fe3ed4e
commit 5f6c3d7d98
9 changed files with 211 additions and 245 deletions

View File

@@ -15,7 +15,7 @@ description = "A bincode-based transport for tarpc services."
[dependencies]
futures-preview = { version = "0.3.0-alpha.18", features = ["compat"] }
futures_legacy = { version = "0.1", package = "futures" }
pin-utils = "0.1.0-alpha.4"
pin-project = "0.4"
serde = "1.0"
tokio-io = "0.1"
async-bincode = "0.4"
@@ -24,3 +24,4 @@ tokio-tcp = "0.1"
[dev-dependencies]
futures-test-preview = { version = "0.3.0-alpha.18" }
assert_matches = "1.0"
pin-utils = "0.1.0-alpha"

View File

@@ -10,7 +10,7 @@
use async_bincode::{AsyncBincodeStream, AsyncDestination};
use futures::{compat::*, prelude::*, ready};
use pin_utils::unsafe_pinned;
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use std::{
error::Error,
@@ -24,17 +24,13 @@ use tokio_io::{AsyncRead, AsyncWrite};
use tokio_tcp::{TcpListener, TcpStream};
/// A transport that serializes to, and deserializes from, a [`TcpStream`].
#[pin_project]
#[derive(Debug)]
pub struct Transport<S, Item, SinkItem> {
#[pin]
inner: Compat01As03Sink<AsyncBincodeStream<S, Item, SinkItem, AsyncDestination>, SinkItem>,
}
impl<S, Item, SinkItem> Transport<S, Item, SinkItem> {
unsafe_pinned!(
inner: Compat01As03Sink<AsyncBincodeStream<S, Item, SinkItem, AsyncDestination>, SinkItem>
);
}
impl<S, Item, SinkItem> Stream for Transport<S, Item, SinkItem>
where
S: AsyncRead,
@@ -43,7 +39,7 @@ where
type Item = io::Result<Item>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Item>>> {
match self.inner().poll_next(cx) {
match self.project().inner.poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Ok(next))) => Poll::Ready(Some(Ok(next))),
@@ -62,21 +58,22 @@ where
type Error = io::Error;
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
self.inner()
self.project()
.inner
.start_send(item)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.inner().poll_ready(cx))
convert(self.project().inner.poll_ready(cx))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.inner().poll_flush(cx))
convert(self.project().inner.poll_flush(cx))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
convert(self.inner().poll_close(cx))
convert(self.project().inner.poll_close(cx))
}
}
@@ -153,16 +150,16 @@ where
}
/// A [`TcpListener`] that wraps connections in bincode transports.
#[pin_project]
#[derive(Debug)]
pub struct Incoming<Item, SinkItem> {
#[pin]
incoming: Compat01As03<tokio_tcp::Incoming>,
local_addr: SocketAddr,
ghost: PhantomData<(Item, SinkItem)>,
}
impl<Item, SinkItem> Incoming<Item, SinkItem> {
unsafe_pinned!(incoming: Compat01As03<tokio_tcp::Incoming>);
/// Returns the address being listened on.
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
@@ -177,7 +174,7 @@ where
type Item = io::Result<Transport<TcpStream, Item, SinkItem>>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let next = ready!(self.incoming().poll_next(cx)?);
let next = ready!(self.project().incoming.poll_next(cx)?);
Poll::Ready(next.map(|conn| Ok(new(conn))))
}
}

View File

@@ -22,7 +22,7 @@ fnv = "1.0"
futures-preview = { version = "0.3.0-alpha.18" }
humantime = "1.0"
log = "0.4"
pin-utils = "0.1.0-alpha.4"
pin-project = "0.4"
raii-counter = "0.2"
rand = "0.7"
tokio-timer = "0.3.0-alpha.4"
@@ -34,3 +34,4 @@ tokio = { optional = true, version = "0.2.0-alpha.4" }
futures-test-preview = { version = "0.3.0-alpha.18" }
env_logger = "0.6"
assert_matches = "1.0"
pin-utils = "0.1.0-alpha"

View File

@@ -19,10 +19,9 @@ use futures::{
Poll,
};
use log::{debug, info, trace};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use pin_project::{pin_project, pinned_drop};
use std::{
io,
marker::Unpin,
pin::Pin,
sync::{
atomic::{AtomicU64, Ordering},
@@ -55,9 +54,11 @@ impl<Req, Resp> Clone for Channel<Req, Resp> {
}
/// A future returned by [`Channel::send`] that resolves to a server response.
#[pin_project]
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
struct Send<'a, Req, Resp> {
#[pin]
fut: MapOkDispatchResponse<SendMapErrConnectionReset<'a, Req, Resp>, Resp>,
}
@@ -65,45 +66,28 @@ type SendMapErrConnectionReset<'a, Req, Resp> = MapErrConnectionReset<
futures::sink::Send<'a, mpsc::Sender<DispatchRequest<Req, Resp>>, DispatchRequest<Req, Resp>>,
>;
impl<'a, Req, Resp> Send<'a, Req, Resp> {
unsafe_pinned!(
fut: MapOkDispatchResponse<
MapErrConnectionReset<
futures::sink::Send<
'a,
mpsc::Sender<DispatchRequest<Req, Resp>>,
DispatchRequest<Req, Resp>,
>,
>,
Resp,
>
);
}
impl<'a, Req, Resp> Future for Send<'a, Req, Resp> {
type Output = io::Result<DispatchResponse<Resp>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.as_mut().fut().poll(cx)
self.as_mut().project().fut.poll(cx)
}
}
/// A future returned by [`Channel::call`] that resolves to a server response.
#[pin_project]
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct Call<'a, Req, Resp> {
#[pin]
fut: AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>,
}
impl<'a, Req, Resp> Call<'a, Req, Resp> {
unsafe_pinned!(fut: AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>);
}
impl<'a, Req, Resp> Future for Call<'a, Req, Resp> {
type Output = io::Result<Resp>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.as_mut().fut().poll(cx)
self.as_mut().project().fut.poll(cx)
}
}
@@ -155,6 +139,7 @@ impl<Req, Resp> Channel<Req, Resp> {
/// A server response that is completed by request dispatch when the corresponding response
/// arrives off the wire.
#[pin_project(PinnedDrop)]
#[derive(Debug)]
struct DispatchResponse<Resp> {
response: Timeout<oneshot::Receiver<Response<Resp>>>,
@@ -164,10 +149,6 @@ struct DispatchResponse<Resp> {
request_id: u64,
}
impl<Resp> DispatchResponse<Resp> {
unsafe_pinned!(ctx: context::Context);
}
impl<Resp> Future for DispatchResponse<Resp> {
type Output = io::Result<Resp>;
@@ -196,8 +177,9 @@ impl<Resp> Future for DispatchResponse<Resp> {
}
// Cancels the request when dropped, if not already complete.
impl<Resp> Drop for DispatchResponse<Resp> {
fn drop(&mut self) {
#[pinned_drop]
impl<Resp> PinnedDrop for DispatchResponse<Resp> {
fn drop(mut self: Pin<&mut Self>) {
if !self.complete {
// The receiver needs to be closed to handle the edge case that the request has not
// yet been received by the dispatch task. It is possible for the cancel message to
@@ -210,7 +192,8 @@ impl<Resp> Drop for DispatchResponse<Resp> {
// dispatch task misses an early-arriving cancellation message, then it will see the
// receiver as closed.
self.response.get_mut().close();
self.cancellation.cancel(self.request_id);
let request_id = self.request_id;
self.cancellation.cancel(request_id);
}
}
}
@@ -246,13 +229,17 @@ where
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
/// and dispatching responses to the appropriate channel.
#[pin_project]
#[derive(Debug)]
pub struct RequestDispatch<Req, Resp, C> {
/// Writes requests to the wire and reads responses off the wire.
#[pin]
transport: Fuse<C>,
/// Requests waiting to be written to the wire.
#[pin]
pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>,
/// Requests that were dropped.
#[pin]
canceled_requests: Fuse<CanceledRequests>,
/// Requests already written to the wire that haven't yet received responses.
in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>,
@@ -264,19 +251,16 @@ impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
unsafe_pinned!(in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>);
unsafe_pinned!(canceled_requests: Fuse<CanceledRequests>);
unsafe_pinned!(pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>);
unsafe_pinned!(transport: Fuse<C>);
fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
Poll::Ready(match ready!(self.as_mut().transport().poll_next(cx)?) {
Some(response) => {
self.complete(response);
Some(Ok(()))
}
None => None,
})
Poll::Ready(
match ready!(self.as_mut().project().transport.poll_next(cx)?) {
Some(response) => {
self.complete(response);
Some(Ok(()))
}
None => None,
},
)
}
fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
@@ -305,12 +289,12 @@ where
match (pending_requests_status, canceled_requests_status) {
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
ready!(self.as_mut().transport().poll_flush(cx)?);
ready!(self.as_mut().project().transport.poll_flush(cx)?);
Poll::Ready(None)
}
(ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
// No more messages to process, so flush any messages buffered in the transport.
ready!(self.as_mut().transport().poll_flush(cx)?);
ready!(self.as_mut().project().transport.poll_flush(cx)?);
// Even if we fully-flush, we return Pending, because we have no more requests
// or cancellations right now.
@@ -324,10 +308,10 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<DispatchRequest<Req, Resp>> {
if self.as_mut().in_flight_requests().len() >= self.config.max_in_flight_requests {
if self.as_mut().project().in_flight_requests.len() >= self.config.max_in_flight_requests {
info!(
"At in-flight request capacity ({}/{}).",
self.as_mut().in_flight_requests().len(),
self.as_mut().project().in_flight_requests.len(),
self.config.max_in_flight_requests
);
@@ -336,13 +320,13 @@ where
return Poll::Pending;
}
while let Poll::Pending = self.as_mut().transport().poll_ready(cx)? {
while let Poll::Pending = self.as_mut().project().transport.poll_ready(cx)? {
// We can't yield a request-to-be-sent before the transport is capable of buffering it.
ready!(self.as_mut().transport().poll_flush(cx)?);
ready!(self.as_mut().project().transport.poll_flush(cx)?);
}
loop {
match ready!(self.as_mut().pending_requests().poll_next_unpin(cx)) {
match ready!(self.as_mut().project().pending_requests.poll_next_unpin(cx)) {
Some(request) => {
if request.response_completion.is_canceled() {
trace!(
@@ -364,18 +348,25 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<(context::Context, u64)> {
while let Poll::Pending = self.as_mut().transport().poll_ready(cx)? {
ready!(self.as_mut().transport().poll_flush(cx)?);
while let Poll::Pending = self.as_mut().project().transport.poll_ready(cx)? {
ready!(self.as_mut().project().transport.poll_flush(cx)?);
}
loop {
let cancellation = self.as_mut().canceled_requests().poll_next_unpin(cx);
let cancellation = self
.as_mut()
.project()
.canceled_requests
.poll_next_unpin(cx);
match ready!(cancellation) {
Some(request_id) => {
if let Some(in_flight_data) =
self.as_mut().in_flight_requests().remove(&request_id)
if let Some(in_flight_data) = self
.as_mut()
.project()
.in_flight_requests
.remove(&request_id)
{
self.as_mut().in_flight_requests().compact(0.1);
self.as_mut().project().in_flight_requests.compact(0.1);
debug!("[{}] Removed request.", in_flight_data.ctx.trace_id());
return Poll::Ready(Some(Ok((in_flight_data.ctx, request_id))));
}
@@ -400,8 +391,8 @@ where
},
_non_exhaustive: (),
});
self.as_mut().transport().start_send(request)?;
self.as_mut().in_flight_requests().insert(
self.as_mut().project().transport.start_send(request)?;
self.as_mut().project().in_flight_requests.insert(
request_id,
InFlightData {
ctx: dispatch_request.ctx,
@@ -421,7 +412,7 @@ where
trace_context: context.trace_context,
request_id,
};
self.as_mut().transport().start_send(cancel)?;
self.as_mut().project().transport.start_send(cancel)?;
trace!("[{}] Cancel message sent.", trace_id);
Ok(())
}
@@ -430,10 +421,11 @@ where
fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
if let Some(in_flight_data) = self
.as_mut()
.in_flight_requests()
.project()
.in_flight_requests
.remove(&response.request_id)
{
self.as_mut().in_flight_requests().compact(0.1);
self.as_mut().project().in_flight_requests.compact(0.1);
trace!("[{}] Received response.", in_flight_data.ctx.trace_id());
let _ = in_flight_data.response_completion.send(response);
@@ -460,13 +452,13 @@ where
loop {
match (self.as_mut().pump_read(cx)?, self.as_mut().pump_write(cx)?) {
(read, Poll::Ready(None)) => {
if self.as_mut().in_flight_requests().is_empty() {
if self.as_mut().project().in_flight_requests.is_empty() {
info!("Shutdown: write half closed, and no requests in flight.");
return Poll::Ready(Ok(()));
}
info!(
"Shutdown: write half closed, and {} requests in flight.",
self.as_mut().in_flight_requests().len()
self.as_mut().project().in_flight_requests.len()
);
match read {
Poll::Ready(Some(())) => continue,
@@ -529,17 +521,16 @@ impl Stream for CanceledRequests {
}
}
#[pin_project]
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
struct MapErrConnectionReset<Fut> {
#[pin]
future: Fut,
finished: Option<()>,
}
impl<Fut> MapErrConnectionReset<Fut> {
unsafe_pinned!(future: Fut);
unsafe_unpinned!(finished: Option<()>);
fn new(future: Fut) -> MapErrConnectionReset<Fut> {
MapErrConnectionReset {
future,
@@ -548,8 +539,6 @@ impl<Fut> MapErrConnectionReset<Fut> {
}
}
impl<Fut: Unpin> Unpin for MapErrConnectionReset<Fut> {}
impl<Fut> Future for MapErrConnectionReset<Fut>
where
Fut: TryFuture,
@@ -557,10 +546,10 @@ where
type Output = io::Result<Fut::Ok>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.as_mut().future().try_poll(cx) {
match self.as_mut().project().future.try_poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(result) => {
self.finished().take().expect(
self.project().finished.take().expect(
"MapErrConnectionReset must not be polled after it returned `Poll::Ready`",
);
Poll::Ready(result.map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset)))
@@ -569,17 +558,16 @@ where
}
}
#[pin_project]
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
struct MapOkDispatchResponse<Fut, Resp> {
#[pin]
future: Fut,
response: Option<DispatchResponse<Resp>>,
}
impl<Fut, Resp> MapOkDispatchResponse<Fut, Resp> {
unsafe_pinned!(future: Fut);
unsafe_unpinned!(response: Option<DispatchResponse<Resp>>);
fn new(future: Fut, response: DispatchResponse<Resp>) -> MapOkDispatchResponse<Fut, Resp> {
MapOkDispatchResponse {
future,
@@ -588,8 +576,6 @@ impl<Fut, Resp> MapOkDispatchResponse<Fut, Resp> {
}
}
impl<Fut: Unpin, Resp> Unpin for MapOkDispatchResponse<Fut, Resp> {}
impl<Fut, Resp> Future for MapOkDispatchResponse<Fut, Resp>
where
Fut: TryFuture,
@@ -597,12 +583,13 @@ where
type Output = Result<DispatchResponse<Resp>, Fut::Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.as_mut().future().try_poll(cx) {
match self.as_mut().project().future.try_poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(result) => {
let response = self
.as_mut()
.response()
.project()
.response
.take()
.expect("MapOk must not be polled after it returned `Poll::Ready`");
Poll::Ready(result.map(|_| response))
@@ -611,9 +598,11 @@ where
}
}
#[pin_project]
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
struct AndThenIdent<Fut1, Fut2> {
#[pin]
try_chain: TryChain<Fut1, Fut2>,
}
@@ -622,8 +611,6 @@ where
Fut1: TryFuture<Ok = Fut2>,
Fut2: TryFuture,
{
unsafe_pinned!(try_chain: TryChain<Fut1, Fut2>);
/// Creates a new `Then`.
fn new(future: Fut1) -> AndThenIdent<Fut1, Fut2> {
AndThenIdent {
@@ -640,7 +627,7 @@ where
type Output = Result<Fut2::Ok, Fut2::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.try_chain().poll(cx, |result| match result {
self.project().try_chain.poll(cx, |result| match result {
Ok(ok) => TryChainAction::Future(ok),
Err(err) => TryChainAction::Output(Err(err)),
})
@@ -830,7 +817,7 @@ mod tests {
let req = send_request(&mut channel, "hi");
assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
assert!(!dispatch.as_mut().in_flight_requests().is_empty());
assert!(!dispatch.as_mut().project().in_flight_requests.is_empty());
// Test that a request future dropped after it's processed by dispatch will cause the request
// to be removed from the in-flight request map.
@@ -840,7 +827,7 @@ mod tests {
} else {
panic!("Expected request to be cancelled")
};
assert!(dispatch.in_flight_requests().is_empty());
assert!(dispatch.project().in_flight_requests.is_empty());
}
#[test]

View File

@@ -18,7 +18,7 @@ use futures::{
task::{Context, Poll},
};
use log::{debug, info, trace};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use pin_project::pin_project;
use raii_counter::{Counter, WeakCounter};
use std::sync::{Arc, Weak};
use std::{
@@ -26,30 +26,32 @@ use std::{
};
/// A single-threaded filter that drops channels based on per-key limits.
#[pin_project]
#[derive(Debug)]
pub struct ChannelFilter<S, K, F>
where
K: Eq + Hash,
{
#[pin]
listener: Fuse<S>,
channels_per_key: u32,
#[pin]
dropped_keys: mpsc::UnboundedReceiver<K>,
#[pin]
dropped_keys_tx: mpsc::UnboundedSender<K>,
key_counts: FnvHashMap<K, TrackerPrototype<K>>,
keymaker: F,
}
/// A channel that is tracked by a ChannelFilter.
#[pin_project]
#[derive(Debug)]
pub struct TrackedChannel<C, K> {
#[pin]
inner: C,
tracker: Tracker<K>,
}
impl<C, K> TrackedChannel<C, K> {
unsafe_pinned!(inner: C);
}
#[derive(Clone, Debug)]
struct Tracker<K> {
key: Option<Arc<K>>,
@@ -130,11 +132,11 @@ where
}
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
self.inner().in_flight_requests()
self.project().inner.in_flight_requests()
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
self.inner().start_request(request_id)
self.project().inner.start_request(request_id)
}
}
@@ -146,22 +148,10 @@ impl<C, K> TrackedChannel<C, K> {
/// Returns the pinned inner channel.
fn channel<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut C> {
self.inner()
self.project().inner
}
}
impl<S, K, F> ChannelFilter<S, K, F>
where
K: fmt::Display + Eq + Hash + Clone,
{
unsafe_pinned!(listener: Fuse<S>);
unsafe_pinned!(dropped_keys: mpsc::UnboundedReceiver<K>);
unsafe_pinned!(dropped_keys_tx: mpsc::UnboundedSender<K>);
unsafe_unpinned!(key_counts: FnvHashMap<K, TrackerPrototype<K>>);
unsafe_unpinned!(channels_per_key: u32);
unsafe_unpinned!(keymaker: F);
}
impl<S, K, F> ChannelFilter<S, K, F>
where
K: Eq + Hash,
@@ -192,14 +182,14 @@ where
mut self: Pin<&mut Self>,
stream: S::Item,
) -> Result<TrackedChannel<S::Item, K>, K> {
let key = self.as_mut().keymaker()(&stream);
let key = (self.as_mut().keymaker)(&stream);
let tracker = self.as_mut().increment_channels_for_key(key.clone())?;
trace!(
"[{}] Opening channel ({}/{}) channels for key.",
key,
tracker.counter.count(),
self.as_mut().channels_per_key()
self.as_mut().project().channels_per_key
);
Ok(TrackedChannel {
@@ -211,7 +201,7 @@ where
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Tracker<K>, K> {
let channels_per_key = self.channels_per_key;
let dropped_keys = self.dropped_keys_tx.clone();
let key_counts = &mut self.as_mut().key_counts();
let key_counts = &mut self.as_mut().project().key_counts;
match key_counts.entry(key.clone()) {
Entry::Vacant(vacant) => {
let key = Arc::new(key);
@@ -256,18 +246,18 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
match ready!(self.as_mut().listener().poll_next_unpin(cx)) {
match ready!(self.as_mut().project().listener.poll_next_unpin(cx)) {
Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
None => Poll::Ready(None),
}
}
fn poll_closed_channels(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
match ready!(self.as_mut().dropped_keys().poll_next_unpin(cx)) {
match ready!(self.as_mut().project().dropped_keys.poll_next_unpin(cx)) {
Some(key) => {
debug!("All channels dropped for key [{}]", key);
self.as_mut().key_counts().remove(&key);
self.as_mut().key_counts().compact(0.1);
self.as_mut().project().key_counts.remove(&key);
self.as_mut().project().key_counts.compact(0.1);
Poll::Ready(())
}
None => unreachable!("Holding a copy of closed_channels and didn't close it."),

View File

@@ -21,7 +21,7 @@ use futures::{
};
use humantime::format_rfc3339;
use log::{debug, trace};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use pin_project::pin_project;
use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
use tokio_timer::{timeout, Timeout};
@@ -165,10 +165,12 @@ where
}
/// BaseChannel lifts a Transport to a Channel by tracking in-flight requests.
#[pin_project]
#[derive(Debug)]
pub struct BaseChannel<Req, Resp, T> {
config: Config,
/// Writes responses to the wire and reads requests off the wire.
#[pin]
transport: Fuse<T>,
/// Number of requests currently being responded to.
in_flight_requests: FnvHashMap<u64, AbortHandle>,
@@ -176,10 +178,6 @@ pub struct BaseChannel<Req, Resp, T> {
ghost: PhantomData<(Req, Resp)>,
}
impl<Req, Resp, T> BaseChannel<Req, Resp, T> {
unsafe_unpinned!(in_flight_requests: FnvHashMap<u64, AbortHandle>);
}
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
@@ -204,19 +202,19 @@ where
self.transport.get_ref()
}
/// Returns the pinned inner transport.
pub fn transport<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut T> {
unsafe { self.map_unchecked_mut(|me| me.transport.get_mut()) }
}
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
// It's possible the request was already completed, so it's fine
// if this is None.
if let Some(cancel_handle) = self.as_mut().in_flight_requests().remove(&request_id) {
self.as_mut().in_flight_requests().compact(0.1);
if let Some(cancel_handle) = self
.as_mut()
.project()
.in_flight_requests
.remove(&request_id)
{
self.as_mut().project().in_flight_requests.compact(0.1);
cancel_handle.abort();
let remaining = self.as_mut().in_flight_requests().len();
let remaining = self.as_mut().project().in_flight_requests.len();
trace!(
"[{}] Request canceled. In-flight requests = {}",
trace_context.trace_id,
@@ -295,7 +293,7 @@ where
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
loop {
match ready!(self.as_mut().transport().poll_next(cx)?) {
match ready!(self.as_mut().project().transport.poll_next(cx)?) {
Some(message) => match message {
ClientMessage::Request(request) => {
return Poll::Ready(Some(Ok(request)));
@@ -321,28 +319,29 @@ where
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.transport().poll_ready(cx)
self.project().transport.poll_ready(cx)
}
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
if self
.as_mut()
.in_flight_requests()
.project()
.in_flight_requests
.remove(&response.request_id)
.is_some()
{
self.as_mut().in_flight_requests().compact(0.1);
self.as_mut().project().in_flight_requests.compact(0.1);
}
self.transport().start_send(response)
self.project().transport.start_send(response)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.transport().poll_flush(cx)
self.project().transport.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.transport().poll_close(cx)
self.project().transport.poll_close(cx)
}
}
@@ -364,13 +363,14 @@ where
}
fn in_flight_requests(mut self: Pin<&mut Self>) -> usize {
self.as_mut().in_flight_requests().len()
self.as_mut().project().in_flight_requests.len()
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
let (abort_handle, abort_registration) = AbortHandle::new_pair();
assert!(self
.in_flight_requests()
.project()
.in_flight_requests
.insert(request_id, abort_handle)
.is_none());
abort_registration
@@ -378,32 +378,24 @@ where
}
/// A running handler serving all requests coming over a channel.
#[pin_project]
#[derive(Debug)]
pub struct ClientHandler<C, S>
where
C: Channel,
{
#[pin]
channel: C,
/// Responses waiting to be written to the wire.
#[pin]
pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>,
/// Handed out to request handlers to fan in responses.
#[pin]
responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
/// Server
server: S,
}
impl<C, S> ClientHandler<C, S>
where
C: Channel,
{
unsafe_pinned!(channel: C);
unsafe_pinned!(pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>);
unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>);
// For this to be safe, field f must be private, and code in this module must never
// construct PinMut<S>.
unsafe_unpinned!(server: S);
}
impl<C, S> ClientHandler<C, S>
where
C: Channel,
@@ -413,7 +405,7 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<RequestHandler<S::Fut, C::Resp>> {
match ready!(self.as_mut().channel().poll_next(cx)?) {
match ready!(self.as_mut().project().channel.poll_next(cx)?) {
Some(request) => Poll::Ready(Some(Ok(self.handle_request(request)))),
None => Poll::Ready(None),
}
@@ -429,24 +421,24 @@ where
trace!(
"[{}] Staging response. In-flight requests = {}.",
ctx.trace_id(),
self.as_mut().channel().in_flight_requests(),
self.as_mut().project().channel.in_flight_requests(),
);
self.as_mut().channel().start_send(response)?;
self.as_mut().project().channel.start_send(response)?;
Poll::Ready(Some(Ok(())))
}
Poll::Ready(None) => {
// Shutdown can't be done before we finish pumping out remaining responses.
ready!(self.as_mut().channel().poll_flush(cx)?);
ready!(self.as_mut().project().channel.poll_flush(cx)?);
Poll::Ready(None)
}
Poll::Pending => {
// No more requests to process, so flush any requests buffered in the transport.
ready!(self.as_mut().channel().poll_flush(cx)?);
ready!(self.as_mut().project().channel.poll_flush(cx)?);
// Being here means there are no staged requests and all written responses are
// fully flushed. So, if the read half is closed and there are no in-flight
// requests, then we can close the write half.
if read_half_closed && self.as_mut().channel().in_flight_requests() == 0 {
if read_half_closed && self.as_mut().project().channel.in_flight_requests() == 0 {
Poll::Ready(None)
} else {
Poll::Pending
@@ -460,11 +452,11 @@ where
cx: &mut Context<'_>,
) -> PollIo<(context::Context, Response<C::Resp>)> {
// Ensure there's room to write a response.
while let Poll::Pending = self.as_mut().channel().poll_ready(cx)? {
ready!(self.as_mut().channel().poll_flush(cx)?);
while let Poll::Pending = self.as_mut().project().channel.poll_ready(cx)? {
ready!(self.as_mut().project().channel.poll_flush(cx)?);
}
match ready!(self.as_mut().pending_responses().poll_next(cx)) {
match ready!(self.as_mut().project().pending_responses.poll_next(cx)) {
Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))),
None => {
// This branch likely won't happen, since the ClientHandler is holding a Sender.
@@ -489,7 +481,7 @@ where
let ctx = request.context;
let request = request.message;
let response = self.as_mut().server().clone().serve(ctx, request);
let response = self.as_mut().project().server.clone().serve(ctx, request);
let response = Resp {
state: RespState::PollResp,
request_id,
@@ -497,9 +489,9 @@ where
deadline,
f: Timeout::new(response, timeout),
response: None,
response_tx: self.as_mut().responses_tx().clone(),
response_tx: self.as_mut().project().responses_tx.clone(),
};
let abort_registration = self.as_mut().channel().start_request(request_id);
let abort_registration = self.as_mut().project().channel.start_request(request_id);
RequestHandler {
resp: Abortable::new(response, abort_registration),
}
@@ -507,15 +499,13 @@ where
}
/// A future fulfilling a single client request.
#[pin_project]
#[derive(Debug)]
pub struct RequestHandler<F, R> {
#[pin]
resp: Abortable<Resp<F, R>>,
}
impl<F, R> RequestHandler<F, R> {
unsafe_pinned!(resp: Abortable<Resp<F, R>>);
}
impl<F, R> Future for RequestHandler<F, R>
where
F: Future<Output = R>,
@@ -523,19 +513,22 @@ where
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
let _ = ready!(self.resp().poll(cx));
let _ = ready!(self.project().resp.poll(cx));
Poll::Ready(())
}
}
#[pin_project]
#[derive(Debug)]
struct Resp<F, R> {
state: RespState,
request_id: u64,
ctx: context::Context,
deadline: SystemTime,
#[pin]
f: Timeout<F>,
response: Option<Response<R>>,
#[pin]
response_tx: mpsc::Sender<(context::Context, Response<R>)>,
}
@@ -546,13 +539,6 @@ enum RespState {
PollFlush,
}
impl<F, R> Resp<F, R> {
unsafe_pinned!(f: Timeout<F>);
unsafe_pinned!(response_tx: mpsc::Sender<(context::Context, Response<R>)>);
unsafe_unpinned!(response: Option<Response<R>>);
unsafe_unpinned!(state: RespState);
}
impl<F, R> Future for Resp<F, R>
where
F: Future<Output = R>,
@@ -561,10 +547,10 @@ where
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
loop {
match self.as_mut().state() {
match self.as_mut().project().state {
RespState::PollResp => {
let result = ready!(self.as_mut().f().poll(cx));
*self.as_mut().response() = Some(Response {
let result = ready!(self.as_mut().project().f.poll(cx));
*self.as_mut().project().response = Some(Response {
request_id: self.request_id,
message: match result {
Ok(message) => Ok(message),
@@ -588,21 +574,27 @@ where
},
_non_exhaustive: (),
});
*self.as_mut().state() = RespState::PollReady;
*self.as_mut().project().state = RespState::PollReady;
}
RespState::PollReady => {
let ready = ready!(self.as_mut().response_tx().poll_ready(cx));
let ready = ready!(self.as_mut().project().response_tx.poll_ready(cx));
if ready.is_err() {
return Poll::Ready(());
}
let resp = (self.ctx, self.as_mut().response().take().unwrap());
if self.as_mut().response_tx().start_send(resp).is_err() {
let resp = (self.ctx, self.as_mut().project().response.take().unwrap());
if self
.as_mut()
.project()
.response_tx
.start_send(resp)
.is_err()
{
return Poll::Ready(());
}
*self.as_mut().state() = RespState::PollFlush;
*self.as_mut().project().state = RespState::PollFlush;
}
RespState::PollFlush => {
let ready = ready!(self.as_mut().response_tx().poll_flush(cx));
let ready = ready!(self.as_mut().project().response_tx.poll_flush(cx));
if ready.is_err() {
return Poll::Ready(());
}
@@ -672,19 +664,15 @@ where
/// A future that drives the server by spawning channels and request handlers on the default
/// executor.
#[pin_project]
#[derive(Debug)]
#[cfg(feature = "tokio1")]
pub struct Running<St, Se> {
#[pin]
incoming: St,
server: Se,
}
#[cfg(feature = "tokio1")]
impl<St, Se> Running<St, Se> {
unsafe_pinned!(incoming: St);
unsafe_unpinned!(server: Se);
}
#[cfg(feature = "tokio1")]
impl<St, C, Se> Future for Running<St, Se>
where
@@ -700,10 +688,10 @@ where
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
use log::info;
while let Some(channel) = ready!(self.as_mut().incoming().poll_next(cx)) {
while let Some(channel) = ready!(self.as_mut().project().incoming.poll_next(cx)) {
tokio::spawn(
channel
.respond_with(self.as_mut().server().clone())
.respond_with(self.as_mut().project().server.clone())
.execute(),
);
}

View File

@@ -4,26 +4,23 @@ use fnv::FnvHashSet;
use futures::future::{AbortHandle, AbortRegistration};
use futures::{Sink, Stream};
use futures_test::task::noop_waker_ref;
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use pin_project::pin_project;
use std::collections::VecDeque;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::SystemTime;
#[pin_project]
pub(crate) struct FakeChannel<In, Out> {
#[pin]
pub stream: VecDeque<In>,
#[pin]
pub sink: VecDeque<Out>,
pub config: Config,
pub in_flight_requests: FnvHashSet<u64>,
}
impl<In, Out> FakeChannel<In, Out> {
unsafe_pinned!(stream: VecDeque<In>);
unsafe_pinned!(sink: VecDeque<Out>);
unsafe_unpinned!(in_flight_requests: FnvHashSet<u64>);
}
impl<In, Out> Stream for FakeChannel<In, Out>
where
In: Unpin,
@@ -31,7 +28,7 @@ where
type Item = In;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
self.stream().poll_next(cx)
self.project().stream.poll_next(cx)
}
}
@@ -39,22 +36,26 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.sink().poll_ready(cx).map_err(|e| match e {})
self.project().sink.poll_ready(cx).map_err(|e| match e {})
}
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
self.as_mut()
.in_flight_requests()
.project()
.in_flight_requests
.remove(&response.request_id);
self.sink().start_send(response).map_err(|e| match e {})
self.project()
.sink
.start_send(response)
.map_err(|e| match e {})
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.sink().poll_flush(cx).map_err(|e| match e {})
self.project().sink.poll_flush(cx).map_err(|e| match e {})
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.sink().poll_close(cx).map_err(|e| match e {})
self.project().sink.poll_close(cx).map_err(|e| match e {})
}
}
@@ -74,7 +75,7 @@ where
}
fn start_request(self: Pin<&mut Self>, id: u64) -> AbortRegistration {
self.in_flight_requests().insert(id);
self.project().in_flight_requests.insert(id);
AbortHandle::new_pair().1
}
}

View File

@@ -7,21 +7,20 @@ use futures::{
task::{Context, Poll},
};
use log::debug;
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use pin_project::pin_project;
use std::{io, pin::Pin};
/// A [`Channel`] that limits the number of concurrent
/// requests by throttling.
#[pin_project]
#[derive(Debug)]
pub struct Throttler<C> {
max_in_flight_requests: usize,
#[pin]
inner: C,
}
impl<C> Throttler<C> {
unsafe_unpinned!(max_in_flight_requests: usize);
unsafe_pinned!(inner: C);
/// Returns the inner channel.
pub fn get_ref(&self) -> &C {
&self.inner
@@ -49,16 +48,17 @@ where
type Item = <C as Stream>::Item;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
while self.as_mut().in_flight_requests() >= *self.as_mut().max_in_flight_requests() {
ready!(self.as_mut().inner().poll_ready(cx)?);
while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
{
ready!(self.as_mut().project().inner.poll_ready(cx)?);
match ready!(self.as_mut().inner().poll_next(cx)?) {
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
Some(request) => {
debug!(
"[{}] Client has reached in-flight request limit ({}/{}).",
request.context.trace_id(),
self.as_mut().in_flight_requests(),
self.as_mut().max_in_flight_requests(),
self.as_mut().project().max_in_flight_requests,
);
self.as_mut().start_send(Response {
@@ -74,7 +74,7 @@ where
None => return Poll::Ready(None),
}
}
self.inner().poll_next(cx)
self.project().inner.poll_next(cx)
}
}
@@ -85,19 +85,19 @@ where
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.inner().poll_ready(cx)
self.project().inner.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Response<<C as Channel>::Resp>) -> io::Result<()> {
self.inner().start_send(item)
self.project().inner.start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.inner().poll_flush(cx)
self.project().inner.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.inner().poll_close(cx)
self.project().inner.poll_close(cx)
}
}
@@ -115,7 +115,7 @@ where
type Resp = <C as Channel>::Resp;
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
self.inner().in_flight_requests()
self.project().inner.in_flight_requests()
}
fn config(&self) -> &Config {
@@ -123,13 +123,15 @@ where
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
self.inner().start_request(request_id)
self.project().inner.start_request(request_id)
}
}
/// A stream of throttling channels.
#[pin_project]
#[derive(Debug)]
pub struct ThrottlerStream<S> {
#[pin]
inner: S,
max_in_flight_requests: usize,
}
@@ -139,9 +141,6 @@ where
S: Stream,
<S as Stream>::Item: Channel,
{
unsafe_pinned!(inner: S);
unsafe_unpinned!(max_in_flight_requests: usize);
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
Self {
inner,
@@ -158,10 +157,10 @@ where
type Item = Throttler<<S as Stream>::Item>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match ready!(self.as_mut().inner().poll_next(cx)) {
match ready!(self.as_mut().project().inner.poll_next(cx)) {
Some(channel) => Poll::Ready(Some(Throttler::new(
channel,
*self.max_in_flight_requests(),
*self.project().max_in_flight_requests,
))),
None => Poll::Ready(None),
}

View File

@@ -8,7 +8,7 @@
use crate::PollIo;
use futures::{channel::mpsc, task::Context, Poll, Sink, Stream};
use pin_utils::unsafe_pinned;
use pin_project::pin_project;
use std::io;
use std::pin::Pin;
@@ -28,22 +28,20 @@ pub fn unbounded<SinkItem, Item>() -> (
/// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender)
/// and [`UnboundedReceiver`](mpsc::UnboundedReceiver).
#[pin_project]
#[derive(Debug)]
pub struct UnboundedChannel<Item, SinkItem> {
#[pin]
rx: mpsc::UnboundedReceiver<Item>,
#[pin]
tx: mpsc::UnboundedSender<SinkItem>,
}
impl<Item, SinkItem> UnboundedChannel<Item, SinkItem> {
unsafe_pinned!(rx: mpsc::UnboundedReceiver<Item>);
unsafe_pinned!(tx: mpsc::UnboundedSender<SinkItem>);
}
impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
type Item = Result<Item, io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<Item> {
self.rx().poll_next(cx).map(|option| option.map(Ok))
self.project().rx.poll_next(cx).map(|option| option.map(Ok))
}
}
@@ -51,25 +49,29 @@ impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.tx()
self.project()
.tx
.poll_ready(cx)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> {
self.tx()
self.project()
.tx
.start_send(item)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx()
self.project()
.tx
.poll_flush(cx)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.tx()
self.project()
.tx
.poll_close(cx)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))
}