diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 0ff853d..3f313db 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -47,8 +47,7 @@ async fn main() -> io::Result<()> { // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. - let mut client = - service::WorldClient::new(client::Config::default(), transport.await?).spawn()?; + let client = service::WorldClient::new(client::Config::default(), transport.await?).spawn()?; // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index c52711e..0964b29 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -664,26 +664,7 @@ impl<'a> ServiceGenerator<'a> { #[derive(Clone, Debug)] /// The client stub that makes RPC calls to the server. ALl request methods return /// [Futures](std::future::Future). - #vis struct #client_ident>(C); - } - } - - fn impl_from_for_client(&self) -> TokenStream2 { - let &Self { - client_ident, - request_ident, - response_ident, - .. - } = self; - - quote! { - impl From for #client_ident - where for <'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident> - { - fn from(client: C) -> Self { - #client_ident(client) - } - } + #vis struct #client_ident(tarpc::client::Channel<#request_ident, #response_ident>); } } @@ -734,16 +715,14 @@ impl<'a> ServiceGenerator<'a> { } = self; quote! { - impl #client_ident - where for<'a> C: tarpc::Client<'a, #request_ident, Response = #response_ident> - { + impl #client_ident { #( #[allow(unused)] #( #method_attrs )* - #vis fn #method_idents(&mut self, ctx: tarpc::context::Context, #( #args ),*) + #vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*) -> impl std::future::Future> + '_ { let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; - let resp = tarpc::Client::call(&mut self.0, ctx, request); + let resp = self.0.call(ctx, request); async move { match resp.await? { #response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg), @@ -769,7 +748,6 @@ impl<'a> ToTokens for ServiceGenerator<'a> { self.impl_debug_for_response_future(), self.impl_future_for_response_future(), self.struct_client(), - self.impl_from_for_client(), self.impl_client_new(), self.impl_client_rpc_methods(), ]) diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index f0faee7..4582b07 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -118,8 +118,7 @@ async fn main() -> anyhow::Result<()> { }); let transport = tcp::connect(addr, Bincode::default).await?; - let mut client = - WorldClient::new(client::Config::default(), add_compression(transport)).spawn()?; + let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn()?; println!( "{}", diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 6b43bc6..3d94c75 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -203,7 +203,7 @@ impl Publisher { async fn initialize_subscription( &mut self, subscriber_addr: SocketAddr, - mut subscriber: subscriber::SubscriberClient, + subscriber: subscriber::SubscriberClient, ) { // Populate the topics if let Ok(topics) = subscriber.topics(context::current()).await { @@ -305,7 +305,7 @@ async fn main() -> anyhow::Result<()> { ) .await?; - let mut publisher = publisher::PublisherClient::new( + let publisher = publisher::PublisherClient::new( client::Config::default(), tcp::connect(addrs.publisher, Json::default).await?, ) diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index b85a0b8..9bd6e4b 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -43,7 +43,7 @@ async fn main() -> io::Result<()> { // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` // that takes a config and any Transport as input. - let mut client = WorldClient::new(client::Config::default(), client_transport).spawn()?; + let client = WorldClient::new(client::Config::default(), client_transport).spawn()?; // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context diff --git a/tarpc/examples/server_calling_server.rs b/tarpc/examples/server_calling_server.rs index eb066f0..f3f518a 100644 --- a/tarpc/examples/server_calling_server.rs +++ b/tarpc/examples/server_calling_server.rs @@ -46,7 +46,7 @@ struct DoubleServer { #[tarpc::server] impl DoubleService for DoubleServer { - async fn double(mut self, _: context::Context, x: i32) -> Result { + async fn double(self, _: context::Context, x: i32) -> Result { self.add_client .add(context::current(), x, x) .await @@ -82,7 +82,7 @@ async fn main() -> io::Result<()> { tokio::spawn(double_server); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; - let mut double_client = + let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?; for i in 1..=5 { diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 4606505..8cc24ae 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -6,7 +6,6 @@ //! Provides a client that connects to a server and sends multiplexed requests. -use crate::context; use futures::prelude::*; use std::fmt; use std::io; @@ -16,94 +15,6 @@ pub mod channel; mod in_flight_requests; pub use channel::{new, Channel}; -/// Sends multiplexed requests to, and receives responses from, a server. -pub trait Client<'a, Req> { - /// The response type. - type Response; - - /// The future response. - type Future: Future> + 'a; - - /// Initiates a request, sending it to the dispatch task. - /// - /// Returns a [`Future`] that resolves to this client and the future response - /// once the request is successfully enqueued. - /// - /// [`Future`]: futures::Future - fn call(&'a mut self, ctx: context::Context, request: Req) -> Self::Future; - - /// Returns a Client that applies a post-processing function to the returned response. - fn map_response(self, f: F) -> MapResponse - where - F: FnMut(Self::Response) -> R, - Self: Sized, - { - MapResponse { inner: self, f } - } - - /// Returns a Client that applies a pre-processing function to the request. - fn with_request(self, f: F) -> WithRequest - where - F: FnMut(Req2) -> Req, - Self: Sized, - { - WithRequest { inner: self, f } - } -} - -/// A Client that applies a function to the returned response. -#[derive(Clone, Debug)] -pub struct MapResponse { - inner: C, - f: F, -} - -impl<'a, C, F, Req, Resp, Resp2> Client<'a, Req> for MapResponse -where - C: Client<'a, Req, Response = Resp>, - F: FnMut(Resp) -> Resp2 + 'a, -{ - type Response = Resp2; - type Future = futures::future::MapOk<>::Future, &'a mut F>; - - fn call(&'a mut self, ctx: context::Context, request: Req) -> Self::Future { - self.inner.call(ctx, request).map_ok(&mut self.f) - } -} - -/// A Client that applies a pre-processing function to the request. -#[derive(Clone, Debug)] -pub struct WithRequest { - inner: C, - f: F, -} - -impl<'a, C, F, Req, Req2, Resp> Client<'a, Req2> for WithRequest -where - C: Client<'a, Req, Response = Resp>, - F: FnMut(Req2) -> Req, -{ - type Response = Resp; - type Future = >::Future; - - fn call(&'a mut self, ctx: context::Context, request: Req2) -> Self::Future { - self.inner.call(ctx, (self.f)(request)) - } -} - -impl<'a, Req, Resp> Client<'a, Req> for Channel -where - Req: 'a, - Resp: 'a, -{ - type Response = Resp; - type Future = channel::Call<'a, Req, Resp>; - - fn call(&'a mut self, ctx: context::Context, request: Req) -> channel::Call<'a, Req, Resp> { - self.call(ctx, request) - } -} - /// Settings that control the behavior of the client. #[derive(Clone, Debug)] #[non_exhaustive] diff --git a/tarpc/src/client/channel.rs b/tarpc/src/client/channel.rs index 851cf2c..bd12135 100644 --- a/tarpc/src/client/channel.rs +++ b/tarpc/src/client/channel.rs @@ -8,13 +8,7 @@ use crate::{ client::in_flight_requests::InFlightRequests, context, trace::SpanId, ClientMessage, PollContext, PollIo, Request, Response, Transport, }; -use futures::{ - channel::{mpsc, oneshot}, - prelude::*, - ready, - stream::Fuse, - task::*, -}; +use futures::{prelude::*, ready, stream::Fuse, task::*}; use log::{info, trace}; use pin_project::{pin_project, pinned_drop}; use std::{ @@ -26,6 +20,7 @@ use std::{ Arc, }, }; +use tokio::sync::{mpsc, oneshot}; #[allow(dead_code)] #[allow(clippy::no_effect)] @@ -58,49 +53,14 @@ impl Clone for Channel { } } -/// A future returned by [`Channel::send`] that resolves to a server response. -#[pin_project] -#[derive(Debug)] -#[must_use = "futures do nothing unless polled"] -struct Send<'a, Req, Resp> { - #[pin] - fut: MapOkDispatchResponse, Resp>, -} - -type SendMapErrConnectionReset<'a, Req, Resp> = MapErrConnectionReset< - futures::sink::Send<'a, mpsc::Sender>, DispatchRequest>, ->; - -impl<'a, Req, Resp> Future for Send<'a, Req, Resp> { - type Output = io::Result>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.as_mut().project().fut.poll(cx) - } -} - -/// A future returned by [`Channel::call`] that resolves to a server response. -#[pin_project] -#[derive(Debug)] -#[must_use = "futures do nothing unless polled"] -pub struct Call<'a, Req, Resp> { - #[pin] - fut: AndThenIdent, DispatchResponse>, -} - -impl<'a, Req, Resp> Future for Call<'a, Req, Resp> { - type Output = io::Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let resp = ready!(self.as_mut().project().fut.poll(cx)); - Poll::Ready(resp) - } -} - impl Channel { /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves when the request is sent (not when the response is received). - fn send(&mut self, mut ctx: context::Context, request: Req) -> Send { + fn send( + &self, + mut ctx: context::Context, + request: Req, + ) -> impl Future>> + '_ { // Convert the context to the call context. ctx.trace_context.parent_id = Some(ctx.trace_context.span_id); ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng()); @@ -109,31 +69,39 @@ impl Channel { let cancellation = self.cancellation.clone(); let request_id = u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); - Send { - fut: MapOkDispatchResponse::new( - MapErrConnectionReset::new(self.to_dispatch.send(DispatchRequest { + + // DispatchResponse impls Drop to cancel in-flight requests. It should be created before + // sending out the request; otherwise, the response future could be dropped after the + // request is sent out but before DispatchResponse is created, rendering the cancellation + // logic inactive. + let response = DispatchResponse { + response, + complete: false, + request_id, + cancellation, + ctx, + }; + async move { + self.to_dispatch + .send(DispatchRequest { ctx, request_id, request, response_completion, - })), - DispatchResponse { - response, - complete: false, - request_id, - cancellation, - ctx, - }, - ), + }) + .await + .map_err(|mpsc::error::SendError(_)| { + io::Error::from(io::ErrorKind::ConnectionReset) + })?; + Ok(response) } } /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves to the response. - pub fn call(&mut self, ctx: context::Context, request: Req) -> Call { - Call { - fut: AndThenIdent::new(self.send(ctx, request)), - } + pub async fn call(&self, ctx: context::Context, request: Req) -> io::Result { + let dispatch_response = self.send(ctx, request).await?; + dispatch_response.await } } @@ -157,7 +125,7 @@ impl Future for DispatchResponse { self.complete = true; Poll::Ready(match resp { Ok(resp) => Ok(resp.message?), - Err(oneshot::Canceled) => { + Err(oneshot::error::RecvError { .. }) => { // The oneshot is Canceled when the dispatch task ends. In that case, // there's nothing listening on the other side, so there's no point in // propagating cancellation. @@ -200,7 +168,7 @@ where { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); - let canceled_requests = canceled_requests.fuse(); + let canceled_requests = canceled_requests; NewClient { client: Channel { @@ -213,7 +181,7 @@ where canceled_requests, transport: transport.fuse(), in_flight_requests: InFlightRequests::default(), - pending_requests: pending_requests.fuse(), + pending_requests, }, } } @@ -228,10 +196,10 @@ pub struct RequestDispatch { transport: Fuse, /// Requests waiting to be written to the wire. #[pin] - pending_requests: Fuse>>, + pending_requests: mpsc::Receiver>, /// Requests that were dropped. #[pin] - canceled_requests: Fuse, + canceled_requests: CanceledRequests, /// Requests already written to the wire that haven't yet received responses. in_flight_requests: InFlightRequests, /// Configures limits to prevent unlimited resource usage. @@ -337,9 +305,9 @@ where } loop { - match ready!(self.as_mut().project().pending_requests.poll_next_unpin(cx)) { + match ready!(self.as_mut().project().pending_requests.poll_recv(cx)) { Some(request) => { - if request.response_completion.is_canceled() { + if request.response_completion.is_closed() { trace!( "[{}] Request canceled before being sent.", request.ctx.trace_id() @@ -496,14 +464,14 @@ fn cancellations() -> (RequestCancellation, CanceledRequests) { // bounded by the number of in-flight requests. Additionally, each request has a clone // of the sender, so the bounded channel would have the same behavior, // since it guarantees a slot. - let (tx, rx) = mpsc::unbounded(); + let (tx, rx) = mpsc::unbounded_channel(); (RequestCancellation(tx), CanceledRequests(rx)) } impl RequestCancellation { /// Cancels the request with ID `request_id`. fn cancel(&mut self, request_id: u64) { - let _ = self.0.unbounded_send(request_id); + let _ = self.0.send(request_id); } } @@ -511,184 +479,7 @@ impl Stream for CanceledRequests { type Item = u64; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.0.poll_next_unpin(cx) - } -} - -#[pin_project] -#[derive(Debug)] -#[must_use = "futures do nothing unless polled"] -struct MapErrConnectionReset { - #[pin] - future: Fut, - finished: Option<()>, -} - -impl MapErrConnectionReset { - fn new(future: Fut) -> MapErrConnectionReset { - MapErrConnectionReset { - future, - finished: Some(()), - } - } -} - -impl Future for MapErrConnectionReset -where - Fut: TryFuture, -{ - type Output = io::Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.as_mut().project().future.try_poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(result) => { - self.project().finished.take().expect( - "MapErrConnectionReset must not be polled after it returned `Poll::Ready`", - ); - Poll::Ready(result.map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset))) - } - } - } -} - -#[pin_project] -#[derive(Debug)] -#[must_use = "futures do nothing unless polled"] -struct MapOkDispatchResponse { - #[pin] - future: Fut, - response: Option>, -} - -impl MapOkDispatchResponse { - fn new(future: Fut, response: DispatchResponse) -> MapOkDispatchResponse { - MapOkDispatchResponse { - future, - response: Some(response), - } - } -} - -impl Future for MapOkDispatchResponse -where - Fut: TryFuture, -{ - type Output = Result, Fut::Error>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.as_mut().project().future.try_poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(result) => { - let response = self - .as_mut() - .project() - .response - .take() - .expect("MapOk must not be polled after it returned `Poll::Ready`"); - Poll::Ready(result.map(|_| response)) - } - } - } -} - -#[pin_project] -#[derive(Debug)] -#[must_use = "futures do nothing unless polled"] -struct AndThenIdent { - #[pin] - try_chain: TryChain, -} - -impl AndThenIdent -where - Fut1: TryFuture, - Fut2: TryFuture, -{ - /// Creates a new `Then`. - fn new(future: Fut1) -> AndThenIdent { - AndThenIdent { - try_chain: TryChain::new(future), - } - } -} - -impl Future for AndThenIdent -where - Fut1: TryFuture, - Fut2: TryFuture, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project().try_chain.poll(cx, |result| match result { - Ok(ok) => TryChainAction::Future(ok), - Err(err) => TryChainAction::Output(Err(err)), - }) - } -} - -#[pin_project(project = TryChainProj)] -#[must_use = "futures do nothing unless polled"] -#[derive(Debug)] -enum TryChain { - First(#[pin] Fut1), - Second(#[pin] Fut2), - Empty, -} - -enum TryChainAction -where - Fut2: TryFuture, -{ - Future(Fut2), - Output(Result), -} - -impl TryChain -where - Fut1: TryFuture, - Fut2: TryFuture, -{ - fn new(fut1: Fut1) -> TryChain { - TryChain::First(fut1) - } - - fn poll( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - f: F, - ) -> Poll> - where - F: FnOnce(Result) -> TryChainAction, - { - let mut f = Some(f); - - loop { - let output = match self.as_mut().project() { - TryChainProj::First(fut1) => { - // Poll the first future - match fut1.try_poll(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(output) => output, - } - } - TryChainProj::Second(fut2) => { - // Poll the second future - return fut2.try_poll(cx); - } - TryChainProj::Empty => { - panic!("future must not be polled after it returned `Poll::Ready`"); - } - }; - - self.set(TryChain::Empty); // Drop fut1 - let f = f.take().unwrap(); - match f(output) { - TryChainAction::Future(fut2) => self.set(TryChain::Second(fut2)), - TryChainAction::Output(output) => return Poll::Ready(output), - } - } + self.0.poll_recv(cx) } } @@ -704,12 +495,9 @@ mod tests { transport::{self, channel::UnboundedChannel}, ClientMessage, Response, }; - use futures::{ - channel::{mpsc, oneshot}, - prelude::*, - task::*, - }; + use futures::{prelude::*, task::*}; use std::{pin::Pin, sync::atomic::AtomicUsize, sync::Arc}; + use tokio::sync::{mpsc, oneshot}; #[tokio::test] async fn dispatch_response_cancels_on_drop() { @@ -723,7 +511,8 @@ mod tests { ctx: context::current(), }); // resp's drop() is run, which should send a cancel message. - assert_eq!(canceled_requests.0.try_next().unwrap(), Some(3)); + let cx = &mut Context::from_waker(&noop_waker_ref()); + assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(Some(3))); } #[tokio::test] @@ -824,13 +613,13 @@ mod tests { let _ = env_logger::try_init(); let (to_dispatch, pending_requests) = mpsc::channel(1); - let (cancel_tx, canceled_requests) = mpsc::unbounded(); + let (cancel_tx, canceled_requests) = mpsc::unbounded_channel(); let (client_channel, server_channel) = transport::channel::unbounded(); let dispatch = RequestDispatch:: { transport: client_channel.fuse(), - pending_requests: pending_requests.fuse(), - canceled_requests: CanceledRequests(canceled_requests).fuse(), + pending_requests: pending_requests, + canceled_requests: CanceledRequests(canceled_requests), in_flight_requests: InFlightRequests::default(), config: Config::default(), }; diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 63ae797..cdfa63d 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -4,13 +4,14 @@ use crate::{ PollIo, Response, ServerError, }; use fnv::FnvHashMap; -use futures::{channel::oneshot, ready}; +use futures::ready; use log::{debug, trace}; use std::{ collections::hash_map, io, task::{Context, Poll}, }; +use tokio::sync::oneshot; use tokio_util::time::delay_queue::{self, DelayQueue}; /// Requests already written to the wire that haven't yet received responses. diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index a3b161d..f08e471 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -299,7 +299,7 @@ pub mod server; pub mod transport; pub(crate) mod util; -pub use crate::{client::Client, transport::sealed::Transport}; +pub use crate::transport::sealed::Transport; use anyhow::Context as _; use futures::task::*; diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index f24e2cb..70485d7 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -8,7 +8,6 @@ use crate::{context, ClientMessage, PollIo, Request, Response, ServerError, Transport}; use futures::{ - channel::mpsc, future::{AbortRegistration, Abortable}, prelude::*, ready, @@ -19,6 +18,7 @@ use humantime::format_rfc3339; use log::{debug, info, trace}; use pin_project::pin_project; use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime}; +use tokio::sync::mpsc; mod filter; mod in_flight_requests; @@ -244,7 +244,6 @@ where Self: Sized, { let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer); - let responses = responses.fuse(); Requests { channel: self, @@ -387,7 +386,7 @@ where channel: C, /// Responses waiting to be written to the wire. #[pin] - pending_responses: Fuse)>>, + pending_responses: mpsc::Receiver<(context::Context, Response)>, /// Handed out to request handlers to fan in responses. #[pin] responses_tx: mpsc::Sender<(context::Context, Response)>, @@ -455,7 +454,7 @@ where request_id, message: Err(ServerError { kind: io::ErrorKind::TimedOut, - detail: Some(format!("Request did not complete before deadline.")), + detail: Some(String::from("Request did not complete before deadline.")), }), })?; return Poll::Ready(Some(Ok(()))); @@ -506,7 +505,7 @@ where ready!(self.as_mut().project().channel.poll_flush(cx)?); } - match ready!(self.as_mut().project().pending_responses.poll_next(cx)) { + match ready!(self.as_mut().project().pending_responses.poll_recv(cx)) { Some(response) => Poll::Ready(Some(Ok(response))), None => { // This branch likely won't happen, since the Requests stream is holding a Sender. @@ -556,7 +555,7 @@ impl InFlightRequest { let Self { abort_registration, request, - mut response_tx, + response_tx, } = self; Abortable::new( async move { diff --git a/tarpc/src/server/filter.rs b/tarpc/src/server/filter.rs index 902cede..2fc8d60 100644 --- a/tarpc/src/server/filter.rs +++ b/tarpc/src/server/filter.rs @@ -10,7 +10,7 @@ use crate::{ PollIo, }; use fnv::FnvHashMap; -use futures::{channel::mpsc, future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*}; +use futures::{future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*}; use log::{debug, info, trace}; use pin_project::pin_project; use std::sync::{Arc, Weak}; @@ -18,6 +18,7 @@ use std::{ collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin, time::SystemTime, }; +use tokio::sync::mpsc; /// A single-threaded filter that drops channels based on per-key limits. #[pin_project] @@ -55,7 +56,7 @@ struct Tracker { impl Drop for Tracker { fn drop(&mut self) { // Don't care if the listener is dropped. - let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap()); + let _ = self.dropped_keys.send(self.key.take().unwrap()); } } @@ -147,7 +148,7 @@ where { /// Sheds new channels to stay under configured limits. pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self { - let (dropped_keys_tx, dropped_keys) = mpsc::unbounded(); + let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel(); ChannelFilter { listener: listener.fuse(), channels_per_key, @@ -233,7 +234,7 @@ where } fn poll_closed_channels(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - match ready!(self.as_mut().project().dropped_keys.poll_next_unpin(cx)) { + match ready!(self.as_mut().project().dropped_keys.poll_recv(cx)) { Some(key) => { debug!("All channels dropped for key [{}]", key); self.as_mut().project().key_counts.remove(&key); @@ -278,7 +279,6 @@ where } } } - #[cfg(test)] fn ctx() -> Context<'static> { use futures::task::*; @@ -290,12 +290,12 @@ fn ctx() -> Context<'static> { fn tracker_drop() { use assert_matches::assert_matches; - let (tx, mut rx) = mpsc::unbounded(); + let (tx, mut rx) = mpsc::unbounded_channel(); Tracker { key: Some(1), dropped_keys: tx, }; - assert_matches!(rx.try_next(), Ok(Some(1))); + assert_matches!(rx.poll_recv(&mut ctx()), Poll::Ready(Some(1))); } #[test] @@ -303,8 +303,8 @@ fn tracked_channel_stream() { use assert_matches::assert_matches; use pin_utils::pin_mut; - let (chan_tx, chan) = mpsc::unbounded(); - let (dropped_keys, _) = mpsc::unbounded(); + let (chan_tx, chan) = futures::channel::mpsc::unbounded(); + let (dropped_keys, _) = mpsc::unbounded_channel(); let channel = TrackedChannel { inner: chan, tracker: Arc::new(Tracker { @@ -323,8 +323,8 @@ fn tracked_channel_sink() { use assert_matches::assert_matches; use pin_utils::pin_mut; - let (chan, mut chan_rx) = mpsc::unbounded(); - let (dropped_keys, _) = mpsc::unbounded(); + let (chan, mut chan_rx) = futures::channel::mpsc::unbounded(); + let (dropped_keys, _) = mpsc::unbounded_channel(); let channel = TrackedChannel { inner: chan, tracker: Arc::new(Tracker { @@ -348,7 +348,7 @@ fn channel_filter_increment_channels_for_key() { struct TestChannel { key: &'static str, } - let (_, listener) = mpsc::unbounded(); + let (_, listener) = futures::channel::mpsc::unbounded(); let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap(); @@ -369,7 +369,7 @@ fn channel_filter_handle_new_channel() { struct TestChannel { key: &'static str, } - let (_, listener) = mpsc::unbounded(); + let (_, listener) = futures::channel::mpsc::unbounded(); let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); let channel1 = filter @@ -401,7 +401,7 @@ fn channel_filter_poll_listener() { struct TestChannel { key: &'static str, } - let (new_channels, listener) = mpsc::unbounded(); + let (new_channels, listener) = futures::channel::mpsc::unbounded(); let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); @@ -437,7 +437,7 @@ fn channel_filter_poll_closed_channels() { struct TestChannel { key: &'static str, } - let (new_channels, listener) = mpsc::unbounded(); + let (new_channels, listener) = futures::channel::mpsc::unbounded(); let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); @@ -465,7 +465,7 @@ fn channel_filter_stream() { struct TestChannel { key: &'static str, } - let (new_channels, listener) = mpsc::unbounded(); + let (new_channels, listener) = futures::channel::mpsc::unbounded(); let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 3223732..59760a2 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -7,10 +7,11 @@ //! Transports backed by in-memory channels. use crate::PollIo; -use futures::{channel::mpsc, task::*, Sink, Stream}; +use futures::{task::*, Sink, Stream}; use pin_project::pin_project; use std::io; use std::pin::Pin; +use tokio::sync::mpsc; /// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's /// [`Sink`]. @@ -18,8 +19,8 @@ pub fn unbounded() -> ( UnboundedChannel, UnboundedChannel, ) { - let (tx1, rx2) = mpsc::unbounded(); - let (tx2, rx1) = mpsc::unbounded(); + let (tx1, rx2) = mpsc::unbounded_channel(); + let (tx2, rx1) = mpsc::unbounded_channel(); ( UnboundedChannel { tx: tx1, rx: rx1 }, UnboundedChannel { tx: tx2, rx: rx2 }, @@ -41,39 +42,36 @@ impl Stream for UnboundedChannel { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo { - self.project().rx.poll_next(cx).map(|option| option.map(Ok)) + self.project().rx.poll_recv(cx).map(|option| option.map(Ok)) } } impl Sink for UnboundedChannel { type Error = io::Error; - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .tx - .poll_ready(cx) - .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) + fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(if self.project().tx.is_closed() { + Err(io::Error::from(io::ErrorKind::NotConnected)) + } else { + Ok(()) + }) } fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> { self.project() .tx - .start_send(item) + .send(item) .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .tx - .poll_flush(cx) - .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + // UnboundedSender requires no flushing. + Poll::Ready(Ok(())) } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .tx - .poll_close(cx) - .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) + fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + // UnboundedSender can't initiate closure. + Poll::Ready(Ok(())) } } @@ -108,7 +106,7 @@ mod tests { }), ); - let mut client = client::new(client::Config::default(), client_channel).spawn()?; + let client = client::new(client::Config::default(), client_channel).spawn()?; let response1 = client.call(context::current(), "123".into()).await?; let response2 = client.call(context::current(), "abc".into()).await?; diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 96e06a0..1bff8a1 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -45,7 +45,7 @@ async fn test_call() -> io::Result<()> { ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; - let mut client = ColorProtocolClient::new(client::Config::default(), transport).spawn()?; + let client = ColorProtocolClient::new(client::Config::default(), transport).spawn()?; let color = client .get_opposite_color(context::current(), TestData::White) diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 8a00274..488e37c 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -50,7 +50,7 @@ async fn sequential() -> io::Result<()> { .execute(Server.serve()), ); - let mut client = ServiceClient::new(client::Config::default(), tx).spawn()?; + let client = ServiceClient::new(client::Config::default(), tx).spawn()?; assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); assert_matches!( @@ -89,7 +89,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> { // Set up a client that initiates a long-lived request. // The request will complete in error when the server drops the connection. tokio::spawn(async move { - let mut client = LoopClient::new(client::Config::default(), tx) + let client = LoopClient::new(client::Config::default(), tx) .spawn() .unwrap(); @@ -130,7 +130,7 @@ async fn serde() -> io::Result<()> { ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; - let mut client = ServiceClient::new(client::Config::default(), transport).spawn()?; + let client = ServiceClient::new(client::Config::default(), transport).spawn()?; assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); assert_matches!( @@ -154,14 +154,9 @@ async fn concurrent() -> io::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn()?; - let mut c = client.clone(); - let req1 = c.add(context::current(), 1, 2); - - let mut c = client.clone(); - let req2 = c.add(context::current(), 3, 4); - - let mut c = client.clone(); - let req3 = c.hey(context::current(), "Tim".to_string()); + let req1 = client.add(context::current(), 1, 2); + let req2 = client.add(context::current(), 3, 4); + let req3 = client.hey(context::current(), "Tim".to_string()); assert_matches!(req1.await, Ok(3)); assert_matches!(req2.await, Ok(7)); @@ -183,14 +178,9 @@ async fn concurrent_join() -> io::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn()?; - let mut c = client.clone(); - let req1 = c.add(context::current(), 1, 2); - - let mut c = client.clone(); - let req2 = c.add(context::current(), 3, 4); - - let mut c = client.clone(); - let req3 = c.hey(context::current(), "Tim".to_string()); + let req1 = client.add(context::current(), 1, 2); + let req2 = client.add(context::current(), 3, 4); + let req3 = client.hey(context::current(), "Tim".to_string()); let (resp1, resp2, resp3) = join!(req1, req2, req3); assert_matches!(resp1, Ok(3)); @@ -213,11 +203,8 @@ async fn concurrent_join_all() -> io::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn()?; - let mut c1 = client.clone(); - let mut c2 = client.clone(); - - let req1 = c1.add(context::current(), 1, 2); - let req2 = c2.add(context::current(), 3, 4); + let req1 = client.add(context::current(), 1, 2); + let req2 = client.add(context::current(), 3, 4); let responses = join_all(vec![req1, req2]).await; assert_matches!(responses[0], Ok(3));