diff --git a/example-service/Cargo.toml b/example-service/Cargo.toml index d38526b..0f8318c 100644 --- a/example-service/Cargo.toml +++ b/example-service/Cargo.toml @@ -19,6 +19,7 @@ serde = { version = "1.0" } tarpc = { version = "0.21", path = "../tarpc", features = ["full"] } tokio = { version = "0.2", features = ["full"] } tokio-serde = { version = "0.6", features = ["json"] } +tokio-util = { version = "0.3", features = ["codec"] } env_logger = "0.6" [lib] diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 65294ae..96bb52b 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -8,6 +8,8 @@ use clap::{App, Arg}; use std::{io, net::SocketAddr}; use tarpc::{client, context}; use tokio_serde::formats::Json; +use tokio_util::codec::LengthDelimitedCodec; +use tokio::net::TcpStream; #[tokio::main] async fn main() -> io::Result<()> { @@ -43,7 +45,12 @@ async fn main() -> io::Result<()> { let name = flags.value_of("name").unwrap().into(); - let transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default()).await?; + let conn = TcpStream::connect(server_addr).await?; + let transport = tarpc::serde_transport::new( + LengthDelimitedCodec::builder().max_frame_length(4294967296).new_framed(conn), + Json::default(), + ); + // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. diff --git a/example-service/src/server.rs b/example-service/src/server.rs index a6cdd1c..9048801 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -57,8 +57,9 @@ async fn main() -> io::Result<()> { // JSON transport is provided by the json_transport tarpc module. It makes it easy // to start up a serde-powered json serialization strategy over TCP. - tarpc::serde_transport::tcp::listen(&server_addr, Json::default) - .await? + let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?; + listener.config_mut().max_frame_length(4294967296); + listener // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) .map(server::BaseChannel::with_defaults) diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 201f4df..9571a8a 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -36,8 +36,8 @@ rand = "0.7" tokio = { version = "0.2", features = ["time"] } serde = { optional = true, version = "1.0", features = ["derive"] } static_assertions = "1.1.0" -tokio-util = { optional = true, version = "0.2" } tarpc-plugins = { path = "../plugins", version = "0.8" } +tokio-util = { optional = true, version = "0.3" } tokio-serde = { optional = true, version = "0.6" } [dev-dependencies] diff --git a/tarpc/src/serde_transport.rs b/tarpc/src/serde_transport.rs index e872740..f43ee56 100644 --- a/tarpc/src/serde_transport.rs +++ b/tarpc/src/serde_transport.rs @@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize}; use std::{error::Error, io, pin::Pin}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_serde::{Framed as SerdeFramed, *}; -use tokio_util::codec::{length_delimited::LengthDelimitedCodec, Framed}; +use tokio_util::codec::{length_delimited::{self, LengthDelimitedCodec}, Framed}; /// A transport that serializes to, and deserializes from, a byte stream. #[pin_project] @@ -90,6 +90,20 @@ fn convert>>( poll.map(|ready| ready.map_err(|e| io::Error::new(io::ErrorKind::Other, e))) } +/// Constructs a new transport from a framed transport and a serialization codec. +pub fn new(framed_io: Framed, codec: Codec) + -> Transport +where + S: AsyncWrite + AsyncRead, + Item: for<'de> Deserialize<'de>, + SinkItem: Serialize, + Codec: Serializer + Deserializer, +{ + Transport { + inner: SerdeFramed::new(framed_io, codec), + } +} + impl From<(S, Codec)> for Transport where S: AsyncWrite + AsyncRead, @@ -97,10 +111,8 @@ where SinkItem: Serialize, Codec: Serializer + Deserializer, { - fn from((inner, codec): (S, Codec)) -> Self { - Transport { - inner: SerdeFramed::new(Framed::new(inner, LengthDelimitedCodec::new()), codec), - } + fn from((io, codec): (S, Codec)) -> Self { + new(Framed::new(io, LengthDelimitedCodec::new()), codec) } } @@ -134,17 +146,19 @@ pub mod tcp { } } - /// Returns a new JSON transport that reads from and writes to `io`. - pub fn new( - io: TcpStream, + /// Connects to `addr`, wrapping the connection in a JSON transport. + pub async fn connect_with( + addr: A, codec: Codec, - ) -> Transport + config: LengthDelimitedCodec, + ) -> io::Result> where + A: ToSocketAddrs, Item: for<'de> Deserialize<'de>, SinkItem: Serialize, Codec: Serializer + Deserializer, { - Transport::from((io, codec)) + Ok(new(Framed::new(TcpStream::connect(addr).await?, config), codec)) } /// Connects to `addr`, wrapping the connection in a JSON transport. @@ -158,7 +172,7 @@ pub mod tcp { SinkItem: Serialize, Codec: Serializer + Deserializer, { - Ok(new(TcpStream::connect(addr).await?, codec)) + connect_with(addr, codec, LengthDelimitedCodec::new()).await } /// Listens on `addr`, wrapping accepted connections in JSON transports. @@ -178,6 +192,7 @@ pub mod tcp { listener, codec_fn, local_addr, + config: LengthDelimitedCodec::builder(), ghost: PhantomData, }) } @@ -189,6 +204,7 @@ pub mod tcp { listener: TcpListener, local_addr: SocketAddr, codec_fn: CodecFn, + config: length_delimited::Builder, ghost: PhantomData<(Item, SinkItem, Codec)>, } @@ -197,6 +213,16 @@ pub mod tcp { pub fn local_addr(&self) -> SocketAddr { self.local_addr } + + /// Returns an immutable reference to the length-delimited codec's config. + pub fn config(&self) -> &length_delimited::Builder { + &self.config + } + + /// Returns a mutable reference to the length-delimited codec's config. + pub fn config_mut(&mut self) -> &mut length_delimited::Builder { + &mut self.config + } } impl Stream for Incoming @@ -211,7 +237,7 @@ pub mod tcp { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let next = ready!(Pin::new(&mut self.as_mut().project().listener.incoming()).poll_next(cx)?); - Poll::Ready(next.map(|conn| Ok(new(conn, (self.codec_fn)())))) + Poll::Ready(next.map(|conn| Ok(new(self.config.new_framed(conn), (self.codec_fn)())))) } } }