From 2264ebecfc0f69966a1b284ec8c768abc8e55b42 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Wed, 19 Aug 2020 17:51:53 -0700 Subject: [PATCH] Remove serde_transport::tcp::connect_with. Instead, serde_transport::tcp::connect returns a future named Connect that has methods to directly access the framing config. This is consistent with how serde_transport::tcp::listen returns a future with methods to access the framing config. In addition to this consistency, it reduces the API surface and provides a simpler user transition from "zero config" to "some config". --- example-service/src/client.rs | 14 +++------ tarpc/src/serde_transport.rs | 58 +++++++++++++++++++++++++---------- 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 1e9d112..a692500 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -7,9 +7,7 @@ use clap::{App, Arg}; use std::{io, net::SocketAddr}; use tarpc::{client, context}; -use tokio::net::TcpStream; use tokio_serde::formats::Json; -use tokio_util::codec::LengthDelimitedCodec; #[tokio::main] async fn main() -> io::Result<()> { @@ -45,17 +43,13 @@ async fn main() -> io::Result<()> { let name = flags.value_of("name").unwrap().into(); - let conn = TcpStream::connect(server_addr).await?; - let transport = tarpc::serde_transport::new( - LengthDelimitedCodec::builder() - .max_frame_length(4294967296) - .new_framed(conn), - Json::default(), - ); + let mut transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default); + transport.config_mut().max_frame_length(4294967296); // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. - let mut client = service::WorldClient::new(client::Config::default(), transport).spawn()?; + let mut client = + service::WorldClient::new(client::Config::default(), transport.await?).spawn()?; // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context diff --git a/tarpc/src/serde_transport.rs b/tarpc/src/serde_transport.rs index 8f053b1..940d62f 100644 --- a/tarpc/src/serde_transport.rs +++ b/tarpc/src/serde_transport.rs @@ -151,36 +151,62 @@ pub mod tcp { } } - /// Connects to `addr`, wrapping the connection in a TCP transport. - pub async fn connect_with( - addr: A, - codec: impl FnOnce() -> Codec, - config: LengthDelimitedCodec, - ) -> io::Result> + /// A connection Future that also exposes the length-delimited framing config. + #[pin_project] + pub struct Connect { + #[pin] + inner: T, + codec_fn: CodecFn, + config: length_delimited::Builder, + ghost: PhantomData<(fn(SinkItem), fn() -> Item)>, + } + + impl Future for Connect where - A: ToSocketAddrs, + T: Future>, Item: for<'de> Deserialize<'de>, SinkItem: Serialize, Codec: Serializer + Deserializer, + CodecFn: Fn() -> Codec, { - Ok(new( - Framed::new(TcpStream::connect(addr).await?, config), - codec(), - )) + type Output = io::Result>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let io = ready!(self.as_mut().project().inner.poll(cx))?; + Poll::Ready(Ok(new(self.config.new_framed(io), (self.codec_fn)()))) + } + } + + impl Connect { + /// Returns an immutable reference to the length-delimited codec's config. + pub fn config(&self) -> &length_delimited::Builder { + &self.config + } + + /// Returns a mutable reference to the length-delimited codec's config. + pub fn config_mut(&mut self) -> &mut length_delimited::Builder { + &mut self.config + } } /// Connects to `addr`, wrapping the connection in a TCP transport. - pub async fn connect( + pub fn connect( addr: A, - codec: impl FnOnce() -> Codec, - ) -> io::Result> + codec_fn: CodecFn, + ) -> Connect>, Item, SinkItem, CodecFn> where A: ToSocketAddrs, Item: for<'de> Deserialize<'de>, SinkItem: Serialize, Codec: Serializer + Deserializer, + CodecFn: Fn() -> Codec, { - connect_with(addr, codec, LengthDelimitedCodec::new()).await + Connect { + inner: TcpStream::connect(addr), + codec_fn, + config: LengthDelimitedCodec::builder(), + ghost: PhantomData, + } } /// Listens on `addr`, wrapping accepted connections in TCP transports. @@ -213,7 +239,7 @@ pub mod tcp { local_addr: SocketAddr, codec_fn: CodecFn, config: length_delimited::Builder, - ghost: PhantomData<(Item, SinkItem, Codec)>, + ghost: PhantomData<(fn() -> Item, fn(SinkItem), Codec)>, } impl Incoming {