From e75193c191835ce4e493d65213ab704d6f4e363d Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Sun, 7 Mar 2021 17:33:31 -0800 Subject: [PATCH] Client RPCs now take &self. This required the breaking change of removing the Client trait. The intent of the Client trait was to facilitate the decorator pattern by allowing users to create their own Clients that added behavior on top of the base client. Unfortunately, this trait had become a maintenance burden, consistently causing issues with lifetimes and the lack of generic associated types. Specifically, it meant that Client impls could not use async fns, which is no longer tenable today. --- example-service/src/client.rs | 3 +- plugins/src/lib.rs | 30 +-- tarpc/examples/compression.rs | 3 +- tarpc/examples/pubsub.rs | 4 +- tarpc/examples/readme.rs | 2 +- tarpc/examples/server_calling_server.rs | 4 +- tarpc/src/client.rs | 89 ------- tarpc/src/client/channel.rs | 307 ++++-------------------- tarpc/src/client/in_flight_requests.rs | 3 +- tarpc/src/lib.rs | 2 +- tarpc/src/server.rs | 11 +- tarpc/src/server/filter.rs | 32 +-- tarpc/src/transport/channel.rs | 40 ++- tarpc/tests/dataservice.rs | 2 +- tarpc/tests/service_functional.rs | 35 +-- 15 files changed, 114 insertions(+), 453 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 0ff853d..3f313db 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -47,8 +47,7 @@ async fn main() -> io::Result<()> { // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. - let mut client = - service::WorldClient::new(client::Config::default(), transport.await?).spawn()?; + let client = service::WorldClient::new(client::Config::default(), transport.await?).spawn()?; // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index c52711e..0964b29 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -664,26 +664,7 @@ impl<'a> ServiceGenerator<'a> { #[derive(Clone, Debug)] /// The client stub that makes RPC calls to the server. ALl request methods return /// [Futures](std::future::Future). - #vis struct #client_ident>(C); - } - } - - fn impl_from_for_client(&self) -> TokenStream2 { - let &Self { - client_ident, - request_ident, - response_ident, - .. - } = self; - - quote! { - impl From for #client_ident - where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident> - { - fn from(client: C) -> Self { - #client_ident(client) - } - } + #vis struct #client_ident(tarpc::client::Channel<#request_ident, #response_ident>); } } @@ -734,16 +715,14 @@ impl<'a> ServiceGenerator<'a> { } = self; quote! { - impl #client_ident - where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident> - { + impl #client_ident { #( #[allow(unused)] #( #method_attrs )* - #vis fn #method_idents(&mut self, ctx: tarpc::context::Context, #( #args ),*) + #vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*) -> impl std::future::Future> + '_ { let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; - let resp = tarpc::Client::call(&mut self.0, ctx, request); + let resp = self.0.call(ctx, request); async move { match resp.await? { #response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg), @@ -769,7 +748,6 @@ impl<'a> ToTokens for ServiceGenerator<'a> { self.impl_debug_for_response_future(), self.impl_future_for_response_future(), self.struct_client(), - self.impl_from_for_client(), self.impl_client_new(), self.impl_client_rpc_methods(), ]) diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index f0faee7..4582b07 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -118,8 +118,7 @@ async fn main() -> anyhow::Result<()> { }); let transport = tcp::connect(addr, Bincode::default).await?; - let mut client = - WorldClient::new(client::Config::default(), add_compression(transport)).spawn()?; + let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn()?; println!( "{}", diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 6b43bc6..3d94c75 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -203,7 +203,7 @@ impl Publisher { async fn initialize_subscription( &mut self, subscriber_addr: SocketAddr, - mut subscriber: subscriber::SubscriberClient, + subscriber: subscriber::SubscriberClient, ) { // Populate the topics if let Ok(topics) = subscriber.topics(context::current()).await { @@ -305,7 +305,7 @@ async fn main() -> anyhow::Result<()> { ) .await?; - let mut publisher = publisher::PublisherClient::new( + let publisher = publisher::PublisherClient::new( client::Config::default(), tcp::connect(addrs.publisher, Json::default).await?, ) diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index b85a0b8..9bd6e4b 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -43,7 +43,7 @@ async fn main() -> io::Result<()> { // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` // that takes a config and any Transport as input. - let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?; + let client = WorldClient::new(client::Config::default(), client_transport).spawn()?; // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context diff --git a/tarpc/examples/server_calling_server.rs b/tarpc/examples/server_calling_server.rs index eb066f0..f3f518a 100644 --- a/tarpc/examples/server_calling_server.rs +++ b/tarpc/examples/server_calling_server.rs @@ -46,7 +46,7 @@ struct DoubleServer { #[tarpc::server] impl DoubleService for DoubleServer { - async fn double(mut self, _: context::Context, x: i32) -> Result { + async fn double(self, _: context::Context, x: i32) -> Result { self.add_client .add(context::current(), x, x) .await @@ -82,7 +82,7 @@ async fn main() -> io::Result<()> { tokio::spawn(double_server); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; - let mut double_client = + let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?; for i in 1..=5 { diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 4606505..8cc24ae 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -6,7 +6,6 @@ //! Provides a client that connects to a server and sends multiplexed requests. -use crate::context; use futures::prelude::*; use std::fmt; use std::io; @@ -16,94 +15,6 @@ pub mod channel; mod in_flight_requests; pub use channel::{new, Channel}; -/// Sends multiplexed requests to, and receives responses from, a server. -pub trait Client<'a, Req> { - /// The response type. - type Response; - - /// The future response. - type Future: Future> + 'a; - - /// Initiates a request, sending it to the dispatch task. - /// - /// Returns a [`Future`] that resolves to this client and the future response - /// once the request is successfully enqueued. - /// - /// [`Future`]: futures::Future - fn call(&'a mut self, ctx: context::Context, request: Req) -> Self::Future; - - /// Returns a Client that applies a post-processing function to the returned response. - fn map_response(self, f: F) -> MapResponse - where - F: FnMut(Self::Response) -> R, - Self: Sized, - { - MapResponse { inner: self, f } - } - - /// Returns a Client that applies a pre-processing function to the request. - fn with_request(self, f: F) -> WithRequest - where - F: FnMut(Req2) -> Req, - Self: Sized, - { - WithRequest { inner: self, f } - } -} - -/// A Client that applies a function to the returned response. -#[derive(Clone, Debug)] -pub struct MapResponse { - inner: C, - f: F, -} - -impl<'a, C, F, Req, Resp, Resp2> Client<'a, Req> for MapResponse -where - C: Client<'a, Req, Response = Resp>, - F: FnMut(Resp) -> Resp2 + 'a, -{ - type Response = Resp2; - type Future = futures::future::MapOk<>::Future, &'a mut F>; - - fn call(&'a mut self, ctx: context::Context, request: Req) -> Self::Future { - self.inner.call(ctx, request).map_ok(&mut self.f) - } -} - -/// A Client that applies a pre-processing function to the request. -#[derive(Clone, Debug)] -pub struct WithRequest { - inner: C, - f: F, -} - -impl<'a, C, F, Req, Req2, Resp> Client<'a, Req2> for WithRequest -where - C: Client<'a, Req, Response = Resp>, - F: FnMut(Req2) -> Req, -{ - type Response = Resp; - type Future = >::Future; - - fn call(&'a mut self, ctx: context::Context, request: Req2) -> Self::Future { - self.inner.call(ctx, (self.f)(request)) - } -} - -impl<'a, Req, Resp> Client<'a, Req> for Channel -where - Req: 'a, - Resp: 'a, -{ - type Response = Resp; - type Future = channel::Call<'a, Req, Resp>; - - fn call(&'a mut self, ctx: context::Context, request: Req) -> channel::Call<'a, Req, Resp> { - self.call(ctx, request) - } -} - /// Settings that control the behavior of the client. #[derive(Clone, Debug)] #[non_exhaustive] diff --git a/tarpc/src/client/channel.rs b/tarpc/src/client/channel.rs index 851cf2c..bd12135 100644 --- a/tarpc/src/client/channel.rs +++ b/tarpc/src/client/channel.rs @@ -8,13 +8,7 @@ use crate::{ client::in_flight_requests::InFlightRequests, context, trace::SpanId, ClientMessage, PollContext, PollIo, Request, Response, Transport, }; -use futures::{ - channel::{mpsc, oneshot}, - prelude::*, - ready, - stream::Fuse, - task::*, -}; +use futures::{prelude::*, ready, stream::Fuse, task::*}; use log::{info, trace}; use pin_project::{pin_project, pinned_drop}; use std::{ @@ -26,6 +20,7 @@ use std::{ Arc, }, }; +use tokio::sync::{mpsc, oneshot}; #[allow(dead_code)] #[allow(clippy::no_effect)] @@ -58,49 +53,14 @@ impl Clone for Channel { } } -/// A future returned by [`Channel::send`] that resolves to a server response. -#[pin_project] -#[derive(Debug)] -#[must_use = "futures do nothing unless polled"] -struct Send<'a, Req, Resp> { - #[pin] - fut: MapOkDispatchResponse, Resp>, -} - -type SendMapErrConnectionReset<'a, Req, Resp> = MapErrConnectionReset< - futures::sink::Send<'a, mpsc::Sender>, DispatchRequest>, ->; - -impl<'a, Req, Resp> Future for Send<'a, Req, Resp> { - type Output = io::Result>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.as_mut().project().fut.poll(cx) - } -} - -/// A future returned by [`Channel::call`] that resolves to a server response. -#[pin_project] -#[derive(Debug)] -#[must_use = "futures do nothing unless polled"] -pub struct Call<'a, Req, Resp> { - #[pin] - fut: AndThenIdent, DispatchResponse>, -} - -impl<'a, Req, Resp> Future for Call<'a, Req, Resp> { - type Output = io::Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let resp = ready!(self.as_mut().project().fut.poll(cx)); - Poll::Ready(resp) - } -} - impl Channel { /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves when the request is sent (not when the response is received). - fn send(&mut self, mut ctx: context::Context, request: Req) -> Send { + fn send( + &self, + mut ctx: context::Context, + request: Req, + ) -> impl Future>> + '_ { // Convert the context to the call context. ctx.trace_context.parent_id = Some(ctx.trace_context.span_id); ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng()); @@ -109,31 +69,39 @@ impl Channel { let cancellation = self.cancellation.clone(); let request_id = u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); - Send { - fut: MapOkDispatchResponse::new( - MapErrConnectionReset::new(self.to_dispatch.send(DispatchRequest { + + // DispatchResponse impls Drop to cancel in-flight requests. It should be created before + // sending out the request; otherwise, the response future could be dropped after the + // request is sent out but before DispatchResponse is created, rendering the cancellation + // logic inactive. + let response = DispatchResponse { + response, + complete: false, + request_id, + cancellation, + ctx, + }; + async move { + self.to_dispatch + .send(DispatchRequest { ctx, request_id, request, response_completion, - })), - DispatchResponse { - response, - complete: false, - request_id, - cancellation, - ctx, - }, - ), + }) + .await + .map_err(|mpsc::error::SendError(_)| { + io::Error::from(io::ErrorKind::ConnectionReset) + })?; + Ok(response) } } /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves to the response. - pub fn call(&mut self, ctx: context::Context, request: Req) -> Call { - Call { - fut: AndThenIdent::new(self.send(ctx, request)), - } + pub async fn call(&self, ctx: context::Context, request: Req) -> io::Result { + let dispatch_response = self.send(ctx, request).await?; + dispatch_response.await } } @@ -157,7 +125,7 @@ impl Future for DispatchResponse { self.complete = true; Poll::Ready(match resp { Ok(resp) => Ok(resp.message?), - Err(oneshot::Canceled) => { + Err(oneshot::error::RecvError { .. }) => { // The oneshot is Canceled when the dispatch task ends. In that case, // there's nothing listening on the other side, so there's no point in // propagating cancellation. @@ -200,7 +168,7 @@ where { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); - let canceled_requests = canceled_requests.fuse(); + let canceled_requests = canceled_requests; NewClient { client: Channel { @@ -213,7 +181,7 @@ where canceled_requests, transport: transport.fuse(), in_flight_requests: InFlightRequests::default(), - pending_requests: pending_requests.fuse(), + pending_requests, }, } } @@ -228,10 +196,10 @@ pub struct RequestDispatch { transport: Fuse, /// Requests waiting to be written to the wire. #[pin] - pending_requests: Fuse>>, + pending_requests: mpsc::Receiver>, /// Requests that were dropped. #[pin] - canceled_requests: Fuse, + canceled_requests: CanceledRequests, /// Requests already written to the wire that haven't yet received responses. in_flight_requests: InFlightRequests, /// Configures limits to prevent unlimited resource usage. @@ -337,9 +305,9 @@ where } loop { - match ready!(self.as_mut().project().pending_requests.poll_next_unpin(cx)) { + match ready!(self.as_mut().project().pending_requests.poll_recv(cx)) { Some(request) => { - if request.response_completion.is_canceled() { + if request.response_completion.is_closed() { trace!( "[{}] Request canceled before being sent.", request.ctx.trace_id() @@ -496,14 +464,14 @@ fn cancellations() -> (RequestCancellation, CanceledRequests) { // bounded by the number of in-flight requests. Additionally, each request has a clone // of the sender, so the bounded channel would have the same behavior, // since it guarantees a slot. - let (tx, rx) = mpsc::unbounded(); + let (tx, rx) = mpsc::unbounded_channel(); (RequestCancellation(tx), CanceledRequests(rx)) } impl RequestCancellation { /// Cancels the request with ID `request_id`. fn cancel(&mut self, request_id: u64) { - let _ = self.0.unbounded_send(request_id); + let _ = self.0.send(request_id); } } @@ -511,184 +479,7 @@ impl Stream for CanceledRequests { type Item = u64; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.0.poll_next_unpin(cx) - } -} - -#[pin_project] -#[derive(Debug)] -#[must_use = "futures do nothing unless polled"] -struct MapErrConnectionReset { - #[pin] - future: Fut, - finished: Option<()>, -} - -impl MapErrConnectionReset { - fn new(future: Fut) -> MapErrConnectionReset { - MapErrConnectionReset { - future, - finished: Some(()), - } - } -} - -impl Future for MapErrConnectionReset -where - Fut: TryFuture, -{ - type Output = io::Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.as_mut().project().future.try_poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(result) => { - self.project().finished.take().expect( - "MapErrConnectionReset must not be polled after it returned `Poll::Ready`", - ); - Poll::Ready(result.map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset))) - } - } - } -} - -#[pin_project] -#[derive(Debug)] -#[must_use = "futures do nothing unless polled"] -struct MapOkDispatchResponse { - #[pin] - future: Fut, - response: Option>, -} - -impl MapOkDispatchResponse { - fn new(future: Fut, response: DispatchResponse) -> MapOkDispatchResponse { - MapOkDispatchResponse { - future, - response: Some(response), - } - } -} - -impl Future for MapOkDispatchResponse -where - Fut: TryFuture, -{ - type Output = Result, Fut::Error>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.as_mut().project().future.try_poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(result) => { - let response = self - .as_mut() - .project() - .response - .take() - .expect("MapOk must not be polled after it returned `Poll::Ready`"); - Poll::Ready(result.map(|_| response)) - } - } - } -} - -#[pin_project] -#[derive(Debug)] -#[must_use = "futures do nothing unless polled"] -struct AndThenIdent { - #[pin] - try_chain: TryChain, -} - -impl AndThenIdent -where - Fut1: TryFuture, - Fut2: TryFuture, -{ - /// Creates a new `Then`. - fn new(future: Fut1) -> AndThenIdent { - AndThenIdent { - try_chain: TryChain::new(future), - } - } -} - -impl Future for AndThenIdent -where - Fut1: TryFuture, - Fut2: TryFuture, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project().try_chain.poll(cx, |result| match result { - Ok(ok) => TryChainAction::Future(ok), - Err(err) => TryChainAction::Output(Err(err)), - }) - } -} - -#[pin_project(project = TryChainProj)] -#[must_use = "futures do nothing unless polled"] -#[derive(Debug)] -enum TryChain { - First(#[pin] Fut1), - Second(#[pin] Fut2), - Empty, -} - -enum TryChainAction -where - Fut2: TryFuture, -{ - Future(Fut2), - Output(Result), -} - -impl TryChain -where - Fut1: TryFuture, - Fut2: TryFuture, -{ - fn new(fut1: Fut1) -> TryChain { - TryChain::First(fut1) - } - - fn poll( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - f: F, - ) -> Poll> - where - F: FnOnce(Result) -> TryChainAction, - { - let mut f = Some(f); - - loop { - let output = match self.as_mut().project() { - TryChainProj::First(fut1) => { - // Poll the first future - match fut1.try_poll(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(output) => output, - } - } - TryChainProj::Second(fut2) => { - // Poll the second future - return fut2.try_poll(cx); - } - TryChainProj::Empty => { - panic!("future must not be polled after it returned `Poll::Ready`"); - } - }; - - self.set(TryChain::Empty); // Drop fut1 - let f = f.take().unwrap(); - match f(output) { - TryChainAction::Future(fut2) => self.set(TryChain::Second(fut2)), - TryChainAction::Output(output) => return Poll::Ready(output), - } - } + self.0.poll_recv(cx) } } @@ -704,12 +495,9 @@ mod tests { transport::{self, channel::UnboundedChannel}, ClientMessage, Response, }; - use futures::{ - channel::{mpsc, oneshot}, - prelude::*, - task::*, - }; + use futures::{prelude::*, task::*}; use std::{pin::Pin, sync::atomic::AtomicUsize, sync::Arc}; + use tokio::sync::{mpsc, oneshot}; #[tokio::test] async fn dispatch_response_cancels_on_drop() { @@ -723,7 +511,8 @@ mod tests { ctx: context::current(), }); // resp's drop() is run, which should send a cancel message. - assert_eq!(canceled_requests.0.try_next().unwrap(), Some(3)); + let cx = &mut Context::from_waker(&noop_waker_ref()); + assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(Some(3))); } #[tokio::test] @@ -824,13 +613,13 @@ mod tests { let _ = env_logger::try_init(); let (to_dispatch, pending_requests) = mpsc::channel(1); - let (cancel_tx, canceled_requests) = mpsc::unbounded(); + let (cancel_tx, canceled_requests) = mpsc::unbounded_channel(); let (client_channel, server_channel) = transport::channel::unbounded(); let dispatch = RequestDispatch:: { transport: client_channel.fuse(), - pending_requests: pending_requests.fuse(), - canceled_requests: CanceledRequests(canceled_requests).fuse(), + pending_requests: pending_requests, + canceled_requests: CanceledRequests(canceled_requests), in_flight_requests: InFlightRequests::default(), config: Config::default(), }; diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 63ae797..cdfa63d 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -4,13 +4,14 @@ use crate::{ PollIo, Response, ServerError, }; use fnv::FnvHashMap; -use futures::{channel::oneshot, ready}; +use futures::ready; use log::{debug, trace}; use std::{ collections::hash_map, io, task::{Context, Poll}, }; +use tokio::sync::oneshot; use tokio_util::time::delay_queue::{self, DelayQueue}; /// Requests already written to the wire that haven't yet received responses. diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index a3b161d..f08e471 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -299,7 +299,7 @@ pub mod server; pub mod transport; pub(crate) mod util; -pub use crate::{client::Client, transport::sealed::Transport}; +pub use crate::transport::sealed::Transport; use anyhow::Context as _; use futures::task::*; diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index f24e2cb..70485d7 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -8,7 +8,6 @@ use crate::{context, ClientMessage, PollIo, Request, Response, ServerError, Transport}; use futures::{ - channel::mpsc, future::{AbortRegistration, Abortable}, prelude::*, ready, @@ -19,6 +18,7 @@ use humantime::format_rfc3339; use log::{debug, info, trace}; use pin_project::pin_project; use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime}; +use tokio::sync::mpsc; mod filter; mod in_flight_requests; @@ -244,7 +244,6 @@ where Self: Sized, { let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer); - let responses = responses.fuse(); Requests { channel: self, @@ -387,7 +386,7 @@ where channel: C, /// Responses waiting to be written to the wire. #[pin] - pending_responses: Fuse)>>, + pending_responses: mpsc::Receiver<(context::Context, Response)>, /// Handed out to request handlers to fan in responses. #[pin] responses_tx: mpsc::Sender<(context::Context, Response)>, @@ -455,7 +454,7 @@ where request_id, message: Err(ServerError { kind: io::ErrorKind::TimedOut, - detail: Some(format!("Request did not complete before deadline.")), + detail: Some(String::from("Request did not complete before deadline.")), }), })?; return Poll::Ready(Some(Ok(()))); @@ -506,7 +505,7 @@ where ready!(self.as_mut().project().channel.poll_flush(cx)?); } - match ready!(self.as_mut().project().pending_responses.poll_next(cx)) { + match ready!(self.as_mut().project().pending_responses.poll_recv(cx)) { Some(response) => Poll::Ready(Some(Ok(response))), None => { // This branch likely won't happen, since the Requests stream is holding a Sender. @@ -556,7 +555,7 @@ impl InFlightRequest { let Self { abort_registration, request, - mut response_tx, + response_tx, } = self; Abortable::new( async move { diff --git a/tarpc/src/server/filter.rs b/tarpc/src/server/filter.rs index 902cede..2fc8d60 100644 --- a/tarpc/src/server/filter.rs +++ b/tarpc/src/server/filter.rs @@ -10,7 +10,7 @@ use crate::{ PollIo, }; use fnv::FnvHashMap; -use futures::{channel::mpsc, future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*}; +use futures::{future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*}; use log::{debug, info, trace}; use pin_project::pin_project; use std::sync::{Arc, Weak}; @@ -18,6 +18,7 @@ use std::{ collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin, time::SystemTime, }; +use tokio::sync::mpsc; /// A single-threaded filter that drops channels based on per-key limits. #[pin_project] @@ -55,7 +56,7 @@ struct Tracker { impl Drop for Tracker { fn drop(&mut self) { // Don't care if the listener is dropped. - let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap()); + let _ = self.dropped_keys.send(self.key.take().unwrap()); } } @@ -147,7 +148,7 @@ where { /// Sheds new channels to stay under configured limits. pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self { - let (dropped_keys_tx, dropped_keys) = mpsc::unbounded(); + let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel(); ChannelFilter { listener: listener.fuse(), channels_per_key, @@ -233,7 +234,7 @@ where } fn poll_closed_channels(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - match ready!(self.as_mut().project().dropped_keys.poll_next_unpin(cx)) { + match ready!(self.as_mut().project().dropped_keys.poll_recv(cx)) { Some(key) => { debug!("All channels dropped for key [{}]", key); self.as_mut().project().key_counts.remove(&key); @@ -278,7 +279,6 @@ where } } } - #[cfg(test)] fn ctx() -> Context<'static> { use futures::task::*; @@ -290,12 +290,12 @@ fn ctx() -> Context<'static> { fn tracker_drop() { use assert_matches::assert_matches; - let (tx, mut rx) = mpsc::unbounded(); + let (tx, mut rx) = mpsc::unbounded_channel(); Tracker { key: Some(1), dropped_keys: tx, }; - assert_matches!(rx.try_next(), Ok(Some(1))); + assert_matches!(rx.poll_recv(&mut ctx()), Poll::Ready(Some(1))); } #[test] @@ -303,8 +303,8 @@ fn tracked_channel_stream() { use assert_matches::assert_matches; use pin_utils::pin_mut; - let (chan_tx, chan) = mpsc::unbounded(); - let (dropped_keys, _) = mpsc::unbounded(); + let (chan_tx, chan) = futures::channel::mpsc::unbounded(); + let (dropped_keys, _) = mpsc::unbounded_channel(); let channel = TrackedChannel { inner: chan, tracker: Arc::new(Tracker { @@ -323,8 +323,8 @@ fn tracked_channel_sink() { use assert_matches::assert_matches; use pin_utils::pin_mut; - let (chan, mut chan_rx) = mpsc::unbounded(); - let (dropped_keys, _) = mpsc::unbounded(); + let (chan, mut chan_rx) = futures::channel::mpsc::unbounded(); + let (dropped_keys, _) = mpsc::unbounded_channel(); let channel = TrackedChannel { inner: chan, tracker: Arc::new(Tracker { @@ -348,7 +348,7 @@ fn channel_filter_increment_channels_for_key() { struct TestChannel { key: &'static str, } - let (_, listener) = mpsc::unbounded(); + let (_, listener) = futures::channel::mpsc::unbounded(); let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap(); @@ -369,7 +369,7 @@ fn channel_filter_handle_new_channel() { struct TestChannel { key: &'static str, } - let (_, listener) = mpsc::unbounded(); + let (_, listener) = futures::channel::mpsc::unbounded(); let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); let channel1 = filter @@ -401,7 +401,7 @@ fn channel_filter_poll_listener() { struct TestChannel { key: &'static str, } - let (new_channels, listener) = mpsc::unbounded(); + let (new_channels, listener) = futures::channel::mpsc::unbounded(); let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); @@ -437,7 +437,7 @@ fn channel_filter_poll_closed_channels() { struct TestChannel { key: &'static str, } - let (new_channels, listener) = mpsc::unbounded(); + let (new_channels, listener) = futures::channel::mpsc::unbounded(); let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); @@ -465,7 +465,7 @@ fn channel_filter_stream() { struct TestChannel { key: &'static str, } - let (new_channels, listener) = mpsc::unbounded(); + let (new_channels, listener) = futures::channel::mpsc::unbounded(); let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 3223732..59760a2 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -7,10 +7,11 @@ //! Transports backed by in-memory channels. use crate::PollIo; -use futures::{channel::mpsc, task::*, Sink, Stream}; +use futures::{task::*, Sink, Stream}; use pin_project::pin_project; use std::io; use std::pin::Pin; +use tokio::sync::mpsc; /// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's /// [`Sink`]. @@ -18,8 +19,8 @@ pub fn unbounded() -> ( UnboundedChannel, UnboundedChannel, ) { - let (tx1, rx2) = mpsc::unbounded(); - let (tx2, rx1) = mpsc::unbounded(); + let (tx1, rx2) = mpsc::unbounded_channel(); + let (tx2, rx1) = mpsc::unbounded_channel(); ( UnboundedChannel { tx: tx1, rx: rx1 }, UnboundedChannel { tx: tx2, rx: rx2 }, @@ -41,39 +42,36 @@ impl Stream for UnboundedChannel { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo { - self.project().rx.poll_next(cx).map(|option| option.map(Ok)) + self.project().rx.poll_recv(cx).map(|option| option.map(Ok)) } } impl Sink for UnboundedChannel { type Error = io::Error; - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .tx - .poll_ready(cx) - .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) + fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(if self.project().tx.is_closed() { + Err(io::Error::from(io::ErrorKind::NotConnected)) + } else { + Ok(()) + }) } fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> { self.project() .tx - .start_send(item) + .send(item) .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .tx - .poll_flush(cx) - .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + // UnboundedSender requires no flushing. + Poll::Ready(Ok(())) } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .tx - .poll_close(cx) - .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) + fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + // UnboundedSender can't initiate closure. + Poll::Ready(Ok(())) } } @@ -108,7 +106,7 @@ mod tests { }), ); - let mut client = client::new(client::Config::default(), client_channel).spawn()?; + let client = client::new(client::Config::default(), client_channel).spawn()?; let response1 = client.call(context::current(), "123".into()).await?; let response2 = client.call(context::current(), "abc".into()).await?; diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 96e06a0..1bff8a1 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -45,7 +45,7 @@ async fn test_call() -> io::Result<()> { ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; - let mut client = ColorProtocolClient::new(client::Config::default(), transport).spawn()?; + let client = ColorProtocolClient::new(client::Config::default(), transport).spawn()?; let color = client .get_opposite_color(context::current(), TestData::White) diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 8a00274..488e37c 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -50,7 +50,7 @@ async fn sequential() -> io::Result<()> { .execute(Server.serve()), ); - let mut client = ServiceClient::new(client::Config::default(), tx).spawn()?; + let client = ServiceClient::new(client::Config::default(), tx).spawn()?; assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); assert_matches!( @@ -89,7 +89,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> { // Set up a client that initiates a long-lived request. // The request will complete in error when the server drops the connection. tokio::spawn(async move { - let mut client = LoopClient::new(client::Config::default(), tx) + let client = LoopClient::new(client::Config::default(), tx) .spawn() .unwrap(); @@ -130,7 +130,7 @@ async fn serde() -> io::Result<()> { ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; - let mut client = ServiceClient::new(client::Config::default(), transport).spawn()?; + let client = ServiceClient::new(client::Config::default(), transport).spawn()?; assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); assert_matches!( @@ -154,14 +154,9 @@ async fn concurrent() -> io::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn()?; - let mut c = client.clone(); - let req1 = c.add(context::current(), 1, 2); - - let mut c = client.clone(); - let req2 = c.add(context::current(), 3, 4); - - let mut c = client.clone(); - let req3 = c.hey(context::current(), "Tim".to_string()); + let req1 = client.add(context::current(), 1, 2); + let req2 = client.add(context::current(), 3, 4); + let req3 = client.hey(context::current(), "Tim".to_string()); assert_matches!(req1.await, Ok(3)); assert_matches!(req2.await, Ok(7)); @@ -183,14 +178,9 @@ async fn concurrent_join() -> io::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn()?; - let mut c = client.clone(); - let req1 = c.add(context::current(), 1, 2); - - let mut c = client.clone(); - let req2 = c.add(context::current(), 3, 4); - - let mut c = client.clone(); - let req3 = c.hey(context::current(), "Tim".to_string()); + let req1 = client.add(context::current(), 1, 2); + let req2 = client.add(context::current(), 3, 4); + let req3 = client.hey(context::current(), "Tim".to_string()); let (resp1, resp2, resp3) = join!(req1, req2, req3); assert_matches!(resp1, Ok(3)); @@ -213,11 +203,8 @@ async fn concurrent_join_all() -> io::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn()?; - let mut c1 = client.clone(); - let mut c2 = client.clone(); - - let req1 = c1.add(context::current(), 1, 2); - let req2 = c2.add(context::current(), 3, 4); + let req1 = client.add(context::current(), 1, 2); + let req2 = client.add(context::current(), 3, 4); let responses = join_all(vec![req1, req2]).await; assert_matches!(responses[0], Ok(3));