Port to pin-project

This commit is contained in:
Artem Vorotnikov
2019-10-09 19:07:47 +03:00
committed by Tim
parent 915fe3ed4e
commit 5f6c3d7d98
9 changed files with 211 additions and 245 deletions

View File

@@ -18,7 +18,7 @@ use futures::{
task::{Context, Poll},
};
use log::{debug, info, trace};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use pin_project::pin_project;
use raii_counter::{Counter, WeakCounter};
use std::sync::{Arc, Weak};
use std::{
@@ -26,30 +26,32 @@ use std::{
};
/// A single-threaded filter that drops channels based on per-key limits.
#[pin_project]
#[derive(Debug)]
pub struct ChannelFilter<S, K, F>
where
K: Eq + Hash,
{
#[pin]
listener: Fuse<S>,
channels_per_key: u32,
#[pin]
dropped_keys: mpsc::UnboundedReceiver<K>,
#[pin]
dropped_keys_tx: mpsc::UnboundedSender<K>,
key_counts: FnvHashMap<K, TrackerPrototype<K>>,
keymaker: F,
}
/// A channel that is tracked by a ChannelFilter.
#[pin_project]
#[derive(Debug)]
pub struct TrackedChannel<C, K> {
#[pin]
inner: C,
tracker: Tracker<K>,
}
impl<C, K> TrackedChannel<C, K> {
unsafe_pinned!(inner: C);
}
#[derive(Clone, Debug)]
struct Tracker<K> {
key: Option<Arc<K>>,
@@ -130,11 +132,11 @@ where
}
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
self.inner().in_flight_requests()
self.project().inner.in_flight_requests()
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
self.inner().start_request(request_id)
self.project().inner.start_request(request_id)
}
}
@@ -146,22 +148,10 @@ impl<C, K> TrackedChannel<C, K> {
/// Returns the pinned inner channel.
fn channel<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut C> {
self.inner()
self.project().inner
}
}
impl<S, K, F> ChannelFilter<S, K, F>
where
K: fmt::Display + Eq + Hash + Clone,
{
unsafe_pinned!(listener: Fuse<S>);
unsafe_pinned!(dropped_keys: mpsc::UnboundedReceiver<K>);
unsafe_pinned!(dropped_keys_tx: mpsc::UnboundedSender<K>);
unsafe_unpinned!(key_counts: FnvHashMap<K, TrackerPrototype<K>>);
unsafe_unpinned!(channels_per_key: u32);
unsafe_unpinned!(keymaker: F);
}
impl<S, K, F> ChannelFilter<S, K, F>
where
K: Eq + Hash,
@@ -192,14 +182,14 @@ where
mut self: Pin<&mut Self>,
stream: S::Item,
) -> Result<TrackedChannel<S::Item, K>, K> {
let key = self.as_mut().keymaker()(&stream);
let key = (self.as_mut().keymaker)(&stream);
let tracker = self.as_mut().increment_channels_for_key(key.clone())?;
trace!(
"[{}] Opening channel ({}/{}) channels for key.",
key,
tracker.counter.count(),
self.as_mut().channels_per_key()
self.as_mut().project().channels_per_key
);
Ok(TrackedChannel {
@@ -211,7 +201,7 @@ where
fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result<Tracker<K>, K> {
let channels_per_key = self.channels_per_key;
let dropped_keys = self.dropped_keys_tx.clone();
let key_counts = &mut self.as_mut().key_counts();
let key_counts = &mut self.as_mut().project().key_counts;
match key_counts.entry(key.clone()) {
Entry::Vacant(vacant) => {
let key = Arc::new(key);
@@ -256,18 +246,18 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<TrackedChannel<S::Item, K>, K>>> {
match ready!(self.as_mut().listener().poll_next_unpin(cx)) {
match ready!(self.as_mut().project().listener.poll_next_unpin(cx)) {
Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))),
None => Poll::Ready(None),
}
}
fn poll_closed_channels(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
match ready!(self.as_mut().dropped_keys().poll_next_unpin(cx)) {
match ready!(self.as_mut().project().dropped_keys.poll_next_unpin(cx)) {
Some(key) => {
debug!("All channels dropped for key [{}]", key);
self.as_mut().key_counts().remove(&key);
self.as_mut().key_counts().compact(0.1);
self.as_mut().project().key_counts.remove(&key);
self.as_mut().project().key_counts.compact(0.1);
Poll::Ready(())
}
None => unreachable!("Holding a copy of closed_channels and didn't close it."),

View File

@@ -21,7 +21,7 @@ use futures::{
};
use humantime::format_rfc3339;
use log::{debug, trace};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use pin_project::pin_project;
use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
use tokio_timer::{timeout, Timeout};
@@ -165,10 +165,12 @@ where
}
/// BaseChannel lifts a Transport to a Channel by tracking in-flight requests.
#[pin_project]
#[derive(Debug)]
pub struct BaseChannel<Req, Resp, T> {
config: Config,
/// Writes responses to the wire and reads requests off the wire.
#[pin]
transport: Fuse<T>,
/// Number of requests currently being responded to.
in_flight_requests: FnvHashMap<u64, AbortHandle>,
@@ -176,10 +178,6 @@ pub struct BaseChannel<Req, Resp, T> {
ghost: PhantomData<(Req, Resp)>,
}
impl<Req, Resp, T> BaseChannel<Req, Resp, T> {
unsafe_unpinned!(in_flight_requests: FnvHashMap<u64, AbortHandle>);
}
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
@@ -204,19 +202,19 @@ where
self.transport.get_ref()
}
/// Returns the pinned inner transport.
pub fn transport<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut T> {
unsafe { self.map_unchecked_mut(|me| me.transport.get_mut()) }
}
fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
// It's possible the request was already completed, so it's fine
// if this is None.
if let Some(cancel_handle) = self.as_mut().in_flight_requests().remove(&request_id) {
self.as_mut().in_flight_requests().compact(0.1);
if let Some(cancel_handle) = self
.as_mut()
.project()
.in_flight_requests
.remove(&request_id)
{
self.as_mut().project().in_flight_requests.compact(0.1);
cancel_handle.abort();
let remaining = self.as_mut().in_flight_requests().len();
let remaining = self.as_mut().project().in_flight_requests.len();
trace!(
"[{}] Request canceled. In-flight requests = {}",
trace_context.trace_id,
@@ -295,7 +293,7 @@ where
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
loop {
match ready!(self.as_mut().transport().poll_next(cx)?) {
match ready!(self.as_mut().project().transport.poll_next(cx)?) {
Some(message) => match message {
ClientMessage::Request(request) => {
return Poll::Ready(Some(Ok(request)));
@@ -321,28 +319,29 @@ where
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.transport().poll_ready(cx)
self.project().transport.poll_ready(cx)
}
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
if self
.as_mut()
.in_flight_requests()
.project()
.in_flight_requests
.remove(&response.request_id)
.is_some()
{
self.as_mut().in_flight_requests().compact(0.1);
self.as_mut().project().in_flight_requests.compact(0.1);
}
self.transport().start_send(response)
self.project().transport.start_send(response)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.transport().poll_flush(cx)
self.project().transport.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.transport().poll_close(cx)
self.project().transport.poll_close(cx)
}
}
@@ -364,13 +363,14 @@ where
}
fn in_flight_requests(mut self: Pin<&mut Self>) -> usize {
self.as_mut().in_flight_requests().len()
self.as_mut().project().in_flight_requests.len()
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
let (abort_handle, abort_registration) = AbortHandle::new_pair();
assert!(self
.in_flight_requests()
.project()
.in_flight_requests
.insert(request_id, abort_handle)
.is_none());
abort_registration
@@ -378,32 +378,24 @@ where
}
/// A running handler serving all requests coming over a channel.
#[pin_project]
#[derive(Debug)]
pub struct ClientHandler<C, S>
where
C: Channel,
{
#[pin]
channel: C,
/// Responses waiting to be written to the wire.
#[pin]
pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>,
/// Handed out to request handlers to fan in responses.
#[pin]
responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
/// Server
server: S,
}
impl<C, S> ClientHandler<C, S>
where
C: Channel,
{
unsafe_pinned!(channel: C);
unsafe_pinned!(pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>);
unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>);
// For this to be safe, field f must be private, and code in this module must never
// construct PinMut<S>.
unsafe_unpinned!(server: S);
}
impl<C, S> ClientHandler<C, S>
where
C: Channel,
@@ -413,7 +405,7 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<RequestHandler<S::Fut, C::Resp>> {
match ready!(self.as_mut().channel().poll_next(cx)?) {
match ready!(self.as_mut().project().channel.poll_next(cx)?) {
Some(request) => Poll::Ready(Some(Ok(self.handle_request(request)))),
None => Poll::Ready(None),
}
@@ -429,24 +421,24 @@ where
trace!(
"[{}] Staging response. In-flight requests = {}.",
ctx.trace_id(),
self.as_mut().channel().in_flight_requests(),
self.as_mut().project().channel.in_flight_requests(),
);
self.as_mut().channel().start_send(response)?;
self.as_mut().project().channel.start_send(response)?;
Poll::Ready(Some(Ok(())))
}
Poll::Ready(None) => {
// Shutdown can't be done before we finish pumping out remaining responses.
ready!(self.as_mut().channel().poll_flush(cx)?);
ready!(self.as_mut().project().channel.poll_flush(cx)?);
Poll::Ready(None)
}
Poll::Pending => {
// No more requests to process, so flush any requests buffered in the transport.
ready!(self.as_mut().channel().poll_flush(cx)?);
ready!(self.as_mut().project().channel.poll_flush(cx)?);
// Being here means there are no staged requests and all written responses are
// fully flushed. So, if the read half is closed and there are no in-flight
// requests, then we can close the write half.
if read_half_closed && self.as_mut().channel().in_flight_requests() == 0 {
if read_half_closed && self.as_mut().project().channel.in_flight_requests() == 0 {
Poll::Ready(None)
} else {
Poll::Pending
@@ -460,11 +452,11 @@ where
cx: &mut Context<'_>,
) -> PollIo<(context::Context, Response<C::Resp>)> {
// Ensure there's room to write a response.
while let Poll::Pending = self.as_mut().channel().poll_ready(cx)? {
ready!(self.as_mut().channel().poll_flush(cx)?);
while let Poll::Pending = self.as_mut().project().channel.poll_ready(cx)? {
ready!(self.as_mut().project().channel.poll_flush(cx)?);
}
match ready!(self.as_mut().pending_responses().poll_next(cx)) {
match ready!(self.as_mut().project().pending_responses.poll_next(cx)) {
Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))),
None => {
// This branch likely won't happen, since the ClientHandler is holding a Sender.
@@ -489,7 +481,7 @@ where
let ctx = request.context;
let request = request.message;
let response = self.as_mut().server().clone().serve(ctx, request);
let response = self.as_mut().project().server.clone().serve(ctx, request);
let response = Resp {
state: RespState::PollResp,
request_id,
@@ -497,9 +489,9 @@ where
deadline,
f: Timeout::new(response, timeout),
response: None,
response_tx: self.as_mut().responses_tx().clone(),
response_tx: self.as_mut().project().responses_tx.clone(),
};
let abort_registration = self.as_mut().channel().start_request(request_id);
let abort_registration = self.as_mut().project().channel.start_request(request_id);
RequestHandler {
resp: Abortable::new(response, abort_registration),
}
@@ -507,15 +499,13 @@ where
}
/// A future fulfilling a single client request.
#[pin_project]
#[derive(Debug)]
pub struct RequestHandler<F, R> {
#[pin]
resp: Abortable<Resp<F, R>>,
}
impl<F, R> RequestHandler<F, R> {
unsafe_pinned!(resp: Abortable<Resp<F, R>>);
}
impl<F, R> Future for RequestHandler<F, R>
where
F: Future<Output = R>,
@@ -523,19 +513,22 @@ where
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
let _ = ready!(self.resp().poll(cx));
let _ = ready!(self.project().resp.poll(cx));
Poll::Ready(())
}
}
#[pin_project]
#[derive(Debug)]
struct Resp<F, R> {
state: RespState,
request_id: u64,
ctx: context::Context,
deadline: SystemTime,
#[pin]
f: Timeout<F>,
response: Option<Response<R>>,
#[pin]
response_tx: mpsc::Sender<(context::Context, Response<R>)>,
}
@@ -546,13 +539,6 @@ enum RespState {
PollFlush,
}
impl<F, R> Resp<F, R> {
unsafe_pinned!(f: Timeout<F>);
unsafe_pinned!(response_tx: mpsc::Sender<(context::Context, Response<R>)>);
unsafe_unpinned!(response: Option<Response<R>>);
unsafe_unpinned!(state: RespState);
}
impl<F, R> Future for Resp<F, R>
where
F: Future<Output = R>,
@@ -561,10 +547,10 @@ where
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
loop {
match self.as_mut().state() {
match self.as_mut().project().state {
RespState::PollResp => {
let result = ready!(self.as_mut().f().poll(cx));
*self.as_mut().response() = Some(Response {
let result = ready!(self.as_mut().project().f.poll(cx));
*self.as_mut().project().response = Some(Response {
request_id: self.request_id,
message: match result {
Ok(message) => Ok(message),
@@ -588,21 +574,27 @@ where
},
_non_exhaustive: (),
});
*self.as_mut().state() = RespState::PollReady;
*self.as_mut().project().state = RespState::PollReady;
}
RespState::PollReady => {
let ready = ready!(self.as_mut().response_tx().poll_ready(cx));
let ready = ready!(self.as_mut().project().response_tx.poll_ready(cx));
if ready.is_err() {
return Poll::Ready(());
}
let resp = (self.ctx, self.as_mut().response().take().unwrap());
if self.as_mut().response_tx().start_send(resp).is_err() {
let resp = (self.ctx, self.as_mut().project().response.take().unwrap());
if self
.as_mut()
.project()
.response_tx
.start_send(resp)
.is_err()
{
return Poll::Ready(());
}
*self.as_mut().state() = RespState::PollFlush;
*self.as_mut().project().state = RespState::PollFlush;
}
RespState::PollFlush => {
let ready = ready!(self.as_mut().response_tx().poll_flush(cx));
let ready = ready!(self.as_mut().project().response_tx.poll_flush(cx));
if ready.is_err() {
return Poll::Ready(());
}
@@ -672,19 +664,15 @@ where
/// A future that drives the server by spawning channels and request handlers on the default
/// executor.
#[pin_project]
#[derive(Debug)]
#[cfg(feature = "tokio1")]
pub struct Running<St, Se> {
#[pin]
incoming: St,
server: Se,
}
#[cfg(feature = "tokio1")]
impl<St, Se> Running<St, Se> {
unsafe_pinned!(incoming: St);
unsafe_unpinned!(server: Se);
}
#[cfg(feature = "tokio1")]
impl<St, C, Se> Future for Running<St, Se>
where
@@ -700,10 +688,10 @@ where
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
use log::info;
while let Some(channel) = ready!(self.as_mut().incoming().poll_next(cx)) {
while let Some(channel) = ready!(self.as_mut().project().incoming.poll_next(cx)) {
tokio::spawn(
channel
.respond_with(self.as_mut().server().clone())
.respond_with(self.as_mut().project().server.clone())
.execute(),
);
}

View File

@@ -4,26 +4,23 @@ use fnv::FnvHashSet;
use futures::future::{AbortHandle, AbortRegistration};
use futures::{Sink, Stream};
use futures_test::task::noop_waker_ref;
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use pin_project::pin_project;
use std::collections::VecDeque;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::SystemTime;
#[pin_project]
pub(crate) struct FakeChannel<In, Out> {
#[pin]
pub stream: VecDeque<In>,
#[pin]
pub sink: VecDeque<Out>,
pub config: Config,
pub in_flight_requests: FnvHashSet<u64>,
}
impl<In, Out> FakeChannel<In, Out> {
unsafe_pinned!(stream: VecDeque<In>);
unsafe_pinned!(sink: VecDeque<Out>);
unsafe_unpinned!(in_flight_requests: FnvHashSet<u64>);
}
impl<In, Out> Stream for FakeChannel<In, Out>
where
In: Unpin,
@@ -31,7 +28,7 @@ where
type Item = In;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
self.stream().poll_next(cx)
self.project().stream.poll_next(cx)
}
}
@@ -39,22 +36,26 @@ impl<In, Resp> Sink<Response<Resp>> for FakeChannel<In, Response<Resp>> {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.sink().poll_ready(cx).map_err(|e| match e {})
self.project().sink.poll_ready(cx).map_err(|e| match e {})
}
fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
self.as_mut()
.in_flight_requests()
.project()
.in_flight_requests
.remove(&response.request_id);
self.sink().start_send(response).map_err(|e| match e {})
self.project()
.sink
.start_send(response)
.map_err(|e| match e {})
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.sink().poll_flush(cx).map_err(|e| match e {})
self.project().sink.poll_flush(cx).map_err(|e| match e {})
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.sink().poll_close(cx).map_err(|e| match e {})
self.project().sink.poll_close(cx).map_err(|e| match e {})
}
}
@@ -74,7 +75,7 @@ where
}
fn start_request(self: Pin<&mut Self>, id: u64) -> AbortRegistration {
self.in_flight_requests().insert(id);
self.project().in_flight_requests.insert(id);
AbortHandle::new_pair().1
}
}

View File

@@ -7,21 +7,20 @@ use futures::{
task::{Context, Poll},
};
use log::debug;
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use pin_project::pin_project;
use std::{io, pin::Pin};
/// A [`Channel`] that limits the number of concurrent
/// requests by throttling.
#[pin_project]
#[derive(Debug)]
pub struct Throttler<C> {
max_in_flight_requests: usize,
#[pin]
inner: C,
}
impl<C> Throttler<C> {
unsafe_unpinned!(max_in_flight_requests: usize);
unsafe_pinned!(inner: C);
/// Returns the inner channel.
pub fn get_ref(&self) -> &C {
&self.inner
@@ -49,16 +48,17 @@ where
type Item = <C as Stream>::Item;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
while self.as_mut().in_flight_requests() >= *self.as_mut().max_in_flight_requests() {
ready!(self.as_mut().inner().poll_ready(cx)?);
while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests
{
ready!(self.as_mut().project().inner.poll_ready(cx)?);
match ready!(self.as_mut().inner().poll_next(cx)?) {
match ready!(self.as_mut().project().inner.poll_next(cx)?) {
Some(request) => {
debug!(
"[{}] Client has reached in-flight request limit ({}/{}).",
request.context.trace_id(),
self.as_mut().in_flight_requests(),
self.as_mut().max_in_flight_requests(),
self.as_mut().project().max_in_flight_requests,
);
self.as_mut().start_send(Response {
@@ -74,7 +74,7 @@ where
None => return Poll::Ready(None),
}
}
self.inner().poll_next(cx)
self.project().inner.poll_next(cx)
}
}
@@ -85,19 +85,19 @@ where
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.inner().poll_ready(cx)
self.project().inner.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Response<<C as Channel>::Resp>) -> io::Result<()> {
self.inner().start_send(item)
self.project().inner.start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.inner().poll_flush(cx)
self.project().inner.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.inner().poll_close(cx)
self.project().inner.poll_close(cx)
}
}
@@ -115,7 +115,7 @@ where
type Resp = <C as Channel>::Resp;
fn in_flight_requests(self: Pin<&mut Self>) -> usize {
self.inner().in_flight_requests()
self.project().inner.in_flight_requests()
}
fn config(&self) -> &Config {
@@ -123,13 +123,15 @@ where
}
fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
self.inner().start_request(request_id)
self.project().inner.start_request(request_id)
}
}
/// A stream of throttling channels.
#[pin_project]
#[derive(Debug)]
pub struct ThrottlerStream<S> {
#[pin]
inner: S,
max_in_flight_requests: usize,
}
@@ -139,9 +141,6 @@ where
S: Stream,
<S as Stream>::Item: Channel,
{
unsafe_pinned!(inner: S);
unsafe_unpinned!(max_in_flight_requests: usize);
pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self {
Self {
inner,
@@ -158,10 +157,10 @@ where
type Item = Throttler<<S as Stream>::Item>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match ready!(self.as_mut().inner().poll_next(cx)) {
match ready!(self.as_mut().project().inner.poll_next(cx)) {
Some(channel) => Poll::Ready(Some(Throttler::new(
channel,
*self.max_in_flight_requests(),
*self.project().max_in_flight_requests,
))),
None => Poll::Ready(None),
}