From 5f6c3d7d989b7d7267c19d4b89c658450e86614e Mon Sep 17 00:00:00 2001 From: Artem Vorotnikov Date: Wed, 9 Oct 2019 19:07:47 +0300 Subject: [PATCH] Port to pin-project --- bincode-transport/Cargo.toml | 3 +- bincode-transport/src/lib.rs | 27 +++---- rpc/Cargo.toml | 3 +- rpc/src/client/channel.rs | 149 ++++++++++++++++------------------- rpc/src/server/filter.rs | 44 ++++------- rpc/src/server/mod.rs | 138 +++++++++++++++----------------- rpc/src/server/testing.rs | 29 +++---- rpc/src/server/throttle.rs | 39 +++++---- rpc/src/transport/channel.rs | 24 +++--- 9 files changed, 211 insertions(+), 245 deletions(-) diff --git a/bincode-transport/Cargo.toml b/bincode-transport/Cargo.toml index 95a378e..c542a6e 100644 --- a/bincode-transport/Cargo.toml +++ b/bincode-transport/Cargo.toml @@ -15,7 +15,7 @@ description = "A bincode-based transport for tarpc services." [dependencies] futures-preview = { version = "0.3.0-alpha.18", features = ["compat"] } futures_legacy = { version = "0.1", package = "futures" } -pin-utils = "0.1.0-alpha.4" +pin-project = "0.4" serde = "1.0" tokio-io = "0.1" async-bincode = "0.4" @@ -24,3 +24,4 @@ tokio-tcp = "0.1" [dev-dependencies] futures-test-preview = { version = "0.3.0-alpha.18" } assert_matches = "1.0" +pin-utils = "0.1.0-alpha" diff --git a/bincode-transport/src/lib.rs b/bincode-transport/src/lib.rs index 93ff2dc..6b2fa26 100644 --- a/bincode-transport/src/lib.rs +++ b/bincode-transport/src/lib.rs @@ -10,7 +10,7 @@ use async_bincode::{AsyncBincodeStream, AsyncDestination}; use futures::{compat::*, prelude::*, ready}; -use pin_utils::unsafe_pinned; +use pin_project::pin_project; use serde::{Deserialize, Serialize}; use std::{ error::Error, @@ -24,17 +24,13 @@ use tokio_io::{AsyncRead, AsyncWrite}; use tokio_tcp::{TcpListener, TcpStream}; /// A transport that serializes to, and deserializes from, a [`TcpStream`]. +#[pin_project] #[derive(Debug)] pub struct Transport { + #[pin] inner: Compat01As03Sink, SinkItem>, } -impl Transport { - unsafe_pinned!( - inner: Compat01As03Sink, SinkItem> - ); -} - impl Stream for Transport where S: AsyncRead, @@ -43,7 +39,7 @@ where type Item = io::Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { - match self.inner().poll_next(cx) { + match self.project().inner.poll_next(cx) { Poll::Pending => Poll::Pending, Poll::Ready(None) => Poll::Ready(None), Poll::Ready(Some(Ok(next))) => Poll::Ready(Some(Ok(next))), @@ -62,21 +58,22 @@ where type Error = io::Error; fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> { - self.inner() + self.project() + .inner .start_send(item) .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) } fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - convert(self.inner().poll_ready(cx)) + convert(self.project().inner.poll_ready(cx)) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - convert(self.inner().poll_flush(cx)) + convert(self.project().inner.poll_flush(cx)) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - convert(self.inner().poll_close(cx)) + convert(self.project().inner.poll_close(cx)) } } @@ -153,16 +150,16 @@ where } /// A [`TcpListener`] that wraps connections in bincode transports. +#[pin_project] #[derive(Debug)] pub struct Incoming { + #[pin] incoming: Compat01As03, local_addr: SocketAddr, ghost: PhantomData<(Item, SinkItem)>, } impl Incoming { - unsafe_pinned!(incoming: Compat01As03); - /// Returns the address being listened on. pub fn local_addr(&self) -> SocketAddr { self.local_addr @@ -177,7 +174,7 @@ where type Item = io::Result>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let next = ready!(self.incoming().poll_next(cx)?); + let next = ready!(self.project().incoming.poll_next(cx)?); Poll::Ready(next.map(|conn| Ok(new(conn)))) } } diff --git a/rpc/Cargo.toml b/rpc/Cargo.toml index 531fb0f..5e85a69 100644 --- a/rpc/Cargo.toml +++ b/rpc/Cargo.toml @@ -22,7 +22,7 @@ fnv = "1.0" futures-preview = { version = "0.3.0-alpha.18" } humantime = "1.0" log = "0.4" -pin-utils = "0.1.0-alpha.4" +pin-project = "0.4" raii-counter = "0.2" rand = "0.7" tokio-timer = "0.3.0-alpha.4" @@ -34,3 +34,4 @@ tokio = { optional = true, version = "0.2.0-alpha.4" } futures-test-preview = { version = "0.3.0-alpha.18" } env_logger = "0.6" assert_matches = "1.0" +pin-utils = "0.1.0-alpha" diff --git a/rpc/src/client/channel.rs b/rpc/src/client/channel.rs index 70a4475..4f93a1f 100644 --- a/rpc/src/client/channel.rs +++ b/rpc/src/client/channel.rs @@ -19,10 +19,9 @@ use futures::{ Poll, }; use log::{debug, info, trace}; -use pin_utils::{unsafe_pinned, unsafe_unpinned}; +use pin_project::{pin_project, pinned_drop}; use std::{ io, - marker::Unpin, pin::Pin, sync::{ atomic::{AtomicU64, Ordering}, @@ -55,9 +54,11 @@ impl Clone for Channel { } /// A future returned by [`Channel::send`] that resolves to a server response. +#[pin_project] #[derive(Debug)] #[must_use = "futures do nothing unless polled"] struct Send<'a, Req, Resp> { + #[pin] fut: MapOkDispatchResponse, Resp>, } @@ -65,45 +66,28 @@ type SendMapErrConnectionReset<'a, Req, Resp> = MapErrConnectionReset< futures::sink::Send<'a, mpsc::Sender>, DispatchRequest>, >; -impl<'a, Req, Resp> Send<'a, Req, Resp> { - unsafe_pinned!( - fut: MapOkDispatchResponse< - MapErrConnectionReset< - futures::sink::Send< - 'a, - mpsc::Sender>, - DispatchRequest, - >, - >, - Resp, - > - ); -} - impl<'a, Req, Resp> Future for Send<'a, Req, Resp> { type Output = io::Result>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.as_mut().fut().poll(cx) + self.as_mut().project().fut.poll(cx) } } /// A future returned by [`Channel::call`] that resolves to a server response. +#[pin_project] #[derive(Debug)] #[must_use = "futures do nothing unless polled"] pub struct Call<'a, Req, Resp> { + #[pin] fut: AndThenIdent, DispatchResponse>, } -impl<'a, Req, Resp> Call<'a, Req, Resp> { - unsafe_pinned!(fut: AndThenIdent, DispatchResponse>); -} - impl<'a, Req, Resp> Future for Call<'a, Req, Resp> { type Output = io::Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.as_mut().fut().poll(cx) + self.as_mut().project().fut.poll(cx) } } @@ -155,6 +139,7 @@ impl Channel { /// A server response that is completed by request dispatch when the corresponding response /// arrives off the wire. +#[pin_project(PinnedDrop)] #[derive(Debug)] struct DispatchResponse { response: Timeout>>, @@ -164,10 +149,6 @@ struct DispatchResponse { request_id: u64, } -impl DispatchResponse { - unsafe_pinned!(ctx: context::Context); -} - impl Future for DispatchResponse { type Output = io::Result; @@ -196,8 +177,9 @@ impl Future for DispatchResponse { } // Cancels the request when dropped, if not already complete. -impl Drop for DispatchResponse { - fn drop(&mut self) { +#[pinned_drop] +impl PinnedDrop for DispatchResponse { + fn drop(mut self: Pin<&mut Self>) { if !self.complete { // The receiver needs to be closed to handle the edge case that the request has not // yet been received by the dispatch task. It is possible for the cancel message to @@ -210,7 +192,8 @@ impl Drop for DispatchResponse { // dispatch task misses an early-arriving cancellation message, then it will see the // receiver as closed. self.response.get_mut().close(); - self.cancellation.cancel(self.request_id); + let request_id = self.request_id; + self.cancellation.cancel(request_id); } } } @@ -246,13 +229,17 @@ where /// Handles the lifecycle of requests, writing requests to the wire, managing cancellations, /// and dispatching responses to the appropriate channel. +#[pin_project] #[derive(Debug)] pub struct RequestDispatch { /// Writes requests to the wire and reads responses off the wire. + #[pin] transport: Fuse, /// Requests waiting to be written to the wire. + #[pin] pending_requests: Fuse>>, /// Requests that were dropped. + #[pin] canceled_requests: Fuse, /// Requests already written to the wire that haven't yet received responses. in_flight_requests: FnvHashMap>, @@ -264,19 +251,16 @@ impl RequestDispatch where C: Transport, Response>, { - unsafe_pinned!(in_flight_requests: FnvHashMap>); - unsafe_pinned!(canceled_requests: Fuse); - unsafe_pinned!(pending_requests: Fuse>>); - unsafe_pinned!(transport: Fuse); - fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> { - Poll::Ready(match ready!(self.as_mut().transport().poll_next(cx)?) { - Some(response) => { - self.complete(response); - Some(Ok(())) - } - None => None, - }) + Poll::Ready( + match ready!(self.as_mut().project().transport.poll_next(cx)?) { + Some(response) => { + self.complete(response); + Some(Ok(())) + } + None => None, + }, + ) } fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> { @@ -305,12 +289,12 @@ where match (pending_requests_status, canceled_requests_status) { (ReceiverStatus::Closed, ReceiverStatus::Closed) => { - ready!(self.as_mut().transport().poll_flush(cx)?); + ready!(self.as_mut().project().transport.poll_flush(cx)?); Poll::Ready(None) } (ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => { // No more messages to process, so flush any messages buffered in the transport. - ready!(self.as_mut().transport().poll_flush(cx)?); + ready!(self.as_mut().project().transport.poll_flush(cx)?); // Even if we fully-flush, we return Pending, because we have no more requests // or cancellations right now. @@ -324,10 +308,10 @@ where mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> PollIo> { - if self.as_mut().in_flight_requests().len() >= self.config.max_in_flight_requests { + if self.as_mut().project().in_flight_requests.len() >= self.config.max_in_flight_requests { info!( "At in-flight request capacity ({}/{}).", - self.as_mut().in_flight_requests().len(), + self.as_mut().project().in_flight_requests.len(), self.config.max_in_flight_requests ); @@ -336,13 +320,13 @@ where return Poll::Pending; } - while let Poll::Pending = self.as_mut().transport().poll_ready(cx)? { + while let Poll::Pending = self.as_mut().project().transport.poll_ready(cx)? { // We can't yield a request-to-be-sent before the transport is capable of buffering it. - ready!(self.as_mut().transport().poll_flush(cx)?); + ready!(self.as_mut().project().transport.poll_flush(cx)?); } loop { - match ready!(self.as_mut().pending_requests().poll_next_unpin(cx)) { + match ready!(self.as_mut().project().pending_requests.poll_next_unpin(cx)) { Some(request) => { if request.response_completion.is_canceled() { trace!( @@ -364,18 +348,25 @@ where mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> PollIo<(context::Context, u64)> { - while let Poll::Pending = self.as_mut().transport().poll_ready(cx)? { - ready!(self.as_mut().transport().poll_flush(cx)?); + while let Poll::Pending = self.as_mut().project().transport.poll_ready(cx)? { + ready!(self.as_mut().project().transport.poll_flush(cx)?); } loop { - let cancellation = self.as_mut().canceled_requests().poll_next_unpin(cx); + let cancellation = self + .as_mut() + .project() + .canceled_requests + .poll_next_unpin(cx); match ready!(cancellation) { Some(request_id) => { - if let Some(in_flight_data) = - self.as_mut().in_flight_requests().remove(&request_id) + if let Some(in_flight_data) = self + .as_mut() + .project() + .in_flight_requests + .remove(&request_id) { - self.as_mut().in_flight_requests().compact(0.1); + self.as_mut().project().in_flight_requests.compact(0.1); debug!("[{}] Removed request.", in_flight_data.ctx.trace_id()); return Poll::Ready(Some(Ok((in_flight_data.ctx, request_id)))); } @@ -400,8 +391,8 @@ where }, _non_exhaustive: (), }); - self.as_mut().transport().start_send(request)?; - self.as_mut().in_flight_requests().insert( + self.as_mut().project().transport.start_send(request)?; + self.as_mut().project().in_flight_requests.insert( request_id, InFlightData { ctx: dispatch_request.ctx, @@ -421,7 +412,7 @@ where trace_context: context.trace_context, request_id, }; - self.as_mut().transport().start_send(cancel)?; + self.as_mut().project().transport.start_send(cancel)?; trace!("[{}] Cancel message sent.", trace_id); Ok(()) } @@ -430,10 +421,11 @@ where fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { if let Some(in_flight_data) = self .as_mut() - .in_flight_requests() + .project() + .in_flight_requests .remove(&response.request_id) { - self.as_mut().in_flight_requests().compact(0.1); + self.as_mut().project().in_flight_requests.compact(0.1); trace!("[{}] Received response.", in_flight_data.ctx.trace_id()); let _ = in_flight_data.response_completion.send(response); @@ -460,13 +452,13 @@ where loop { match (self.as_mut().pump_read(cx)?, self.as_mut().pump_write(cx)?) { (read, Poll::Ready(None)) => { - if self.as_mut().in_flight_requests().is_empty() { + if self.as_mut().project().in_flight_requests.is_empty() { info!("Shutdown: write half closed, and no requests in flight."); return Poll::Ready(Ok(())); } info!( "Shutdown: write half closed, and {} requests in flight.", - self.as_mut().in_flight_requests().len() + self.as_mut().project().in_flight_requests.len() ); match read { Poll::Ready(Some(())) => continue, @@ -529,17 +521,16 @@ impl Stream for CanceledRequests { } } +#[pin_project] #[derive(Debug)] #[must_use = "futures do nothing unless polled"] struct MapErrConnectionReset { + #[pin] future: Fut, finished: Option<()>, } impl MapErrConnectionReset { - unsafe_pinned!(future: Fut); - unsafe_unpinned!(finished: Option<()>); - fn new(future: Fut) -> MapErrConnectionReset { MapErrConnectionReset { future, @@ -548,8 +539,6 @@ impl MapErrConnectionReset { } } -impl Unpin for MapErrConnectionReset {} - impl Future for MapErrConnectionReset where Fut: TryFuture, @@ -557,10 +546,10 @@ where type Output = io::Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.as_mut().future().try_poll(cx) { + match self.as_mut().project().future.try_poll(cx) { Poll::Pending => Poll::Pending, Poll::Ready(result) => { - self.finished().take().expect( + self.project().finished.take().expect( "MapErrConnectionReset must not be polled after it returned `Poll::Ready`", ); Poll::Ready(result.map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset))) @@ -569,17 +558,16 @@ where } } +#[pin_project] #[derive(Debug)] #[must_use = "futures do nothing unless polled"] struct MapOkDispatchResponse { + #[pin] future: Fut, response: Option>, } impl MapOkDispatchResponse { - unsafe_pinned!(future: Fut); - unsafe_unpinned!(response: Option>); - fn new(future: Fut, response: DispatchResponse) -> MapOkDispatchResponse { MapOkDispatchResponse { future, @@ -588,8 +576,6 @@ impl MapOkDispatchResponse { } } -impl Unpin for MapOkDispatchResponse {} - impl Future for MapOkDispatchResponse where Fut: TryFuture, @@ -597,12 +583,13 @@ where type Output = Result, Fut::Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.as_mut().future().try_poll(cx) { + match self.as_mut().project().future.try_poll(cx) { Poll::Pending => Poll::Pending, Poll::Ready(result) => { let response = self .as_mut() - .response() + .project() + .response .take() .expect("MapOk must not be polled after it returned `Poll::Ready`"); Poll::Ready(result.map(|_| response)) @@ -611,9 +598,11 @@ where } } +#[pin_project] #[derive(Debug)] #[must_use = "futures do nothing unless polled"] struct AndThenIdent { + #[pin] try_chain: TryChain, } @@ -622,8 +611,6 @@ where Fut1: TryFuture, Fut2: TryFuture, { - unsafe_pinned!(try_chain: TryChain); - /// Creates a new `Then`. fn new(future: Fut1) -> AndThenIdent { AndThenIdent { @@ -640,7 +627,7 @@ where type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.try_chain().poll(cx, |result| match result { + self.project().try_chain.poll(cx, |result| match result { Ok(ok) => TryChainAction::Future(ok), Err(err) => TryChainAction::Output(Err(err)), }) @@ -830,7 +817,7 @@ mod tests { let req = send_request(&mut channel, "hi"); assert!(dispatch.as_mut().pump_write(cx).ready().is_some()); - assert!(!dispatch.as_mut().in_flight_requests().is_empty()); + assert!(!dispatch.as_mut().project().in_flight_requests.is_empty()); // Test that a request future dropped after it's processed by dispatch will cause the request // to be removed from the in-flight request map. @@ -840,7 +827,7 @@ mod tests { } else { panic!("Expected request to be cancelled") }; - assert!(dispatch.in_flight_requests().is_empty()); + assert!(dispatch.project().in_flight_requests.is_empty()); } #[test] diff --git a/rpc/src/server/filter.rs b/rpc/src/server/filter.rs index c59daa3..024ea46 100644 --- a/rpc/src/server/filter.rs +++ b/rpc/src/server/filter.rs @@ -18,7 +18,7 @@ use futures::{ task::{Context, Poll}, }; use log::{debug, info, trace}; -use pin_utils::{unsafe_pinned, unsafe_unpinned}; +use pin_project::pin_project; use raii_counter::{Counter, WeakCounter}; use std::sync::{Arc, Weak}; use std::{ @@ -26,30 +26,32 @@ use std::{ }; /// A single-threaded filter that drops channels based on per-key limits. +#[pin_project] #[derive(Debug)] pub struct ChannelFilter where K: Eq + Hash, { + #[pin] listener: Fuse, channels_per_key: u32, + #[pin] dropped_keys: mpsc::UnboundedReceiver, + #[pin] dropped_keys_tx: mpsc::UnboundedSender, key_counts: FnvHashMap>, keymaker: F, } /// A channel that is tracked by a ChannelFilter. +#[pin_project] #[derive(Debug)] pub struct TrackedChannel { + #[pin] inner: C, tracker: Tracker, } -impl TrackedChannel { - unsafe_pinned!(inner: C); -} - #[derive(Clone, Debug)] struct Tracker { key: Option>, @@ -130,11 +132,11 @@ where } fn in_flight_requests(self: Pin<&mut Self>) -> usize { - self.inner().in_flight_requests() + self.project().inner.in_flight_requests() } fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration { - self.inner().start_request(request_id) + self.project().inner.start_request(request_id) } } @@ -146,22 +148,10 @@ impl TrackedChannel { /// Returns the pinned inner channel. fn channel<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut C> { - self.inner() + self.project().inner } } -impl ChannelFilter -where - K: fmt::Display + Eq + Hash + Clone, -{ - unsafe_pinned!(listener: Fuse); - unsafe_pinned!(dropped_keys: mpsc::UnboundedReceiver); - unsafe_pinned!(dropped_keys_tx: mpsc::UnboundedSender); - unsafe_unpinned!(key_counts: FnvHashMap>); - unsafe_unpinned!(channels_per_key: u32); - unsafe_unpinned!(keymaker: F); -} - impl ChannelFilter where K: Eq + Hash, @@ -192,14 +182,14 @@ where mut self: Pin<&mut Self>, stream: S::Item, ) -> Result, K> { - let key = self.as_mut().keymaker()(&stream); + let key = (self.as_mut().keymaker)(&stream); let tracker = self.as_mut().increment_channels_for_key(key.clone())?; trace!( "[{}] Opening channel ({}/{}) channels for key.", key, tracker.counter.count(), - self.as_mut().channels_per_key() + self.as_mut().project().channels_per_key ); Ok(TrackedChannel { @@ -211,7 +201,7 @@ where fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result, K> { let channels_per_key = self.channels_per_key; let dropped_keys = self.dropped_keys_tx.clone(); - let key_counts = &mut self.as_mut().key_counts(); + let key_counts = &mut self.as_mut().project().key_counts; match key_counts.entry(key.clone()) { Entry::Vacant(vacant) => { let key = Arc::new(key); @@ -256,18 +246,18 @@ where mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, K>>> { - match ready!(self.as_mut().listener().poll_next_unpin(cx)) { + match ready!(self.as_mut().project().listener.poll_next_unpin(cx)) { Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))), None => Poll::Ready(None), } } fn poll_closed_channels(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - match ready!(self.as_mut().dropped_keys().poll_next_unpin(cx)) { + match ready!(self.as_mut().project().dropped_keys.poll_next_unpin(cx)) { Some(key) => { debug!("All channels dropped for key [{}]", key); - self.as_mut().key_counts().remove(&key); - self.as_mut().key_counts().compact(0.1); + self.as_mut().project().key_counts.remove(&key); + self.as_mut().project().key_counts.compact(0.1); Poll::Ready(()) } None => unreachable!("Holding a copy of closed_channels and didn't close it."), diff --git a/rpc/src/server/mod.rs b/rpc/src/server/mod.rs index 27d430b..4011891 100644 --- a/rpc/src/server/mod.rs +++ b/rpc/src/server/mod.rs @@ -21,7 +21,7 @@ use futures::{ }; use humantime::format_rfc3339; use log::{debug, trace}; -use pin_utils::{unsafe_pinned, unsafe_unpinned}; +use pin_project::pin_project; use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime}; use tokio_timer::{timeout, Timeout}; @@ -165,10 +165,12 @@ where } /// BaseChannel lifts a Transport to a Channel by tracking in-flight requests. +#[pin_project] #[derive(Debug)] pub struct BaseChannel { config: Config, /// Writes responses to the wire and reads requests off the wire. + #[pin] transport: Fuse, /// Number of requests currently being responded to. in_flight_requests: FnvHashMap, @@ -176,10 +178,6 @@ pub struct BaseChannel { ghost: PhantomData<(Req, Resp)>, } -impl BaseChannel { - unsafe_unpinned!(in_flight_requests: FnvHashMap); -} - impl BaseChannel where T: Transport, ClientMessage>, @@ -204,19 +202,19 @@ where self.transport.get_ref() } - /// Returns the pinned inner transport. - pub fn transport<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut T> { - unsafe { self.map_unchecked_mut(|me| me.transport.get_mut()) } - } - fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) { // It's possible the request was already completed, so it's fine // if this is None. - if let Some(cancel_handle) = self.as_mut().in_flight_requests().remove(&request_id) { - self.as_mut().in_flight_requests().compact(0.1); + if let Some(cancel_handle) = self + .as_mut() + .project() + .in_flight_requests + .remove(&request_id) + { + self.as_mut().project().in_flight_requests.compact(0.1); cancel_handle.abort(); - let remaining = self.as_mut().in_flight_requests().len(); + let remaining = self.as_mut().project().in_flight_requests.len(); trace!( "[{}] Request canceled. In-flight requests = {}", trace_context.trace_id, @@ -295,7 +293,7 @@ where fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { loop { - match ready!(self.as_mut().transport().poll_next(cx)?) { + match ready!(self.as_mut().project().transport.poll_next(cx)?) { Some(message) => match message { ClientMessage::Request(request) => { return Poll::Ready(Some(Ok(request))); @@ -321,28 +319,29 @@ where type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.transport().poll_ready(cx) + self.project().transport.poll_ready(cx) } fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { if self .as_mut() - .in_flight_requests() + .project() + .in_flight_requests .remove(&response.request_id) .is_some() { - self.as_mut().in_flight_requests().compact(0.1); + self.as_mut().project().in_flight_requests.compact(0.1); } - self.transport().start_send(response) + self.project().transport.start_send(response) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.transport().poll_flush(cx) + self.project().transport.poll_flush(cx) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.transport().poll_close(cx) + self.project().transport.poll_close(cx) } } @@ -364,13 +363,14 @@ where } fn in_flight_requests(mut self: Pin<&mut Self>) -> usize { - self.as_mut().in_flight_requests().len() + self.as_mut().project().in_flight_requests.len() } fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration { let (abort_handle, abort_registration) = AbortHandle::new_pair(); assert!(self - .in_flight_requests() + .project() + .in_flight_requests .insert(request_id, abort_handle) .is_none()); abort_registration @@ -378,32 +378,24 @@ where } /// A running handler serving all requests coming over a channel. +#[pin_project] #[derive(Debug)] pub struct ClientHandler where C: Channel, { + #[pin] channel: C, /// Responses waiting to be written to the wire. + #[pin] pending_responses: Fuse)>>, /// Handed out to request handlers to fan in responses. + #[pin] responses_tx: mpsc::Sender<(context::Context, Response)>, /// Server server: S, } -impl ClientHandler -where - C: Channel, -{ - unsafe_pinned!(channel: C); - unsafe_pinned!(pending_responses: Fuse)>>); - unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response)>); - // For this to be safe, field f must be private, and code in this module must never - // construct PinMut. - unsafe_unpinned!(server: S); -} - impl ClientHandler where C: Channel, @@ -413,7 +405,7 @@ where mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> PollIo> { - match ready!(self.as_mut().channel().poll_next(cx)?) { + match ready!(self.as_mut().project().channel.poll_next(cx)?) { Some(request) => Poll::Ready(Some(Ok(self.handle_request(request)))), None => Poll::Ready(None), } @@ -429,24 +421,24 @@ where trace!( "[{}] Staging response. In-flight requests = {}.", ctx.trace_id(), - self.as_mut().channel().in_flight_requests(), + self.as_mut().project().channel.in_flight_requests(), ); - self.as_mut().channel().start_send(response)?; + self.as_mut().project().channel.start_send(response)?; Poll::Ready(Some(Ok(()))) } Poll::Ready(None) => { // Shutdown can't be done before we finish pumping out remaining responses. - ready!(self.as_mut().channel().poll_flush(cx)?); + ready!(self.as_mut().project().channel.poll_flush(cx)?); Poll::Ready(None) } Poll::Pending => { // No more requests to process, so flush any requests buffered in the transport. - ready!(self.as_mut().channel().poll_flush(cx)?); + ready!(self.as_mut().project().channel.poll_flush(cx)?); // Being here means there are no staged requests and all written responses are // fully flushed. So, if the read half is closed and there are no in-flight // requests, then we can close the write half. - if read_half_closed && self.as_mut().channel().in_flight_requests() == 0 { + if read_half_closed && self.as_mut().project().channel.in_flight_requests() == 0 { Poll::Ready(None) } else { Poll::Pending @@ -460,11 +452,11 @@ where cx: &mut Context<'_>, ) -> PollIo<(context::Context, Response)> { // Ensure there's room to write a response. - while let Poll::Pending = self.as_mut().channel().poll_ready(cx)? { - ready!(self.as_mut().channel().poll_flush(cx)?); + while let Poll::Pending = self.as_mut().project().channel.poll_ready(cx)? { + ready!(self.as_mut().project().channel.poll_flush(cx)?); } - match ready!(self.as_mut().pending_responses().poll_next(cx)) { + match ready!(self.as_mut().project().pending_responses.poll_next(cx)) { Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))), None => { // This branch likely won't happen, since the ClientHandler is holding a Sender. @@ -489,7 +481,7 @@ where let ctx = request.context; let request = request.message; - let response = self.as_mut().server().clone().serve(ctx, request); + let response = self.as_mut().project().server.clone().serve(ctx, request); let response = Resp { state: RespState::PollResp, request_id, @@ -497,9 +489,9 @@ where deadline, f: Timeout::new(response, timeout), response: None, - response_tx: self.as_mut().responses_tx().clone(), + response_tx: self.as_mut().project().responses_tx.clone(), }; - let abort_registration = self.as_mut().channel().start_request(request_id); + let abort_registration = self.as_mut().project().channel.start_request(request_id); RequestHandler { resp: Abortable::new(response, abort_registration), } @@ -507,15 +499,13 @@ where } /// A future fulfilling a single client request. +#[pin_project] #[derive(Debug)] pub struct RequestHandler { + #[pin] resp: Abortable>, } -impl RequestHandler { - unsafe_pinned!(resp: Abortable>); -} - impl Future for RequestHandler where F: Future, @@ -523,19 +513,22 @@ where type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - let _ = ready!(self.resp().poll(cx)); + let _ = ready!(self.project().resp.poll(cx)); Poll::Ready(()) } } +#[pin_project] #[derive(Debug)] struct Resp { state: RespState, request_id: u64, ctx: context::Context, deadline: SystemTime, + #[pin] f: Timeout, response: Option>, + #[pin] response_tx: mpsc::Sender<(context::Context, Response)>, } @@ -546,13 +539,6 @@ enum RespState { PollFlush, } -impl Resp { - unsafe_pinned!(f: Timeout); - unsafe_pinned!(response_tx: mpsc::Sender<(context::Context, Response)>); - unsafe_unpinned!(response: Option>); - unsafe_unpinned!(state: RespState); -} - impl Future for Resp where F: Future, @@ -561,10 +547,10 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { loop { - match self.as_mut().state() { + match self.as_mut().project().state { RespState::PollResp => { - let result = ready!(self.as_mut().f().poll(cx)); - *self.as_mut().response() = Some(Response { + let result = ready!(self.as_mut().project().f.poll(cx)); + *self.as_mut().project().response = Some(Response { request_id: self.request_id, message: match result { Ok(message) => Ok(message), @@ -588,21 +574,27 @@ where }, _non_exhaustive: (), }); - *self.as_mut().state() = RespState::PollReady; + *self.as_mut().project().state = RespState::PollReady; } RespState::PollReady => { - let ready = ready!(self.as_mut().response_tx().poll_ready(cx)); + let ready = ready!(self.as_mut().project().response_tx.poll_ready(cx)); if ready.is_err() { return Poll::Ready(()); } - let resp = (self.ctx, self.as_mut().response().take().unwrap()); - if self.as_mut().response_tx().start_send(resp).is_err() { + let resp = (self.ctx, self.as_mut().project().response.take().unwrap()); + if self + .as_mut() + .project() + .response_tx + .start_send(resp) + .is_err() + { return Poll::Ready(()); } - *self.as_mut().state() = RespState::PollFlush; + *self.as_mut().project().state = RespState::PollFlush; } RespState::PollFlush => { - let ready = ready!(self.as_mut().response_tx().poll_flush(cx)); + let ready = ready!(self.as_mut().project().response_tx.poll_flush(cx)); if ready.is_err() { return Poll::Ready(()); } @@ -672,19 +664,15 @@ where /// A future that drives the server by spawning channels and request handlers on the default /// executor. +#[pin_project] #[derive(Debug)] #[cfg(feature = "tokio1")] pub struct Running { + #[pin] incoming: St, server: Se, } -#[cfg(feature = "tokio1")] -impl Running { - unsafe_pinned!(incoming: St); - unsafe_unpinned!(server: Se); -} - #[cfg(feature = "tokio1")] impl Future for Running where @@ -700,10 +688,10 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { use log::info; - while let Some(channel) = ready!(self.as_mut().incoming().poll_next(cx)) { + while let Some(channel) = ready!(self.as_mut().project().incoming.poll_next(cx)) { tokio::spawn( channel - .respond_with(self.as_mut().server().clone()) + .respond_with(self.as_mut().project().server.clone()) .execute(), ); } diff --git a/rpc/src/server/testing.rs b/rpc/src/server/testing.rs index 17a5efe..26bc1df 100644 --- a/rpc/src/server/testing.rs +++ b/rpc/src/server/testing.rs @@ -4,26 +4,23 @@ use fnv::FnvHashSet; use futures::future::{AbortHandle, AbortRegistration}; use futures::{Sink, Stream}; use futures_test::task::noop_waker_ref; -use pin_utils::{unsafe_pinned, unsafe_unpinned}; +use pin_project::pin_project; use std::collections::VecDeque; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::SystemTime; +#[pin_project] pub(crate) struct FakeChannel { + #[pin] pub stream: VecDeque, + #[pin] pub sink: VecDeque, pub config: Config, pub in_flight_requests: FnvHashSet, } -impl FakeChannel { - unsafe_pinned!(stream: VecDeque); - unsafe_pinned!(sink: VecDeque); - unsafe_unpinned!(in_flight_requests: FnvHashSet); -} - impl Stream for FakeChannel where In: Unpin, @@ -31,7 +28,7 @@ where type Item = In; fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.stream().poll_next(cx) + self.project().stream.poll_next(cx) } } @@ -39,22 +36,26 @@ impl Sink> for FakeChannel> { type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.sink().poll_ready(cx).map_err(|e| match e {}) + self.project().sink.poll_ready(cx).map_err(|e| match e {}) } fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { self.as_mut() - .in_flight_requests() + .project() + .in_flight_requests .remove(&response.request_id); - self.sink().start_send(response).map_err(|e| match e {}) + self.project() + .sink + .start_send(response) + .map_err(|e| match e {}) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.sink().poll_flush(cx).map_err(|e| match e {}) + self.project().sink.poll_flush(cx).map_err(|e| match e {}) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.sink().poll_close(cx).map_err(|e| match e {}) + self.project().sink.poll_close(cx).map_err(|e| match e {}) } } @@ -74,7 +75,7 @@ where } fn start_request(self: Pin<&mut Self>, id: u64) -> AbortRegistration { - self.in_flight_requests().insert(id); + self.project().in_flight_requests.insert(id); AbortHandle::new_pair().1 } } diff --git a/rpc/src/server/throttle.rs b/rpc/src/server/throttle.rs index 7c2ee79..7da7080 100644 --- a/rpc/src/server/throttle.rs +++ b/rpc/src/server/throttle.rs @@ -7,21 +7,20 @@ use futures::{ task::{Context, Poll}, }; use log::debug; -use pin_utils::{unsafe_pinned, unsafe_unpinned}; +use pin_project::pin_project; use std::{io, pin::Pin}; /// A [`Channel`] that limits the number of concurrent /// requests by throttling. +#[pin_project] #[derive(Debug)] pub struct Throttler { max_in_flight_requests: usize, + #[pin] inner: C, } impl Throttler { - unsafe_unpinned!(max_in_flight_requests: usize); - unsafe_pinned!(inner: C); - /// Returns the inner channel. pub fn get_ref(&self) -> &C { &self.inner @@ -49,16 +48,17 @@ where type Item = ::Item; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - while self.as_mut().in_flight_requests() >= *self.as_mut().max_in_flight_requests() { - ready!(self.as_mut().inner().poll_ready(cx)?); + while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests + { + ready!(self.as_mut().project().inner.poll_ready(cx)?); - match ready!(self.as_mut().inner().poll_next(cx)?) { + match ready!(self.as_mut().project().inner.poll_next(cx)?) { Some(request) => { debug!( "[{}] Client has reached in-flight request limit ({}/{}).", request.context.trace_id(), self.as_mut().in_flight_requests(), - self.as_mut().max_in_flight_requests(), + self.as_mut().project().max_in_flight_requests, ); self.as_mut().start_send(Response { @@ -74,7 +74,7 @@ where None => return Poll::Ready(None), } } - self.inner().poll_next(cx) + self.project().inner.poll_next(cx) } } @@ -85,19 +85,19 @@ where type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.inner().poll_ready(cx) + self.project().inner.poll_ready(cx) } fn start_send(self: Pin<&mut Self>, item: Response<::Resp>) -> io::Result<()> { - self.inner().start_send(item) + self.project().inner.start_send(item) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.inner().poll_flush(cx) + self.project().inner.poll_flush(cx) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.inner().poll_close(cx) + self.project().inner.poll_close(cx) } } @@ -115,7 +115,7 @@ where type Resp = ::Resp; fn in_flight_requests(self: Pin<&mut Self>) -> usize { - self.inner().in_flight_requests() + self.project().inner.in_flight_requests() } fn config(&self) -> &Config { @@ -123,13 +123,15 @@ where } fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration { - self.inner().start_request(request_id) + self.project().inner.start_request(request_id) } } /// A stream of throttling channels. +#[pin_project] #[derive(Debug)] pub struct ThrottlerStream { + #[pin] inner: S, max_in_flight_requests: usize, } @@ -139,9 +141,6 @@ where S: Stream, ::Item: Channel, { - unsafe_pinned!(inner: S); - unsafe_unpinned!(max_in_flight_requests: usize); - pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self { Self { inner, @@ -158,10 +157,10 @@ where type Item = Throttler<::Item>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - match ready!(self.as_mut().inner().poll_next(cx)) { + match ready!(self.as_mut().project().inner.poll_next(cx)) { Some(channel) => Poll::Ready(Some(Throttler::new( channel, - *self.max_in_flight_requests(), + *self.project().max_in_flight_requests, ))), None => Poll::Ready(None), } diff --git a/rpc/src/transport/channel.rs b/rpc/src/transport/channel.rs index 76bd5b8..45cb0fc 100644 --- a/rpc/src/transport/channel.rs +++ b/rpc/src/transport/channel.rs @@ -8,7 +8,7 @@ use crate::PollIo; use futures::{channel::mpsc, task::Context, Poll, Sink, Stream}; -use pin_utils::unsafe_pinned; +use pin_project::pin_project; use std::io; use std::pin::Pin; @@ -28,22 +28,20 @@ pub fn unbounded() -> ( /// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender) /// and [`UnboundedReceiver`](mpsc::UnboundedReceiver). +#[pin_project] #[derive(Debug)] pub struct UnboundedChannel { + #[pin] rx: mpsc::UnboundedReceiver, + #[pin] tx: mpsc::UnboundedSender, } -impl UnboundedChannel { - unsafe_pinned!(rx: mpsc::UnboundedReceiver); - unsafe_pinned!(tx: mpsc::UnboundedSender); -} - impl Stream for UnboundedChannel { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo { - self.rx().poll_next(cx).map(|option| option.map(Ok)) + self.project().rx.poll_next(cx).map(|option| option.map(Ok)) } } @@ -51,25 +49,29 @@ impl Sink for UnboundedChannel { type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.tx() + self.project() + .tx .poll_ready(cx) .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) } fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> { - self.tx() + self.project() + .tx .start_send(item) .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.tx() + self.project() + .tx .poll_flush(cx) .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.tx() + self.project() + .tx .poll_close(cx) .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) }